In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt 
import numpy as np
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from skimage.transform import resize
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from sklearn.preprocessing import MinMaxScaler
import os

from tensorflow.keras.layers import Conv2D,InputLayer, Input, concatenate ,RepeatVector ,Reshape ,UpSampling2D

from tensorflow.keras.models import Model, load_model, Sequential



from tensorflow.keras.callbacks import EarlyStopping,ModelCheckpoint



In [2]:
# Training data: https://www.floydhub.com/emilwallner/datasets/colornet
# Download the dataset and put it in a folder called "data" at the project root.
TRAIN_DIR = "./data/images/Train/"
N_TRAIN_IMAGES = 250       # number of training images to load
IMAGE_SIZE = (256, 256)    # must match the (256, 256, 1) encoder input of the model

# os.listdir returns entries in arbitrary, platform-dependent order: sort so the
# same subset of images is picked on every run, and slice so we stop after
# N_TRAIN_IMAGES instead of walking the whole directory.
train_files = sorted(os.listdir(TRAIN_DIR))[:N_TRAIN_IMAGES]
items = np.array([
    img_to_array(load_img(os.path.join(TRAIN_DIR, fname), target_size=IMAGE_SIZE))
    for fname in train_files
])
# Scale pixel values from [0, 255] to [0, 1].
X_train = 1.0 / 255 * items



In [4]:
# Switch TensorFlow to TF1-style graph execution and build the InceptionResNetV2
# classifier that create_inception_embedding uses as a global feature extractor.
# Note: this rebinds `tf` to the tensorflow.compat.v1 module; the
# `import tensorflow as tf` placed after create_inception_embedding rebinds it back.
import tensorflow.compat.v1 as tf
# Turn off TF2 behaviors (graph mode instead of eager); called before the model below is built.
tf.disable_v2_behavior()
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, decode_predictions, preprocess_input
# Build the full classifier (include_top=True) without downloading weights,
# then load the pretrained weights from a local .h5 file.
inception = InceptionResNetV2(weights=None, include_top=True)
inception.load_weights('./data/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5')
# Keep a handle to the graph the model was built in; create_inception_embedding
# calls predict() inside `inception.graph.as_default()`.
inception.graph = tf.get_default_graph()

In [5]:
# Generate a 1000-value embedding per image by passing it through InceptionResNetV2.
def create_inception_embedding(grayscaled_rgb):
    """Compute InceptionResNetV2 outputs for a batch of grayscale images.

    Each image is resized spatially to 299x299, a single channel is replicated
    to the 3 channels Inception expects, the batch is run through
    ``preprocess_input`` and then through the global ``inception`` model inside
    the graph stored on ``inception.graph``.

    Args:
        grayscaled_rgb: iterable of images shaped (H, W), (H, W, 1) or (H, W, 3).
            NOTE(review): ``preprocess_input`` maps [0, 255] to [-1, 1], so
            pixel values are presumably expected in [0, 255] — confirm at caller.

    Returns:
        The array returned by ``inception.predict`` — one row per input image.
    """
    resized = []
    for image in grayscaled_rgb:
        # Resize only the spatial axes. Requesting (299, 299, 3) directly would
        # also interpolate the channel axis 1 -> 3 with mode='constant', which
        # can blend the zero padding into the outer channels so they are not
        # copies of the gray value.
        image = resize(image, (299, 299), mode='constant')
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        resized.append(image)
    batch = preprocess_input(np.array(resized))
    with inception.graph.as_default():
        embed = inception.predict(batch)
    return embed
import tensorflow as tf

In [6]:
lab_img = rgb2lab(X_train)
# Network input: the L (lightness, 0..100) channel with a trailing channel axis.
# Slicing with :1 keeps that axis and follows the actual number of loaded images
# instead of hard-coding reshape(250, ...), which fails for any other count.
x_batch = lab_img[:, :, :, :1]
# Targets: a*/b* channels scaled by 1/128 to match the decoder's tanh output range.
y_batch = lab_img[:, :, :, 1:] / 128

# Creating embeddings for the training data.
# Inception's preprocess_input expects pixel values in [0, 255]; rescale L from
# [0, 100] so the classifier sees a full-range gray image, not a nearly black one.
incept_em = create_inception_embedding(x_batch * (255.0 / 100.0))
# Tile each embedding over the 32x32 grid of the encoder output for the fusion layer.
# NOTE(review): under disable_v2_behavior these layer calls presumably return
# symbolic tensors rather than NumPy arrays — confirm when feeding model.fit.
embeddings = RepeatVector(32 * 32)(incept_em)
layer_embedding_train = Reshape(([32, 32, 1000]))(embeddings)

In [7]:

# Colorization network: inputs are the L channel and the tiled Inception
# embedding; output is the two a*/b* channels at 256x256.
# Optional: to train on a TPU, connect a TPUClusterResolver and build the
# layers below inside a TPUStrategy scope.


def _conv(filters, kernel_size=(3, 3), strides=1, activation='relu'):
    """Return a Conv2D layer with 'same' padding; defaults match the 3x3 ReLU convs."""
    return Conv2D(filters, kernel_size, activation=activation, padding='same', strides=strides)


embed_input = Input(shape=(32, 32, 1000))
encoder_input = Input(shape=(256, 256, 1,))

