In [1]:
from keras.layers import Flatten, Reshape, Lambda, Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, LeakyReLU, BatchNormalization, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
from keras.utils import plot_model
from keras.losses import mse

import numpy as np
from PIL import Image
import os 
# from numba import cuda
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.python.client import device_lib


Using TensorFlow backend.


In [2]:
## Data loading 

# trainData = "C:/Users/mostafaosama2/Desktop/autoenctrain/train"
trainData = "../../../autoenctrain/train"
# testData = "C:/Users/mostafaosama2/Desktop/autoenctrain/test"
testData = "../../../autoenctrain/test"

new_train = []
new_test = []

for filename in os.listdir(trainData):
    if filename.endswith('.tif'):
        image = Image.open(os.path.join(trainData, filename)) 
        new_train.append(np.asarray( image, dtype="uint8" ))

for filename in os.listdir(testData):
    if filename.endswith('.tif'):
        image = Image.open(os.path.join(testData, filename)) 
        new_test.append(np.asarray( image, dtype="uint8" ))

## Data preprocessing

x_train = np.asarray(new_train)
x_test = np.asarray(new_test)
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

In [74]:
n_epochs = 5
batch_size = 32
optimizer = 'SGD'
image_size = x_train.shape[1]
latent_dimension = 2

In [75]:
input_img = Input(shape=(96, 96, 3))  # adapt this if using `channels_first` image data format

encoder = Conv2D(filters=16, kernel_size=(3,3),strides=1, padding='same')(input_img)#x^2*16
encoder = LeakyReLU()(encoder)
encoder = Conv2D(filters=32, kernel_size=(3,3),strides=1, padding='same')(encoder)#x^2*32
encoder = LeakyReLU()(encoder)
encoder = BatchNormalization()(encoder)
encoder = Conv2D(filters=64, kernel_size=(3,3),strides=2, padding='same')(encoder)#(x/2)^2*64
encoder = LeakyReLU()(encoder)
encoder = BatchNormalization()(encoder)
encoder = MaxPooling2D()(encoder)#(x/4)^2*64
encoder = Conv2D(filters=128, kernel_size=(3,3),strides=2, padding='same')(encoder)#(x/8)^2*128
encoder = LeakyReLU()(encoder)
encoder = BatchNormalization()(encoder)
encoder = MaxPooling2D()(encoder)#(x/16)^2*64
encoder = Conv2D(filters=64, kernel_size=(3,3),strides=1, padding='same')(encoder)#(x/16)^2*64
encoder = LeakyReLU()(encoder)
encoder = BatchNormalization()(encoder)
encoder = Conv2D(filters=32, kernel_size=(3,3),strides=1, padding='same')(encoder)#(x/16)^2*32
encoder = LeakyReLU()(encoder)

bottleneck = Conv2D(filters=32, kernel_size=(1,1),strides=1, padding='same')(encoder)#(x/16)^2*16
bottleneck = LeakyReLU()(bottleneck)



In [76]:
# reparameterization trick
# instead of sampling from Q(z|X), sample epsilon = N(0,I)
# z = z_mean + sqrt(var) * epsilon
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.

    # Arguments
        args (tensor): mean and log of variance of Q(z|X)

    # Returns
        z (tensor): sampled latent vector
    """

    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [77]:
## VAE Part

# shape info needed to build decoder model
shape = K.int_shape(bottleneck)  

# generate latent vector Q(z|X)
x = Flatten()(bottleneck)
x = Dense(128, activation='relu')(x) # (None, 6, 6, 32)

z_mean = Dense(latent_dimension, name='z_mean')(x) # (None, 6, 6, 6)
z_log_var = Dense(latent_dimension, name='z_log_var')(x) # (None, 6, 6, 6)

z = Lambda(sampling, output_shape=(latent_dimension,), name='z')([z_mean, z_log_var]) # (None, 6)

encoder = Model(input_img, [z_mean, z_log_var, z], name='encoder')

# build decoder model
latent_inputs = Input(shape=(latent_dimension,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)


In [78]:
decoder = Conv2D(filters=32, kernel_size=(1,1),strides=1, padding='same')(x)#(x/16)^2*32
decoder = LeakyReLU()(decoder)
decoder = Conv2D(filters=64, kernel_size=(3,3),strides=1, padding='same')(decoder)#(x/16)^2*64
decoder = LeakyReLU()(decoder)
decoder = BatchNormalization()(decoder)
decoder = UpSampling2D()(decoder)#(x/8)^2*64
decoder = Conv2DTranspose(filters=128, kernel_size=(3,3),strides=2, padding='same')(decoder)#(x/4)^2*128
decoder = LeakyReLU()(decoder)
decoder = BatchNormalization()(decoder)
decoder = UpSampling2D()(decoder)#(x/2)^2*128
decoder = Conv2DTranspose(filters=64, kernel_size=(3,3),strides=2, padding='same')(decoder)#x^2*64
decoder = LeakyReLU()(decoder)
decoder = BatchNormalization()(decoder)
decoder = Conv2D(filters=32, kernel_size=(3,3),strides=1, padding='same')(decoder)#x^2*32
decoder = LeakyReLU()(decoder)
decoder = BatchNormalization()(decoder)
decoder = Conv2D(filters=16, kernel_size=(3,3),strides=1, padding='same')(decoder)#x^2*16
decoder = LeakyReLU()(decoder)
decoder = Conv2D(filters=3, kernel_size=(3,3),strides=1, padding='same')(decoder)#x^2*3
decoded = LeakyReLU()(decoder)

decoder = Model(latent_inputs, decoded, name='decoder')

In [79]:
# instantiate VAE model
outputs = decoder(encoder(input_img)[2])
vae = Model(input_img, outputs, name='vae')

In [80]:
loss = mse(K.flatten(input_img), K.flatten(outputs)) # output is z (the latent dimension)
loss *= image_size * image_size # multiply the loss by the dimensions of the image

kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5

vae_loss = K.mean(loss + kl_loss)
vae.add_loss(vae_loss)

In [81]:
vae.compile(optimizer=optimizer)

autoencoder_train = vae.fit(x_train,
                epochs=n_epochs,
                batch_size=batch_size,
                shuffle=True,
                validation_data=(x_test, None)
               )


  'be expecting any data to be passed to {0}.'.format(name))


Train on 9999 samples, validate on 1500 samples
Epoch 1/5
 416/9999 [>.............................] - ETA: 20:26 - loss: nan

KeyboardInterrupt: 