# Variational Autoencoder (VAE)

In [None]:
# --- Dataset -------------------------------------------------------------
DATASET = 'cifar10'
CHANNELS = 3
INPUT_SHAPE = (32, 32, CHANNELS)  # (height, width, channels) of one image
VALID_SIZE = 0  # fraction of training images held out for validation; 0 disables it

# --- Optimization --------------------------------------------------------
LATENT_DIM = 2048  # width of the latent vector z
REC_LOSS_W = 1  # multiplier on the reconstruction term of the loss
KL_LOSS_W = 1  # multiplier on the KL-divergence term of the loss

# --- Training ------------------------------------------------------------
EPOCHS = 200
BATCH_SIZE = 64
LR = 1e-3  # Adam learning rate

# Filters per encoder conv layer; the decoder walks this list in reverse.
CVS = [32, 64, 128]
CV_PARAMS = {'strides': 2, 'padding': 'same', 'activation': 'relu'}

### Setup

In [None]:
import os
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

import seaborn as sns
import matplotlib.pyplot as plt

from tensorflow.keras import Model
from tensorflow.keras.layers import (Layer, Conv2D, Conv2DTranspose, Dense,
                                     Dropout, BatchNormalization,
                                     Activation, GlobalAveragePooling2D,
                                     Reshape, Flatten)

In [None]:
# Sanity check: how many GPUs TensorFlow can see before anything is built.
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices(device_type='GPU')))

In [None]:
# White background without grid lines, which would otherwise overlay the images.
sns.set_style(style="whitegrid", rc={'axes.grid': False})

In [None]:
def plot(y, titles=None, rows=1, i0=0):
    """Draw a sequence of images as axis-less subplots on the current figure.

    The subplot grid has ``rows`` rows and ``len(y)`` columns.

    Args:
        y: sequence of images; a ``None`` entry reserves an empty cell.
        titles: optional per-image titles, indexed in step with ``y``.
        rows: number of rows in the subplot grid.
        i0: offset added to each subplot index, so a second call can fill a
            later row of the same grid.
    """
    for offset, img in enumerate(y):
        position = i0 + offset + 1
        if img is not None:
            title = titles[offset] if titles else None
            plt.subplot(rows, len(y), position, title=title)
            plt.imshow(img)
        else:
            # Placeholder cell: create the axes only so it can be hidden.
            plt.subplot(rows, len(y), position)
        plt.axis('off')

def plot_to_image(figure):
    """Render a matplotlib figure into an image tensor for ``tf.summary.image``.

    The figure is closed after rendering.

    Args:
        figure: the ``matplotlib.figure.Figure`` to render.

    Returns:
        A tensor of shape ``(1, height, width, 4)`` (a batch of one RGBA image).
    """
    import io

    buf = io.BytesIO()
    # Save *this* figure explicitly. ``plt.savefig`` writes whichever figure is
    # current; after ``plt.show()`` that is typically a new, empty figure,
    # which produced blank TensorBoard images.
    figure.savefig(buf, format='png')
    plt.close(figure)
    # getvalue() returns the whole buffer regardless of the stream position.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)

class TensorBoardImage(tf.keras.callbacks.Callback):
    """Log reconstructions of a fixed image batch to TensorBoard after every epoch.

    Intended for the ``VAE`` model: ``self.model`` must expose ``encoder``
    (returning ``(z_mean, z_log_var, z)``) and ``decoder``.

    Args:
        tag: summary tag under which the images are written.
        logs_dir: directory for the summary file writer.
        images: optional array of input images to reconstruct. When omitted,
            the first ``n`` training images (``Data.x[:n]``) are used.
        n: number of training images to use when ``images`` is omitted.
    """

    def __init__(self, tag, logs_dir, images=None, n=16):
        super().__init__()
        self.tag = tag
        self.logs_dir = logs_dir
        self.images = images
        self.n = n
        self.writer = None

    def on_train_begin(self, logs=None):
        self.writer = tf.summary.create_file_writer(self.logs_dir)

    def on_epoch_end(self, epoch, logs=None):
        images = self.images if self.images is not None else Data.x[:self.n]
        # Use the networks owned by the model being trained, not notebook
        # globals, and keep predict() quiet so it does not print a progress
        # bar every epoch.
        _, _, z = self.model.encoder.predict(images, verbose=0)
        rec = self.model.decoder.predict(z, verbose=0)
        # Tile the reconstructions side by side into a single image:
        # (n, H, W, C) -> (H, n*W, C), then add the batch axis.
        strip = np.expand_dims(np.hstack(rec), 0)
        with self.writer.as_default():
            tf.summary.image(self.tag, strip, step=epoch)

    def on_train_end(self, logs=None):
        self.writer.close()

