In [1]:
import numpy as np
import tensorflow as tf

In [None]:
# roadmap:
# 1-implement dcgan in tf and observe generated examples
# 2-investigate tf dcgan implementation
# 3-implement dcgan in torch and observe generated examples
# 4-investigate torch dcgan implementation

In [155]:
# Load MNIST and build the GAN training pipeline. Only the training images are
# fed to train_ds; y_train and the test split are loaded but left untouched.
BATCH_SIZE = 64

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis so images match the (28, 28, 1) input that the
# conv layers expect, then convert to float for scaling.
x_train = np.expand_dims(x_train, axis=3)
x_train = tf.cast(x_train, dtype=tf.float32)

# [0,255] -> [0,1] -> [-1,1], the same range as the generator's tanh output
x_train = (x_train / 255.) * 2. - 1.

# A shuffle buffer as large as the dataset gives a full uniform shuffle every
# epoch; the previous 1000-element buffer only mixed images locally. (Costs one
# dataset-sized buffer in memory, ~190 MB for 60k float32 28x28 images.)
train_ds = tf.data.Dataset.from_tensor_slices(x_train)
train_ds = train_ds.shuffle(buffer_size=x_train.shape[0]).batch(BATCH_SIZE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)


In [137]:
def get_generator(latent_dim=100):
    """Build the DCGAN generator: latent vector -> 28x28x1 image.

    Spatial progression: Dense -> 4x4x256 -> 8x8x128 (stride 2, 'same')
    -> 14x14x64 (kernel 7, stride 1, 'valid': 8 + 7 - 1 = 14)
    -> 28x28x1 (stride 2, 'same'). The tanh output lies in [-1, 1], the same
    range the training images are scaled to.

    Args:
        latent_dim: length of the input noise vector.

    Returns:
        An uncompiled tf.keras.Model.
    """
    # TODO: add batchnorm (and try dropout) on G, as originally planned.
    inputs = tf.keras.Input(shape=(latent_dim,))
    x = tf.keras.layers.Dense(units=(4 * 4 * 256), activation='relu')(inputs)
    # A Reshape layer instead of tf.reshape on a symbolic tensor keeps the
    # functional graph made of Keras layers (raw TF ops on KerasTensors are
    # version-dependent).
    x = tf.keras.layers.Reshape((4, 4, 256))(x)
    x = tf.keras.layers.Conv2DTranspose(filters=128, kernel_size=5, strides=2,
                                        activation='relu', padding='same')(x)
    x = tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=7, strides=1,
                                        activation='relu', padding='valid')(x)
    outputs = tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=5, strides=2,
                                              activation='tanh', padding='same')(x)
    return tf.keras.Model(inputs=inputs, outputs=[outputs], name='generator')


def get_discriminator(img_shape=(28, 28, 1)):
    """Build the DCGAN discriminator: image -> one unnormalized score.

    Two 5x5 stride-2 convs with default 'valid' padding shrink 28 -> 12 -> 4,
    followed by Flatten and a linear Dense(1). There is no sigmoid, so the
    output is a logit (presumably to be used with a from_logits=True loss).

    Args:
        img_shape: input image shape (height, width, channels).

    Returns:
        An uncompiled tf.keras.Model.
    """
    # TODO: add batchnorm on D, as originally planned.
    inputs = tf.keras.Input(shape=img_shape)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=5, strides=2)(inputs)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=5, strides=2)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = tf.keras.layers.Flatten()(x)
    outputs = tf.keras.layers.Dense(units=1)(x)
    return tf.keras.Model(inputs=inputs, outputs=[outputs], name='discriminator')


Generator = get_generator()
Discriminator = get_discriminator()

# Smoke test: push a small batch of noise through G, check the image shape,
# and display D's logits for the fakes.
batch_size = 2
latent_code = tf.random.normal(shape=(batch_size, 100))
fake_imgs = Generator(latent_code)
assert fake_imgs.shape == (batch_size, 28, 28, 1), fake_imgs.shape
Discriminator(fake_imgs)


