In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Enumerate every file below the Kaggle input directory and print one
# joined path per line (nothing is printed for directories without files).
kaggle_input_paths = (
    os.path.join(folder, name)
    for folder, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in kaggle_input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [28]:
import pickle
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Reusing Keras VAE code

In [29]:
class Sampling(layers.Layer):
    """Reparameterization-trick layer: samples z from (z_mean, z_log_var).

    Given the mean and the log-variance of a diagonal Gaussian, returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon ~ N(0, 1)``,
    which keeps the sampling step differentiable with respect to both inputs.
    """

    def call(self, inputs) -> tf.Tensor:
        """Draw one latent sample per input element.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of tensors with identical
                shapes (``(batch_size, latent_dim)`` in this notebook).

        Returns:
            A tensor shaped like ``z_mean`` holding the sampled latent vectors.
        """
        z_mean, z_log_var = inputs
        # Draw the noise with z_mean's own dynamic shape and dtype so the layer
        # is not tied to rank-2 float32 inputs and does not depend on the
        # legacy keras.backend random helpers.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log(var)) == sqrt(var) == sigma, hence z = mu + sigma * epsilon
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

################ ENCODER ################
latent_dim = 2  # dimensionality of the latent vector z

# 28x28 single-channel images go through four 3x3 "same"-padded ReLU
# convolutions. Only the second one uses stride 2, so the feature map that
# reaches Flatten is 14x14 with 64 channels.
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = encoder_inputs
for n_filters, step in ((32, 1), (64, 2), (64, 1), (64, 1)):
    x = layers.Conv2D(
        n_filters, kernel_size=3, strides=step, padding="same", activation="relu"
    )(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)

# Two parallel linear heads read the same 16-unit features: one predicts the
# mean and the other the log-variance of the latent Gaussian.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

# The Sampling layer turns (mean, log-variance) into a latent sample z.
# The encoder exposes all three tensors.
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

################ DECODER ################
latent_inputs = keras.Input(shape=(latent_dim,))
# The Dense width is not free: it must equal 7 * 7 * 64 so that the Reshape
# below can fold the vector into a 7x7 map with 64 channels.
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
# Each stride-2 transposed convolution doubles the spatial size: 7 -> 14 -> 28.
for n_filters in (64, 32):
    x = layers.Conv2DTranspose(
        n_filters, 3, strides=2, padding="same", activation="relu"
    )(x)

# The final layer collapses the channels to 1; sigmoid keeps outputs in [0, 1].
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")

## Define the VAE as a Model with a custom train_step

In [55]:
kl_term = 1e0  # default weight of the KL term in the VAE loss


class VAE(keras.Model):
    """Variational autoencoder trained through a custom ``train_step``.

    The loss is ``reconstruction_loss + kl_weight * kl_loss``, where the
    reconstruction term is the per-image sum of pixel-wise binary
    cross-entropies and the KL term is the divergence between the encoder's
    diagonal Gaussian and a standard normal.

    Args:
        encoder: Model mapping an image batch to ``(z_mean, z_log_var, z)``.
        decoder: Model mapping latent vectors ``z`` back to images.
        kl_weight: Multiplier of the KL term. ``None`` (default) uses the
            module-level ``kl_term`` as it is when the model is constructed.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, kl_weight=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Resolve the weight once so the instance no longer depends on a
        # global being looked up inside train_step.
        self.kl_weight = kl_term if kl_weight is None else kl_weight
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Loss trackers owned by this model (total, reconstruction, KL)."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker
        ]

    def call(self, inputs):
        """Encode ``inputs``, sample a latent vector and return its decoding."""
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one optimisation step on a batch and return the running losses.

        Args:
            data: A batch of images, or a tuple/list whose first element is
                the image batch (e.g. ``(x, y)``); anything after the first
                element is ignored because training is unsupervised.

        Returns:
            Dict with the running means of ``loss``, ``reconstruction_loss``
            and ``kl_loss``.
        """
        # fit(x, y) or a dataset yielding tuples delivers (x, ...); keep x only.
        if isinstance(data, (tuple, list)):
            data = data[0]

        with tf.GradientTape() as tape:
            # The encoder returns the mean, the log-variance and a sampled z.
            z_mean, z_log_var, z = self.encoder(data)

            # Decode the sampled latent vector back to image space.
            reconstruction = self.decoder(z)

            # binary_crossentropy averages over the last (channel) axis, which
            # leaves one value per pixel; sum over both spatial axes to get a
            # per-image loss, then average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            # Closed-form KL( N(mu, sigma) || N(0, 1) ), summed over the latent
            # dimensions and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + self.kl_weight * kl_loss

        # Gradients of the total loss w.r.t. every trainable weight of the
        # model, which covers both the encoder and the decoder.
        grads = tape.gradient(total_loss, self.trainable_weights)

        # Let the optimizer apply one update per (gradient, weight) pair.
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Fold this batch's losses into the running means reported by fit().
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result()
        }