# Encoder: three stride-2 convolutions bring 256x256 down to 32x32, matching embed_input.
encoder1 = _conv(64, strides=2)(encoder_input)
encoder2 = _conv(128)(encoder1)
encoder3 = _conv(128, strides=2)(encoder2)
encoder4 = _conv(256)(encoder3)
encoder5 = _conv(256, strides=2)(encoder4)
encoder6 = _conv(512)(encoder5)
encoder7 = _conv(512)(encoder6)
encoder_output = _conv(256)(encoder7)

# Fusion: append the embedding channels at every spatial position, then mix with 1x1 convs.
fusion1 = concatenate([encoder_output, embed_input], axis=3)
fusion2 = _conv(256, kernel_size=(1, 1))(fusion1)
fusion_output = _conv(256, kernel_size=(1, 1))(fusion2)

# Decoder: three 2x upsamplings restore 256x256; tanh bounds the outputs to [-1, 1].
decoder1 = _conv(128)(fusion_output)
decoder2 = UpSampling2D((2, 2))(decoder1)
decoder3 = _conv(64)(decoder2)
decoder4 = UpSampling2D((2, 2))(decoder3)
decoder5 = _conv(32)(decoder4)
decoder6 = _conv(16)(decoder5)
decoder7 = _conv(2, activation='tanh')(decoder6)
decoder_output = UpSampling2D((2, 2))(decoder7)

model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)





In [None]:
# Pixel-wise regression of the scaled a*/b* channels, so MSE loss. 'accuracy'
# (exact-match of continuous values) is meaningless here; track MAE instead.
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
# NOTE(review): layer_embedding_train comes from Keras layer calls in graph mode
# (likely a symbolic tensor), which is presumably why steps_per_epoch is needed;
# with steps_per_epoch=1, batch_size may have no effect — confirm.
model.fit(x=[x_batch, layer_embedding_train], y=y_batch, batch_size=5, epochs=1400, steps_per_epoch=1)

In [None]:
# Util functions

def lab2RGB(l, ab):
    """Convert a Lab image given as separate L and a*/b* parts to 8-bit RGB.

    Args:
        l: lightness, shape (H, W) or (H, W, 1), expected in skimage's Lab
            range [0, 100] (as produced by ``rgb2lab``).
        ab: a*/b* channels, shape (H, W, 2), expected in skimage's range
            (roughly -128..127).

    Returns:
        uint8 RGB image of shape (H, W, 3).
    """
    l = np.asarray(l, dtype=np.float64)
    if l.ndim == 3:
        l = l[:, :, 0]
    lab = np.empty(l.shape + (3,), dtype=np.float64)
    # OpenCV's 8-bit Lab encoding is L*255/100, a+128, b+128. Casting raw Lab
    # values to uint8 would wrap negative a/b and leave L in 0..100.
    lab[:, :, 0] = l * (255.0 / 100.0)
    lab[:, :, 1:] = np.asarray(ab, dtype=np.float64) + 128.0
    lab_u8 = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2RGB)
def display(img, title=None):
    """Show ``img`` in its own matplotlib figure.

    2-D (single-channel) arrays are drawn with a gray colormap; matplotlib
    ignores the colormap for RGB input. Passing ``cmap`` to ``imshow`` (rather
    than calling ``plt.set_cmap``) leaves the global default colormap untouched
    for later plots.

    Note: this name shadows IPython's built-in ``display`` in the notebook.

    Args:
        img: image array, (H, W) or (H, W, 3).
        title: optional figure title.
    """
    fig, ax = plt.subplots()
    ax.imshow(img, cmap='gray')
    if title is not None:
        ax.set_title(title)
    plt.show()


# Sanity check: rebuild the first training image from its Lab parts.
# x_batch holds the L channel and y_batch the a*/b* channels divided by 128,
# so undo that scaling. (X_train holds RGB values in [0, 1], not L, and no
# Y_train exists in this notebook.)
img = lab2RGB(x_batch[0], y_batch[0] * 128)
display(img)
    


In [None]:
# Load an image to predict on and inspect its Lab channel ranges.
test_img_path = "./data/ww1.jpeg"
test_img = cv2.imread(test_img_path)
if test_img is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f"Could not read image: {test_img_path}")
rgb_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
print(rgb_img.shape)

display(rgb_img)

# cv2.imread yields BGR channel order; rgb2lab needs RGB, so convert rgb_img.
img_lab = rgb2lab(rgb_img)
img_l = img_lab[:, :, 0]
img_ab = img_lab[:, :, 1:]
print(img_l.max())
print(img_l.min())
print(img_ab.max())
print(img_ab.min())
display(img_l)

In [None]:

# Load a second test image and inspect its Lab channel ranges.
test_img_path = "./data/testImg.jpeg"
test_img = cv2.imread(test_img_path)
if test_img is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f"Could not read image: {test_img_path}")
rgb_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
print(rgb_img.shape)

display(rgb_img)

# cv2.imread yields BGR channel order; rgb2lab needs RGB, so convert rgb_img.
# (The discarded astype() call and the overwritten OpenCV LAB conversion were removed.)
img_lab = rgb2lab(rgb_img)
img_l = img_lab[:, :, 0]
img_ab = img_lab[:, :, 1:]
print(img_l.max())
print(img_l.min())
print(img_ab.max())
print(img_ab.min())
display(img_l)