<a href="https://colab.research.google.com/github/conwayjw97/Image-Colourisation/blob/master/src/CVAECifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
BATCH_SIZE = 100
LATENT_DIM = 2
EPOCHS = 5

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from scipy.stats import norm
from tensorflow.keras.layers import Dense, Input, Conv2D, Flatten, Lambda, Reshape, Conv2DTranspose, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import backend as K

In [0]:
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

class_index = 6

class_train_indices = np.argwhere(train_labels == class_index)
class_train_images = train_images[class_train_indices[:,0]]
class_test_indices = np.argwhere(test_labels == class_index)
class_test_images = test_images[class_test_indices[:,0]]

train_yuv  = tf.image.rgb_to_yuv(class_train_images)
train_y = tf.expand_dims(train_yuv[:,:,:,0], 3)
train_uv = train_yuv[:,:,:,1:]

test_yuv = tf.image.rgb_to_yuv(class_test_images)
test_y = tf.expand_dims(test_yuv[:,:,:,0], 3)
test_uv = test_yuv[:,:,:,1:]

In [90]:
# https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder_deconv.py
# https://xiangyutang2.github.io/auto-colorization-autoencoders/

# Sampling with the reparametrisation trick
def sample(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean))
    return z_mean + K.exp(z_log_var / 2) * epsilon

def slice_uv(args):
    yuv_image = args
    return yuv_image[:,:,:,1:]

yuv_in = Input(shape=(train_yuv.shape[1], train_yuv.shape[2], train_yuv.shape[3]))
uv = Lambda(slice_uv, output_shape=(train_yuv.shape[1], train_yuv.shape[2], 2))(yuv_in)

# Define Training Encoder q(z|yuv)
layer = Conv2D(filters=8, kernel_size=3, activation='relu', strides=2, padding='same')(uv)
layer = Conv2D(filters=16, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=64, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=128, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
shape = K.int_shape(layer) # Shape before flattening

# Latent space for encoder
layer = Flatten()(layer)
layer = Dense(512, activation='relu')(layer)
z_mean_training = Dense(LATENT_DIM)(layer)
z_log_var_training = Dense(LATENT_DIM)(layer)
z_training = Lambda(sample, output_shape=(LATENT_DIM,))([z_mean_training, z_log_var_training]) # Data passable to the decoder

# Instantiate encoder 
encoder_training = Model(yuv_in, z_training, name='training_encoder')
encoder_training.summary()

Model: "training_encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_89 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
lambda_18 (Lambda)              (None, 32, 32, 2)    0           input_89[0][0]                   
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 16, 16, 8)    152         lambda_18[0][0]                  
__________________________________________________________________________________________________
conv2d_49 (Conv2D)              (None, 8, 8, 16)     1168        conv2d_48[0][0]                  
___________________________________________________________________________________

In [91]:
y_conditional_in = Input(shape=(train_y.shape[1], train_y.shape[2], train_y.shape[3]))

