# Colorization with autoencoders - celeba 96x96


### Import

In [30]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import PIL
from tqdm.notebook import trange, tqdm
from PIL import Image
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from matplotlib import font_manager as fm, rcParams
from sklearn.model_selection import train_test_split

### Fonts

In [2]:
# Keyword bundles that can be splatted into matplotlib text calls
# (e.g. plt.title(..., **csfont)) to select a font family.
csfont = dict(fontname='Georgia')
hfont = dict(fontname='Helvetica')

### Parameters

In [3]:
# Both output locations currently point at the same folder.
save_dir = imgs_dir = '../../data/celeba/colorization/'

### Find all Data Files

In [28]:
data = {}
filenames = []
# Walk the pickle folder bottom-up and collect every matching file path.
# NOTE(review): the filter is a substring test, so any name containing ".p"
# (not only names ending in it) is accepted — confirm this is intended.
for root, dirs, files in os.walk("../../data/celeba/pickle/", topdown=False):
    filenames.extend(
        os.path.join(root, name) for name in files if '.p' in name
    )
# np.sort returns the paths as a sorted NumPy array.
filenames = np.sort(filenames)

### Get labels

### Read and store

In [31]:
# Load the first two pickle files found above, stack their image arrays into
# a single X, and gather the person-id labels of every file in ycombined.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
datalist = []
ycombined = []
for i in range(2):
    # The context manager closes the file handle as soon as loading is done.
    with open(filenames[i], "rb") as fh:
        data = pickle.load(fh)
    X = data['X']
    y = data['y']
    # presumably data['y'] is a pandas DataFrame with a 'person_id' column —
    # TODO confirm against the code that wrote the pickles
    y = list(y['person_id'].values)
    datalist.append(X)
    # Keep the labels of every file rather than only the last one
    # (assumes one label per row of X — verify).
    ycombined.extend(y)
datatuple = tuple(datalist)
X = np.concatenate(datatuple)


# y still holds only the labels of the last file read; ycombined has all.
y
# x_train = data['x_train']
# y_train = data['y_train']
# x_test = data['x_test'] 
# y_test = data['y_test']

array([5024,  165, 5050, ..., 9570, 3717, 1217])

### RGB to grayscale

In [None]:
def rgb2gray(rgb):
    """Collapse the colour channels of an image array into one gray plane.

    The first three entries of the last axis (assumed to be R, G, B in that
    order) are weighted by 0.299, 0.587 and 0.114 and summed; any further
    channels are ignored.  The result has the input's shape minus the
    last axis.
    """
    luma_weights = [0.299, 0.587, 0.114]
    colour_planes = rgb[..., :3]
    return np.dot(colour_planes, luma_weights)

### Load the CIFAR10 data

In [None]:
# Read the pre-split dataset from a single pickle file.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
# The context manager closes the file handle once loading has finished.
with open("../../data/cifar10/pickle/data.p", "rb") as fh:
    data = pickle.load(fh)
x_train = data['x_train']
y_train = data['y_train']
x_test = data['x_test']
y_test = data['y_test']

### Plot some images 

In [None]:
# Image geometry is read from axes 1-3 of the training array.
img_rows, img_cols, channels = x_train.shape[1:4]

# Tile the first 100 test images into a 10x10 mosaic: each group of ten is
# joined side by side, then the ten strips are stacked vertically.
imgs = x_test[:100].reshape((10, 10, img_rows, img_cols, channels))
imgs_orig = np.vstack([np.hstack(strip) for strip in imgs])

plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Test color images (Ground  Truth)', fontsize=16, **csfont)
plt.imshow(imgs_orig, interpolation='none')
plt.show()

### Convert to gray and plot again

In [None]:
# Gray versions of both splits; these become the network inputs.
x_train_gray, x_test_gray = rgb2gray(x_train), rgb2gray(x_test)

# Same 10x10 mosaic layout as for the colour images, minus the channel axis.
imgs = x_test_gray[:100].reshape((10, 10, img_rows, img_cols))
imgs_gray = np.vstack([np.hstack(strip) for strip in imgs])

plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Test gray images (Input)', fontsize=16, **csfont)
plt.imshow(imgs_gray, cmap='gray', interpolation='none')
plt.show()

### Normalize and Reshape

In [None]:
# Cast to float32 and divide by 255 (assumes pixel values in 0-255 — TODO
# confirm), then give every array an explicit (N, rows, cols, channels)
# layout; the gray arrays get a trailing single-channel axis.
x_train = (x_train.astype('float32') / 255).reshape(
    len(x_train), img_rows, img_cols, channels)
