In [None]:
import os
import random
import time

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.engine import Layer
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave

In [None]:
# Get images
TRAIN_DIR = '/color_300/Train/'
# sorted(): os.listdir returns entries in arbitrary order, which would make
# the train/test split below differ between machines and runs.
X = [img_to_array(load_img(os.path.join(TRAIN_DIR, filename)))
     for filename in sorted(os.listdir(TRAIN_DIR))]
X = np.array(X, dtype=float)

# Set up train and test data: the first 95% of the images are used for
# training (scaled to [0, 1]); X[split:] is the held-out test set used later.
split = int(0.95*len(X))
Xtrain = 1.0/255*X[:split]

In [None]:
def prepare_image_for_inception(input_tensor, height=300, width=300):
    """
    Pre-process an image tensor so it can be fed into Inception.

    Casts to ``float32`` and maps pixel values linearly from ``[0, 255]``
    (e.g. a ``uint8`` image) onto ``[-1, +1]``, then reshapes the result to a
    batch of ``(height, width, 3)`` images.

    NOTE(review): slim's Inception-ResNet-v2 is usually run at 299x299 —
    confirm that the 300x300 default is intended.

    :param input_tensor: image tensor with values in ``[0, 255]`` holding
        ``height * width * 3`` elements per image.
    :param height: image height used in the reshape (default 300).
    :param width: image width used in the reshape (default 300).
    :return: ``float32`` tensor of shape ``(batch, height, width, 3)``.
    """
    res = tf.cast(input_tensor, dtype=tf.float32)
    res = 2 * res / 255 - 1
    return tf.reshape(res, [-1, height, width, 3])

In [None]:
class InceptionEmbeddingBatcher:
    """Queue-based TF1 pipeline: grayscale images -> Inception-ResNet-v2 embeddings.

    NOTE(review): these four methods were pasted from a larger class (they all
    take ``self``) at inconsistent indentation levels, which made this cell
    fail with an IndentationError. They are grouped under a class here so the
    cell parses; nothing in the visible notebook instantiates it. To be usable,
    an instance (or subclass) must still provide:

    - ``inputs_dir``: image folder passed to ``queue_single_images_from_folder``
    - ``checkpoint_file``: checkpoint restored into the session
    - ``_write_record(examples_per_record, operations, sess)``: consumes one
      batch of operations (presumably also updating ``_examples_count``)

    and ``queue_single_images_from_folder``, ``batch_operations``, ``slim``,
    ``inception_resnet_v2`` and ``inception_resnet_v2_arg_scope`` must exist in
    the notebook namespace; none of them are defined in the visible cells.
    """

    def _create_operations(self, examples_per_record):
        """
        Create the operations to read images from the queue and
        extract inception features.

        :param examples_per_record: batch size passed to ``batch_operations``.
        :return: the result of ``batch_operations`` applied to the
            ``(image_key, image_tensor, input_embedding)`` tuple.
        """
        # Create the queue operations
        image_key, image_tensor, _ = \
            queue_single_images_from_folder(self.inputs_dir)

        # Build Inception Resnet v2 operations using the image as input
        # - from rgb to grayscale to lose the color information
        # - from grayscale to rgb just to have 3 identical channels
        # - from a [0, 255] range to [-1,+1] float32
        # - feed the image into inception and get the embedding
        img_for_inception = tf.image.rgb_to_grayscale(image_tensor)
        img_for_inception = tf.image.grayscale_to_rgb(img_for_inception)
        img_for_inception = prepare_image_for_inception(img_for_inception)
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            input_embedding, _ = inception_resnet_v2(img_for_inception,
                                                     is_training=False)

        operations = image_key, image_tensor, input_embedding

        return batch_operations(operations, examples_per_record)

    def _run_session(self, sess, operations, examples_per_record):
        """
        Run the whole reading -> extracting features -> writing to records
        pipeline in a TensorFlow session.

        Stops cleanly when the input queue is exhausted
        (``tf.errors.OutOfRangeError``). Any other exception is reported to
        the coordinator and re-raised by ``coord.join`` once the queue-runner
        threads have been stopped.

        :param sess: an initialised ``tf.Session``.
        :param operations: the batched operations from ``_create_operations``.
        :param examples_per_record: forwarded to ``_write_record``.
        """
        # Coordinate the loading of image files.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        start_time = time.time()
        self._examples_count = 0

        # These are the only lines where something happens:
        # we execute the operations to get the image, compute the
        # embedding and write everything in the TFRecord
        try:
            while not coord.should_stop():
                self._write_record(examples_per_record, operations, sess)
        except tf.errors.OutOfRangeError:
            # The string_input_producer queue ran out of strings
            pass
        except Exception as exc:
            # Record the failure so coord.join() re-raises it after the
            # threads have stopped, instead of leaving them running.
            coord.request_stop(exc)
        finally:
            # Ask the threads (filename queue) to stop.
            coord.request_stop()
            print('Finished writing {} images in {:.2f}s'
                  .format(self._examples_count, time.time() - start_time))
            # Wait for threads to finish (re-raises a recorded exception).
            coord.join(threads)

    def batch_all(self, examples_per_record):
        """Build the graph, then initialise a session and run the pipeline.

        :param examples_per_record: batch size forwarded to
            ``_create_operations`` and ``_run_session``.
        """
        operations = self._create_operations(examples_per_record)

        with tf.Session() as sess:
            self._initialize_session(sess)
            self._run_session(sess, operations, examples_per_record)

    def _initialize_session(self, sess):
        """
        Initialize a new session to run the operations.

        :param sess: the ``tf.Session`` to initialise.
        """

        # Initialize the variables that we introduced (like queues etc.)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore the weights from Inception
        # (do not call a global/local variable initializer after this call)
        saver = tf.train.Saver()
        saver.restore(sess, self.checkpoint_file)