### Load Dataset

In [None]:
# CIFAR-10 class names; index i is the name for integer label i.
CLASSES = np.array(['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck'])

In [None]:
# Record type for the dataset splits. It gets its own name because ``Data``
# is rebound to the loaded instance below; building records through the
# shadowed name would make a second ``load_data()`` call fail.
DataSplits = namedtuple('Data', 'x y xv yv xt yt')


def load_data():
    """Load CIFAR-10, scale pixels to [0, 1] and optionally split off validation data.

    When ``VALID_SIZE`` is non-zero, the first ``int(VALID_SIZE * len(x))``
    training samples (images *and* their labels) become the validation split.
    Prints the split sizes with example labels and plots the first four
    training images.

    Returns:
        A ``DataSplits`` record: ``x, y`` (train), ``xv, yv`` (validation, both
        ``None`` when ``VALID_SIZE`` is 0) and ``xt, yt`` (test). Images are
        float32 in [0, 1]; labels are flattened to 1-D.
    """
    (x, y), (xt, yt) = tf.keras.datasets.cifar10.load_data()
    # For a grayscale variant, append ``.mean(axis=-1, keepdims=True)`` to the
    # two lines below and set CHANNELS = 1.
    x = x.astype("float32") / 255
    xt = xt.astype("float32") / 255
    # Flatten the label arrays to 1-D so they index CLASSES directly.
    y, yt = y.ravel(), yt.ravel()

    if VALID_SIZE:
        n_valid = int(VALID_SIZE * len(x))
        x, xv = x[n_valid:], x[:n_valid]
        # Labels are sliced from ``y`` in step with the images.
        y, yv = y[n_valid:], y[:n_valid]
    else:
        xv, yv = None, None

    print('Training')
    print('  samples:', len(x))
    print('  labels examples:', y[:10])
    if VALID_SIZE:
        print('Validating')
        print('  samples:', len(xv))
        print('  labels examples:', yv[:10])
    print('Testing')
    print('  samples:', len(xt))
    print('  labels examples:', yt[:10])

    plot(x[:4])

    return DataSplits(x=x, y=y, xv=xv, yv=yv, xt=xt, yt=yt)


Data = load_data()

In [None]:
# Input pipeline for training: images only (the VAE needs no labels),
# shuffled with a buffer covering the whole training set and grouped into
# fixed-size batches; a trailing partial batch is dropped.
x_ds = tf.data.Dataset.from_tensor_slices(Data.x)
x_ds = x_ds.shuffle(buffer_size=len(Data.x))
x_ds = x_ds.batch(BATCH_SIZE, drop_remainder=True)

### Defining Model

In [None]:
# Unix timestamp that makes every run's log directory unique.
RUN_ID = int(time.time())

# Log directory: the run's hyper-parameters as space-separated key:value
# tags, with the run id as the final path component.
LOGS = '/tf/logs/' + ' '.join([
    f'd:{DATASET}',
    f'e:{EPOCHS}',
    f'b:{BATCH_SIZE}',
    f'lr:{LR}',
    f'arch:({",".join(str(c) for c in CVS)})',
    f'latent:{LATENT_DIM}',
    f'rec_w:{round(REC_LOSS_W, 6)}',
    f'kl_w:{round(KL_LOSS_W, 6)}',
]) + f'/{RUN_ID}'