In [57]:
# Load data
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file; it is opened in binary mode.

    Returns:
        The deserialized object. Because ``encoding='bytes'`` is used, 8-bit
        string instances written by Python 2 are returned as ``bytes``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only call this on files from a source you trust.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')
# Read the pickled QMNIST dictionary, then print a short summary: its keys,
# the number of images, the shape of the image array and the shape of the
# first image.
qmnist = unpickle("/kaggle/input/qmnist-the-extended-mnist-dataset-120k-images/MNIST-120k")
qmnist_images = qmnist['data']
for fact in (
    qmnist.keys(),
    len(qmnist_images),
    qmnist_images.shape,
    qmnist_images[0].shape,
):
    print(fact)

In [58]:
# Preview one raw sample. It uses the same "Greys_r" colormap as the latent
# space plot further down so digits look the same everywhere in the notebook,
# and ends with plt.show() so the cell output is the figure, not an object repr.
plt.imshow(qmnist['data'][1], cmap="Greys_r")
plt.title("QMNIST sample (index 1)")
plt.show()

In [None]:
# Convert the raw images to float32 and divide by 255 (this maps them to
# [0, 1] assuming 8-bit pixel values), then append a trailing channel axis so
# each image is (height, width, 1) as the encoder input expects.
shape = qmnist['data'].shape
mnist_digits = (qmnist['data'].astype('float32') / 255).reshape(shape[:3] + (1,))

# Wrap the encoder/decoder pair in the VAE and train it with Adam at its
# default settings; `history` keeps the per-epoch loss values.
fit_options = dict(epochs=30, batch_size=128)
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
history = vae.fit(mnist_digits, **fit_options)

In [None]:
def plot_latent_space(vae, n=30, figsize=15, scale=1.0, digit_size=28):
    """Display an n-by-n grid of images decoded from a 2D latent grid.

    Args:
        vae: Object whose ``decoder.predict`` maps an array of shape
            ``(num_points, 2)`` to images of ``digit_size * digit_size``
            values each.
        n: Number of grid points along each latent axis.
        figsize: Width and height of the (square) figure, in inches.
        scale: The grid spans ``[-scale, scale]`` on both latent axes.
        digit_size: Side length, in pixels, of one decoded image.
    """
    # Linearly spaced latent coordinates; the y axis is reversed so that the
    # largest z[1] ends up in the top row of the picture.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # All n*n latent points in row-major order (row i -> grid_y[i],
    # column j -> grid_x[j]), decoded with a single predict call instead of
    # one call per point.
    z_grid = np.stack(np.meshgrid(grid_x, grid_y), axis=-1).reshape(-1, 2)
    decoded = np.asarray(vae.decoder.predict(z_grid))

    # (n*n, ...) -> (row, col, y, x) -> (row, y, col, x) -> one big mosaic in
    # which tile (i, j) occupies rows i*digit_size.. and cols j*digit_size..
    figure = (
        decoded.reshape(n, n, digit_size, digit_size)
        .transpose(0, 2, 1, 3)
        .reshape(n * digit_size, n * digit_size)
    )

    plt.figure(figsize=(figsize, figsize))
    # Put one tick at the centre of every tile and label it with the latent
    # coordinate that produced the tile.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(vae)

In [None]:
def plot_label_clusters(vae, data, labels):
    """Scatter the encoder's latent means for `data`, coloured by `labels`.

    Args:
        vae: Object whose ``encoder.predict`` returns three outputs, the first
            being the latent means.
        data: Inputs passed to the encoder.
        labels: One colour value per sample, passed to scatter as ``c``.
    """
    latent_means, _log_vars, _samples = vae.encoder.predict(data)

    fig, ax = plt.subplots(figsize=(12, 10))
    points = ax.scatter(latent_means[:, 0], latent_means[:, 1], c=labels)
    fig.colorbar(points, ax=ax)
    ax.set_xlabel("z[0]")
    ax.set_ylabel("z[1]")
    plt.show()



plot_label_clusters(vae, mnist_digits, qmnist['labels'])