In [1]:
!pip install pydot
!pip install graphviz



In [8]:
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

import pathlib
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU,Add
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

In [3]:
# Input-pipeline configuration shared by the data-loading cells.
AUTOTUNE = tf.data.AUTOTUNE       # let tf.data tune parallelism / prefetch depth at runtime
img_height, img_width = 256, 256  # every image is resized to this spatial size
batch_size = 18

In [4]:

def decode_img(img):
  """Decode an encoded image byte string into a resized float32 tensor.

  Args:
    img: scalar string tensor with the raw file contents, in any format
      supported by tf.io.decode_image (JPEG, PNG, BMP, GIF).

  Returns:
    A float32 tensor of shape (img_height, img_width, 3). Pixel values keep
    their original 0-255 scale; tf.image.resize only changes the dtype.
  """
  # decode_image accepts every format tf.io can read, unlike decode_jpeg, so a
  # stray PNG in the data folder no longer aborts the pipeline.
  # expand_animations=False guarantees a 3-D tensor (first frame of a GIF),
  # which tf.image.resize requires.
  img = tf.io.decode_image(img, channels=3, expand_animations=False)
  # NOTE(review): the decoder ends in tanh (range [-1, 1]) while these pixels are
  # in [0, 255]; confirm the training step rescales images before the loss.
  return tf.image.resize(img, [img_height, img_width])


def process_path(file_path):
  """Read one image file from disk and wrap the decoded image in a 1-tuple.

  Args:
    file_path: scalar string tensor holding the path of an image file.

  Returns:
    A single-element tuple `(image,)`, where `image` is what `decode_img` returns.
  """
  # NOTE(review): a 1-tuple is returned on purpose to match the original code; if the
  # training loop expects (input, target) pairs this likely should be (image, image).
  raw_bytes = tf.io.read_file(file_path)
  return (decode_img(raw_bytes),)

def configure_for_performance(ds, shuffle=True, shuffle_buffer_size=1000):
  """Cache, optionally shuffle, batch and prefetch a dataset.

  Args:
    ds: tf.data.Dataset of decoded images to prepare.
    shuffle: whether to shuffle elements every epoch. Defaults to True (the
      previous behaviour); pass False for evaluation data, where order is
      irrelevant and shuffling only costs memory.
    shuffle_buffer_size: number of elements held in the shuffle buffer.

  Returns:
    The transformed dataset, batched with the notebook-level `batch_size`.
  """
  # cache() before shuffle(): decoding happens once, yet the cached elements
  # are still reshuffled each epoch. Without a filename the cache is kept in
  # memory, so very large image folders may need a file-backed cache instead.
  ds = ds.cache()
  if shuffle:
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
  ds = ds.batch(batch_size)
  # Prepare the next batches while the current one is being consumed.
  return ds.prefetch(buffer_size=AUTOTUNE)


In [5]:
# Root of the image folder; files are matched one level down (e.g. data_dir/<subfolder>/<image>).
DATA_DIR = pathlib.Path('data_dir')
VAL_FRACTION = 0.3
SPLIT_SEED = 42

# The former pattern 'data_dir''*/*' relied on implicit string concatenation and became
# 'data_dir*/*' (no path separator), which does not match the files inside data_dir.
list_ds = tf.data.Dataset.list_files(str(DATA_DIR / '*' / '*'), shuffle=False)
image_count = tf.data.experimental.cardinality(list_ds).numpy()

# With shuffle=False the paths come back in a deterministic (typically path-sorted) order,
# so taking the first 30% would put entire sub-folders into validation. Shuffle once with a
# fixed seed and never reshuffle, so the take/skip split is random yet identical every epoch.
list_ds = list_ds.shuffle(image_count, seed=SPLIT_SEED, reshuffle_each_iteration=False)
val_size = int(image_count * VAL_FRACTION)

train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)

In [6]:
class VectorQuantizer(layers.Layer):
    """Vector-quantization bottleneck of a VQ-VAE.

    Every `embedding_dim`-sized input vector is replaced by its nearest entry
    (squared Euclidean distance) of a learned codebook of `num_embeddings`
    vectors. The codebook + commitment loss is registered with `add_loss`,
    and gradients reach the encoder through a straight-through estimator.

    Args:
        num_embeddings: number of codebook vectors.
        embedding_dim: size of each codebook vector; must equal the last
            dimension of the layer input.
        beta: weight of the commitment loss term.
        **kwargs: forwarded to `keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # The `beta` parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Codebook of shape (embedding_dim, num_embeddings): column k is code vector k.
        # e.g. embedding_dim=16, num_embeddings=64 gives a (16, 64) matrix of 64 vectors of size 16.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantize `x` to its nearest codebook vectors.

        Args:
            x: tensor whose last dimension is `embedding_dim` (e.g. (B, H, W, D)).

        Returns:
            A tensor with the same shape as `x` holding the selected code
            vectors; in the backward pass it acts as the identity on `x`.
        """
        # Remember the dynamic input shape, then flatten to (N, embedding_dim), where
        # N is the product of all leading dimensions (B*H*W for a 4-D input).
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Index of the closest codebook vector for every flattened input vector.
        encoding_indices = self.get_code_indices(flattened)

        # One-hot encode the indices and gather the code vectors with a matrix
        # multiplication: (N, num_embeddings) x (num_embeddings, embedding_dim).
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # Restore the original input shape. This must be a reshape (not a transpose):
        # the rows are still in the same order as the flattened input.
        quantized = tf.reshape(quantized, input_shape)

        # VQ-VAE loss (van den Oord et al., 2017):
        #   commitment loss - pulls encoder outputs towards the frozen codes;
        #   codebook loss   - pulls the codes towards the frozen encoder outputs.
        # add_loss attaches the sum to the model; see
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator: the forward value is `quantized`, while the
        # gradient flows to `x` unchanged.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        """Return the index of the nearest codebook vector for each input row.

        Args:
            flattened_inputs: tensor of shape (N, embedding_dim).

        Returns:
            int64 tensor of shape (N,) with values in [0, num_embeddings).
        """
        # Dot products of every input with every code:
        # (N, embedding_dim) x (embedding_dim, num_embeddings) -> (N, num_embeddings).
        similarity = tf.matmul(flattened_inputs, self.embeddings)

        # Squared Euclidean distances ||x||^2 + ||e||^2 - 2 x.e, shape (N, num_embeddings).
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # For each input row, pick the code with the smallest distance.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer.

        Without this override, saving, loading or cloning a model that contains
        this layer fails, because `num_embeddings` and `embedding_dim` are
        required constructor arguments that the base config does not carry.
        """
        config = super().get_config()
        config.update(
            {
                "num_embeddings": self.num_embeddings,
                "embedding_dim": self.embedding_dim,
                "beta": self.beta,
            }
        )
        return config


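### Check Dimensionality

Walk a stand-in 7x7 encoder output through each step of `VectorQuantizer` and print the intermediate shapes.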

In [7]:
# Small sizes used only for the dimensionality walk-through below.
embedding_dim, num_embeddings = 16, 128  # size of each code vector, number of code vectors

In [8]:
# Define the components

# Only the quantizer is needed for the shape check: the next cell feeds it a plain
# keras.Input, and the real encoder is built later with resnet_encoder.
vq_layer = VectorQuantizer(num_embeddings, embedding_dim, name="vector_quantizer")

In [9]:
# Stand-in for an encoder output: a 7x7 grid of `embedding_dim`-sized vectors.
encoder_outputs = keras.Input(shape=(7, 7, embedding_dim))
quantized_latents = vq_layer(encoder_outputs)
# Tiny model used only to inspect the quantizer; named distinctly so it is not
# confused with the full "vq_vae" model assembled later.
quantizer_check = keras.Model(encoder_outputs, quantized_latents, name="vq_check")
qunatitized = quantizer_check  # backward-compatible alias for the original (misspelled) name

In [10]:

# Mirror the first step of VectorQuantizer.call: keep the dynamic shape, then
# flatten to rows of length embedding_dim (batch size is unknown, hence None).
input_shape = tf.shape(encoder_outputs)
flattened = tf.reshape(encoder_outputs, shape=(-1, embedding_dim))
print(flattened.shape)

(None, 16)


In [11]:
# Stand-alone copy of the quantizer codebook: one column per code vector.
w_init = tf.random_uniform_initializer()
embeddings = tf.Variable(
    w_init(shape=(embedding_dim, num_embeddings), dtype="float32"),
    name="embeddings_vqvae",
    trainable=True,
)
print(embeddings.shape)


(16, 128)


In [12]:
# Dot product of every flattened row with every code vector -> (N, num_embeddings).
similarity = tf.linalg.matmul(flattened, embeddings)
print(similarity.shape)

(None, 128)


In [13]:
# Squared Euclidean distance ||x||^2 + ||e||^2 - 2 x.e between each row and each code.
distances = tf.reduce_sum(flattened ** 2, axis=1, keepdims=True) \
    + tf.reduce_sum(embeddings ** 2, axis=0) \
    - 2 * similarity

print(distances.shape)

(None, 128)


In [14]:
# Index of the nearest code for each flattened row -> shape (N,).
encoding_indices = tf.math.argmin(distances, axis=1)
print(encoding_indices.shape)

(None,)


In [15]:
flattened.shape         # A (batch, 7, 7, 16) tensor is flattened to (batch * 49, 16); batch is unknown here, hence None

TensorShape([None, 16])

## Build the VQ-VAE

In [27]:
# Codebook configuration for the real VQ-VAE.
embedding_dim, num_embeddings = 32, 512  # size of each code vector, number of code vectors in the codebook

### Resnet Encoder

In [41]:
def resnet_encoder(input_shape, num_filters, latent_dim):
    """Build a pre-activation ResNet-style convolutional encoder.

    A strided 7x7 stem is followed by four stages of two residual blocks;
    the first block of every stage has stride 2. Together with the stem this
    reduces the spatial resolution by 32 (e.g. 256x256 -> 8x8). A final 1x1
    convolution projects to `latent_dim` channels.

    Args:
        input_shape: shape of one input image, e.g. (256, 256, 3).
        num_filters: channel width of the stem; the stages use 1x, 2x, 4x and
            8x this value.
        latent_dim: number of output channels (the VQ embedding size).

    Returns:
        A keras Model named 'Encoder'.
    """
    inputs = Input(shape=input_shape)

    # Stem: strided 7x7 conv followed by BN + LeakyReLU.
    x = Conv2D(num_filters, kernel_size=7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # Four down-sampling stages; only the first block of each stage is strided.
    for multiplier in (1, 2, 4, 8):
        stage_filters = num_filters * multiplier
        for block_index in range(2):
            x = resnet_block(x, stage_filters, strides=1 if block_index else 2)

    # 1x1 projection to the latent (embedding) dimension.
    x = Conv2D(latent_dim, kernel_size=1, strides=1)(x)
    return Model(inputs, x, name='Encoder')


def resnet_block(inputs, filters, strides=1):
    """Pre-activation residual block: (BN -> LeakyReLU -> 3x3 conv) twice, plus a skip.

    Args:
        inputs: 4-D channels-last tensor.
        filters: number of output channels.
        strides: stride of the first 3x3 conv (2 halves the spatial size).

    Returns:
        The sum of the residual path and the (possibly projected) shortcut.
    """
    residual = BatchNormalization()(inputs)
    residual = LeakyReLU()(residual)
    residual = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(residual)
    residual = BatchNormalization()(residual)
    residual = LeakyReLU()(residual)
    residual = Conv2D(filters, kernel_size=3, strides=1, padding='same')(residual)

    # Project the skip path with a 1x1 conv whenever the residual path changed shape.
    needs_projection = strides != 1 or inputs.shape[3] != filters
    shortcut = (
        Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(inputs)
        if needs_projection
        else inputs
    )
    return Add()([residual, shortcut])


In [42]:
# Stand-alone encoder, built to inspect its layers and output shape.
encoder_model = resnet_encoder(
    input_shape=(256, 256, 3),
    num_filters=64,
    latent_dim=embedding_dim,
)
encoder_model.summary()


Model: "Encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_11 (InputLayer)          [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_110 (Conv2D)            (None, 128, 128, 64  9472        ['input_11[0][0]']               
                                )                                                                 
                                                                                                  
 batch_normalization_97 (BatchN  (None, 128, 128, 64  256        ['conv2d_110[0][0]']             
 ormalization)                  )                                                           

In [None]:
# tf.keras.utils.plot_model(encoder_model, to_file='encoder.png', show_shapes=True,show_layer_names=False,dpi=180)

### DCGAN-style Generator (used as the decoder)

In [33]:
encoder_output_shape = encoder_model.output_shape[1:]  # drop the batch axis: (height, width, channels)

In [43]:
encoder_output_shape  # (height, width, channels) of the encoder output, without the batch axis

(8, 8, 32)

In [44]:



def dcgenerator(encoder_output_shape, filters=(128, 64, 32, 16), out_channels=3):
    """Build a DCGAN-style decoder that upsamples latent grids to images.

    Every Conv2DTranspose uses stride 2, so the output resolution is
    2 ** (len(filters) + 1) times the latent resolution. With the defaults,
    an (8, 8, C) latent becomes a (256, 256, 3) image.

    Args:
        encoder_output_shape: (height, width, channels) of the latent grid,
            without the batch axis (e.g. (8, 8, 32) for the encoder above).
        filters: channel widths of the intermediate upsampling stages; the
            default reproduces the original 128 -> 64 -> 32 -> 16 ladder.
        out_channels: number of channels of the generated image.

    Returns:
        A keras Model named 'Decoder' whose output is tanh-activated, i.e. in [-1, 1].
    """
    # NOTE(review): the tanh range [-1, 1] must match how the training images are
    # scaled for the reconstruction loss; confirm the input pipeline rescales them.
    inputs = tf.keras.layers.Input(shape=encoder_output_shape)

    # Each stage doubles height and width, e.g. (8, 8) -> (16, 16) -> ... -> (128, 128).
    x = inputs
    for stage_filters in filters:
        x = tf.keras.layers.Conv2DTranspose(
            stage_filters, (4, 4), strides=(2, 2), padding='same', use_bias=False
        )(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

    # Final upsampling to the image; tanh bounds the pixel values to [-1, 1].
    outputs = tf.keras.layers.Conv2DTranspose(
        out_channels, (4, 4), strides=(2, 2), padding='same',
        activation='tanh', use_bias=False,
    )(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name='Decoder')


In [46]:
vq_layer = VectorQuantizer(num_embeddings, embedding_dim, name="vector_quantizer")
# Use the pipeline's image size so the model input matches what decode_img produces.
encoder = resnet_encoder(input_shape=(img_height, img_width, 3), num_filters=64, latent_dim=embedding_dim)
# Size the decoder from *this* encoder rather than the separately built
# `encoder_model`, so the two shapes cannot drift apart.
decoder = dcgenerator(encoder.output_shape[1:])

inputs = keras.Input(shape=(img_height, img_width, 3))
encoder_outputs = encoder(inputs)
quantized_latents = vq_layer(encoder_outputs)
reconstructions = decoder(quantized_latents)
model = keras.Model(inputs, reconstructions, name="vq_vae")

In [47]:
model.summary()  # encoder -> vector quantizer -> decoder, with per-component parameter counts

Model: "vq_vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_15 (InputLayer)       [(None, 256, 256, 3)]     0         
                                                                 
 Encoder (Functional)        (None, 8, 8, 32)          11206112  
                                                                 
 vector_quantizer (VectorQua  (None, 8, 8, 32)         16384     
 ntizer)                                                         
                                                                 
 Decoder (Functional)        (None, 256, 256, 3)       239296    
                                                                 
Total params: 11,461,792
Trainable params: 11,454,400
Non-trainable params: 7,392
_________________________________________________________________
