<a href="https://colab.research.google.com/github/facial09/GDL_code/blob/master/VAE(VariationalAutoEncoder).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorflow==1.15

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from tensorflow.keras.layers import Lambda,Input,Conv2D, Conv2DTranspose, Dense, Flatten, Dropout, BatchNormalization, Reshape, LeakyReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
import tensorflow as tf


## TEST

In [34]:
## 8.23 VAE encoder network
import numpy as np

# Hyper-parameters shared with the cells further down.
img_shape = (28, 28, 1)
batch_size = 16
latent_dim = 2

input_img = tf.keras.Input(shape=img_shape)

# Every convolution uses a 3x3 kernel, 'same' padding and ReLU; only the
# filter count and the stride differ between layers.
conv_kwargs = dict(padding='same', activation='relu')

x = Conv2D(32, 3, **conv_kwargs)(input_img)
# layers.Conv2D https://stackoverflow.com/questions/43624625/why-do-we-have-to-specify-output-shape-during-deconvolution-in-tensorflow/43624992#43624992

# The only strided layer of the encoder (stride 2 along both spatial axes).
x = Conv2D(64, 3, strides=(2, 2), **conv_kwargs)(x)
x = Conv2D(64, 3, **conv_kwargs)(x)
x = Conv2D(64, 3, **conv_kwargs)(x)

# Static shape (tuple of ints/None) of the last feature map; the decoder
# reads it to rebuild a tensor of the same per-sample shape.
shape_before_flattening = K.int_shape(x)

x = Flatten()(x)
x = Dense(32, activation='relu')(x)

# Two linear heads on the same hidden vector: one is used as the mean and the
# other as the log-variance of the latent distribution.
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)

# Q1. How do these two Dense outputs come to represent the actual mean and log-variance of z? Nothing in the encoder enforces it; see the KL term of the loss function (vae_loss) below.
## 8.24 latent_space_sampling function

def sampling(args):
    """Draw a latent sample with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of backend tensors; the second
            axis of ``z_mean`` is taken as the latent dimension.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
        standard-normal noise with the same (batch, latent) shape.
    """
    z_mean, z_log_var = args
    # Read the shape from the tensor itself so the function does not depend
    # on a module-level `latent_dim`.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    # z_log_var is log(sigma^2), so sigma = exp(0.5 * z_log_var). Scaling the
    # noise by exp(z_log_var) would use the variance as if it were the
    # standard deviation, contradicting the KL term of the loss, which treats
    # exp(z_log_var) as the variance.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Wrap the sampling function in a layer so it becomes part of the graph.
z = Lambda(sampling)([z_mean, z_log_var])

## 8.25 VAE decoder network, mapping latent space points to images

# Per-sample shapes (batch axis dropped) at both ends of the decoder's dense stage.
latent_shape = K.int_shape(z)[1:]
feature_map_shape = shape_before_flattening[1:]

decoder_input = Input(latent_shape)

# Expand the latent vector to as many units as the encoder's last feature map
# holds, then fold it back into that feature-map shape.
x = Dense(np.prod(feature_map_shape), activation='relu')(decoder_input)
# np.prod = https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.prod.html
x = Reshape(feature_map_shape)(x)

# Stride-2 transposed convolution, the counterpart of the encoder's strided layer.
x = Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
# output shape = https://stackoverflow.com/questions/43624625/why-do-we-have-to-specify-output-shape-during-deconvolution-in-tensorflow/43624992#43624992

# Single-channel output squashed by a sigmoid.
x = Conv2D(1, 3, padding='same', activation='sigmoid')(x)

# The decoder is a standalone model so it can also be called on its own later.
decoder = Model(decoder_input, x)

# Reconstruction obtained by decoding the sampled latent point.
z_decoded = decoder(z)

