In [None]:
import os
import random

import numpy as np
import tensorflow as tf
from keras.callbacks import TensorBoard
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers import Cropping2D
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave

#tf.python.control_flow_ops = t

In [None]:
# Image transformer: random shear/zoom/rotation/flip applied to training images on the fly.
datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True)

# Get images.
# sorted() makes the file order -- and therefore the train/test split -- reproducible
# (os.listdir order is filesystem-dependent). target_size resizes every image to the
# model's 299x299 input so np.array builds one dense 4-D array instead of a ragged one.
X = []
for filename in sorted(os.listdir('/color/Train/')):
    X.append(img_to_array(load_img('/color/Train/' + filename, target_size=(299, 299))))
X = np.array(X, dtype=float)

# Set up train and test data: first 90% for training, last 10% held out.
# Xtrain stays as raw 0-255 RGB because the augmentation generator consumes it directly.
split = int(0.9 * len(X))
Xtrain = X[:split]
# Convert the held-out images to Lab once (rgb2lab expects RGB floats in [0, 1]).
lab_test = rgb2lab(1.0 / 255 * X[split:])
# Model input: the L (lightness) channel with an explicit trailing channel axis.
Xtest = lab_test[:, :, :, 0]
Xtest = Xtest.reshape(Xtest.shape + (1,))
# Model target: the a/b channels scaled by 1/128 to roughly [-1, 1] (tanh output range).
Ytest = lab_test[:, :, :, 1:] / 128

In [None]:
def conv_stack(filters, d, strides):
    """Append Conv2D + BatchNormalization pairs to the module-level `model`.

    One pair is added per entry in `strides`.

    Args:
        filters: number of output channels for every Conv2D in the stack.
        d: dilation rate shared by every Conv2D in the stack.
        strides: iterable of stride values, one Conv2D layer per entry.
    """
    for stride in strides:
        conv_layer = Conv2D(filters, (3, 3),
                            strides=stride,
                            activation='relu',
                            dilation_rate=d,
                            padding='same')
        model.add(conv_layer)
        model.add(BatchNormalization())

def upsampling_stack(filters):
    """Append one upsampling stage per entry in `filters` to the module-level `model`.

    Each stage adds, in order:
      1. a stride-2 Conv2DTranspose with ReLU activation,
      2. two stride-1 convolutions via conv_stack,
      3. a BatchNormalization layer.

    Args:
        filters: iterable of channel counts, one upsampling stage per entry.
    """
    for n_channels in filters:
        model.add(Conv2DTranspose(n_channels, (3, 3),
                                  strides=(2, 2),
                                  activation='relu',
                                  padding='same'))
        conv_stack(n_channels, 1, [1, 1])
        # NOTE(review): conv_stack already ends with BatchNormalization, so this adds a
        # second consecutive BN layer -- likely redundant; kept to preserve the architecture.
        model.add(BatchNormalization())

# Build the colorization network.
# Input: the L (lightness) channel of a 299x299 image, shape (299, 299, 1).
# Output: two channels (the a/b targets), squashed to [-1, 1] by the final tanh.
model = Sequential()
model.add(InputLayer(input_shape=(299, 299, 1)))

# Encoder: three stride-2 convolutions; 'same' padding rounds up, 299 -> 150 -> 75 -> 38.
conv_stack(8, 1, [1, 2])
conv_stack(16, 1, [1, 2, 2])
# Bottleneck: dilated convolutions (dilation_rate=2), no further downsampling.
conv_stack(32, 2, [1, 1, 1, 1, 1, 1])
# Decoder: three stride-2 transposed convolutions, 38 -> 76 -> 152 -> 304.
upsampling_stack([32, 16, 8])
conv_stack(4, 1, [1, 1, 1])
model.add(Conv2DTranspose(2, (3, 3), activation='tanh', padding='same'))
# The odd input size does not survive the down/up round trip (304 != 299). Crop 5 pixels
# per spatial axis so the output matches the (299, 299, 2) targets and the L channel it
# is later recombined with.
model.add(Cropping2D(cropping=((2, 3), (2, 3))))

