In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt 
import numpy as np
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from skimage.transform import resize
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from sklearn.preprocessing import MinMaxScaler
import os

In [2]:
# Data: https://www.floydhub.com/emilwallner/datasets/colornet
# Download it and put it in a folder called "data" at the project root;
# training images are read from ./data/images/Train/.
TRAIN_DIR = "./data/images/Train/"

# os.listdir returns entries in arbitrary order; sort so X_train is reproducible.
# Hidden entries (e.g. .DS_Store) are skipped because load_img cannot decode them.
items = [img_to_array(load_img(os.path.join(TRAIN_DIR, fname)))
         for fname in sorted(os.listdir(TRAIN_DIR))
         if not fname.startswith('.')]
# np.stack fails loudly if images differ in size (the model input is 256x256).
items = np.stack(items)
# Scale pixel values from [0, 255] to [0, 1].
X_train = 1.0/255 * items



In [None]:
# # splits the data into train and test sets
# X_train, X_test, Y_train, Y_test = train_test_split(imgs_gray, imgs_ab, test_size=0.2, random_state=42, shuffle=True)
# X_train.astype('float')
# X_test.astype('float')
# Y_train.astype('float')
# Y_test.astype('float')
# print("X_train " + str(X_train.shape))
# print("X_test " + str(X_test.shape))
# print("y_train " + str(Y_train.shape))
# print("y_test " + str(Y_test.shape))

# print(X_train.min())
# print(X_train.max())

# print(Y_train.min())
# print(Y_train.max())



In [3]:
# Load InceptionResNetV2 (classification head included) from a local weights file;
# its predictions are used as a global embedding of each image.
# TF1 (graph-mode) behaviour is switched on here for the rest of the notebook.
import tensorflow.compat.v1 as tf_compact
tf_compact.disable_v2_behavior()

INCEPTION_WEIGHTS = "./data/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5"
inception = InceptionResNetV2(include_top=True, weights=None)
inception.load_weights(INCEPTION_WEIGHTS)
# Keep a handle on the graph the model was built in so callers can re-enter it.
inception.graph = tf_compact.get_default_graph()

Instructions for updating:
non-resource variables are not supported in the long term


In [4]:
# Inception "embedding": InceptionResNetV2's output for each image, fed to the
# colorization model as a global description of the scene.

# Capture the Keras session holding Inception's loaded weights while still on the
# main thread. With TF1 behaviour, Keras keeps its session per thread, so a call
# from a worker thread (e.g. a data generator) would otherwise use a fresh session
# whose variables were never initialised (FailedPreconditionError).
_INCEPTION_SESSION = None if tf.executing_eagerly() else tf.compat.v1.keras.backend.get_session()


def embed_inception_prediction(img_unsized, pixel_scale=255.0):
    """Run InceptionResNetV2 on a batch of images and return its predictions.

    Args:
        img_unsized: iterable of H x W x 3 images. In this notebook they come from
            ``gray2rgb(rgb2gray(...))`` on data scaled to [0, 1].
        pixel_scale: factor applied after resizing to bring pixels into the
            [0, 255] range ``preprocess_input`` expects. The default assumes
            [0, 1] input; pass 1.0 for images already in [0, 255].

    Returns:
        The array returned by ``inception.predict``; with ``include_top=True``
        this should be one 1000-wide row per image (matches the model's
        1000-wide embedding input).
    """
    # InceptionResNetV2 takes 299x299x3 input.
    img_resized = np.array([resize(img, (299, 299, 3), mode='constant')
                            for img in img_unsized])
    # preprocess_input maps [0, 255] -> [-1, 1]; feeding [0, 1] data unscaled would
    # squash every pixel to about -1 and make the embedding nearly constant.
    img_resized = preprocess_input(img_resized * pixel_scale)
    if _INCEPTION_SESSION is None:
        # Eager (TF2) mode: no graph/session juggling needed.
        return inception.predict(img_resized)
    with inception.graph.as_default(), _INCEPTION_SESSION.as_default():
        return inception.predict(img_resized)
    


In [13]:
# Random geometric augmentation applied to every training batch.
train_datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True
)


