# Generating images with variational autoencoders

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import keras
import numpy as np

from keras.datasets import mnist
from keras import layers
from keras import backend as K
from keras.models import Model

from keras import backend as K
#K.tensorflow_backend._get_available_gpus()


import pickle
import os
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import array_to_img, img_to_array

def deserialize(path):
    """Load and return the object pickled in the file at ``path``.

    Args:
        path: Location of the pickle file; it is opened in binary read mode.

    Returns:
        The object reconstructed by ``pickle.load``.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # SECURITY: unpickling can execute arbitrary code, so only call this on
    # files that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)
        

In [None]:
# Reset the Keras backend state before the model is (re)built in this notebook.
K.clear_session()
# Restrict CUDA to the device with index 1.
# NOTE(review): this variable is normally only honoured if it is set before
# the backend initialises CUDA — confirm it still takes effect at this point.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

## Loading the BreakHis adenosis dataset (grayscale)

In [None]:
#(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [None]:
# Load the pickled image collection and hold out 20% of it for validation;
# the fixed random_state makes the split reproducible.
# assumes the pickle holds a 3-D array indexed (sample, row, col) — TODO confirm
dataset = deserialize("../../BreakHis_Encoder/derived/adenosis_gray.pkl")
x_train, x_test = train_test_split(dataset, test_size=0.2, random_state=42)
# Preview the first training image with a grayscale colormap.
plt.imshow(x_train[0,:,:], cmap='gray')

## VAE Encoder Network

In [None]:
# Derive the network input shape from the first training image, appending a
# single channel axis.
img = x_train[0,:,:]
img_shape = (img.shape[0], img.shape[1], 1)
batch_size = 16  # mini-batch size used when fitting the model
latent_dim = 2  # number of latent dimensions produced by the encoder
# Symbolic input tensor that the encoder layers are stacked on.
input_img = keras.Input(shape=img_shape)

In [None]:
# Notebook display: show the symbolic input tensor.
input_img

In [None]:
#x = layers.Conv2D(32, 3,
#                  padding='same', activation='relu')(input_img)
#x = layers.Conv2D(64, 3,
#                  padding='same', activation='relu',
#                  strides=(2, 2))(x)
#K.int_shape(x)

In [None]:
# Convolutional encoder: four 3x3 ReLU convolutions with 'same' padding.
x = layers.Conv2D(32, 3,
                  padding='same', activation='relu')(input_img)
# The only strided convolution: stride 2 downsamples both spatial axes.
x = layers.Conv2D(64, 3,
                  padding='same', activation='relu',
                  strides=(2, 2))(x)
x = layers.Conv2D(64, 3,
                  padding='same', activation='relu')(x)
x = layers.Conv2D(64, 3,
                  padding='same', activation='relu')(x)
# Keep the feature-map shape; the decoder uses it to size its Dense layer and
# to reshape back into feature maps.
shape_before_flattening = K.int_shape(x)

In [None]:
# Notebook display: show the encoder's final feature-map shape.
shape_before_flattening

In [None]:
# Flatten the feature maps and pass them through a 32-unit ReLU layer; the
# latent heads are built on top of this tensor.
x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)

In [None]:
# Two linear (no activation) heads of size latent_dim. They are consumed by
# the sampling function and the KL term of the loss, which exponentiates
# z_log_var.
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

## Latent-space-sampling function

In [None]:
def sampling(args):
    """Draw a latent point with the reparameterization trick.

    Args:
        args: Pair ``(z_mean, z_log_var)`` of tensors. ``z_log_var`` is
            interpreted as the log of the variance, matching the KL term of
            the loss in this file, which uses ``exp(z_log_var)`` as the
            variance.

    Returns:
        Tensor ``z_mean + sigma * epsilon`` where ``epsilon`` is standard
        normal noise with one row per sample in the batch and ``latent_dim``
        columns.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    # sigma = sqrt(exp(log_var)) = exp(0.5 * log_var). Scaling the noise by
    # exp(log_var) would sample with the variance in place of the standard
    # deviation and disagree with the KL term.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


# Wrap the sampler in a Lambda layer so it becomes part of the model graph.
z = layers.Lambda(sampling)([z_mean, z_log_var])

## VAE decoder network, mapping latent space points to images

