# VQ-VAE with Residual Stacks

In [1]:
from keras import layers as Layer, Input, Model, Sequential
from keras.datasets import mnist, cifar10
from keras.optimizers import Adam
from keras.metrics import Mean, MAE
from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import glob

In [2]:
# Confirm which devices (CPU / GPU) TensorFlow can see before training.
visible_devices = tf.config.get_visible_devices()
visible_devices

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [79]:
class ResidualStack(tf.Module):
    """A stack of pre-activation residual units followed by a final ReLU.

    Each unit computes relu -> 3x3 conv (`num_residual_hiddens` filters) ->
    relu -> 1x1 conv (`num_hiddens` filters) and adds the result to the running
    features. The inputs are therefore expected to have `num_hiddens` channels
    so the residual addition lines up.

    Args:
        num_hiddens: Channel count of the inputs/outputs of each unit.
        num_residual_layers: Number of residual units.
        num_residual_hiddens: Filters in the inner 3x3 convolution.
        name: Optional module name.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super().__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # One (3x3 bottleneck, 1x1 expansion) pair per residual unit.
        self._layers = [
            (
                Layer.Conv2D(num_residual_hiddens, kernel_size=3, strides=1, padding='same', name=f'res3x3_{i}'),
                Layer.Conv2D(num_hiddens, kernel_size=1, strides=1, padding='same', name=f'res1x1_{i}'),
            )
            for i in range(num_residual_layers)
        ]

    def __call__(self, inputs):
        features = inputs
        for bottleneck, expand in self._layers:
            residual = expand(tf.nn.relu(bottleneck(tf.nn.relu(features))))
            features += residual
        return tf.nn.relu(features)

In [80]:
class Encoder(Model):
    """Convolutional encoder that downsamples images to a latent feature map.

    Two 4x4 stride-2 convolutions (4x spatial reduction in total) and one 3x3
    stride-1 convolution, each followed by ReLU, then a residual stack.

    Args:
        num_hiddens: Output channels; the first convolution uses half of it.
        num_residual_layers: Residual units in the trailing stack.
        num_residual_hiddens: Bottleneck width inside each residual unit.
        name: Optional model name.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super().__init__(name=name)

        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        half_width = num_hiddens // 2
        self._enc_l1 = Layer.Conv2D(half_width, kernel_size=(4, 4), strides=(2, 2), padding='same', name='enc_l1')
        self._enc_l2 = Layer.Conv2D(num_hiddens, kernel_size=(4, 4), strides=(2, 2), padding='same', name='enc_l2')
        self._enc_l3 = Layer.Conv2D(num_hiddens, kernel_size=(3, 3), strides=(1, 1), padding='same', name='enc_l3')
        self._residual_stack = ResidualStack(num_hiddens, num_residual_layers, num_residual_hiddens,
                                             name='resblock1')

    def call(self, input, training=None, mask=None):
        features = input
        for conv in (self._enc_l1, self._enc_l2, self._enc_l3):
            features = tf.nn.relu(conv(features))
        return self._residual_stack(features)

