In [None]:
import os
# Select the Keras backend through the environment; this cell runs before the
# cell that imports keras.
os.environ.update({'KERAS_BACKEND': 'tensorflow'})

In [None]:
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave

In [None]:
import numpy as np
import random
import tensorflow as tf

In [None]:
# Root folder of the image data; the notebook reads from its `train` and
# `test` sub-directories.
DATASET_PATH: str = './dataset'

In [None]:
# Load every training image as an array.
# os.listdir() returns names in arbitrary order, so they are sorted here to
# make the train / hold-out split below reproducible across runs and machines.
# NOTE(review): np.array() below needs all images to share one shape
# (the model expects 256x256) - confirm the dataset is pre-sized.
train_dir = f"{DATASET_PATH}/train"
X = np.array(
    [img_to_array(load_img(f"{train_dir}/{filename}"))
     for filename in sorted(os.listdir(train_dir))],
    dtype=float,
)

# Set up train and test data: the first 95% of the images are used for
# training, the remainder (X[split:]) is kept as a hold-out set.
split = int(0.95 * len(X))
# Scale by 1/255 (presumably mapping 8-bit pixel values into [0, 1]).
X_train = 1.0 / 255 * X[:split]

In [None]:
# Build the colorization network: a single-channel 256x256 input goes through
# a stack of 3x3 convolutions (three of them with stride 2), then through
# convolutions interleaved with three 2x upsampling steps, ending in a
# 2-channel tanh output.
conv_opts = dict(activation='relu', padding='same')

model = Sequential([
    InputLayer(input_shape=(256, 256, 1)),
    # Downsampling path: each strides=2 convolution halves the resolution.
    Conv2D(64, (3, 3), **conv_opts),
    Conv2D(64, (3, 3), strides=2, **conv_opts),
    Conv2D(128, (3, 3), **conv_opts),
    Conv2D(128, (3, 3), strides=2, **conv_opts),
    Conv2D(256, (3, 3), **conv_opts),
    Conv2D(256, (3, 3), strides=2, **conv_opts),
    # Widest layers, at the lowest resolution.
    Conv2D(512, (3, 3), **conv_opts),
    Conv2D(256, (3, 3), **conv_opts),
    Conv2D(128, (3, 3), **conv_opts),
    # Upsampling path back to the input resolution.
    UpSampling2D((2, 2)),
    Conv2D(64, (3, 3), **conv_opts),
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), **conv_opts),
    # Two output channels squashed to [-1, 1] by tanh, then upsampled once more.
    Conv2D(2, (3, 3), activation='tanh', padding='same'),
    UpSampling2D((2, 2)),
])

In [None]:
# Configure training: mean-squared-error loss, RMSprop optimizer.
model.compile(loss='mse', optimizer='rmsprop')

In [None]:
# Image transformer: random shear, zoom, rotation and horizontal flips.
augmentation_options = {
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'rotation_range': 20,
    'horizontal_flip': True,
}
datagen = ImageDataGenerator(**augmentation_options)


# Generate training data
def image_a_b_gen(batch_size):
    """Yield ``(inputs, targets)`` pairs built from augmented training batches.

    Each batch drawn from ``datagen.flow(X_train, ...)`` is converted with
    ``rgb2lab``. Channel 0 of the result becomes the network input (kept as a
    trailing axis of size 1); the remaining channels, divided by 128, become
    the target.

    Args:
        batch_size: number of images requested per batch from ``datagen.flow``.

    Yields:
        Tuple ``(X_batch, y_batch)`` where ``X_batch`` has one channel and
        ``y_batch`` holds the remaining channels scaled by 1/128.
    """
    augmented_batches = datagen.flow(X_train, batch_size=batch_size)
    for rgb_batch in augmented_batches:
        lab = rgb2lab(rgb_batch)
        # Slicing with :1 keeps the channel axis, giving (..., 1) directly.
        first_channel = lab[:, :, :, :1]
        scaled_rest = lab[:, :, :, 1:] / 128

        yield (first_channel, scaled_rest)

In [None]:
# Train the model on augmented batches: 1000 epochs of 10 steps each, with
# 10 images per step, logging to TensorBoard and printing no progress output.
batch_size = 10
tensorboard = TensorBoard(log_dir="./log/beta")
training_batches = image_a_b_gen(batch_size)
model.fit_generator(training_batches,
                    steps_per_epoch=10,
                    epochs=1000,
                    callbacks=[tensorboard],
                    verbose=False)

In [None]:
# Save model: the architecture goes to model.json, the weights to
# model/beta_1.h5.
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)

# Make sure the target directory exists first, so the result of a long
# training run is not lost to a missing-directory error.
os.makedirs("model", exist_ok=True)
model.save_weights("model/beta_1.h5")

In [None]:
# Test model on the hold-out images (X[split:]) that were excluded from
# X_train. The Lab conversion is done once and both the input and the target
# are sliced from the same array.
lab_test = rgb2lab(1.0 / 255 * X[split:])
X_test = lab_test[:, :, :, 0]
X_test = X_test.reshape(X_test.shape + (1, ))  # from (N, H, W) => (N, H, W, 1)
# Same 1/128 scaling as the training targets in image_a_b_gen.
y_test = lab_test[:, :, :, 1:] / 128
print(model.evaluate(X_test, y_test, batch_size=batch_size))

In [None]:
# Load the images to colorize. os.listdir() returns names in arbitrary order,
# so they are sorted to keep position i (and therefore the name of the i-th
# saved result) tied to the same input file on every run.
test_dir = f"{DATASET_PATH}/test/"
test_rgb = np.array(
    [img_to_array(load_img(f"{test_dir}{filename}"))
     for filename in sorted(os.listdir(test_dir))],
    dtype=float,
)

# Keep only channel 0 of the Lab conversion as the network input, with a
# trailing axis of size 1: (N, H, W) => (N, H, W, 1).
color_me = rgb2lab(1.0 / 255 * test_rgb)[:, :, :, 0]
color_me = color_me.reshape(color_me.shape + (1, ))

In [None]:
# Colorize: predict the two missing channels for every input and undo the
# 1/128 scaling that was applied to the training targets.
output = model.predict(color_me)
output = output * 128

# Create the output directory up front so imsave does not fail on a fresh
# checkout.
os.makedirs("result", exist_ok=True)

for i in range(len(output)):
    # Reassemble a 3-channel Lab image: channel 0 from the input, channels 1-2
    # from the prediction. The canvas size follows the input image.
    cur = np.zeros(color_me[i].shape[:2] + (3,))
    cur[:, :, 0] = color_me[i][:, :, 0]
    cur[:, :, 1:] = output[i]
    imsave(f"result/img_{i}.png", lab2rgb(cur))