In [None]:
import os
import random

import keras
import numpy as np
import tensorflow as tf
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.engine import Layer
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose
from keras.layers import GlobalAveragePooling2D, Input, Reshape, concatenate
from keras.layers.core import RepeatVector, Permute
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.models import Sequential
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave

In [None]:
# Get images
# Resize to 256x256 while loading: the Lab tensors built from this picture are
# fed to the network as 256x256 arrays, so an arbitrarily sized 'dog.jpg' must
# not reach the later cells unscaled. (A picture that already is 256x256 is
# left untouched by load_img.)
image_raw = img_to_array(load_img('dog.jpg', target_size=(256, 256)))
image_raw = np.array(image_raw, dtype=float)
#Load weights
# ImageNet-pretrained InceptionResNetV2 without its classification head
# (include_top=False); it is only used to compute an embedding of the picture.
inception = InceptionResNetV2(weights='imagenet', include_top=False)

In [None]:
# Convert the RGB picture (scaled to [0, 1]) to Lab once and slice the result,
# instead of running the conversion twice on the same data.
lab = rgb2lab(1.0/255*image_raw)
# X: lightness channel, the network input.
X = lab[:,:,0]
# Y: the a/b colour channels, divided by 128 so the targets are on the same
# scale as the tanh output of the network.
Y = lab[:,:,1:] / 128
# Add the batch axis (and the channel axis for X) expected by Keras.
X = X.reshape(1, 256, 256, 1)
Y = Y.reshape(1, 256, 256, 2)

In [None]:
def conv_stack(model, filters, strides):
    """Append Conv2D + BatchNormalization pairs to ``model``.

    One 3x3, 'same'-padded, ReLU-activated convolution followed by a batch
    normalisation layer is added for every entry of ``strides``.

    Args:
        model: object exposing ``add(layer)`` (e.g. a Keras ``Sequential``).
        filters: number of filters used by every convolution in the stack.
        strides: iterable of stride values, one per convolution to append.

    Returns:
        None. ``model`` is extended in place.
    """
    for stride in strides:
        convolution = Conv2D(
            filters,
            kernel_size=(3, 3),
            strides=stride,
            padding='same',
            activation='relu',
        )
        model.add(convolution)
        model.add(BatchNormalization())

# Inception embedding of the training picture.
# The picture is reloaded at 299x299, turned into a one-element batch and run
# through the InceptionResNetV2 preprocessing before being fed to `inception`.
img_path = 'dog.jpg'
img = preprocess_input(
    np.expand_dims(
        image.img_to_array(image.load_img(img_path, target_size=(299, 299))),
        axis=0,
    )
)
embed = inception.predict(img)

# The encoder contains three stride-2 convolutions, so it shrinks every spatial
# side by a factor of 8; the decoder's three 2x up-samplings undo that. The
# input therefore has to be a multiple of 8 for the output to line up with Y.
height, width = X.shape[1:3]
if height % 8 or width % 8:
    raise ValueError(
        'Input size must be a multiple of 8, got %dx%d' % (height, width))
grid_h, grid_w = height // 8, width // 8

# The embedding is either a flat vector per sample (2-D batch) or a spatial
# feature map (4-D batch, e.g. a network built with include_top=False and no
# pooling). Anything else cannot be broadcast over the encoder grid.
if embed.ndim not in (2, 4):
    raise ValueError(
        'Expected a 2-D or 4-D embedding batch, got shape %r' % (embed.shape,))
embed_dim = embed.shape[-1]

#Encoder
# Grayscale (L channel) picture -> (None, grid_h, grid_w, 128) feature map.
encoder = Sequential()
encoder.add(InputLayer(input_shape=(height, width, 1)))
conv_stack(encoder, 64, [2])
conv_stack(encoder, 128, [1, 2])
conv_stack(encoder, 256, [1, 2])
conv_stack(encoder, 512, [1, 1])
conv_stack(encoder, 256, [1])
conv_stack(encoder, 128, [1])

encoder_input = Input(shape=(height, width, 1), name='encoder_input')
encoder_output = encoder(encoder_input)

#Fusion
# Copy the embedding vector to every cell of the encoder grid (channels-last),
# concatenate it with the encoder features and mix both with a 1x1 convolution.
fusion = Sequential()
fusion.add(InputLayer(input_shape=embed.shape[1:]))
if embed.ndim == 4:
    # Collapse the spatial feature map to one vector per sample first.
    fusion.add(GlobalAveragePooling2D())  # shape: (None, embed_dim)
fusion.add(RepeatVector(grid_h * grid_w))  # shape: (None, grid_h*grid_w, embed_dim)
fusion.add(Reshape((grid_h, grid_w, embed_dim)))  # shape: (None, grid_h, grid_w, embed_dim)

embed_input = Input(shape=embed.shape[1:], name='embed_input')
y_concat = concatenate([fusion(embed_input), encoder_output], axis=3)  # (None, grid_h, grid_w, embed_dim + 128)
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(y_concat)  # (None, grid_h, grid_w, 256)

#Decoder
# Fused features -> two tanh-activated colour channels at the input resolution.
decoder = Sequential()
decoder.add(InputLayer(input_shape=(grid_h, grid_w, 256)))
decoder.add(UpSampling2D((2, 2)))
conv_stack(decoder, 64, [1, 1])
decoder.add(UpSampling2D((2, 2)))
conv_stack(decoder, 32, [1])
# padding='same' keeps the spatial size; the default 'valid' padding would trim
# a border and the prediction would no longer match the shape of Y.
decoder.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
decoder.add(UpSampling2D((2, 2)))

# Two-input model: [lightness picture, Inception embedding] -> a/b channels.
model = Model(inputs=[encoder_input, embed_input], outputs=decoder(fusion_output))

In [None]:
# Finish model
model.compile(optimizer='rmsprop', loss='mse')
# Model.fit has no per-input keyword arguments: the arrays for a multi-input
# model are passed as a list through `x`, in the order of the model's inputs
# (lightness picture first, Inception embedding second).
model.fit(x=[X, embed],
    y=Y,
    batch_size=1,
    epochs=1000)

In [None]:
# Evaluate and predict with the full two-input model: it is the object that was
# compiled and trained, and it needs both the lightness picture and the
# Inception embedding.
print(model.evaluate([X, embed], Y, batch_size=1))
output = model.predict([X, embed])
# Undo the /128 scaling applied to the a/b training targets.
output *= 128
# Output colorizations
# Lab canvas with the same height/width as the network input, so that the
# lightness channel and the predicted a/b channels fit without a shape error.
cur = np.zeros(X.shape[1:3] + (3,))
cur[:,:,0] = X[0][:,:,0]
cur[:,:,1:] = output[0]
imsave("img_result.png", lab2rgb(cur))
imsave("img_gray_version.png", rgb2gray(lab2rgb(cur)))