In [None]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path
from sklearn.model_selection import train_test_split
import tensorflow_io as tfio
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, MaxPooling2D

In [None]:
ds_path = "../datasets/horse2zebra" # enter path to the dataset
image_size = 256 # enter image size (keep divisible by 8: the model pools/upsamples by 2 three times)
batch_size = 8 # enter batch size

# Fix TensorFlow's global seed so weight initialization is repeatable across runs.
tf.random.set_seed(42)

# rglob yields files in filesystem order, which is not guaranteed to be stable;
# sort so the seeded split below produces the same train/test sets everywhere.
all_img_path = sorted(str(p) for p in Path(ds_path).rglob('*.jpg'))
if not all_img_path:
    # Fail early with an actionable message instead of sklearn's generic n_samples=0 error.
    raise FileNotFoundError(f"No .jpg images found under {Path(ds_path).resolve()}")
train_path, test_path = train_test_split(all_img_path, test_size=0.2, shuffle=True, random_state=42)

In [None]:
# function to load images and separate the gray and color channels. this function will be used in the tensorflow data pipeline.
def load_and_preprocess_images(impath):
    """Load one image file and split it into normalized Lab input/target parts.

    Args:
        impath: scalar string tensor holding the path to an image file.

    Returns:
        graypart: (image_size, image_size, 1) float32 tensor, the L channel divided by 100.
        colorpart: (image_size, image_size, 2) float32 tensor, the a/b channels divided by 128.
    """
    im = tf.io.read_file(impath)
    # expand_animations=False guarantees a rank-3 (H, W, 3) result with a static
    # channel dimension. The default may return a 4-D tensor for GIFs and leaves
    # the rank unknown during tf.data tracing, which erases downstream static shapes.
    colorimg = tf.image.decode_image(im, channels=3, expand_animations=False)
    # Letterbox to a square without distorting the aspect ratio (padding is zeros).
    colorimg = tf.image.resize_with_pad(colorimg, target_height=image_size, target_width=image_size)
    # rgb_to_lab expects RGB in [0, 1].
    colorimg = tf.cast(colorimg, dtype=tf.float32) / 255.0

    labimg = tfio.experimental.color.rgb_to_lab(colorimg)
    # Scale L by 1/100 and a/b by 1/128 so targets roughly match the model's tanh output range.
    graypart = tf.expand_dims(labimg[...,0], axis=-1) / 100
    colorpart = labimg[...,1:] / 128

    return graypart, colorpart

In [None]:
# standard data pipeline for train and test dataset
def make_dataset(paths, training):
    """Build a batched, prefetched tf.data pipeline of (L, ab) example pairs.

    Args:
        paths: list of image file paths.
        training: if True, shuffle the cached examples with a fresh order every epoch.

    Returns:
        A tf.data.Dataset yielding (graypart, colorpart) batches of up to
        batch_size examples (the final batch may be smaller).
    """
    ds = tf.data.Dataset.from_tensor_slices(paths)
    ds = ds.map(load_and_preprocess_images, num_parallel_calls=tf.data.AUTOTUNE)
    # cache() with no filename keeps decoded examples in memory, so the
    # decode + Lab conversion runs only during the first pass.
    ds = ds.cache()
    if training:
        # A buffer as large as the dataset gives a uniform shuffle. Reshuffle every
        # epoch: a fixed order would feed identical batches each epoch.
        ds = ds.shuffle(len(paths), seed=42, reshuffle_each_iteration=True)
    ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)

ds_train = make_dataset(train_path, training=True)
ds_test = make_dataset(test_path, training=False)

In [None]:
# show a sample gray image next to its original color version.
# Pull only the first batch; list(ds_train.as_numpy_iterator()) would copy the
# whole training set into host memory just to keep one batch.
sample_input, sample_output = next(ds_train.as_numpy_iterator())
grayimg = sample_input
# Undo the L/100 and ab/128 scaling, convert back to RGB, and clip small
# out-of-gamut excursions so imshow does not warn about or distort them.
colorimg = tfio.experimental.color.lab_to_rgb(np.concatenate((100*grayimg,128*sample_output), axis=-1)).numpy()
colorimg = np.clip(colorimg, 0.0, 1.0)
n_show = grayimg.shape[0]  # a batch can hold fewer than batch_size images
# squeeze=False keeps axs 2-D even when there is a single row.
fig, axs = plt.subplots(n_show, 2, figsize=(n_show,4*n_show), squeeze=False)
for i in range(n_show):
    # Fixed [0, 1] range shows the real luminance the model receives, not a per-image stretch.
    axs[i,0].imshow(grayimg[i,...,0], cmap='gray', vmin=0.0, vmax=1.0)
    axs[i,1].imshow(colorimg[i,...])
    axs[i,0].axis('off')
    axs[i,1].axis('off')
    axs[i,0].set_title('gray image (input)', size=8)
    axs[i,1].set_title('colorful image (GT output)', size=8)