In [None]:
class DCGAN(tf.keras.Model):
    """Model holding a generator/discriminator pair.

    NOTE(review): `call` is carried over from a conditional VAE. It uses
    `self.encoder_block1`, `self.encoder_block2` and `self.decoder_block`,
    which `__init__` never creates, so calling this model currently raises
    AttributeError. TODO: replace it with DCGAN logic (e.g. a `train_step`
    alternating discriminator and generator updates).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # get_generator / get_discriminator must be defined in another cell.
        self.generator = get_generator()
        self.discriminator = get_discriminator()

    def call(self, img, labels):
        """Encode (img, labels), sample z, and decode (z, labels).

        Returns:
            (z_mu, z_rho, decoded_img): posterior mean, the pre-softplus
            scale parameter, and the reconstruction.
        """
        # encoder q(z|x,y)
        enc1_output = self.encoder_block1(img)
        # Concat feature maps with the label vector using tf.concat, not
        # np.concatenate: a NumPy round-trip drops the op from the gradient
        # tape. Cast labels because tf.concat requires matching dtypes.
        labels = tf.cast(labels, enc1_output.dtype)
        img_lbl_concat = tf.concat([enc1_output, labels], axis=1)
        z_mu, z_rho = self.encoder_block2(img_lbl_concat)

        # Reparameterized sampling: z = mu + softplus(rho) * eps.
        # tf.shape gives the dynamic batch size, which is None when traced.
        epsilon = tf.random.normal(shape=tf.shape(z_mu), mean=0.0, stddev=1.0)
        z = z_mu + tf.math.softplus(z_rho) * epsilon

        # decoder p(x|z,y)
        z_lbl_concat = tf.concat([z, labels], axis=1)
        decoded_img = self.decoder_block(z_lbl_concat)

        return z_mu, z_rho, decoded_img

In [139]:
# NOTE(review): this loop is carried over from a (C)VAE notebook. It relies on
# VAE, latent_dim, epochs, elbo, beta, generate_images, dataset_mean,
# dataset_std, visualize_latent_space and plot_latent_images, none of which are
# defined in the visible cells. It also unpacks (imgs, labels) pairs, while
# train_ds is built from images only. TODO: rewrite as a DCGAN training loop.
model = VAE(latent_dim)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')
mse_loss_tracker = tf.keras.metrics.Mean(name='mse_loss')


for epoch in range(epochs):

    # Per-batch chunks, concatenated once after the epoch; concatenating inside
    # the loop re-copies everything seen so far (quadratic in batches).
    label_chunks = []
    z_mu_chunks = []

    for imgs, labels in train_ds:

        # training loop
        with tf.GradientTape() as tape:
            # forward pass
            z_mu, z_rho, decoded_imgs = model(imgs)

            # compute loss
            mse, kl = elbo(z_mu, z_rho, decoded_imgs, imgs)
            loss = mse + beta * kl

        # Differentiate only w.r.t. trainable weights: model.variables also
        # holds non-trainable state (e.g. batchnorm moving statistics) whose
        # gradients are None.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # update metrics
        kl_loss_tracker.update_state(beta * kl)
        mse_loss_tracker.update_state(mse)

        # save encoded means and labels for latent space visualization
        label_chunks.append(labels)
        z_mu_chunks.append(z_mu)

    label_list = np.concatenate(label_chunks) if label_chunks else None
    z_mu_list = np.concatenate(z_mu_chunks, axis=0) if z_mu_chunks else None

    # generate new samples
    generate_images(model, dataset_mean, dataset_std, temp_x_test=None)
    # encode and decode samples from test data
    generate_images(model, dataset_mean, dataset_std, temp_x_test=x_test[:16])
    # visualize the latent space by non-linear dim reduction
    visualize_latent_space(z_mu_list, label_list)
    # plot 2D digit manifold if latent dim=2
    if latent_dim == 2:
        plot_latent_images(model, dataset_mean, dataset_std)

    # display metrics at the end of each epoch.
    epoch_kl, epoch_mse = kl_loss_tracker.result(), mse_loss_tracker.result()
    print(f'epoch: {epoch}, mse: {epoch_mse:.4f}, kl_div: {epoch_kl:.4f}')

    # reset metric states
    kl_loss_tracker.reset_state()
    mse_loss_tracker.reset_state()

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-0.00242834],
       [-0.00438537]], dtype=float32)>