
Colorization autoencoder<br>
The autoencoder is trained with grayscale images as input<br>
and colored images as output.<br>
A colorization autoencoder can be viewed as the opposite<br>
of a denoising autoencoder: instead of removing noise, colorization<br>
adds "noise" (color) to the grayscale image.<br>
Grayscale Images --> Colorization --> Color Images<br>


In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [None]:
from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import ModelCheckpoint
from keras.datasets import cifar10
from keras.utils import plot_model
from keras import backend as K

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
def rgb2gray(rgb):
    """Collapse an RGB(A) image array into a single luma channel.

    Uses the ITU-R BT.601 weights quoted by OpenCV:
        gray = 0.299 * R + 0.587 * G + 0.114 * B
    Any channels beyond the first three (e.g. alpha) are ignored.

    Argument:
        rgb (ndarray): array whose last axis holds the color channels.
    Return:
        (ndarray): weighted sum over the first three channels; the
            channel axis is removed.
    """
    luma_weights = [0.299, 0.587, 0.114]
    color_channels = rgb[..., :3]
    return np.dot(color_channels, luma_weights)

Load the CIFAR-10 data (labels are discarded; only the images are needed)

In [None]:
[x_train, _], [x_test, _] = cifar10.load_data()  # keep images, drop class labels

Input image dimensions<br>
We assume the "channels_last" data format, i.e. (rows, cols, channels)

In [None]:
# Axis 0 is the sample index; the next three axes are rows, cols, channels.
img_rows, img_cols, channels = x_train.shape[1:4]

Create the `saved_images` folder for the output figures

In [None]:
# Figures are written below the current working directory.
imgs_dir = 'saved_images'
save_dir = os.path.join(os.getcwd(), imgs_dir)

# Create the folder on the first run only.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

Display the first 100 test images in color (the ground truth); their grayscale versions are shown after the conversion step

In [None]:
# Tile the first 100 color test images (the ground truth) into a 10x10 grid.
imgs = x_test[:100]
imgs = imgs.reshape((10, 10, img_rows, img_cols, channels))
# hstack places the 10 images of each row side by side; vstack stacks the rows.
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure()
plt.axis('off')
plt.title('Test color images (Ground Truth)')
plt.imshow(imgs, interpolation='none')
# os.path.join keeps the path portable instead of hard-coding a '/' separator.
plt.savefig(os.path.join(imgs_dir, 'test_color.png'))
plt.show()

convert color train and test images to gray

In [None]:
# Luma-weighted grayscale copies of both splits (still on the 0-255 pixel scale).
x_train_gray, x_test_gray = rgb2gray(x_train), rgb2gray(x_test)

display grayscale version of test images

In [None]:
# Arrange the first 100 grayscale test images as a 10x10 mosaic:
# (grid_row, grid_col, y, x) -> (grid_row, y, grid_col, x) -> one 2-D image.
imgs = x_test_gray[:100].reshape((10, 10, img_rows, img_cols))
imgs = imgs.transpose(0, 2, 1, 3).reshape((10 * img_rows, 10 * img_cols))
plt.figure()
plt.axis('off')
plt.title('Test gray images (Input)')
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.savefig('%s/test_gray.png' % imgs_dir)
plt.show()

normalize output train and test color images

In [None]:
# Rescale color targets from the 0-255 pixel range to float32 in [0, 1].
x_train, x_test = (arr.astype('float32') / 255 for arr in (x_train, x_test))

normalize input train and test grayscale images

In [None]:
# Rescale grayscale inputs from the 0-255 range to float32 in [0, 1].
x_train_gray, x_test_gray = (arr.astype('float32') / 255
                             for arr in (x_train_gray, x_test_gray))

reshape images to row x col x channel for CNN output/validation

In [None]:
# Explicit (samples, rows, cols, channels) layout for the CNN targets.
x_train = x_train.reshape((len(x_train), img_rows, img_cols, channels))
x_test = x_test.reshape((len(x_test), img_rows, img_cols, channels))

reshape images to row x col x channel for CNN input

In [None]:
# Add a trailing single-channel axis: (samples, rows, cols, 1) for the CNN input.
x_train_gray = x_train_gray.reshape((len(x_train_gray), img_rows, img_cols, 1))
x_test_gray = x_test_gray.reshape((len(x_test_gray), img_rows, img_cols, 1))

network parameters

In [None]:
# The encoder consumes single-channel (grayscale) images.
input_shape = (img_rows, img_cols, 1)
kernel_size = 3
# Filters per CNN layer: the encoder walks this list forward, the decoder backward.
layer_filters = [64, 128, 256]
latent_dim = 256
batch_size = 32

build the autoencoder model<br>
first build the encoder model

In [None]:
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
# Each stride-2 convolution halves the spatial size:
# Conv2D(64) -> Conv2D(128) -> Conv2D(256)
for filters in layer_filters:
    x = Conv2D(filters, kernel_size,
               strides=2, padding='same', activation='relu')(x)

