# vari'art: 
### Example of latent analysis of a rap clip

In [2]:
import numpy as np
import pandas as pd
import random
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.layers import (
    InputLayer, 
    Dense, 
    Reshape, 
    Flatten, 
    Dropout, 
    Conv2D, 
    Conv2DTranspose, 
    MaxPool2D,
    BatchNormalization
)
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.optimizers import Lookahead

from variart.preprocessing import ArtVideo
from variart.model import VAE
from variart.latent import Latent

## 1. Load data and preprocessing

In [4]:
# Load video
# Wrap the clip at `filename` in an ArtVideo labelled `name`, then read it.
name = 'DrillFR4' 
filename = 'inputs/DrillFR4.mp4'
DrillFR4 = ArtVideo(name, filename)
DrillFR4.load_video()

# Crop images as squares
# (return values of the preprocessing calls are discarded, so they presumably
# modify the ArtVideo in place — confirm against variart.preprocessing)
DrillFR4.square()

# Resize images
size = 64
new_shape=(size,size)
DrillFR4.resize(new_shape=new_shape)

# Rescale pixels in (0,1)
DrillFR4.rescale_image()

# Input data shape
# The recorded output for this cell reads (6692, 64, 64, 3).
print(f"Shape {DrillFR4.name}: {DrillFR4.shape}")

Shape DrillFR4: (6692, 64, 64, 3)


In [7]:
# Show a random image from the preprocessed video
DrillFR4.show_random_image()

## 2. Train VAE

In [None]:
# Prepare data for training
# Take the frame array held by the ArtVideo object created in the loading cell
# (`DrillFR4`) and cast it to float32 before it is fed to tf.data.
data = DrillFR4.X.astype('float32')
# Shuffle once with a fixed seed so the split below is reproducible and does
# not simply separate the beginning of the clip from its end.
data = shuffle(data, random_state=0)

# First 90% of the shuffled frames for training, remaining 10% for validation
TRAIN_BUF = int(data.shape[0] * 0.9)
data_train = data[:TRAIN_BUF]
data_validation = data[TRAIN_BUF:]

In [None]:
# Training hyper-parameters (consumed by the dataset, network and train cells)
batch_size = 128
epochs = 10_000
early_stop_patience = 15
# The encoder's last Dense layer emits latent_dim + latent_dim values
latent_dim = 8
# Adam with learning rate 1e-3, wrapped in the Lookahead optimizer
optimizer = Lookahead(Adam(learning_rate=1e-3))

In [None]:
# Batched tf.data pipelines over the train / validation arrays
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices(data_train)
    .batch(batch_size)
)
validation_dataset = (
    tf.data.Dataset
    .from_tensor_slices(data_validation)
    .batch(batch_size)
)

# Number of values in one frame (product of the three per-frame dimensions)
nb_features = data.shape[1] * data.shape[2] * data.shape[3]
# Batch size followed by the per-frame dimensions; handed to VAE(...) below
input_shape = (batch_size, *data.shape[1:4])

In [None]:
# Encoder and decoder networks (inference and generative)
# Encoder: two Conv2D -> MaxPool2D -> BatchNormalization stages, then a Dense
# head producing 2 * latent_dim values per frame.
inference_net = tf.keras.Sequential([
    InputLayer(input_shape=(data.shape[1], data.shape[2], data.shape[3])),
    Conv2D(filters=4, kernel_size=3, strides=(1, 1), activation='tanh'),
    MaxPool2D(pool_size=(2, 2)),
    BatchNormalization(),
    Conv2D(filters=8, kernel_size=3, strides=(1, 1), activation='tanh'),
    MaxPool2D(pool_size=(2, 2)),
    BatchNormalization(),
    Flatten(),
    Dense(2 * latent_dim),
])