x_test = (x_test.astype('float32') / 255).reshape(
    len(x_test), img_rows, img_cols, channels)
x_train_gray = (x_train_gray.astype('float32') / 255).reshape(
    len(x_train_gray), img_rows, img_cols, 1)
x_test_gray = (x_test_gray.astype('float32') / 255).reshape(
    len(x_test_gray), img_rows, img_cols, 1)

### Network Parameters

In [None]:
# Architecture settings shared by the encoder and decoder.
layer_filters = [64, 128, 256]   # filters per conv layer, in encoder order
kernel_size = 3                  # kernel size used by every conv layer
latent_dim = 256                 # width of the bottleneck vector

# Training / input settings.
batch_size = 32
input_shape = (img_rows, img_cols, 1)   # single-channel (gray) input

### Encoder

In [None]:
# Encoder: a stack of stride-2 convolutions followed by a dense projection
# down to the latent vector.
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for filters in layer_filters:
    conv = Conv2D(filters=filters,
                  kernel_size=kernel_size,
                  strides=2,
                  padding='same',
                  activation='relu')
    x = conv(x)

# Remember the feature-map shape before flattening; the decoder uses it to
# rebuild a tensor of the same size.
shape = K.int_shape(x)

flat = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(flat)
x = flat

encoder = Model(inputs, latent, name='encoder')
encoder.summary()

### Decoder

In [None]:
# Decoder: expand the latent vector back to the encoder's last feature-map
# shape, then apply stride-2 transposed convolutions with the encoder's
# filter counts in reverse order.
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape(shape[1:4])(x)
for filters in reversed(layer_filters):
    deconv = Conv2DTranspose(filters=filters,
                             kernel_size=kernel_size,
                             strides=2,
                             padding='same',
                             activation='relu')
    x = deconv(x)

# Final layer emits `channels` feature maps through a sigmoid.
output_layer = Conv2DTranspose(filters=channels,
                               kernel_size=kernel_size,
                               padding='same',
                               activation='sigmoid',
                               name='decoder_output')
outputs = output_layer(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()

### Combine into single model

In [None]:
# Chain encoder and decoder into one end-to-end model (gray in, colour out).
latent_code = encoder(inputs)
reconstruction = decoder(latent_code)
autoencoder = Model(inputs, reconstruction, name='autoencoder')
autoencoder.summary()

### Set Up Checkpoints, Learning Rate Adjustments, Optimizer, Loss Function and Callbacks

In [None]:
# Where per-epoch checkpoints are written; the epoch number is substituted
# into the file name by the checkpoint callback.
save_dir = '../../data/cifar10/colorization/saved_models'
model_name = 'colorized_ae_model.{epoch:03d}.h5'
# Create the folder up front so saving does not depend on the callback
# creating a missing directory itself.
os.makedirs(save_dir, exist_ok=True)
filepath = os.path.join(save_dir, model_name)

# Shrink the learning rate by a factor of sqrt(0.1) when the monitored
# metric has stalled for `patience` epochs, never going below min_lr.
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               verbose=1,
                               min_lr=0.5e-6)

# Only write a checkpoint when the validation loss improves.
checkpoint = ModelCheckpoint(filepath=filepath,
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=True)

# Mean-squared error between predicted and true colour images, Adam optimizer.
autoencoder.compile(loss='mse', optimizer='adam')
callbacks = [lr_reducer, checkpoint]

### Train Model

In [None]:
# Fit gray -> colour on the training split, validating on the test split.
autoencoder.fit(x=x_train_gray,
                y=x_train,
                batch_size=batch_size,
                epochs=30,
                validation_data=(x_test_gray, x_test),
                callbacks=callbacks)

# Colourize the gray test images with the trained model.
x_decoded = autoencoder.predict(x_test_gray)

### Display the result

In [None]:
# Tile the first 100 predictions into a 10x10 mosaic, using the same layout
# as the ground-truth and gray-input figures so they can be compared.
imgs = x_decoded[:100]
imgs = imgs.reshape((10, 10, img_rows, img_cols, channels))
imgs_colorized = np.vstack([np.hstack(i) for i in imgs])
plt.figure(figsize=(8,8))
plt.axis('off')
# Same font family and size as the other two mosaic titles.
plt.title('Colorized test images (Predicted)',**csfont,fontsize=16)
plt.imshow(imgs_colorized, interpolation='none')
plt.show()