def image_train_datagen(batch_size, images=None, seed=None):
    """Yield endless augmented batches for the colorization model.

    For every augmented batch of RGB images (presumably in [0, 1], since
    X_train is scaled by 1/255):
      * a grayscale copy replicated to 3 channels goes through
        ``embed_inception_prediction`` to build the embedding input;
      * the batch is converted to Lab with ``rgb2lab``; channel 0 (L), kept
        with a trailing channel axis, is the encoder input, and channels 1-2
        (a, b) divided by 128 form the regression target.

    Args:
        batch_size: number of images per batch.
        images: image array to draw from; defaults to the module-level
            ``X_train`` so existing calls behave as before. Pass a held-out
            array to build validation/test batches.
        seed: optional seed forwarded to ``ImageDataGenerator.flow`` for
            reproducible shuffling and augmentation.

    Yields:
        ``([l_batch, embedding], ab_batch)`` tuples.
    """
    if images is None:
        images = X_train
    for batch_imgs in train_datagen.flow(images, batch_size=batch_size, seed=seed):
        grayscale_rgb = gray2rgb(rgb2gray(batch_imgs))
        embedding = embed_inception_prediction(grayscale_rgb)

        lab_batch = rgb2lab(batch_imgs)
        # Keep a trailing channel axis: (N, H, W) -> (N, H, W, 1).
        l_batch = lab_batch[:, :, :, 0][..., np.newaxis]
        ab_batch = lab_batch[:, :, :, 1:] / 128

        yield [l_batch, embedding], ab_batch

                

In [None]:
# def test_image_train_datagen(X_data, batch_size):
#     for batch_imgs in train_datagen.flow(X_data, batch_size=batch_size):
#         grayscale_for_embeding = gray2rgb(rgb2gray(batch_imgs))
#         embeding = embed_inception_prediction(grayscale_for_embeding)

#         lab_train = rgb2lab(batch_imgs)
#         l_batch = lab_train[:, :, :, 0]
#         l_batch = l_batch.reshape(l_batch.shape + (1,))
#         print(l_batch.shape)

#         ab_batch = lab_train[:, :, :, 1:] / 128
#         print(l_batch.shape)
#         print(l_batch[0])

#         print(ab_batch.shape)
#         print(ab_batch[0])
#         return([l_batch, embeding], ab_batch)


# data = np.array([X_train[0]])
# print()
# test = test_image_train_datagen(data, 1)


In [6]:
from tensorflow.keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, concatenate, Activation, Dense, Dropout, Flatten, RepeatVector
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint


In [7]:
# Colorization network (Keras functional API): a conv encoder over the L channel,
# fused with the 1000-wide Inception embedding at 32x32, then a conv decoder that
# predicts the two ab channels at 256x256.
embedings_input = Input(shape=(1000, ))
encoder_input = Input(shape=(256, 256, 1))


def _same_conv(filters, kernel_size=(3, 3), strides=(1, 1), activation='relu'):
    """Build a 'same'-padded Conv2D layer."""
    return Conv2D(filters, kernel_size, strides=strides, activation=activation, padding='same')


# Encoder: (filters, stride) per layer; the three stride-2 layers take 256 -> 32.
_encoder_specs = [(64, 2), (128, 1), (128, 2), (256, 1),
                  (256, 2), (512, 1), (512, 1), (512, 1)]
encoder_output = encoder_input
for _filters, _stride in _encoder_specs:
    encoder_output = _same_conv(_filters, strides=(_stride, _stride))(encoder_output)

# Fusion: tile the embedding over the 32x32 grid, append it to the encoder
# features, and mix with a 1x1 convolution.
tiled_embedding = Reshape((32, 32, 1000))(RepeatVector(32 * 32)(embedings_input))
fusion_output = _same_conv(256, kernel_size=(1, 1))(
    concatenate([encoder_output, tiled_embedding], axis=3))

# Decoder: three 2x upsamplings take 32 -> 256; tanh keeps ab in [-1, 1].
_decoder_layers = [
    _same_conv(128),
    UpSampling2D((2, 2)),
    _same_conv(64),
    _same_conv(64),
    UpSampling2D((2, 2)),
    _same_conv(32),
    _same_conv(16),
    _same_conv(2, activation='tanh'),
    UpSampling2D((2, 2)),
]
final_decoder_output = fusion_output
for _layer in _decoder_layers:
    final_decoder_output = _layer(final_decoder_output)

myModel = Model(inputs=[encoder_input, embedings_input], outputs=final_decoder_output)

In [8]:
model_path = './chroma_model.h5'

# Training is driven by a generator with no validation data, so "val_loss" is never
# logged: ModelCheckpoint(save_best_only=True) would skip every save and
# EarlyStopping would only warn. Monitor the training loss instead; switch to
# "val_loss" once validation data is passed to fit.
MONITOR = "loss"

make_checkpoint = ModelCheckpoint(model_path,
                                  monitor=MONITOR,
                                  mode="min",
                                  save_best_only=True,
                                  verbose=1)
early_stoping = EarlyStopping(monitor=MONITOR, mode='min', verbose=1)

# MSE regression on continuous ab values: "accuracy" (exact-match) is meaningless
# here, so report mean absolute error instead.
myModel.compile(optimizer='adam', loss="mse", metrics=['mae'])

