In [1]:
import keras

Using TensorFlow backend.


In [2]:
print('keras: ', keras.__version__)

keras:  2.0.8


In [93]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave
import numpy as np
import numpy as np
import os
import random
import tensorflow as tf

In [94]:
# Get images
image_raw = img_to_array(load_img('dog.jpg'))
image_raw = np.array(image_raw, dtype=float)
#Load weights
inception = InceptionResNetV2(weights='imagenet', include_top=False)

In [95]:
X = rgb2lab(1.0/255*image_raw)[:,:,0]
Y = rgb2lab(1.0/255*image_raw)[:,:,1:]
Y /= 128
X = X.reshape(1, 256, 256, 1)
Y = Y.reshape(1, 256, 256, 2)

In [122]:
def conv_stack(data, filters, s):
        output = Conv2D(filters, (3, 3), strides=s, activation='relu', padding='same')(data)
        output = BatchNormalization()(output)
        return output

#Add inception embedding
img_path = 'dog.jpg'
img = image.load_img(img_path, target_size=(299, 299))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input(img)
embed = inception.predict(img)

embed_input = Input(shape=(8, 8, 1536,))
embed_output= Conv2D(512, (3, 3), activation="relu", strides=2)(embed_input)
embed_output= Conv2D(512, (3, 3), activation="relu", strides=2)(embed_output)
embed_output = Dense(1024, activation='relu')(embed_output)
embed_output = Dense(512, activation='relu')(embed_output)
embed_output = Dense(256, activation='relu')(embed_output)
embed_output = Reshape(([256,]))(embed_output)


#Encoder
encoder_input = Input(shape=(224, 224, 1,))
encoder_output = conv_stack(encoder_input, 64, 2)
print(encoder_output.shape)
encoder_output = conv_stack(encoder_output, 128, 2)
print(encoder_output.shape)
encoder_output = conv_stack(encoder_output, 256, 2)
print(encoder_output.shape)
encoder_output = conv_stack(encoder_output, 512, 1)
encoder_output = conv_stack(encoder_output, 256, 1)
print(encoder_output.shape)

#Fusion
# y_mid: (None, 256, 28, 28)
fusion_output = RepeatVector(28 * 28)(embed_output) # shape: (None, 28*28, 256)
fusion_output = Permute((2, 1))(fusion_output) # shape: (None, 256, 28*28)
fusion_output = Reshape(([256, 28, 28]))(fusion_output)# shape: (None, 256, 28, 28)
fusion_output = concatenate([fusion_output, encoder_output], axis=1) # (None, 512, 28, 28)
fusion_output = Conv2D(256, 1, 1, activation='relu')(fusion_output) # (None, 256, 28, 28) and Eq. (5)

#Decoder
decoder_output = UpSampling2D((2, 2))(fusion_output)
decoder_output = conv_stack(decoder_output, 64, [1, 1])(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = conv_stack(decoder_output, 32, [1])(decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)

model = Model(inputs=[encoder_input, fusion_input], outputs=decoder_output)

(?, 112, 112, 64)
(?, 56, 56, 128)
(?, 28, 28, 256)
(?, 28, 28, 256)


ValueError: `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 256, 28, 28), (None, 28, 28, 256)]

In [None]:
# Finish model
model.compile(optimizer='rmsprop', loss='mse')
model.fit(encoder_input=X,
    fusion_input=encoder_output,
    embed_input=embed,
    y=Y,
    batch_size=1,
    epochs=1000)

In [None]:
print(decoder.evaluate(X, Y, batch_size=1))
output = decoder.predict(X)
output *= 128
# Output colorizations
cur = np.zeros((300, 300, 3))
cur[:,:,0] = X[0][:,:,0]
cur[:,:,1:] = output[0]
imsave("img_result.png", lab2rgb(cur))
imsave("img_gray_version.png", rgb2gray(lab2rgb(cur)))