In [None]:
# Finish model: RMSprop optimizer, mean-squared error between predicted and target a/b.
model.compile(loss='mse', optimizer='rmsprop')

# Generate training data: number of images per batch (also reused by evaluate below).
batch_size = 100
def image_a_b_gen(batch_size):
    """Yield (L, a/b) training batches built from augmented copies of `Xtrain`.

    Each batch comes from ``datagen.flow`` and is converted from RGB to CIE Lab,
    then split into model input and target.

    Args:
        batch_size: number of images per yielded batch.

    Yields:
        (X_batch, Y_batch) tuples:
          * X_batch: the L channel, shape (batch, H, W, 1).
          * Y_batch: the a/b channels divided by 128, shape (batch, H, W, 2).
    """
    for batch in datagen.flow(Xtrain, batch_size=batch_size):
        # Xtrain holds raw 0-255 pixel values, but rgb2lab expects RGB floats in [0, 1]
        # (the held-out set is scaled the same way before its Lab conversion).
        lab_batch = rgb2lab(1.0 / 255 * batch)
        X_batch = lab_batch[:, :, :, 0]
        # Scale a/b by 1/128 so the targets match the tanh output range and Ytest's scaling.
        Y_batch = lab_batch[:, :, :, 1:] / 128
        yield (X_batch.reshape(X_batch.shape + (1,)), Y_batch)

# Train model.
# The TensorBoard callback only logs if it is passed to fit_generator via `callbacks`.
tensorboard = TensorBoard(log_dir='/output')
# Keras 2 measures an epoch in batches (`steps_per_epoch`); `samples_per_epoch` is the
# deprecated Keras 1 keyword. NOTE(review): 1000 is kept here as a batch count -- if
# 1000 *images* per epoch was intended, use 1000 // batch_size instead.
model.fit_generator(image_a_b_gen(batch_size),
                    steps_per_epoch=1000,
                    epochs=1,
                    callbacks=[tensorboard])

In [None]:
# Persist the trained network: architecture as JSON text, learned weights to model.h5.
model_json = model.to_json()
with open("model.json", "w") as handle:
    handle.write(model_json)
model.save_weights("model.h5")

In [None]:
# Load the images to colorize and keep only their L (lightness) channel as model input.
color_me = []
# sorted() makes the image order -- and thus the result/img_<i>.png numbering --
# reproducible. target_size resizes to the model's 299x299 input so the images stack
# into one dense array.
for filename in sorted(os.listdir('/color/Test')):
    path = os.path.join('/color/Test', filename)
    color_me.append(img_to_array(load_img(path, target_size=(299, 299))))
color_me = np.array(color_me, dtype=float)
# rgb2lab expects RGB floats in [0, 1]; channel 0 of the result is L.
color_me = rgb2lab(1.0 / 255 * color_me)[:, :, :, 0]
color_me = color_me.reshape(color_me.shape + (1,))

# Test model: report the compiled loss (MSE) on the held-out set, then colorize.
# print(...) with one argument behaves the same on Python 2 and 3; the bare
# `print x` statement is a SyntaxError on Python 3 kernels.
print(model.evaluate(Xtest, Ytest, batch_size=batch_size))
# Undo the 1/128 target scaling to bring predicted a/b back to Lab units.
output = model.predict(color_me) * 128

# Output colorizations: recombine each image's own L channel with its predicted a/b,
# convert Lab -> RGB, and write result/img_<i>.png.
if not os.path.isdir("result"):
    os.makedirs("result")
for i, ab in enumerate(output):
    cur = np.zeros((299, 299, 3))
    # Luminance must come from the image that was colorized (color_me), not from the
    # held-out Xtest set -- those are different images and may differ in count.
    cur[:, :, 0] = color_me[i][:, :, 0]
    cur[:, :, 1:] = ab
    imsave("result/img_" + str(i) + ".png", lab2rgb(cur))