# Decoder: a Dense layer expands a latent vector to a (height, width, 4) map,
# which three stride-1, SAME-padded transposed convolutions turn into a
# 3-channel output at the same spatial size. The last layer has no activation.
generative_net = tf.keras.Sequential([
    InputLayer(input_shape=(latent_dim,)),
    Dense(units=data.shape[1] * data.shape[2] * 4, activation='tanh'),
    BatchNormalization(),
    Reshape(target_shape=(data.shape[1], data.shape[2], 4)),
    Conv2DTranspose(
        filters=8,
        kernel_size=3,
        strides=(1, 1),
        padding="SAME",
        activation='tanh',
    ),
    BatchNormalization(),
    Conv2DTranspose(
        filters=4,
        kernel_size=3,
        strides=(1, 1),
        padding="SAME",
        activation='tanh',
    ),
    BatchNormalization(),
    Conv2DTranspose(
        filters=3,
        kernel_size=3,
        strides=(1, 1),
        padding="SAME",
    ),
])

In [None]:
# Model definition
# Build the VAE from the video's name, the latent size, the batch-prefixed
# input shape and the encoder / decoder networks defined above.
model = VAE(DrillFR4.name, latent_dim, input_shape, inference_net, generative_net)

In [None]:
# Train
# NOTE(review): `model` is rebound to whatever train() returns, and the latent
# analysis below passes that value to Latent(data, model) — confirm that
# train() returns the trained model rather than None.
model = model.train(optimizer, 
                    train_dataset, 
                    validation_dataset, 
                    epochs,
                    batch_size,
                    early_stop_patience = early_stop_patience, 
                    freq_plot = 25, 
                    plot_test = True,
                    n_to_plot = 4)

## 3. Latent analysis

In [None]:
# Create latent object
# Pairs the full shuffled frame array (train + validation) with the model
# returned by the training cell.
LatentDrillFR4 = Latent(data, model)

In [None]:
# Encode and decode data
# Return values are discarded, so the results are presumably stored on the
# Latent object for the plotting cells below — confirm against variart.latent.
LatentDrillFR4.encode_data()
LatentDrillFR4.decode_data()

In [None]:
# Create t-SNE representation of data in latent space, then plot it
LatentDrillFR4.latent_tsne()
LatentDrillFR4.plot_latent_tsne()

In [None]:
# Compute distributions of latent space dimensions, then plot them
LatentDrillFR4.compute_dist_coord()
LatentDrillFR4.plot_latent_dist_coord()

In [None]:
# Perform clustering in latent space, testing the number of clusters on a grid
# (range(5,50,5) yields 5, 10, ..., 45), then plot the silhouette scores
LatentDrillFR4.latent_space_clustering(grid=range(5,50,5))
LatentDrillFR4.plot_silhouette_score()

In [None]:
# Select number of clusters
# NOTE(review): the clustering cell searches grid=range(5,50,5), which does not
# contain 2. If dico_clust is keyed by the grid values, this lookup raises
# KeyError — confirm, and pick a value from the grid if so.
n_clusters = 2
clusterer = LatentDrillFR4.dico_clust[n_clusters]['clusterer']
# Re-draw the t-SNE plot, this time passing the selected clusterer
LatentDrillFR4.plot_latent_tsne(clusterer=clusterer)

In [None]:
# Show images for a given cluster
label = 0
# Indices of the first five samples whose cluster label equals `label`
list_id = [
    idx
    for idx, cluster_label in enumerate(clusterer.labels_)
    if cluster_label == label
][:5]
LatentDrillFR4.plot_encoded_decoded(list_id=list_id)

In [None]:
# Generate images by sampling from distributions in the latent space
# generate_image returns a pair: `list_z` (reused by the GIF cell below) and
# the figure that is displayed here.
list_z, fig = LatentDrillFR4.generate_image(n=5, method='dist')
fig.show()

In [None]:
# Create a GIF from generated images
# NOTE(review): `filename` is built here but never passed to create_gif(), so
# the GIF is written wherever create_gif() decides by default — confirm whether
# it should receive `filename`. This also overwrites the video path stored in
# `filename` by the loading cell.
filename = f"outputs/gif_{LatentDrillFR4.name}.gif"
LatentDrillFR4.create_gif(list_z)