<a href="https://colab.research.google.com/github/conwayjw97/Image-Colourisation/blob/master/src/CVAEBojoneCNNCifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Conv2D, Flatten, Lambda, Reshape, Conv2DTranspose, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import cifar10, mnist

In [0]:
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

class_train_indices = np.argwhere(train_labels == 0)
class_train_images = train_images[class_train_indices[:,0]]
class_test_indices = np.argwhere(test_labels == 0)
class_test_images = test_images[class_test_indices[:,0]]

train_yuvImages = tf.image.rgb_to_yuv(class_train_images)
train_y = tf.expand_dims(train_yuvImages[:,:,:,0], 3)
train_uv = train_yuvImages[:,:,:,1:]

test_yuvImages = tf.image.rgb_to_yuv(class_test_images)
test_y = tf.expand_dims(test_yuvImages[:,:,:,0], 3)
test_uv = test_yuvImages[:,:,:,1:]

In [21]:
# https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder_deconv.py

# (x_train, y_train_), (x_test, y_test_) = mnist.load_data()

# image_size = x_train.shape[1]
# x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
# x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
# x_train = x_train.astype('float32') / 255
# x_test = x_test.astype('float32') / 255

batch_size = 100
kernel_size = 3
filters = 64
# latent_dim = 2
latent_dim = 512
epochs = 30

# Sampling with the reparametrisation trick
def sample(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean))
    return z_mean + K.exp(z_log_var / 2) * epsilon