Record the encoder's output shape so the decoder can be built without computing it by hand.<br>
The decoder's first Conv2DTranspose receives input of this shape.<br>
For 32x32 CIFAR-10 images the shape is (4, 4, 256), which the decoder upsamples back to (32, 32, 3).

In [None]:
# Static shape of the last encoder feature map; index 0 is the batch axis and
# indices 1-3 (rows, cols, filters) are reused to size the decoder's Dense/Reshape.
shape = K.int_shape(x)

generate a latent vector

In [None]:
# Flatten the final feature maps and project them linearly (no activation given)
# to a latent_dim-sized vector.
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)

instantiate encoder model

In [None]:
# Grayscale image -> latent vector.
encoder = Model(inputs=inputs, outputs=latent, name='encoder')
encoder.summary()

build the decoder model

In [None]:
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
# Expand the latent vector to as many units as the encoder's last feature map
# holds, then restore the (rows, cols, filters) layout.
x = Dense(int(np.prod(shape[1:])))(latent_inputs)
x = Reshape(shape[1:])(x)

stack of Conv2DTranspose(256)-Conv2DTranspose(128)-Conv2DTranspose(64)

In [None]:
# Mirror the encoder: each stride-2 transposed convolution doubles the spatial size,
# going through 256 -> 128 -> 64 filters.
for filters in reversed(layer_filters):
    x = Conv2DTranspose(filters, kernel_size,
                        strides=2, padding='same', activation='relu')(x)

In [None]:
# One filter per color channel with the default stride; the sigmoid keeps
# predictions in [0, 1], matching the /255-normalized color targets.
outputs = Conv2DTranspose(channels, kernel_size,
                          padding='same',
                          activation='sigmoid',
                          name='decoder_output')(x)

instantiate decoder model

In [None]:
# Latent vector -> color image.
decoder = Model(inputs=latent_inputs, outputs=outputs, name='decoder')
decoder.summary()

autoencoder = encoder + decoder<br>
instantiate autoencoder model

In [None]:
# Chain the sub-models: grayscale image -> latent vector -> color image.
autoencoder = Model(inputs=inputs,
                    outputs=decoder(encoder(inputs)),
                    name='autoencoder')
autoencoder.summary()

prepare model saving directory.

In [None]:
# Checkpoint file pattern; ModelCheckpoint fills in the {epoch:03d} placeholder.
model_name = 'colorized_ae_model.{epoch:03d}.h5'
save_dir = os.path.join(os.getcwd(), 'saved_models')
filepath = os.path.join(save_dir, model_name)

# Create the directory on the first run; later runs reuse it.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

Reduce the learning rate by a factor of sqrt(0.1) if the validation loss does not improve for 5 epochs

In [None]:
# Multiply the learning rate by sqrt(0.1) (~0.316) whenever val_loss has not
# improved for 5 epochs, never going below 5e-7. The monitored metric is
# stated explicitly (it matches the library default) so it visibly agrees
# with the checkpoint callback, which also tracks 'val_loss'.
lr_reducer = ReduceLROnPlateau(monitor='val_loss',
                               factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               verbose=1,
                               min_lr=0.5e-6)

Save a checkpoint whenever the validation loss improves, so the trained model can be reloaded later without retraining

In [None]:
# Write a checkpoint only when val_loss improves on the best value seen so far.
checkpoint = ModelCheckpoint(filepath=filepath,
                             save_best_only=True,
                             monitor='val_loss',
                             verbose=1)

Mean Squared Error (MSE) loss function and Adam optimizer

In [None]:
# Pixel-wise MSE between predicted and true color images, optimized with Adam.
autoencoder.compile(optimizer='adam', loss='mse')

Callbacks invoked at the end of every epoch

In [None]:
# Passed to fit(): learning-rate reduction on plateau, then best-model checkpointing.
callbacks = [lr_reducer, checkpoint]

train the autoencoder

In [None]:
# Grayscale images are the inputs; the original color images are the targets.
autoencoder.fit(x_train_gray,
                x_train,
                batch_size=batch_size,
                epochs=30,
                callbacks=callbacks,
                validation_data=(x_test_gray, x_test))

predict the autoencoder output from test data

In [None]:
# Colorize every grayscale test image; values lie in [0, 1] because the
# decoder's output layer uses a sigmoid activation.
x_decoded = autoencoder.predict(x_test_gray)

Display the first 100 colorized test images

In [None]:
# Arrange the first 100 predictions as a 10x10 mosaic:
# (grid_row, grid_col, y, x, c) -> (grid_row, y, grid_col, x, c) -> one RGB image.
imgs = x_decoded[:100].reshape((10, 10, img_rows, img_cols, channels))
imgs = imgs.transpose(0, 2, 1, 3, 4).reshape((10 * img_rows, 10 * img_cols, channels))
plt.figure()
plt.axis('off')
plt.title('Colorized test images (Predicted)')
plt.imshow(imgs, interpolation='none')
plt.savefig('%s/colorized.png' % imgs_dir)
plt.show()