In [None]:
class Sampling(Layer):
    """Reparameterization trick: ``z = mean + exp(0.5 * log_var) * eps``, with eps ~ N(0, I).

    Called on ``[z_mean, z_log_var]``. The noise is drawn separately, so
    gradients flow through both inputs. The noise has shape
    ``(batch, dim)``, so 2-D inputs are assumed.
    """

    def call(self, inputs):
        mean, log_var = inputs
        shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(shape=(shape[0], shape[1]))
        std = tf.exp(0.5 * log_var)
        return mean + std * noise

def encode(x, latent_dim=LATENT_DIM):
    """Build the encoder graph on input tensor ``x``.

    One 3x3 ``Conv2D`` per entry of ``CVS`` (configured by ``CV_PARAMS``),
    then flatten, a ReLU dense layer and two linear heads.

    Returns:
        ``(z_mean, z_log_var)``, each with ``latent_dim`` units.
    """
    h = x
    for idx, filters in enumerate(CVS):
        conv = Conv2D(filters, 3, name=f'cv{idx}', **CV_PARAMS)
        h = conv(h)
    h = Flatten(name='ft')(h)
    h = Dense(latent_dim, activation="relu", name='fc1')(h)
    z_mean = Dense(latent_dim, name='zu')(h)
    z_log_var = Dense(latent_dim, name='zlv')(h)
    return z_mean, z_log_var

In [None]:
# Encoder model: image -> (z_mean, z_log_var, sampled z).
x = tf.keras.Input(shape=INPUT_SHAPE, name='images')
zu, zlv = encode(x)
z = Sampling(name='zs')([zu, zlv])
encoder = Model(inputs=x, outputs=[zu, zlv, z], name='encoder')

In [None]:
# Print the encoder's layers with output shapes and parameter counts.
encoder.summary()

In [None]:
def decode(z, act='sigmoid', size=None):
    """Build the decoder graph mapping latent tensor ``z`` back to an image.

    Mirrors ``encode``: a ReLU dense layer seeds a small feature map with
    ``CVS[-1]`` channels, then one stride-2 ``Conv2DTranspose`` per entry of
    ``CVS`` (in reverse) doubles the spatial size each time, and a final
    transposed conv emits ``CHANNELS`` channels.

    Args:
        z: latent input tensor.
        act: activation of the output layer. The default 'sigmoid' keeps
            outputs in (0, 1).
        size: side of the square seed feature map. By default it is derived
            from ``INPUT_SHAPE`` and ``len(CVS)`` so the output matches
            ``INPUT_SHAPE`` (4x4 for 32x32 inputs and three conv layers).
            Square inputs are not required in that case.

    Raises:
        ValueError: if ``size`` is omitted and the input height/width is not
            divisible by ``2 ** len(CVS)``.
    """
    upscale = 2 ** len(CVS)
    if size is None:
        if INPUT_SHAPE[0] % upscale or INPUT_SHAPE[1] % upscale:
            raise ValueError(
                f'INPUT_SHAPE {INPUT_SHAPE[:2]} is not divisible by {upscale}; '
                'pass an explicit size')
        height, width = INPUT_SHAPE[0] // upscale, INPUT_SHAPE[1] // upscale
    else:
        height = width = size
    filters = CVS[-1]

    z = Dense(height * width * filters, activation="relu", name='fc1')(z)
    z = Reshape((height, width, filters), name='rs')(z)
    for ix, c in enumerate(reversed(CVS)):
        z = Conv2DTranspose(c, 3, name=f'cvt{ix}', **CV_PARAMS)(z)
    # Stride-1 projection to the image's channel count.
    z = Conv2DTranspose(CHANNELS, 3, activation=act, padding='same', name='decoded')(z)

    return z

# Decoder model: latent vector of size LATENT_DIM -> reconstructed image.
lvs = tf.keras.Input(shape=(LATENT_DIM,), name='latent_vars')
ty = decode(lvs)
decoder = Model(inputs=lvs, outputs=ty, name='decoder')

In [None]:
# Print the decoder's layers with output shapes and parameter counts.
decoder.summary()

In [None]:
from tensorflow.keras import metrics, losses