In [None]:
class FusionLayer(Layer):
    """Concatenate a per-image embedding vector onto every pixel of a feature map.

    Inputs are ``[imgs, embs]``: ``imgs`` is a 4-D feature map
    ``(batch, height, width, depth)`` and ``embs`` is presumably a 2-D
    ``(batch, embedding_len)`` tensor. The embedding is broadcast over the
    spatial grid and concatenated on the channel axis, giving
    ``(batch, height, width, depth + embedding_len)``.
    """

    def call(self, inputs, mask=None):
        imgs, embs = inputs
        # Height/width come from the *dynamic* shape: the encoder accepts
        # (None, None, 1) inputs, so they are unknown while the graph is
        # built, and the batch dimension is typically unknown too. Static
        # ``imgs.shape[...]`` arithmetic and reshaping to a partially-unknown
        # TensorShape cannot handle that.
        img_shape = K.shape(imgs)
        height, width = img_shape[1], img_shape[2]
        # (batch, emb_len) -> (batch, 1, 1, emb_len) -> (batch, H, W, emb_len)
        embs = K.expand_dims(K.expand_dims(embs, axis=1), axis=1)
        embs = K.tile(embs, [1, height, width, 1])
        return K.concatenate([imgs, embs], axis=3)

    def compute_output_shape(self, input_shapes):
        # Must have 2 tensors as input
        assert input_shapes and len(input_shapes) == 2
        imgs_shape, embs_shape = input_shapes

        # The batch size of the two tensors must match
        assert imgs_shape[0] == embs_shape[0]

        # (batch_size, height, width, depth + embedding_len); tuple() so list
        # shapes concatenate correctly too.
        return tuple(imgs_shape[:3]) + (imgs_shape[3] + embs_shape[1],)

In [None]:
class Colorization:
    """Encoder -> fusion -> decoder colorization network.

    The grayscale (L) input is encoded by ``_build_encoder``, the resulting
    feature map is combined with an image embedding by ``FusionLayer``, mixed
    by a 1x1 ReLU convolution and decoded to two output channels by
    ``_build_decoder``.

    :param depth_after_fusion: filters of the 1x1 convolution applied after
        fusion; also the input depth expected by the decoder.
    """

    def __init__(self, depth_after_fusion):
        self.encoder = _build_encoder()
        self.fusion = FusionLayer()
        self.after_fusion = Conv2D(
            depth_after_fusion, (1, 1), activation='relu')
        self.decoder = _build_decoder(depth_after_fusion)

    def build(self, img_l, img_emb):
        """Connect the sub-networks and return the decoder's output tensor.

        :param img_l: grayscale input tensor, ``(batch, H, W, 1)`` as declared
            by the encoder's InputLayer.
        :param img_emb: embedding tensor fused into the encoded features
            (presumably ``(batch, embedding_len)`` — confirm against the
            embedding producer).
        :return: two-channel decoder output tensor.
        """
        img_enc = self.encoder(img_l)

        fusion = self.fusion([img_enc, img_emb])
        fusion = self.after_fusion(fusion)

        return self.decoder(fusion)


def conv_stack(filters, d, strides, model=None):
    """Create one ``Conv2D`` + ``BatchNormalization`` pair per stride.

    Each convolution is 3x3, ReLU, 'same' padding, with ``dilation_rate=d``
    and the corresponding entry of ``strides`` as its stride.

    :param filters: number of convolution filters.
    :param d: dilation rate.
    :param strides: iterable of strides, one conv/BN pair per element.
    :param model: optional ``Sequential`` the layers are appended to. The
        target model must be passed explicitly: a builder's local ``model`` is
        not visible here, and no global ``model`` exists when the builders run.
    :return: list of the created layers, in order.
    """
    layers = []
    for stride in strides:
        layers.append(Conv2D(filters, (3, 3), strides=stride, activation='relu',
                             dilation_rate=d, padding='same'))
        layers.append(BatchNormalization())
    if model is not None:
        for layer in layers:
            model.add(layer)
    return layers


