In [1]:
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'

In [32]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Input, Reshape, concatenate
from keras.callbacks import TensorBoard
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.io import imsave
from skimage.transform import resize

import numpy as np
import random
import tensorflow as tf

In [3]:
print(keras.__version__)
print(tf.__version__)

2.0.9
1.3.0


In [4]:
DATASET_PATH = './dataset'

In [5]:
# Get images
X = []
for filename in os.listdir(f"{DATASET_PATH}/train/"):
    X.append(img_to_array(load_img(f"{DATASET_PATH}/train/{filename}")))

X = np.array(X, dtype='float32')
X_train = 1.0 / 255 * X

In [6]:
# Load weights
inception = InceptionResNetV2(weights="imagenet", include_top=True)
inception.graph = tf.get_default_graph()

Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.7/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5


In [8]:
from functools import partial

conv = partial(Conv2D,
               1,
               (3, 3),
               activation='relu',
               padding='same'
               )

In [24]:
embed_input = Input(shape=(1000,))

# Encoder
encoder_input = Input(shape=(256, 256, 1,))
encoder = Conv2D(64, (3, 3), activation='relu', padding='same', strides=2)(encoder_input)
encoder = Conv2D(128, (3, 3), activation='relu', padding='same')(encoder)
encoder = Conv2D(128, (3, 3), activation='relu', padding='same', strides=2)(encoder)
encoder = Conv2D(256, (3, 3), activation='relu', padding='same')(encoder)
encoder = Conv2D(256, (3, 3), activation='relu', padding='same', strides=2)(encoder)
encoder = Conv2D(512, (3, 3), activation='relu', padding='same')(encoder)
encoder = Conv2D(512, (3, 3), activation='relu', padding='same')(encoder)
encoder = Conv2D(256, (3, 3), activation='relu', padding='same')(encoder)


In [25]:
#Fusion
fusion = RepeatVector(32 * 32)(embed_input) 
# embed_input.shape = (?, 1000)
# funsion.shape = (?, 1024, 1000)
fusion = Reshape(([32, 32, 1000]))(fusion)
fusion = concatenate([encoder, fusion], axis=3)
# funsion.shape = (?, 32, 32, 1000 + 256)
fusion = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion)
# fusion.shape = (?, 32, 32, 256)

In [26]:
# Decoder
decoder = Conv2D(128, (3, 3), activation='relu', padding='same')(fusion)
decoder = UpSampling2D((2, 2))(decoder)
decoder = Conv2D(64, (3, 3), activation='relu', padding='same')(decoder)
decoder = UpSampling2D((2, 2))(decoder)
decoder = Conv2D(32, (3, 3), activation='relu', padding='same')(decoder)
decoder = Conv2D(16, (3, 3), activation='relu', padding='same')(decoder)
decoder = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder)
decoder = UpSampling2D((2, 2))(decoder)

In [27]:
model = Model(inputs=[encoder_input, embed_input], outputs=decoder)

In [34]:
def create_inception_embedding(grayscaled_rgb):
    grayscaled_rgb_resized = []
    for i in grayscaled_rgb:
        i = resize(i, (299, 299, 3), mode='constant')
        grayscaled_rgb_resized.append(i)
    grayscaled_rgb_resized = np.array(grayscaled_rgb_resized)
    grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized)
    with inception.graph.as_default():
        embed = inception.predict(grayscaled_rgb_resized)
    return embed

In [29]:
def image_a_b_gen(batch_size):
    for batch in datagen.flow(X_train, batch_size=batch_size):
        grayscaled_rgb = gray2rgb(rgb2gray(batch))
        embed = create_inception_embedding(grayscaled_rgb)
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:, :, :, 0]
        X_batch = X_batch.reshape(X_batch.shape + (1, ))
        y_batch = lab_batch[:, :, :, 1:] / 128.
        
        yield ([X_batch, create_inception_embedding(grayscaled_rgb)], y_batch)

In [30]:
# Image transformer
datagen = ImageDataGenerator(shear_range=0.2,
                             zoom_range=0.2,
                             rotation_range=20,
                             horizontal_flip=True)
batch_size = 10

model.compile(optimizer='rmsprop',
              loss = 'mse')

In [35]:
model.fit_generator(image_a_b_gen(batch_size), epochs=1, steps_per_epoch=1)

Epoch 1/1


<keras.callbacks.History at 0x2bd0c630>

In [36]:
color_me = []
for filename in os.listdir(f"{DATASET_PATH}/test/"):
    color_me.append(img_to_array(load_img(f"{DATASET_PATH}/test/{filename}")))
color_me = np.array(color_me, dtype='float32')
gray_me = gray2rgb(rgb2gray(1.0 / 255 * color_me))
color_me_embed = create_inception_embedding(gray_me)
color_me = rgb2lab(1.0 / 255 * color_me)[:, :, : , 0]
color_me = color_me.reshape(color_me.shape + (1, ))

In [39]:
# Test model
output = model.predict([color_me, color_me_embed])
output = output * 128

# Output colorizations
for i in range(len(output)):
    cur = np.zeros((256, 256, 3))
    cur[:, :, 0] = color_me[i][:,:, 0]
    cur[:, :, 1:] = output[i]
    imsave(f"result/full_img_{str(i)}.png", lab2rgb(cur))

  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  .format(dtypeobj_in, dtypeobj_out))
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


In [41]:
output = model.predict([color_me, color_me_embed])
output

array([[[[ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         ..., 
         [ 1.        ,  1.        ],
         [ 1.        , -0.99999785],
         [ 1.        , -0.99999785]],

        [[ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         ..., 
         [ 1.        ,  1.        ],
         [ 1.        , -0.99999785],
         [ 1.        , -0.99999785]],

        [[ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         ..., 
         [ 1.        ,  1.        ],
         [ 1.        ,  0.9999997 ],
         [ 1.        ,  0.9999997 ]],

        ..., 
        [[ 1.        ,  1.        ],
         [ 1.        ,  1.        ],
         [ 1.        , -1.        ],
         ..., 
         [ 1.        , -1.        ],
         [ 1.        , -1.        ],
         [ 1.        , -1.        ]],

        [[ 1.        ,  0.9998