In [14]:
batch_size_number = 10
num_epochs = 1
# One epoch = one pass over the training images (the generator itself never ends).
steps_per_epoch_num = max(1, len(X_train) // batch_size_number)

# fit_generator is deprecated; fit accepts generators directly.
# workers=0 runs the generator on the main thread. With TF1 behaviour Keras keeps a
# session per thread, so running the Inception embedding in a worker thread
# presumably caused the earlier FailedPreconditionError (uninitialised variables).
history = myModel.fit(image_train_datagen(batch_size_number),
                      epochs=num_epochs,
                      steps_per_epoch=steps_per_epoch_num,
                      verbose=2,
                      workers=0,
                      callbacks=[make_checkpoint, early_stoping])

FailedPreconditionError: Error while reading resource variable conv2d_167/kernel from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/conv2d_167/kernel)
	 [[{{node conv2d_167/Conv2D/ReadVariableOp}}]]

In [None]:
def lab2RGB(l, ab):
    """Combine a lightness channel and two chroma channels into an RGB uint8 image.

    Inputs are expected in OpenCV's 8-bit Lab encoding (L scaled to [0, 255],
    a/b offset by 128), because the stacked array is cast to uint8 and converted
    with ``cv2.COLOR_LAB2RGB``. TODO confirm callers use that encoding.

    Args:
        l: lightness as an H x W array, or H x W x C (channel 0 is used).
        ab: H x W x 2 array of a/b channels.

    Returns:
        H x W x 3 uint8 RGB image.
    """
    l = np.asarray(l)
    l_channel = l[:, :, 0] if l.ndim == 3 else l
    img = np.zeros(l_channel.shape + (3,))
    img[:, :, 0] = l_channel
    img[:, :, 1:] = ab
    # Round and clip before casting: a bare astype('uint8') truncates and wraps
    # out-of-range values (e.g. negative a/b) into garbage colours.
    img = np.clip(np.rint(img), 0, 255).astype('uint8')
    return cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
def display(img):
    """Show ``img`` in a new figure; single-channel (2-D) images are drawn in gray.

    Note: this name shadows IPython's ``display`` in the notebook namespace.
    """
    plt.figure()
    # cmap only affects 2-D images. Passing it here avoids plt.set_cmap, which
    # changes the global default colormap for every later plot.
    plt.imshow(img, cmap='gray')
    plt.show()


# Sanity check for lab2RGB: round-trip the first training image through Lab.
# X_train holds RGB images and Y_train was only defined in the commented-out split
# cell, so build L/ab explicitly, re-encoded in OpenCV's 8-bit Lab convention
# (L * 255/100, a/b + 128), which lab2RGB's uint8 cast + COLOR_LAB2RGB expects.
sample_lab = rgb2lab(X_train[0])
img = lab2RGB(sample_lab[:, :, :1] * 255 / 100, sample_lab[:, :, 1:] + 128)
display(img)
    
# testing_resacel = testing_resacel * 255
# testing_res_a_b = testing_res_a_b * 255

# img = l_ab_to_RGB(X_train[1], y_train[1])
# # img = l_ab_to_RGB(testing_resacel, testing_res_a_b)
# # img = np.resize(img, (256, 256, 3))
# img = np.array(resize(img, (300, 300, 3), mode='constant'))
# print(img.shape)

# # testing_resacel = testing_resacel * 255
# # display(testing_resacel)
# display(img)

In [None]:
# for loading and image to predict

test_img_path = "./data/ww1.jpeg"
test_img = cv2.imread(test_img_path)
# cv2.imread returns None instead of raising when the file cannot be read.
if test_img is None:
    raise FileNotFoundError(test_img_path)
rgb_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
print(rgb_img.shape)

display(rgb_img)

# cv2 loads images as BGR; skimage's rgb2lab expects RGB, so use the converted copy.
img_lab = rgb2lab(rgb_img)
img_l = img_lab[:, :, 0]
img_ab = img_lab[:, :, 1:]
print(img_l.max())
print(img_l.min())
print(img_ab.max())
print(img_ab.min())
display(img_l)

In [None]:

test_img_path = "./data/testImg.jpeg"
test_img = cv2.imread(test_img_path)
# cv2.imread returns None instead of raising when the file cannot be read.
if test_img is None:
    raise FileNotFoundError(test_img_path)
rgb_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
print(rgb_img.shape)

display(rgb_img)

# cv2 loads images as BGR; skimage's rgb2lab expects RGB, so use the converted copy.
# (The discarded astype() call and the cv2 Lab conversion that was immediately
# overwritten have been dropped.)
img_lab = rgb2lab(rgb_img)
img_l = img_lab[:, :, 0]
img_ab = img_lab[:, :, 1:]
print(img_l.max())
print(img_l.min())
print(img_ab.max())
print(img_ab.min())
display(img_l)