def _build_encoder():
    """Build the grayscale encoder ``(batch, H, W, 1) -> (batch, h, w, 128)``.

    Three stride-2 convolutions downsample by 8 per spatial dimension; with
    'same' padding odd sizes round up (e.g. 300 -> 150 -> 75 -> 38).
    """
    model = Sequential(name='encoder')
    model.add(InputLayer(input_shape=(None, None, 1)))
    conv_stack(64, 1, [2], model)
    conv_stack(128, 1, [1, 2], model)
    conv_stack(256, 1, [1, 2], model)
    conv_stack(512, 1, [1, 1], model)
    conv_stack(256, 1, [1], model)
    conv_stack(128, 1, [1], model)
    return model


def _build_decoder(encoding_depth):
    """Build the decoder ``(batch, h, w, encoding_depth) -> (batch, H, W, 2)``.

    Upsamples by 8 in total. The final 3x3 convolution keeps the default
    'valid' padding, which trims 2 pixels per spatial dimension: for the 38x38
    encoding of a 300x300 image this gives 38 -> 76 -> 152 -> 150 -> 300, so
    the output matches the input size. Other input sizes may not line up.
    The tanh activation bounds both output channels to [-1, 1].
    """
    model = Sequential(name='decoder')
    model.add(InputLayer(input_shape=(None, None, encoding_depth)))
    model.add(UpSampling2D((2, 2)))
    conv_stack(64, 1, [1, 1], model)
    model.add(UpSampling2D((2, 2)))
    conv_stack(32, 1, [1], model)
    model.add(Conv2D(2, (3, 3), activation='tanh'))
    model.add(UpSampling2D((2, 2)))
    return model


# NOTE(review): ``model`` was never defined before this compile call
# (NameError). The training generator in the next cell yields only the L
# channel, with no Inception embedding, so the fusion-based ``Colorization``
# network cannot be trained by it. Build a plain encoder -> decoder model whose
# input (L) and output (a*b*) match that generator. 128 is the encoder's
# output depth.
model = Sequential(name='colorization')
model.add(_build_encoder())
model.add(_build_decoder(128))
model.compile(optimizer='rmsprop', loss='mse')

In [None]:
# Image transformer: random augmentation applied to every training batch
datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True)

# Generate training data
batch_size = 50
def image_a_b_gen(batch_size):
    """Yield endless ``(L, ab)`` training batches from augmented ``Xtrain``.

    Each augmented RGB batch (``Xtrain`` is scaled to [0, 1]) is converted to
    Lab. The L channel, with a trailing channel axis, is the input. The a*b*
    channels divided by 128 are the targets.
    """
    for batch in datagen.flow(Xtrain, batch_size=batch_size):
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0]
        Y_batch = lab_batch[:,:,:,1:] / 128
        yield (X_batch.reshape(X_batch.shape+(1,)), Y_batch)

# Train model
# "/output/{}" was never formatted, so every run logged into a literal "{}"
# directory; use a per-run timestamped subdirectory instead.
tensorboard = TensorBoard(log_dir="/output/{}".format(int(time.time())))
# ``samples_per_epoch`` is the Keras 1 argument. For a plain Python generator
# the Keras 2 legacy shim mapped it to ``steps_per_epoch`` unchanged, so 200
# steps per epoch is kept explicitly.
model.fit_generator(image_a_b_gen(batch_size), callbacks=[tensorboard], epochs=2, steps_per_epoch=200)

In [None]:
# Save model: architecture serialized to JSON, weights written to HDF5
architecture_json = model.to_json()
with open("model.json", "w") as json_out:
    json_out.write(architecture_json)
model.save_weights("model.h5")

In [None]:
# Test images: convert the held-out RGB images (X[split:]) to Lab once, then
# split into the L input channel (kept as a trailing axis) and the a*b*
# targets, scaled by 1/128 like the training targets.
lab_test = rgb2lab(1.0/255*X[split:])
Xtest = lab_test[:, :, :, :1]
Ytest = lab_test[:, :, :, 1:] / 128
print(model.evaluate(Xtest, Ytest, batch_size=batch_size))

In [None]:
TEST_DIR = '/color_300/Test/'
RESULT_DIR = 'result'

# sorted() makes the img_<i>.png numbering deterministic across runs.
color_me = [img_to_array(load_img(os.path.join(TEST_DIR, filename)))
            for filename in sorted(os.listdir(TEST_DIR))]
color_me = np.array(color_me, dtype=float)
color_me = rgb2lab(1.0/255*color_me)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))

# Test model: undo the 1/128 scaling applied to the a*b* training targets
output = model.predict(color_me)
output = output * 128

# Output colorizations
os.makedirs(RESULT_DIR, exist_ok=True)  # imsave does not create directories
for i in range(len(output)):
    cur = np.zeros(color_me.shape[1:3] + (3,))
    # The L channel must come from the same image the prediction was made
    # for; Xtest[i] belongs to a different image set.
    cur[:,:,0] = color_me[i][:,:,0]
    cur[:,:,1:] = output[i]
    imsave(os.path.join(RESULT_DIR, "img_"+str(i)+".png"), lab2rgb(cur))