In [None]:
import numpy as np
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf
import os

# Restrict this notebook to GPU 5. This has to take effect before TensorFlow
# initializes the GPU. NOTE(review): tensorflow is already imported above; this
# assumes the import alone does not claim a device (the session is only
# created in a later cell) -- confirm for the installed TF version.
os.environ.update({"CUDA_VISIBLE_DEVICES": "5"})

# Directory holding the CIFAR test images (*.png) used for inference below.
TEST_DIR = '/data_service/source_datasets/cifar_images/images_test'

In [None]:
from tqdm import tqdm
# Collect the full paths of all .png test images.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so the
# listing is sorted: picking an image by position (filenames[1] further down)
# must select the same file on every run.
filenames = []
for f_name in tqdm(sorted(os.listdir(TEST_DIR))):
    if os.path.splitext(f_name)[-1] == '.png':
        filenames.append(os.path.join(TEST_DIR, f_name))


In [None]:
from keras.preprocessing.image import img_to_array, load_img
from skimage.color import lab2rgb, rgb2lab

In [None]:
from IPython.display import display, Image
from matplotlib.pyplot import imshow

# Rendering the L channel on its own (a = b = 0) yields a grayscale image.
def get_lab_channel(x, y):
    """Render the L, a and b channels of a Lab image as three separate RGB images.

    Each channel is placed into an otherwise-constant Lab image and converted
    with ``lab2rgb``:
      * L image: L = x[:, :, 0], a = b = 0 (a grayscale rendering);
      * a image: L = 50, a = y[:, :, 0], b = 0;
      * b image: L = 50, a = 0, b = y[:, :, 1].

    The output size follows the spatial size of ``x`` (previously hard-coded to
    32x32, which made any other image size fail to broadcast).

    Args:
        x: array indexed as ``x[:, :, 0]`` holding the L channel.
            NOTE(review): presumably un-normalized Lab (L in [0, 100]); the
            [0, 1]-normalized values used elsewhere in this notebook would
            render nearly black -- confirm against callers.
        y: array whose channels 0 and 1 hold the a and b values; assumed to
            have the same height/width as ``x``.

    Returns:
        Tuple ``(l_channel, a_channel, b_channel)`` of RGB images, each of
        shape (H, W, 3).
    """
    height, width = x.shape[0], x.shape[1]

    def _lab_to_rgb(lightness, a=0, b=0):
        # Build one (H, W, 3) Lab image from scalar-or-array channels, then convert.
        lab = np.zeros((height, width, 3))
        lab[:, :, 0] = lightness
        lab[:, :, 1] = a
        lab[:, :, 2] = b
        return lab2rgb(lab)

    l_channel = _lab_to_rgb(x[:, :, 0])
    # Fixed mid-range lightness so only the chroma channel varies.
    a_channel = _lab_to_rgb(50, a=y[:, :, 0])
    b_channel = _lab_to_rgb(50, b=y[:, :, 1])

    return l_channel, a_channel, b_channel


In [None]:
# GPU config: enable TensorFlow's allow_growth option, so GPU memory is
# allocated as needed rather than all at once, and register the session with
# the Keras backend.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
KTF.set_session(sess)

In [None]:
from unet import unet
from Colorizer import Colorizer

In [None]:
# Wrap a freshly built U-Net in the Colorizer and restore its saved state
# (presumably trained weights) from 'unet.model'.
colorizer = Colorizer(
    model=unet(),
)
colorizer.load('unet.model')

In [None]:
# Load one test image, resized to 32x32 and scaled from [0, 255] to [0, 1].
f_name = filenames[1]
img = img_to_array(load_img(f_name, target_size=(32, 32))) / 255
lab_image = rgb2lab(img)

# Normalize each Lab channel toward [0, 1]: L is divided by 100, while a and b
# are shifted by +128 and divided by 255 (assuming a/b lie roughly in
# [-128, 127]). The inverse transform is applied after prediction.
lab_image_norm = (lab_image + [0, 128, 128]) / [100, 255, 255]

# Model input: the L channel alone, keeping a trailing channel axis -> (H, W, 1).
# Target: the a/b channels -> (H, W, 2).
x = lab_image_norm[:, :, 0:1]
y = lab_image_norm[:, :, 1:]

# Prepend the batch axis: (H, W, 1) -> (1, H, W, 1).
x = x[np.newaxis]
predicted = colorizer.predict(x)

In [None]:
# Reassemble a full (still normalized) Lab image from the input L channel and
# the predicted a/b channels. x has shape (1, H, W, 1) and predicted has shape
# (1, H, W, 2) (e.g. (1, 32, 32, 2)), so stacking along the channel axis yields
# (H, W, 3) without hard-coding the 32x32 image size.
cur = np.concatenate([x[0], predicted[0]], axis=-1)

In [None]:
# Undo the input normalization (x * [100, 255, 255] - [0, 128, 128]) to get
# back to raw Lab values, then convert to RGB for display.
# The result goes into a new name instead of rebinding `cur`: overwriting the
# input made this cell non-idempotent (running it twice denormalized twice).
lab_predicted = (cur * [100, 255, 255]) - [0, 128, 128]
rgb_predicted = lab2rgb(lab_predicted)


In [None]:
import matplotlib.pyplot as plt
# Side-by-side comparison: colorized prediction (left) vs. the original image (right).
fig, (ax_pred, ax_true) = plt.subplots(1, 2, figsize=(7, 2))

ax_pred.imshow(rgb_predicted)
ax_pred.axis("off")
ax_pred.set_title("predicted image")

ax_true.imshow(img)
ax_true.axis("off")
ax_true.set_title("ground truth image")

# Render the figure explicitly so the cell does not also echo the Text repr
# returned by set_title.
plt.show()