# Define Conditional Encoder p(z|y)
layer = Conv2D(filters=8, kernel_size=3, activation='relu', strides=2, padding='same')(y_conditional_in)
layer = Conv2D(filters=16, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=64, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2D(filters=128, kernel_size=3, activation='relu', strides=2, padding='same')(layer)

# Latent space for conditional encoder
layer = Flatten()(layer)
layer = Dense(512, activation='relu')(layer)
z_mean_conditional = Dense(LATENT_DIM)(layer)
z_log_var_conditional = Dense(LATENT_DIM)(layer)
z_conditional = Lambda(sample, output_shape=(LATENT_DIM,))([z_mean_conditional, z_log_var_conditional]) # Data passable to the decoder

# Instantiate conditional encoder
encoder_conditional = Model(y_conditional_in, z_conditional, name='conditional_encoder')
encoder_conditional.summary()

Model: "conditional_encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_90 (InputLayer)           [(None, 32, 32, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_52 (Conv2D)              (None, 16, 16, 8)    80          input_90[0][0]                   
__________________________________________________________________________________________________
conv2d_53 (Conv2D)              (None, 8, 8, 16)     1168        conv2d_52[0][0]                  
__________________________________________________________________________________________________
conv2d_54 (Conv2D)              (None, 4, 4, 64)     9280        conv2d_53[0][0]                  
________________________________________________________________________________

In [92]:
y_in = Input(shape=(train_y.shape[1], train_y.shape[2], train_y.shape[3]))

# Define Decoder p(yuv|z,y)
latent_inputs = Input(shape=(LATENT_DIM,))
layer = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
layer = Reshape((shape[1], shape[2], shape[3]))(layer)
layer = Conv2DTranspose(filters=128, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2DTranspose(filters=64, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2DTranspose(filters=16, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
layer = Conv2DTranspose(filters=8, kernel_size=3, activation='relu', strides=2, padding='same')(layer)
uv_out = Conv2DTranspose(filters=2, kernel_size=3, activation='tanh', padding='same')(layer)
concat_outputs = concatenate([uv_out, y_in], 3)
layer = Conv2DTranspose(filters=8, kernel_size=3, activation='relu', padding='same')(concat_outputs)
mean_yuv = Conv2DTranspose(filters=3, kernel_size=3, activation='tanh', padding='same')(layer)
log_sig_sqr_yuv = Conv2DTranspose(filters=3, kernel_size=3, activation='relu', padding='same')(layer)

# Instantiate Decoder
decoder = Model([latent_inputs, y_in], mean_yuv, name='decoder')
decoder.summary()
training_yuv_out = decoder([z_training, y_in])
conditional_yuv_out = decoder([z_conditional, y_conditional_in])

Model: "decoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_92 (InputLayer)           [(None, 2)]          0                                            
__________________________________________________________________________________________________
dense_50 (Dense)                (None, 512)          1536        input_92[0][0]                   
__________________________________________________________________________________________________
reshape_8 (Reshape)             (None, 2, 2, 128)    0           dense_50[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose_64 (Conv2DTran (None, 4, 4, 128)    147584      reshape_8[0][0]                  
____________________________________________________________________________________________

In [94]:
# Instantiate Training VAE
# training_yuv_in = Input(shape=(train_yuv.shape[1], train_yuv.shape[2], train_yuv.shape[3]))
# training_y_in = Input(shape=(train_y.shape[1], train_y.shape[2], train_y.shape[3]))
# training_z = encoder_training(training_yuv_in)
# training_yuv_out = decoder([z_training, y_decoder_in])

training_vae = Model([yuv_in, y_in], training_yuv_out, name='training_vae')
training_vae.summary()

# Instantiate Predictor VAE
# conditional_y_in = Input(shape=(train_y.shape[1], train_y.shape[2], train_y.shape[3]))
# conditional_z = encoder_conditional(y_conditional_in)
# conditional_yuv_out = decoder([conditional_z, y_conditional_in])

# predictor_vae = Model(y_conditional_in, conditional_yuv_out, name='predictor_vae')
# predictor_vae.summary()



# vae = Model([yuv_in, y_in, y_conditional_in], yuv_out, name='vae')
# vae.summary()

Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.
The tensor that caused the issue was: input_91:0
Model: "training_vae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_89 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
lambda_18 (Lambda)              (None, 32, 32, 2)    0           input_89[0][0]                   
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 16, 16, 8)    152         lambda_18[0][0]                  
__________________________________________________________________________________________________
conv2d_49 (Conv2D)              (None, 8, 8, 16)     1168 

In [95]:
reconstruction_loss = K.sum(mse(yuv_in, training_yuv_out))
reconstruction_loss *= (train_yuv.shape[1] * train_yuv.shape[2])
kl_loss = 1 + z_log_var_training - K.square(z_mean_training) - K.exp(z_log_var_training)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
training_loss = K.mean(reconstruction_loss + kl_loss)

training_vae.add_loss(training_loss)
training_vae.compile(optimizer='adam')

# # COST FROM RECONSTRUCTION
# SMALL_CONSTANT = 1e-6
# normalising_factor_uv_vae = - 0.5 * K.log(SMALL_CONSTANT+K.exp(log_sig_sqr_yuv)) - 0.5 * K.log(2 * np.pi)
# square_diff_between_mu_and_yuv_vae = K.square(mean_yuv - yuv_in) # yuv_in???
# inside_exp_x_vae = -0.5 * (square_diff_between_mu_and_yuv_vae / (SMALL_CONSTANT+K.exp(log_sig_sqr_yuv)))
# reconstr_loss_x_vae = -K.sum(normalising_factor_uv_vae + inside_exp_x_vae, 1)
# cost_R_vae = K.mean(reconstr_loss_x_vae)

# # KL(q(z|uv,y)||p(z|y))
# v_mean = z_mean_conditional #2
# aux_mean = z_mean_training #1
# v_log_sig_sq = K.log(K.exp(z_log_var_conditional)+SMALL_CONSTANT) #2
# aux_log_sig_sq = K.log(K.exp(z_log_var_training)+SMALL_CONSTANT) #1
# v_log_sig = K.log(K.sqrt(K.exp(v_log_sig_sq))) #2
# aux_log_sig = K.log(K.sqrt(K.exp(aux_log_sig_sq))) #1
# cost_VAE_a = v_log_sig-aux_log_sig+((K.exp(aux_log_sig_sq)+K.square(aux_mean-v_mean))/(2*K.exp(v_log_sig_sq)))-0.5
# cost_VAE_b = K.sum(cost_VAE_a,1)
# KL_vae = K.mean(cost_VAE_b)



In [96]:
# Training VAE fitting
history = training_vae.fit([train_yuv, train_y], shuffle=False, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_data=([test_yuv, test_y], None))

Train on 5000 samples, validate on 1000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [98]:
conditional_loss = K.sum(mse(z_training, z_conditional))
encoder_conditional.add_loss(conditional_loss)
encoder_conditional.compile(optimizer='adam')

ValueError: ignored

In [0]:
fig = plt.figure(figsize=(38, 9))

fig.add_subplot(1,4,1)
plt.title("Fidelity")
plt.plot(history.history["loss"][5:], label="Training Image Loss")
plt.plot(history.history["val_loss"][5:], label="Testing Image Loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

In [0]:
result = vae.predict([train_yuv, train_y])

output_count = 12
index_offset = 20

fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("Original")
  plt.imshow(class_train_images[i+index_offset])

fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("Output")
  plt.imshow(tf.image.yuv_to_rgb(result[i+index_offset]))

fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("U original")
  plt.imshow(train_yuv[i+index_offset,:,:,1])
  
fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("U output")
  plt.imshow(result[i+index_offset,:,:,1])
  
fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("V original")
  plt.imshow(train_yuv[i+index_offset,:,:,2])
  
fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("V output")
  plt.imshow(result[i+index_offset,:,:,2])

In [0]:
result = vae.predict([train_yuv[test_y.shape[0]:test_y.shape[0]*2], test_y])

output_count = 12
index_offset = 20

fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("Original")
  plt.imshow(class_test_images[i+index_offset])

fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("Output")
  plt.imshow(tf.image.yuv_to_rgb(result[i+index_offset]))

fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("U original")
  plt.imshow(test_yuv[i+index_offset,:,:,1])
  
fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("U output")
  plt.imshow(result[i+index_offset,:,:,1])
  
fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("V original")
  plt.imshow(test_yuv[i+index_offset,:,:,2])
  
fig = plt.figure(figsize=(39, 39))
for i in range(output_count):
  fig.add_subplot(1,output_count,i+1)
  plt.title("V output")
  plt.imshow(result[i+index_offset,:,:,2])