In [None]:
from glob import glob
from math import sqrt
import os

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Conv2DTranspose, Lambda, Reshape, Flatten, Dense, BatchNormalization
from tensorflow.keras.losses import KLDivergence
from tensorflow.keras.models import Model, load_model
import tensorflow.python.keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [None]:
# Dataset preprocessing
# Only needs to be run once, Tyler ran it already

# DATA_DIR = "/hdd/datasets/UTKFace"
# image_paths = glob(DATA_DIR + "/*")

# images = np.zeros((len(image_paths), 96, 96, 3), dtype=np.float32)
# for i, image_path in enumerate(image_paths):
#     image = Image.open(image_path)
#     image = image.resize((96, 96), Image.LANCZOS)
#     images[i] = np.array(image) / 255.
# np.save("/hdd/datasets/UTKFace.npy", images)

In [None]:
# Loads the preprocessed dataset written by the (commented-out) preprocessing cell above.
DATASET_PATH = "/hdd/datasets/UTKFace.npy"
images = np.load(DATASET_PATH)

In [None]:
# Shows the first few dataset images side by side, one per axes.
fig, ax = plt.subplots(ncols=5, figsize=(15, 15))
for idx in range(ax.size):
    ax[idx].imshow(images[idx])

In [None]:
# Create the encoder using the tf.Keras functional API
# Four Conv2D->Conv2D->BatchNorm->MaxPool blocks, starting at 16 filters and doubling each block
# Then a final pixel-wise Conv2D with 128 filters
def create_encoder(input_shape, encoding_dim):
    """Build the convolutional encoder (currently a stub that only validates its arguments).

    Args:
        input_shape: shape of a single input image, e.g. ``images[0].shape``.
        encoding_dim: size of the latent code; must be a positive perfect square.

    Raises:
        ValueError: if ``encoding_dim`` is not a positive perfect square.
    """
    # Reject non-positive values first: sqrt() of a negative number raises an
    # unrelated "math domain error", and 0 would otherwise pass the square check.
    if encoding_dim <= 0:
        raise ValueError("Encoding dim must be positive.")
    if not sqrt(encoding_dim).is_integer():
        raise ValueError("Encoding dim must be a perfect square.")
    # TODO: build the layers described in the cell comment and return the encoder Model.

In [None]:
# Create the decoder using the tf.Keras functional API
# Four Conv2DTranspose layers, starting at 128 filters and halving each time
# Then a final Conv2D with 3 filters and sigmoid activation
def create_decoder(encoding_dim):
    """Build the convolutional decoder (currently a stub that only validates its argument).

    Args:
        encoding_dim: size of the latent code; must be a positive perfect square.

    Raises:
        ValueError: if ``encoding_dim`` is not a positive perfect square.
    """
    # Reject non-positive values first: sqrt() of a negative number raises an
    # unrelated "math domain error", and 0 would otherwise pass the square check.
    if encoding_dim <= 0:
        raise ValueError("Encoding dim must be positive.")
    if not sqrt(encoding_dim).is_integer():
        raise ValueError("Encoding dim must be a perfect square.")
    # TODO: build the layers described in the cell comment and return the decoder Model.

In [None]:
# Latent size; presumably a square (6 * 6) so it can be laid out as a square map — TODO confirm.
ENCODING_DIM = 36
# Shape of a single image (the model's per-example input shape).
input_shape = images[0].shape

In [None]:
# Assemble your autoencoder and compile with MSE loss


In [None]:
# Create the checkpoint directory portably (equivalent to `mkdir -p`, but pure Python).
os.makedirs("models/ae", exist_ok=True)
# Keras fills {epoch} and {val_loss} at save time; save_best_only=True writes a file
# only when val_loss improves, so the newest file is always the best so far.
checkpointer = ModelCheckpoint("models/ae/epoch{epoch}_loss{val_loss:.4f}.h5", save_best_only=True, verbose=1)

# Fit the model for 50 epochs with the checkpointer callback

In [None]:
# Loads best model
# Checkpoints are written only on val_loss improvement, so the most recently
# modified file is the best one.
saved_models = sorted(glob("models/ae/*"), key=os.path.getmtime)
if not saved_models:
    raise FileNotFoundError("No checkpoints found in models/ae/ - run the training cell first.")
autoencoder = load_model(saved_models[-1])
# NOTE(review): assumes layer 0 is the Input and layers 1/2 are the nested
# encoder/decoder models, as assembled in the autoencoder cell — confirm.
encoder = autoencoder.get_layer(index=1)
decoder = autoencoder.get_layer(index=2)

In [None]:
# Plots some example reconstructions: originals on the first row of a 5x5 grid,
# the autoencoder's outputs directly beneath them.
visual = images[:5]
preds = autoencoder.predict(visual)

plt.figure(figsize=(15, 15))
panel_rows = [(visual, "Original"), (preds, "Reconstructed")]
for row_idx, (row_images, caption) in enumerate(panel_rows):
    for col_idx in range(5):
        plt.subplot(5, 5, row_idx * 5 + col_idx + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(row_images[col_idx])
        plt.xlabel(caption)
plt.show()

In [None]:
# Compares the mean of the reconstructions with the reconstruction of the mean latent value
# Just thought this was cool :)

plt.figure(figsize=(15, 15))

plt.subplot(5, 5, 1)
preds = autoencoder.predict(images)
# Average the model's outputs (not the raw inputs) so this panel really is
# the mean reconstruction.
avg = np.mean(preds, axis=0)
plt.imshow(avg)
plt.title("Mean of reconstructions")

plt.subplot(5, 5, 2)
latent = encoder.predict(images)
# Keep a leading batch axis of size 1 so the decoder accepts the single mean code.
avg = np.expand_dims(np.mean(latent, axis=0), 0)
pred = decoder.predict(avg).squeeze()
plt.imshow(pred)
plt.title("Decoded mean latent")
plt.show()

In [None]:
# Create the variational encoder
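# Same as the earlier encoder, but with a tfp.layers.Convolution2DFlipout layer at the end
# (requires `import tensorflow_probability as tfp`, which the imports cell does not include yet)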
# This layer learns a distribution over each weight
def create_variational_encoder(input_shape, encoding_dim):
    """Build the variational encoder (currently a stub that only validates its arguments).

    Args:
        input_shape: shape of a single input image, e.g. ``images[0].shape``.
        encoding_dim: size of the latent code; must be a positive perfect square.

    Raises:
        ValueError: if ``encoding_dim`` is not a positive perfect square.
    """
    # Reject non-positive values first: sqrt() of a negative number raises an
    # unrelated "math domain error", and 0 would otherwise pass the square check.
    if encoding_dim <= 0:
        raise ValueError("Encoding dim must be positive.")
    if not sqrt(encoding_dim).is_integer():
        raise ValueError("Encoding dim must be a perfect square.")
    # TODO: build the layers described in the cell comment and return the encoder Model.

In [None]:
# Create the variational decoder
# Same as the earlier decoder!
def create_variational_decoder(encoding_dim):
    """Build the variational decoder (currently a stub that only validates its argument).

    Args:
        encoding_dim: size of the latent code; must be a positive perfect square.

    Raises:
        ValueError: if ``encoding_dim`` is not a positive perfect square.
    """
    # Reject non-positive values first: sqrt() of a negative number raises an
    # unrelated "math domain error", and 0 would otherwise pass the square check.
    if encoding_dim <= 0:
        raise ValueError("Encoding dim must be positive.")
    if not sqrt(encoding_dim).is_integer():
        raise ValueError("Encoding dim must be a perfect square.")
    # TODO: build the layers described in the cell comment and return the decoder Model.

In [None]:
# Assemble the variational autoencoder

In [None]:
# For the ELBO loss, we need to create a "closure":
# a function defined inside another function that captures variables
# (here model, batch_size and dataset_size) from the enclosing scope.
# We need this because the KL term depends on those variables, but
# Keras expects a loss function with the signature loss_fn(y_true, y_pred).
def create_loss_fn(model, batch_size, dataset_size):
    """Build a Keras-compatible ELBO loss (reconstruction BCE + weighted KL).

    Args:
        model: the Keras model being trained. Its ``losses`` collection is read
            on every call; presumably it holds the KL terms contributed by the
            Flipout layers (TODO confirm).
        batch_size: training batch size.
        dataset_size: number of training examples; must be positive.

    Returns:
        ``loss_fn(y_true, y_pred)`` returning one loss value per example: the
        binary cross-entropy summed over all non-batch axes, plus the KL term
        scaled by ``batch_size / dataset_size``.

    Raises:
        ValueError: if ``dataset_size`` is not positive.
    """
    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive.")
    # Obtain the model KL divergence by sum(model.losses), then re-weight it by
    # batch / dataset so that, summed over one epoch's batches, it contributes
    # about once per epoch.
    kl_weight = batch_size / dataset_size

    def loss_fn(y_true, y_pred):
        # Compute BCE and add KL.
        # Element-wise BCE, summed over every non-batch axis -> one value per example.
        bce = tf.keras.backend.binary_crossentropy(y_true, y_pred)
        bce = tf.reduce_sum(bce, axis=list(range(1, len(bce.shape))))
        # sum() of an empty list is 0, so this also works before any KL terms exist.
        kl = sum(model.losses) * kl_weight
        return bce + kl

    return loss_fn

# Compile the model with your custom loss

In [None]:
# Create the checkpoint directory portably (equivalent to `mkdir -p`, but pure Python).
os.makedirs("models/vae", exist_ok=True)
# val_loss is formatted with no decimals because the summed-BCE ELBO is large;
# save_best_only=True writes a file only when val_loss improves.
checkpointer = ModelCheckpoint("models/vae/epoch{epoch}_loss{val_loss:.0f}.h5", save_best_only=True, verbose=1)

# Fit the model for 150 epochs with the checkpointer callback

In [None]:
# Loads best model
# Checkpoints are written only on val_loss improvement, so the most recently
# modified file is the best one.
saved_models = sorted(glob("models/vae/*"), key=os.path.getmtime)
if not saved_models:
    raise FileNotFoundError("No checkpoints found in models/vae/ - run the training cell first.")
# compile=False: the model was compiled with the custom loss_fn closure, which
# load_model cannot deserialize without custom_objects; inference doesn't need it.
# NOTE(review): tfp Flipout layers may additionally need to be passed via custom_objects.
variational_autoencoder = load_model(saved_models[-1], compile=False)
# Take the sub-models from the *variational* autoencoder, not the plain one.
variational_encoder = variational_autoencoder.get_layer(index=1)
variational_decoder = variational_autoencoder.get_layer(index=2)

In [None]:
# Plots some example reconstructions
# Why is this worse than the original autoencoder?
# (Recall what adding KL divergence does!)
visual = images[:5]

# Runs Monte Carlo samples and averages them.
# Necessary because model forward pass is stochastic.
N_MC_SAMPLES = 25
# The averaged result takes its shape from `visual`, instead of a hard-coded
# (5, 96, 96, 3) buffer; verbose=0 avoids printing one progress bar per sample.
preds = np.mean(
    [variational_autoencoder.predict(visual, verbose=0) for _ in range(N_MC_SAMPLES)],
    axis=0,
)

plt.figure(figsize=(15, 15))
for row, (row_images, label) in enumerate([(visual, "Original"), (preds, "Reconstructed")]):
    for col in range(5):
        plt.subplot(5, 5, row * 5 + col + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(row_images[col])
        plt.xlabel(label)
plt.show()

In [None]:
# Compares the mean of the reconstructions with the reconstruction of the mean latent value
# Why do you think this is different from the original autoencoder?

plt.figure(figsize=(15, 15))

plt.subplot(5, 5, 1)
preds = variational_autoencoder.predict(images)
# Average the model's outputs (not the raw inputs) so this panel really is
# the mean reconstruction.
avg = np.mean(preds, axis=0)
plt.imshow(avg)
plt.title("Mean of reconstructions")

plt.subplot(5, 5, 2)
latent = variational_encoder.predict(images)
# Keep a leading batch axis of size 1 so the decoder accepts the single mean code.
avg = np.expand_dims(np.mean(latent, axis=0), 0)
pred = variational_decoder.predict(avg).squeeze()
plt.imshow(pred)
plt.title("Decoded mean latent")
plt.show()

In [None]:
# Feel free to play around with architectures and encoding_dim values!
# Post any cool results on the Slack.


# CHALLENGE PROBLEM:
# I wasn't able to figure out a way to plot the latent space of the VAE 
# to obtain an image like the MNIST plot in the slides
# If you can figure out how to do this and display latent space plots
# for both models, you'll win a prize!