In [81]:
class Decoder(Model):
    """Convolutional decoder mapping quantized latents back to image space.

    A 3x3 convolution and a residual stack at latent resolution, followed by
    two 4x4 stride-2 transposed convolutions (4x spatial upsampling in total).
    The final layer has no activation, so reconstructions are unbounded.

    Args:
        num_hiddens: Channel width of the first conv and the residual stack;
            the first transposed conv uses half of it.
        num_residual_layers: Residual units in the stack.
        num_residual_hiddens: Bottleneck width inside each residual unit.
        name: Optional model name.
        out_channels: Channels of the reconstruction (e.g. 3 for RGB,
            1 for grayscale). Defaults to 3, the previously hard-coded value.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None, out_channels=3):
        super(Decoder, self).__init__(name=name)

        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens
        self._out_channels = out_channels

        self._dec1 = Layer.Conv2D(self._num_hiddens, kernel_size=(3, 3), strides=(1, 1), padding='same', name='dec_l1')

        self._residual_stack = ResidualStack(
            self._num_hiddens,
            self._num_residual_layers,
            self._num_residual_hiddens,
            name='resblock2'
        )

        self._dec2 = Layer.Conv2DTranspose(self._num_hiddens // 2, kernel_size=(4, 4), strides=(2, 2), padding='same', name='dec_l2')
        # Output layer: channel count follows the dataset instead of being fixed to RGB.
        self._dec3 = Layer.Conv2DTranspose(self._out_channels, kernel_size=(4, 4), strides=(2, 2), padding='same', name='dec_l3')

    def call(self, inputs, training=None, mask=None):
        h = self._dec1(inputs)
        h = self._residual_stack(h)
        h = tf.nn.relu(self._dec2(h))
        # No activation on the last layer: raw reconstruction values.
        reconstruction = self._dec3(h)
        return reconstruction

In [82]:
class VectorQuantizer(Layer.Layer):
    """Vector-quantization bottleneck of a VQ-VAE.

    Every trailing `embedding_dim`-sized vector of the input is replaced by the
    nearest codebook vector (squared Euclidean distance). The codebook and
    commitment losses are registered via `add_loss`, and the output uses the
    straight-through estimator so gradients reach the encoder unchanged.

    Args:
        embedding_dim: Size of each code vector; must equal the last axis of
            the inputs for the reshape in `call` to be meaningful.
        num_embeddings: Number of code vectors in the codebook.
        beta: Weight of the commitment loss (the "commitment cost").
        name: Optional layer name.
        **kwargs: Forwarded to the base Layer (e.g. `trainable`, `dtype`) so
            that configs produced by `get_config` can be passed back to
            `from_config`.
    """

    def __init__(self, embedding_dim, num_embeddings, beta=0.25, name=None, **kwargs):
        super(VectorQuantizer, self).__init__(name=name, **kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        # Codebook of shape (embedding_dim, num_embeddings): one code per column,
        # drawn from tf.random_uniform_initializer's default range.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, inputs, training=None, mask=None):
        """Quantize `inputs`; the result has the same shape as `inputs`."""
        # Remember the dynamic input shape, then flatten to (N, embedding_dim).
        input_shape = tf.shape(inputs)
        flattened = tf.reshape(inputs, [-1, self.embedding_dim])

        # Nearest-code lookup: one-hot of the argmin index times the codebook.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # Restore the original input shape.
        quantized = tf.reshape(quantized, input_shape)

        # Commitment loss moves encoder outputs toward their codes (codebook
        # frozen); codebook loss moves codes toward encoder outputs (encoder frozen).
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - inputs) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(inputs)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator: forward value equals `quantized`, while the
        # gradient w.r.t. `inputs` is the identity.
        quantized = inputs + tf.stop_gradient(quantized - inputs)
        return quantized

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of `flattened_inputs`, the index of the closest code.

        Distances are squared Euclidean (not normalized), expanded as
        ||x||^2 + ||e||^2 - 2 x.e to avoid materializing pairwise differences.
        """
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
                tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
                + tf.reduce_sum(self.embeddings ** 2, axis=0)
                - 2 * similarity
        )

        # Index of the minimum distance per row.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

    def get_config(self):
        """Return a serializable config that includes the base Layer fields."""
        config = super(VectorQuantizer, self).get_config()
        config.update({
            'num_embeddings': self.num_embeddings,
            'embedding_dim': self.embedding_dim,
            'beta': self.beta,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [83]:
def get_vqvae(shape, latent_dim, num_embeddings, model_encoder, model_decoder):
    """Wire encoder -> 1x1 projection -> vector quantizer -> decoder.

    Args:
        shape: Per-sample input shape (no batch axis).
        latent_dim: Channels of the projected latents; also the codebook
            vector size.
        num_embeddings: Number of codebook vectors.
        model_encoder: Model applied to the raw inputs.
        model_decoder: Model applied to the quantized latents.

    Returns:
        A functional keras Model named "vq_vae" mapping inputs to reconstructions.
    """
    quantizer = VectorQuantizer(latent_dim, num_embeddings, name="vector_quantizer")
    projection = Layer.Conv2D(latent_dim, kernel_size=(1,1), strides=(1,1), name='to_vq')
    images = Input(shape=shape)
    latents = quantizer(projection(model_encoder(images)))
    return Model(inputs=images, outputs=model_decoder(latents), name="vq_vae")

In [84]:
class VQVAETrainer(Model):
    """Builds a VQ-VAE and trains it with a custom `train_step`.

    The wrapped functional model (`self.vqvae`) chains an `Encoder`, a 1x1
    projection to `latent_dim` channels, a `VectorQuantizer` and a `Decoder`.
    The training objective is the reconstruction MSE divided by
    `train_variance`, plus the losses the quantizer registers.

    Args:
        input_shape: Per-sample input shape (no batch axis).
        num_hiddens: Channel width of encoder and decoder.
        num_residual_layers: Residual units per residual stack.
        num_residual_hiddens: Bottleneck width inside each residual unit.
        train_variance: Data variance used to normalize the reconstruction MSE.
        latent_dim: Codebook vector size (embedding dimension).
        num_embeddings: Number of codebook vectors.
        name: Optional model name.
    """

    def __init__(self, input_shape, num_hiddens, num_residual_layers, num_residual_hiddens, train_variance, latent_dim,
                 num_embeddings, name=None):
        super(VQVAETrainer, self).__init__(name=name)
        self.train_variance = train_variance
        self.latent_dim = latent_dim  # embedding_dim of the codebook
        self.num_embeddings = num_embeddings
        self.num_hiddens = num_hiddens
        # Stored so get_config() can reproduce this constructor call.
        # `input_shape` is a read-only property on keras Models, hence the private name.
        self._model_input_shape = tuple(input_shape)
        self.num_residual_layers = num_residual_layers
        self.num_residual_hiddens = num_residual_hiddens

        self.model_encoder = Encoder(num_hiddens=self.num_hiddens, num_residual_layers=num_residual_layers,
                                     num_residual_hiddens=num_residual_hiddens, name='Encoder')
        self.model_decoder = Decoder(num_hiddens=self.num_hiddens, num_residual_layers=num_residual_layers,
                                     num_residual_hiddens=num_residual_hiddens, name='Decoder')

        self.vqvae = get_vqvae(
            shape=input_shape,
            latent_dim=self.latent_dim,
            num_embeddings=self.num_embeddings,
            model_encoder=self.model_encoder,
            model_decoder=self.model_decoder)

        self.total_loss_tracker = Mean(name="total_loss")
        self.reconstruction_loss_tracker = Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = Mean(name="vq_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def call(self, inputs, training=None, mask=None):
        """Run the wrapped VQ-VAE, so `predict()` and `trainer(x)` work."""
        return self.vqvae(inputs, training=training)

    def train_step(self, x):
        # fit(x, y) yields an (x, y) tuple; the target is the input itself.
        if isinstance(x, tuple):
            x = x[0]

        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x, training=True)

            # Calculate the losses once; the quantizer's losses come from add_loss.
            reconstruction_loss = (tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance)
            vq_loss = sum(self.vqvae.losses)
            total_loss = reconstruction_loss + vq_loss

        # Backpropagation
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Loss tracking
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(vq_loss)

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }

    def get_config(self):
        """Return the constructor arguments, matching `from_config`."""
        return {
            "input_shape": self._model_input_shape,
            "num_hiddens": self.num_hiddens,
            "num_residual_layers": self.num_residual_layers,
            "num_residual_hiddens": self.num_residual_hiddens,
            "train_variance": float(self.train_variance),
            "latent_dim": self.latent_dim,
            "num_embeddings": self.num_embeddings,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [92]:
(x_train, _), (x_test, _) = cifar10.load_data()

# Scale uint8 pixels to [-0.5, 0.5]. Casting to float32 first avoids the
# float64 copy that `x / 255.0` would create, halving the memory used.
x_train_scaled = x_train.astype("float32") / 255.0 - 0.5
x_test_scaled = x_test.astype("float32") / 255.0 - 0.5

In [93]:
np.shape(x_train_scaled)

(50000, 32, 32, 3)

Define the model hyperparameters

In [94]:
# Per-sample input shape (height, width, channels). Keep the channel axis:
# `shape[1:-1]` would give (32, 32) and drop it.
input_shape = x_train.shape[1:]
num_hiddens = 128
num_residual_layers = 2
num_residual_hiddens = 32
# Pixel variance in [0, 1] units. The -0.5 shift used for x_train_scaled does
# not change the variance, so this also matches the scaled training data.
data_variance = np.var(x_train / 255.0)
embedding_dim = 64
num_embeddings = 512

In [97]:
np.shape(x_train)[1:]

(32, 32, 3)

Instantiate the `VQVAETrainer`, which builds the encoder, the decoder and the end-to-end VQ-VAE model

In [98]:
# Collect the architecture settings, then build the trainer in one call.
trainer_kwargs = dict(
    input_shape=x_train.shape[1:],
    num_hiddens=num_hiddens,
    num_residual_layers=num_residual_layers,
    num_residual_hiddens=num_residual_hiddens,
    train_variance=data_variance,
    latent_dim=embedding_dim,
    num_embeddings=num_embeddings,
)
vqvae_trainer = VQVAETrainer(name='VQVAETrainer', **trainer_kwargs)

In [99]:
vqvae_model = vqvae_trainer.vqvae
vqvae_model.summary()

Model: "vq_vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_16 (InputLayer)       [(None, 32, 32, 3)]       0         
                                                                 
 Encoder (Encoder)           (None, 8, 8, 128)         364160    
                                                                 
 to_vq (Conv2D)              (None, 8, 8, 64)          8256      
                                                                 
 vector_quantizer (VectorQua  (None, 8, 8, 64)         32768     
 ntizer)                                                         
                                                                 
 Decoder (Decoder)           (None, 32, 32, 3)         290307    
                                                                 
Total params: 695,491
Trainable params: 695,491
Non-trainable params: 0
______________________________________________________

Compile the trainer with the Adam optimizer

In [None]:
optimizer = Adam()
vqvae_trainer.compile(optimizer=optimizer)

Train the VQ-VAE on the scaled CIFAR-10 training images

In [100]:
# No targets are passed: the custom train_step reconstructs its own inputs.
history = vqvae_trainer.fit(
    x_train_scaled,
    batch_size=128,
    epochs=2,
)

Epoch 1/2


2023-07-15 20:46:37.303328: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/2


Plot performance of the network