# VQVAE and Residual Stack

In [9]:
from keras import layers as Layer, Input, Model, Sequential
from keras.datasets import mnist, cifar10
from keras.optimizers import Adam
from keras.metrics import Mean, MAE
from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import glob

In [10]:
# List the devices TensorFlow can see; the recorded cell output shows one CPU and one GPU.
tf.config.get_visible_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [98]:
class ResidualStack(tf.Module):
    """Stack of residual blocks built from Keras ``Conv2D`` layers.

    Every block computes ``h + conv1x1(relu(conv3x3(relu(h))))``; one more
    ReLU is applied to the result of the last block.

    Args:
        num_hiddens: Number of filters of each 1x1 convolution. Its output is
            added to the block input, so the input's channel dimension must
            be compatible with this value.
        num_residual_layers: Number of residual blocks in the stack.
        num_residual_hiddens: Number of filters of each 3x3 convolution.
        name: Optional module name, forwarded to ``tf.Module``.

    NOTE(review): this is a ``tf.Module`` rather than a Keras layer. Confirm
    that an enclosing Keras ``Model`` tracks the ``Conv2D`` weights stored in
    ``self._layers`` (this matters for training and for saving).
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super().__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # One (3x3 conv, 1x1 conv) pair per residual block. 'same' padding
        # keeps the spatial size unchanged so the skip connection can be added.
        self._layers = [
            (
                Layer.Conv2D(num_residual_hiddens, kernel_size=3, strides=1,
                             padding='same', name=f'res3x3_{block}'),
                Layer.Conv2D(num_hiddens, kernel_size=1, strides=1,
                             padding='same', name=f'res1x1_{block}'),
            )
            for block in range(num_residual_layers)
        ]

    def __call__(self, inputs):
        """Run ``inputs`` through every residual block and a final ReLU."""
        h = inputs
        for conv3x3, conv1x1 in self._layers:
            bottleneck = conv3x3(tf.nn.relu(h))
            # Skip connection: accumulate the block output onto its input.
            h += conv1x1(tf.nn.relu(bottleneck))
        return tf.nn.relu(h)

In [147]:
class Encoder(Model):
    """Convolutional encoder: three ``Conv2D`` layers and a ``ResidualStack``.

    The first two convolutions use 4x4 kernels with stride 2, the third a 3x3
    kernel with stride 1; each is followed by a ReLU. No ``padding`` argument
    is passed to these layers, so with the Keras default every convolution
    shrinks the spatial size (e.g. 128 -> 63 -> 30 -> 28 along each axis).

    Args:
        num_hiddens: Number of filters of the second and third convolution
            (the first one uses ``num_hiddens // 2``); also passed to the
            residual stack.
        num_residual_layers: Number of blocks in the residual stack.
        num_residual_hiddens: Number of hidden filters inside each residual
            block.
        name: Optional model name, forwarded to ``Model``.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super().__init__(name=name)

        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        self._enc_l1 = Layer.Conv2D(
            num_hiddens // 2, kernel_size=(4, 4), strides=(2, 2), name='enc_l1')
        self._enc_l2 = Layer.Conv2D(
            num_hiddens, kernel_size=(4, 4), strides=(2, 2), name='enc_l2')
        self._enc_l3 = Layer.Conv2D(
            num_hiddens, kernel_size=(3, 3), strides=(1, 1), name='enc_l3')
        self._residual_stack = ResidualStack(
            num_hiddens, num_residual_layers, num_residual_hiddens,
            name='resblock1')

    def call(self, input, training=None, mask=None):
        """Encode ``input``; ``training`` and ``mask`` are accepted but unused."""
        h = input
        # Each convolution is followed by a ReLU before the residual stack.
        for conv in (self._enc_l1, self._enc_l2, self._enc_l3):
            h = tf.nn.relu(conv(h))
        return self._residual_stack(h)

In [None]:
class Decoder(Model):
    """Placeholder for the decoder; the forward pass is not implemented yet.

    The constructor keeps the hyperparameters (mirroring ``Encoder``) so the
    layers can be built from them once the architecture is written.

    Args:
        num_hiddens: Number of hidden filters (stored, not used yet).
        num_residual_layers: Number of residual blocks (stored, not used yet).
        num_residual_hiddens: Number of hidden filters inside each residual
            block (stored, not used yet).
        name: Optional model name, forwarded to ``Model``.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super().__init__(name=name)

        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

    def call(self, inputs, training=None, mask=None):
        """Decode ``inputs``.

        Raises:
            NotImplementedError: Always; the decoder has no layers yet.
        """
        # Fail loudly instead of silently returning None, which would only
        # surface later as a confusing error in whatever consumes the output.
        raise NotImplementedError('Decoder.call is not implemented yet.')

In [148]:
# All-ones dummy input used for shape checks.
# Presumably laid out as (batch, height, width, channels) -- confirm the Keras data_format.
sample = tf.ones([1, 128, 128, 3])
sample.shape

TensorShape([1, 128, 128, 3])

In [152]:
# Smoke-test the encoder on the dummy batch and display the output shape.
# The residual hyperparameters follow the reference VQ-VAE configuration
# (2 residual blocks with 32 hidden filters each); the two values were
# previously passed the other way round (32 blocks with 2 filters each).
# The output shape does not depend on either value.
encoder_test = Encoder(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32)
encoder_test(sample).shape

TensorShape([1, 28, 28, 128])