In [None]:
import os
import random

import numpy as np
import tensorflow as tf
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from keras.callbacks import TensorBoard
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers import Input, Permute, RepeatVector, Reshape, concatenate
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.io import imsave
from skimage.transform import resize

In [None]:
# Get images
TRAIN_DIR = '/color_300/Train/'
# os.listdir returns entries in arbitrary order; sort them so that X has a
# reproducible order (the evaluation cell slices X by position).
X = [img_to_array(load_img(os.path.join(TRAIN_DIR, filename)))
     for filename in sorted(os.listdir(TRAIN_DIR))]
# Stack into one float array; this requires every image to have the same shape.
X = np.array(X, dtype=float)
# Scale by 1/255 (maps 8-bit pixel values into [0, 1]).
Xtrain = 1.0/255*X


In [None]:
def conv_stack(data, filters, s):
    """Apply a 3x3 ReLU convolution followed by batch normalization.

    Args:
        data: Input tensor the convolution is applied to.
        filters: Number of convolution filters.
        s: Stride of the convolution.

    Returns:
        The batch-normalized output tensor of the convolution.
    """
    convolved = Conv2D(
        filters,
        (3, 3),
        strides=s,
        activation='relu',
        padding='same',
    )(data)
    normalizer = BatchNormalization()
    return normalizer(convolved)

# 1000-dimensional image embedding (the output of create_inception_embedding).
embed_input = Input(shape=(1000,))

#Encoder
# Three stride-2 stacks with 'same' padding shrink 256x256 -> 128 -> 64 -> 32.
encoder_input = Input(shape=(256, 256, 1,))
encoder_output = conv_stack(encoder_input, 64, 2)
encoder_output = conv_stack(encoder_output, 128, 2)
encoder_output = conv_stack(encoder_output, 256, 2)
encoder_output = conv_stack(encoder_output, 512, 1)
encoder_output = conv_stack(encoder_output, 256, 1)    # (None, 32, 32, 256)

#Fusion
# Tile the embedding over all 32*32 spatial positions and stack it onto the
# encoder features. RepeatVector already yields (positions, features), so it
# can be reshaped directly; transposing it first would make the row-major
# Reshape hand each position 1000 copies of a single feature instead of the
# full embedding vector.
fusion_output = RepeatVector(32 * 32)(embed_input)      # (None, 1024, 1000)
fusion_output = Reshape((32, 32, 1000))(fusion_output)  # (None, 32, 32, 1000)
fusion_output = concatenate([fusion_output, encoder_output], axis=3)  # (None, 32, 32, 1256)
# 1x1 convolution mixes embedding and encoder channels back down to 256.
fusion_output = Conv2D(256, (1, 1), activation='relu')(fusion_output)

#Decoder
# Three 2x upsamplings restore the spatial size: 32 -> 64 -> 128 -> 256.
decoder_output = UpSampling2D((2, 2))(fusion_output)
decoder_output = conv_stack(decoder_output, 128, 1)
decoder_output = conv_stack(decoder_output, 64, 1)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = conv_stack(decoder_output, 32, 1)
decoder_output = conv_stack(decoder_output, 16, 1)
# Two tanh output channels in [-1, 1]; the training targets are the Lab
# a/b channels divided by 128.
decoder_output = Conv2D(2, (2, 2), activation='tanh', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)  # (None, 256, 256, 2)

# Input order matters: [lightness image, embedding].
model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

In [None]:
#Create inception embedding
# Pretrained ImageNet classifier; its 1000-way output is used as the embedding.
# NOTE(review): weights='imagenet' downloads the weights on first use. Point
# this at a local weights file instead if the environment has no network.
inception = InceptionResNetV2(weights='imagenet', include_top=True)

def create_inception_embedding(greyscale_rgb):
    """Compute the Inception-ResNet-v2 embedding of grayscale images.

    Args:
        greyscale_rgb: A single image of shape (H, W, 3) or a batch of shape
            (N, H, W, 3), with pixel values expected in [0, 1].

    Returns:
        The array returned by ``inception.predict``: one row of 1000 values
        per input image.
    """
    images = np.asarray(greyscale_rgb, dtype=float)
    # Accept a single image by promoting it to a batch of one.
    if images.ndim == 3:
        images = np.expand_dims(images, axis=0)
    # The network with include_top=True only accepts 299x299 inputs.
    resized = np.array([resize(image, (299, 299, 3), mode='constant')
                        for image in images])
    # preprocess_input maps [0, 255] pixels to [-1, 1], so undo the [0, 1]
    # scaling first.
    resized = preprocess_input(resized * 255.0)
    embed = inception.predict(resized)
    return embed