class VAE(Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    ``encoder`` maps images to ``(z_mean, z_log_var, z)``. ``decoder`` maps
    ``z`` back to images with values in [0, 1], as required by the binary
    cross-entropy reconstruction loss.

    Loss = ``REC_LOSS_W * reconstruction + KL_LOSS_W * KL``. The KL term is
    the closed-form divergence of N(z_mean, exp(z_log_var)) from N(0, I).
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers lets Keras reset them at the start of each
        # epoch, so the reported numbers are per-epoch means.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Reconstruct ``inputs``, enabling ``model(x)``, ``predict`` and ``evaluate``."""
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    @staticmethod
    def _images(data):
        # Datasets may yield (x,) or (x, y) tuples; the VAE only uses x.
        if isinstance(data, tuple):
            data = data[0]
        return data

    def _compute_losses(self, data):
        """Return ``(total, reconstruction, kl)`` weighted losses for one batch."""
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        # binary_crossentropy averages over the last (channel) axis; the sum
        # over axes (1, 2) totals that over the image pixels, and the outer
        # mean averages over the batch.
        reconstruction_loss = REC_LOSS_W * tf.reduce_mean(
            tf.reduce_sum(losses.binary_crossentropy(data, reconstruction),
                          axis=(1, 2))
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = KL_LOSS_W * tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _record(self, total_loss, reconstruction_loss, kl_loss):
        """Update the running means and return the logs dict Keras expects."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        data = self._images(data)
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(data)

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._record(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        # Same losses as training, without a gradient update
        # (used by evaluate() and by fit() when validation data is given).
        data = self._images(data)
        return self._record(*self._compute_losses(data))

In [None]:
# No loss is passed to compile(): VAE.train_step computes its own loss.
model = VAE(encoder=encoder, decoder=decoder)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LR))

### Summary

In [None]:
# Diagram of the encoder graph, laid out left-to-right, with tensor shapes.
tf.keras.utils.plot_model(encoder, show_shapes=True, rankdir="LR")

In [None]:
# Diagram of the decoder graph, laid out left-to-right, with tensor shapes.
tf.keras.utils.plot_model(decoder, show_shapes=True, rankdir="LR")

## Training

In [None]:
# Re-check GPU visibility right before the (long) training run.
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

In [None]:
# Refuse to write into an existing log directory so runs never share one.
if os.path.exists(LOGS):
    raise ValueError(f'Conflicting logs {LOGS}. Change or delete the target folder.')

model.fit(
    x=x_ds,
    epochs=EPOCHS,
    verbose=2,
    callbacks=[
        # 'loss' is the total-loss key returned by VAE.train_step; stop when
        # it has not improved for ~10% of the epoch budget (at least 1).
        tf.keras.callbacks.EarlyStopping(
            monitor='loss', patience=max(1, EPOCHS // 10), verbose=1),
        tf.keras.callbacks.TerminateOnNaN(),
        tf.keras.callbacks.TensorBoard(
            log_dir=LOGS, histogram_freq=1, embeddings_freq=3),
        TensorBoardImage(tag='reconstruction', logs_dir=LOGS + '/rec'),
    ]);

## Generating

In [None]:
# Summary writer for the post-training figures, in the run's 'train' subdirectory.
file_writer = tf.summary.create_file_writer(logdir=f'{LOGS}/train')

In [None]:
import matplotlib.pyplot as plt

SAMPLES = 10


def plot_latent_space(vae, n=SAMPLES, figsize=15):
    """Decode an ``n`` x ``n`` grid of latent points and display it as one image.

    Only the first two latent dimensions vary, each over
    ``np.linspace(-1, 1, n)``; every other dimension is held at 0. Columns
    follow z[0] left to right; rows follow z[1] from +1 (top) to -1.

    Args:
        vae: model exposing a ``decoder`` that maps ``(batch, LATENT_DIM)``
            latent vectors to images of shape ``INPUT_SHAPE``.
        n: number of grid points per axis.
        figsize: width and height in inches of the matplotlib figure.

    Returns:
        The tiled grid as an array of shape
        ``(1, n * INPUT_SHAPE[0], n * INPUT_SHAPE[1], CHANNELS)``.
    """
    dh, dw = INPUT_SHAPE[:2]
    scale = 1.0
    grid_x = np.linspace(-scale, scale, n)
    # Reversed so the top row of the figure has the largest z[1].
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # All n*n latent vectors in row-major order (row i <-> grid_y[i],
    # column j <-> grid_x[j]), decoded in one batch instead of n*n predict calls.
    z_grid = np.zeros((n * n, LATENT_DIM))
    z_grid[:, 0] = np.tile(grid_x, n)
    z_grid[:, 1] = np.repeat(grid_y, n)
    decoded = vae.decoder.predict(z_grid)

    figure = np.zeros((dh * n, dw * n, CHANNELS))
    for k, image in enumerate(decoded):
        i, j = divmod(k, n)
        # Rows advance by the image height, columns by the image width.
        figure[i * dh:(i + 1) * dh, j * dw:(j + 1) * dw] = image.reshape(INPUT_SHAPE)

    plt.figure(figsize=(figsize, figsize))
    # imshow rejects (H, W, 1) arrays; show grayscale as 2-D with the gray colormap.
    plt.imshow(figure[..., 0] if CHANNELS == 1 else figure, cmap="Greys_r")
    plt.show()

    return figure.reshape(1, *figure.shape)

predictions = plot_latent_space(model)

# Write the decoded grid to TensorBoard once, as step 0.
with file_writer.as_default():
    tf.summary.image(name=f'{SAMPLES**2} samples generated', data=predictions, step=0);

In [None]:
from sklearn.decomposition import PCA

SAMPLES_PLOTTED = 10000


def plot_label_clusters(data, labels):
    """Scatter the 2-D PCA projection of the encoder means, coloured by label.

    Only the first ``SAMPLES_PLOTTED`` items of ``data``/``labels`` are used.

    Returns:
        The matplotlib figure, already shown with ``plt.show()``.
    """
    subset, hue = data[:SAMPLES_PLOTTED], labels[:SAMPLES_PLOTTED]
    z_mean, _log_var, _z = encoder.predict(subset)
    projected = PCA(n_components=2).fit_transform(z_mean)

    fig = plt.figure(figsize=(12, 10))
    sns.scatterplot(x=projected[:, 0], y=projected[:, 1], hue=hue)
    plt.xlabel("Pz[0]")
    plt.ylabel("Pz[1]")
    plt.show()
    return fig

# Prefer the held-out split when one exists, otherwise use the training set.
if VALID_SIZE:
    xp, yp = Data.xv, Data.yv
else:
    xp, yp = Data.x, Data.y
clu_fig = plot_label_clusters(xp, CLASSES[yp])
plt.show()

with file_writer.as_default():
    tf.summary.image("PCA(z)", plot_to_image(clu_fig), step=0);

In [None]:
import warnings
from math import ceil
from itertools import combinations

import pandas as pd

# Separate name so this cell does not overwrite ``SAMPLES`` (the grid side
# used by the latent-space plot); rebinding it made re-running that plot's
# logging cell report a 10000**2-sample grid.
PAIRPLOT_SAMPLES = 10000
DIMS = 8


def plot_label_pairs_clusters(data, labels, samples=PAIRPLOT_SAMPLES, dims=DIMS):
    """Seaborn pair plot of the first ``dims`` encoder-mean dimensions, coloured by label.

    Args:
        data: images to encode; only the first ``samples`` are used.
        labels: per-image labels aligned with ``data``.
        samples: maximum number of items to plot.
        dims: number of leading latent dimensions to include.

    Returns:
        The ``seaborn.PairGrid`` produced by ``sns.pairplot``.
    """
    zu, _, _ = encoder.predict(data[:samples])

    d = pd.DataFrame(zu[:, :dims])
    d['labels'] = labels[:samples]
    return sns.pairplot(d, hue='labels')

# Prefer the held-out split when one exists, otherwise use the training set.
if VALID_SIZE:
    xp, yp = Data.xv, Data.yv
else:
    xp, yp = Data.x, Data.y
g = plot_label_pairs_clusters(xp, CLASSES[yp])
plt.show()

with file_writer.as_default():
    tf.summary.image("latent_variables", plot_to_image(g.fig), step=0);