# Define Training Encoder q(z|uv,y)
yuv_in = Input(shape=(train_yuvImages.shape[1], train_yuvImages.shape[2], train_yuvImages.shape[3]))
layer = yuv_in
layer = Conv2D(filters=8, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=16, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=32, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=64, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
shape = K.int_shape(layer) # Shape before flattening

# Latent space for training encoder
layer = Flatten()(layer)
layer = Dense(16, activation='relu')(layer)
z_mean_training = Dense(latent_dim)(layer)
z_log_var_training = Dense(latent_dim)(layer)
z_training = Lambda(sample, output_shape=(latent_dim,))([z_mean_training, z_log_var_training]) # Data passable to the decoder

# Instantiate training encoder 
encoder_training = Model(yuv_in, [z_mean_training, z_log_var_training, z_training], name='encoder')
encoder_training.summary()

# Define Conditional Encoder p(z|y)
y_in = Input(shape=(train_y.shape[1], train_y.shape[2], train_y.shape[3]))
layer = y_in
layer = Conv2D(filters=8, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=16, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=32, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=64, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
shapeConditional = K.int_shape(layer) # Shape before flattening

# Latent space for conditional encoder
layer = Flatten()(layer)
layer = Dense(16, activation='relu')(layer)
z_mean_conditional = Dense(latent_dim)(layer)
z_log_var_conditional = Dense(latent_dim)(layer)
z_conditional = Lambda(sample, output_shape=(latent_dim,))([z_mean_conditional, z_log_var_conditional]) # Data passable to the decoder

# Instantiate conditional encoder
encoder_training = Model(y_in, [z_mean_conditional, z_log_var_conditional, z_conditional], name='encoder')
encoder_training.summary()


# Define Decoder p(uv|z,y)
latent_inputs = Input(shape=(latent_dim,))
layer = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
layer = Reshape((shape[1], shape[2], shape[3]))(layer)
layer = Conv2DTranspose(filters=64, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2DTranspose(filters=32, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2DTranspose(filters=16, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
layer = Conv2DTranspose(filters=8, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(layer)
uv_out = Conv2DTranspose(filters=2, kernel_size=kernel_size, activation='sigmoid', padding='same')(layer)

concat_outputs = concatenate([uv_out, y_in], 3)
layer = Conv2DTranspose(filters=8, kernel_size=kernel_size, activation='relu', padding='same')(concat_outputs)
mean_yuv = Conv2DTranspose(filters=2, kernel_size=kernel_size, activation='sigmoid', padding='same')(layer)
log_sig_sqr_yuv = Conv2DTranspose(filters=2, kernel_size=kernel_size, activation='relu', padding='same')(layer)

# Instantiate Decoder
decoder = Model([latent_inputs, y_in], [mean_yuv, log_sig_sqr_yuv])
decoder.summary()

# COST FROM RECONSTRUCTION
SMALL_CONSTANT = 1e-6
normalising_factor_uv_vae = - 0.5 * K.log(SMALL_CONSTANT+K.exp(log_sig_sqr_yuv)) - 0.5 * K.log(2 * np.pi)
square_diff_between_mu_and_x_vae = K.square(mean_yuv - yuv_in) # yuv_in???
inside_exp_x_vae = -0.5 * K.div(square_diff_between_mu_and_x_vae,SMALL_CONSTANT+tf.exp(log_sig_sqr_yuv))
reconstr_loss_x_vae = -K.reduce_sum(normalising_factor_uv_vae + inside_exp_x_vae, 1)
cost_R_vae = K.reduce_mean(reconstr_loss_x_vae)

# KL(q(z|uv,y)||p(z|y))
v_mean = z_mean_conditional #2
aux_mean = z_mean_training #1
v_log_sig_sq = K.log(K.exp(z_log_var_conditional)+SMALL_CONSTANT) #2
aux_log_sig_sq = K.log(K.exp(z_log_var_training)+SMALL_CONSTANT) #1
v_log_sig = K.log(K.sqrt(K.exp(v_log_sig_sq))) #2
aux_log_sig = K.log(K.sqrt(K.exp(aux_log_sig_sq))) #1
cost_VAE_a = v_log_sig-aux_log_sig+K.divide(K.exp(aux_log_sig_sq)+K.square(aux_mean-v_mean),2*K.exp(v_log_sig_sq))-0.5
cost_VAE_b = K.reduce_sum(cost_VAE_a,1)
KL_vae = K.reduce_mean(cost_VAE_b)

# VAE
vae = Model(yuv_in, y_in, [mean_yuv, log_sig_sqr_yuv])
vae.add_loss(K.sum(cost_R_vae, KL_vae))
vae.compile(optimizer='rmsprop')

# xent_loss = K.sum(K.mean_squared_error(x_in, x_out), axis=[1, 2, 3]) 
# kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
# vae_loss = K.mean(xent_loss + kl_loss)
# vae.add_loss(vae_loss)
# vae.compile(optimizer='rmsprop')



Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_28 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_76 (Conv2D)              (None, 16, 16, 8)    224         input_28[0][0]                   
__________________________________________________________________________________________________
conv2d_77 (Conv2D)              (None, 8, 8, 16)     1168        conv2d_76[0][0]                  
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 4, 4, 32)     4640        conv2d_77[0][0]                  
____________________________________________________________________________________________

ValueError: ignored

In [0]:
# Train
# vae.fit(x_train, shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None)) # MNIST version
vae.fit(train_uv, shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=(test_uv, None))

In [0]:
# encoder = Model(x_in, z_mean)

# x_test_encoded = encoder.predict(x_test, batch_size=batch_size) 
# plt.figure(figsize=(6, 6))
# plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=test_y)
# plt.colorbar()
# plt.show()

# n = 15  # figure with 15x15 digits
# digit_size = 28
# figure = np.zeros((digit_size * n, digit_size * n))

# grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
# grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# for i, yi in enumerate(grid_x):
#     for j, xi in enumerate(grid_y):
#         z_sample = np.array([[xi, yi]])
#         x_decoded = decoder.predict(z_sample)
#         digit = x_decoded[0].reshape(digit_size, digit_size)
#         figure[i * digit_size: (i + 1) * digit_size,
#                j * digit_size: (j + 1) * digit_size] = digit

# plt.figure(figsize=(10, 10))
# plt.imshow(figure, cmap='Greys_r')
# plt.show()