# Image transformer
# Augmentation settings for the generator that feeds image_a_b_gen.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=20,
    zoom_range=0.2,
    shear_range=0.2,
)

# Generate training data
batch_size = 50
def image_a_b_gen(batch_size):
    """Yield augmented training batches for the two-input colorization model.

    Args:
        batch_size: Number of images drawn from ``datagen.flow`` per batch.

    Yields:
        Tuples ``([X_batch, embed], Y_batch)`` where ``X_batch`` is the Lab
        lightness channel with a trailing channel axis, ``embed`` is the
        output of ``create_inception_embedding`` for the grayscale version of
        the batch, and ``Y_batch`` is the Lab a/b channels divided by 128.
    """
    for batch in datagen.flow(Xtrain, batch_size=batch_size):
        # The embedding network takes 3-channel input, so replicate the
        # grayscale image across the RGB channels.
        gray_scaled = gray2rgb(rgb2gray(batch))
        embed = create_inception_embedding(gray_scaled)
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0]
        X_batch = X_batch.reshape(X_batch.shape+(1,))
        # Divide a/b by 128 to match the model's tanh output range.
        Y_batch = lab_batch[:,:,:,1:] / 128
        # The model takes [lightness, embedding], in that order.
        yield ([X_batch, embed], Y_batch)

# Train model
# NOTE(review): the "{}" placeholder in log_dir is never formatted, so logs go
# to a directory literally named "{}". It was presumably meant to hold a run
# name; confirm before changing the path.
tensorboard = TensorBoard(log_dir="/output/{}")
# Keras 2 counts batches (steps_per_epoch), not samples, per epoch; the legacy
# samples_per_epoch=200 therefore becomes 200 // batch_size batches.
model.fit_generator(image_a_b_gen(batch_size),
                    callbacks=[tensorboard],
                    epochs=2,
                    steps_per_epoch=max(1, 200 // batch_size))

In [None]:
# Test images
# Evaluate on the last 5% of the loaded images.
# NOTE(review): Xtrain is built from all of X, so these images are also seen
# during training; train on X[:split] to make this a true hold-out set.
split = int(0.95*len(X))
test_rgb = 1.0/255*X[split:]
# Convert once and reuse for both the input (L) and target (a/b) channels.
test_lab = rgb2lab(test_rgb)
Xtest = test_lab[:,:,:,0]
Xtest = Xtest.reshape(Xtest.shape+(1,))
Ytest = test_lab[:,:,:,1:]
Ytest = Ytest / 128
# The model has two inputs, so the embedding must be supplied as well.
test_embed = create_inception_embedding(gray2rgb(rgb2gray(test_rgb)))
print(model.evaluate([Xtest, test_embed], Ytest, batch_size=batch_size))

In [None]:
color_me = []
# Sort the listing so that output file numbering is reproducible.
for filename in sorted(os.listdir('/color_300/Test/')):
    color_me.append(img_to_array(load_img('/color_300/Test/'+filename)))
color_me = np.array(color_me, dtype=float)
color_me = 1.0/255*color_me
# Compute the embedding from the RGB data before it is reduced to lightness.
color_me_embed = create_inception_embedding(gray2rgb(rgb2gray(color_me)))
color_me = rgb2lab(color_me)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))

# Test model
# The model takes [lightness, embedding], in that order.
output = model.predict([color_me, color_me_embed])
# Undo the /128 scaling applied to the a/b training targets.
output = output * 128

# Output colorizations
os.makedirs("result", exist_ok=True)
for i in range(len(output)):
    cur = np.zeros((256, 256, 3))
    # Lightness comes from the image being colorized, not from the
    # evaluation set.
    cur[:,:,0] = color_me[i][:,:,0]
    cur[:,:,1:] = output[i]
    imsave("result/img_"+str(i)+".png", lab2rgb(cur))