plt.show()

In [None]:
# define auto encoder-decoder network
def define_model():
    """Build and compile the convolutional encoder-decoder colorization model.

    The input is a single-channel (image_size, image_size, 1) tensor. The output
    is a two-channel map of the same spatial size, squashed to [-1, 1] by tanh.
    The three 2x pooling and three 2x upsampling stages mean the output matches
    the input size only when image_size is divisible by 8.

    Returns:
        A compiled tf.keras Sequential model (Adam optimizer, MSE loss).
    """
    conv = dict(kernel_size=(3, 3), activation='relu', padding='same')

    network = tf.keras.models.Sequential()
    network.add(Input(shape=(image_size, image_size, 1)))

    # Encoder: two 3x3 ReLU convs per stage, then halve the resolution.
    for width in (64, 128, 256):
        network.add(Conv2D(width, **conv))
        network.add(Conv2D(width, **conv))
        network.add(MaxPooling2D())

    # Decoder: two 3x3 ReLU convs per stage, then double the resolution.
    for width in (256, 128, 64):
        network.add(Conv2D(width, **conv))
        network.add(Conv2D(width, **conv))
        network.add(UpSampling2D((2, 2)))

    # Head: refine at full resolution, then project to two tanh-bounded channels.
    network.add(Conv2D(32, **conv))
    network.add(Conv2D(32, **conv))
    network.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))

    network.compile(optimizer='adam', loss='mse')
    return network
    
model = define_model()
model.summary()
# Resume from a previously saved checkpoint if one exists. An explicit existence
# check replaces catching FileNotFoundError, because the exception raised for a
# missing file depends on the installed h5py/Keras version (older h5py raises a
# plain OSError). Real load failures such as an architecture mismatch still raise.
weights_path = Path('weights.h5')
if weights_path.exists():
    model.load_weights(str(weights_path))
    print(f'weights loaded from {weights_path}.')
else:
    print('weights NOT found.')

In [None]:
# train the model
# Use tf.keras.callbacks to match the tf.keras model, rather than importing the
# standalone keras package mid-notebook.
# Keep only the best (lowest val_loss) weights in weights.h5, the file the
# model-building cell loads on startup.
# NOTE(review): Keras 3 requires save_weights_only paths to end in ".weights.h5";
# if upgrading, rename it here and in the load cell together.
checkpoint = tf.keras.callbacks.ModelCheckpoint("weights.h5", monitor='val_loss', verbose=1, save_weights_only=True, save_best_only=True, mode='min')
# Build the log path with pathlib: "logs\\my_saved_model" only nests correctly on
# Windows and creates a literal backslash-named directory elsewhere.
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=str(Path("logs") / "my_saved_model"))
callbacks_list = [checkpoint, tensorboard]
history = model.fit(ds_train, validation_data=ds_test, epochs=200, callbacks=callbacks_list)

In [None]:
# show sample result of model from the test dataset
# Fetch a single batch (the 5th, clamped for small test sets) instead of
# materializing the whole test set with list(ds_test.as_numpy_iterator()).
sample_batch_index = min(4, len(ds_test) - 1)
sample_input, sample_output = next(ds_test.skip(sample_batch_index).as_numpy_iterator())
grayimg = sample_input
# Undo the L/100 and ab/128 scaling before converting back to RGB. Predicted a/b
# values can fall outside the sRGB gamut, so clip both images to [0, 1] for display.
colorimg = np.clip(tfio.experimental.color.lab_to_rgb(np.concatenate((100*grayimg,128*sample_output), axis=-1)).numpy(), 0.0, 1.0)
pred_ab = model.predict(grayimg, verbose=0)
pred_colorimg = np.clip(tfio.experimental.color.lab_to_rgb(np.concatenate((100*grayimg,128*pred_ab), axis=-1)).numpy(), 0.0, 1.0)
n_show = grayimg.shape[0]  # the final batch can hold fewer than batch_size images
# squeeze=False keeps axs 2-D even when there is a single row.
fig, axs = plt.subplots(n_show, 3, figsize=(2*n_show,6*n_show), squeeze=False)
for i in range(n_show):
    # Fixed [0, 1] range shows the real luminance the model receives.
    axs[i,0].imshow(grayimg[i,...,0], cmap='gray', vmin=0.0, vmax=1.0)
    axs[i,1].imshow(pred_colorimg[i,...])
    axs[i,2].imshow(colorimg[i,...])
    for ax, title in zip(axs[i], ('gray image (input)', 'predicted color', 'GT color')):
        ax.set_title(title, size=8)
        ax.axis('off')
plt.show()