In [None]:
# The decoder takes inputs with the same per-sample shape as the sampled
# latent tensor z.
decoder_input = layers.Input(K.int_shape(z)[1:])
# Expand the latent vector to as many units as the encoder's last feature map
# holds, then restore the feature-map layout.
x = layers.Dense(np.prod(shape_before_flattening[1:]),
                 activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
# Stride-2 transposed convolution upsamples, countering the encoder's single
# stride-2 convolution.
x = layers.Conv2DTranspose(32, 3,
                           padding='same',
                           activation='relu',
                           strides=(2, 2))(x)
# Single-channel sigmoid output, so decoded values lie between 0 and 1.
x = layers.Conv2D(1, 3,
                  padding='same',
                  activation='sigmoid')(x)

In [None]:
# Package the decoder layers as a standalone Model (it is also called directly
# later to decode hand-picked latent points), then apply it to the sampled z.
decoder = Model(decoder_input, x)
z_decoded = decoder(z)

In [None]:
# Notebook display: show the symbolic decoder output tensor.
z_decoded

In [None]:
class CustomVariationalLayer(keras.layers.Layer):
    """Pass-through layer that registers the VAE loss on the model.

    The layer returns its first input unchanged; its purpose is the side
    effect of adding the combined reconstruction + KL loss via ``add_loss``.
    """

    def vae_loss(self, x, z_decoded):
        """Return the scalar VAE loss for a batch.

        Args:
            x: Original input images.
            z_decoded: Decoder reconstructions of ``x``.

        Returns:
            Mean of binary cross-entropy between the flattened tensors plus
            a KL term scaled by ``-5e-4``. The KL term reads ``z_mean`` and
            ``z_log_var`` from module scope, not from the arguments.
        """
        flat_target = K.flatten(x)
        flat_reconstruction = K.flatten(z_decoded)
        reconstruction_term = keras.metrics.binary_crossentropy(
            flat_target, flat_reconstruction)
        kl_inner = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_term = -5e-4 * K.mean(kl_inner, axis=-1)
        return K.mean(reconstruction_term + kl_term)

    def call(self, inputs):
        """Register the loss for ``inputs`` and return the first input.

        Args:
            inputs: Sequence whose first element is the original image tensor
                and whose second element is the decoded tensor.
        """
        original, reconstructed = inputs[0], inputs[1]
        self.add_loss(self.vae_loss(original, reconstructed), inputs=inputs)
        # The output itself is not used for training; the loss is what matters.
        return original

In [None]:
# Attach the loss layer; it returns input_img unchanged and registers the VAE
# loss as a side effect.
y = CustomVariationalLayer()([input_img, z_decoded])

## Training the VAE

In [None]:
# End-to-end model from input image to the loss layer's output.
vae = Model(input_img, y)
# loss=None: the loss is registered inside CustomVariationalLayer via add_loss,
# so no external loss function is supplied at compile time.
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()

In [None]:
#(x_train, _), (x_test, y_test) = mnist.load_data()
# Convert to float32, divide by 255 and append a trailing channel axis so the
# arrays match the model's input shape.
# assumes raw pixel values are in 0-255 — TODO confirm
# NOTE(review): re-running this cell rescales and reshapes the already
# processed arrays again.
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))
# No targets are passed (y=None, validation target None) because the model
# was compiled with loss=None and carries its own loss.
vae.fit(x=x_train, y=None,
        shuffle=True,
        epochs=10,
        batch_size=batch_size,
        validation_data=(x_test, None))

## Sampling a grid of points from the 2D latent space and decoding them to images

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import norm

# Decode an n x n grid of 2-D latent points and paste the results into one
# large mosaic image.
n = 10
height = img_shape[0]
width = img_shape[1]

figure = np.zeros((height * n, width * n))
# Grid coordinates are standard-normal quantiles of evenly spaced
# probabilities between 0.05 and 0.95.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        # Decode the point as a batch of one. Previously the point was tiled
        # batch_size times and only the first result was kept, so every grid
        # cell did batch_size times the necessary decoding work.
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample, batch_size=1)
        decoded_img = x_decoded[0].reshape(height, width)
        figure[i * height: (i + 1) * height,
               j * width: (j + 1) * width] = decoded_img

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

#n = 15
#digit_size = 28
#figure = np.zeros((digit_size * n, digit_size * n))
#grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
#grid_y = norm.ppf(np.linspace(0.05, 0.95, n))
#
#for i, yi in enumerate(grid_x):
#    for j, xi in enumerate(grid_y):
#        z_sample = np.array([[xi, yi]])
#        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
#        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
#        digit = x_decoded[0].reshape(digit_size, digit_size)
#        figure[i * digit_size: (i + 1) * digit_size,
#               j * digit_size: (j + 1) * digit_size] = digit
#        
#plt.figure(figsize=(10, 10))
#plt.imshow(figure, cmap='Greys_r')
#plt.show()

In [None]:
# Decode the latent point at the first grid coordinates and display it.
# The point is decoded as a batch of one; tiling it batch_size times only
# repeated the same computation, since just the first result is used.
z_sample = np.array([[grid_x[0], grid_y[0]]])
x_decoded = decoder.predict(z_sample, batch_size=1)
decoded_img = x_decoded[0]

# Last expression of the cell: the notebook renders the returned image.
array_to_img(decoded_img, scale=True)

In [None]:
# Decode the latent point at the second grid coordinates and display it.
# The point is decoded as a batch of one; tiling it batch_size times only
# repeated the same computation, since just the first result is used.
z_sample = np.array([[grid_x[1], grid_y[1]]])
x_decoded = decoder.predict(z_sample, batch_size=1)
decoded_img = x_decoded[0]

# Last expression of the cell: the notebook renders the returned image.
array_to_img(decoded_img, scale=True)