class CustomVariationalLayer(tf.keras.layers.Layer):
    """Pass-through layer whose only job is to register the VAE loss.

    ``call`` returns its first input unchanged and attaches the value of
    ``vae_loss`` to the layer through ``add_loss``.
    """

    def vae_loss(self, x, z_decoded):
        """Return the scalar loss: reconstruction term plus weighted KL term.

        Args:
            x: tensor of original images.
            z_decoded: tensor of reconstructions.

        Note:
            The KL term reads the module-level ``z_mean`` and ``z_log_var``
            tensors; they are not passed in as arguments.
        """
        flat_original = K.flatten(x)
        flat_reconstruction = K.flatten(z_decoded)
        # Binary cross-entropy between the two flattened tensors.
        reconstruction_term = tf.keras.metrics.binary_crossentropy(
            flat_original, flat_reconstruction
        )
        # 1 + log(var) - mean^2 - var, averaged over the last axis and
        # scaled by -5e-4.
        kl_inner = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_term = -5e-4 * K.mean(kl_inner, axis=-1)
        return K.mean(reconstruction_term + kl_term)

    def call(self, inputs):
        """Register the loss for ``inputs = [x, z_decoded]`` and return ``x``."""
        x, z_decoded = inputs[0], inputs[1]
        self.add_loss(self.vae_loss(x, z_decoded), inputs=inputs)
        # The returned tensor is simply the first input.
        return x

# Feed the input image and its reconstruction through the loss layer; the
# layer returns the input image and registers the loss as a side effect.
loss_layer = CustomVariationalLayer()
y = loss_layer([input_img, z_decoded]) #check

## 8.27 Training the VAE

vae = Model(inputs=input_img, outputs=y)
# loss=None: no external loss is given to compile(); the loss comes from the
# add_loss call inside CustomVariationalLayer.
vae.compile(loss=None, optimizer='rmsprop')
vae.summary()




Model: "model_16"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_30 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 28, 28, 32)   320         input_30[0][0]                   
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 14, 14, 64)   18496       conv2d_26[0][0]                  
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 14, 14, 64)   36928       conv2d_27[0][0]                  
___________________________________________________________________________________________

In [45]:
# Load MNIST and shape it as (num_samples, 28, 28, 1) for the convolutional encoder.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# load_data() returns uint8 pixel intensities in 0..255. The decoder ends in a
# sigmoid and the reconstruction loss is binary cross-entropy, so the images
# must be floats scaled to [0, 1]; raw 0..255 values make that loss meaningless.
x_train = x_train.astype('float32').reshape(-1, 28, 28, 1) / 255.
x_val = x_val.astype('float32').reshape(-1, 28, 28, 1) / 255.

In [47]:
# No targets are passed (y=None, and None in validation_data): the model is
# fitted on the images alone.
vae.fit(
    x=x_train,
    y=None,
    epochs=10,
    batch_size=batch_size,
    shuffle=True,
    validation_data=(x_val, None),
)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Train on 60000 samples, validate on 10000 samples
Epoch 1/10

KeyboardInterrupt: ignored

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import norm

# Decode an n x n grid of 2-D latent points and show the results as one image.
n = 15            # tiles per side
digit_size = 28   # side length of one decoded tile, in pixels

# Grid coordinates are the standard-normal quantiles of evenly spaced
# probabilities between 0.05 and 0.95.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# Tile (i, j) is decoded from the latent point [grid_y[j], grid_x[i]].
# meshgrid gives first_coord[i, j] == grid_y[j] and second_coord[i, j] == grid_x[i].
first_coord, second_coord = np.meshgrid(grid_y, grid_x)
z_grid = np.stack([first_coord.ravel(), second_coord.ravel()], axis=1)  # (n*n, 2)

# One batched predict over the whole grid instead of one call per grid point
# on batch_size identical copies of that point.
x_decoded = decoder.predict(z_grid, batch_size=batch_size)

# (n*n, H, W, 1) -> (n, n, H, W) -> (n, H, n, W) -> (n*H, n*W), so tile (i, j)
# occupies rows i*H:(i+1)*H and columns j*W:(j+1)*W.
figure = (
    x_decoded.reshape(n, n, digit_size, digit_size)
    .transpose(0, 2, 1, 3)
    .reshape(n * digit_size, n * digit_size)
)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(figure, cmap='Greys_r')
ax.set_title(f'Decoder output over a {n}x{n} grid of latent points')
plt.show()