In [1]:
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers.normalization import layer_normalization
from keras.callbacks import TensorBoard
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave
import numpy as np
import os
import random
import tensorflow as tf

In [2]:
# Get images
X = []
for filename in os.listdir('/Train/'):
    X.append(img_to_array(load_img('/Train/'+filename)))
X = np.array(X, dtype=float)

# Set up train and test data
split = int(0.95*len(X))
Xtrain = X[:split]
Xtrain = 1.0/255*Xtrain

In [8]:
model = Sequential()
model.add(InputLayer(input_shape=(256, 256, 1)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
model.add(UpSampling2D((2, 2)))
model.compile(optimizer='rmsprop', loss='mse', metrics=['accuracy'])

In [9]:
# Image transformer
datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True)

# Generate training data
batch_size = 10
def image_a_b_gen(batch_size):
    for batch in datagen.flow(Xtrain, batch_size=batch_size):
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0]
        Y_batch = lab_batch[:,:,:,1:] / 128
        yield (X_batch.reshape(X_batch.shape+(1,)), Y_batch)

# Train model      
tensorboard = TensorBoard(log_dir="output/first_run")
model.fit_generator(image_a_b_gen(batch_size), callbacks=[tensorboard], epochs=1, steps_per_epoch=10)

2021-11-15 14:34:40.605613: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-11-15 14:34:40.605630: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2021-11-15 14:34:40.605831: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


 1/10 [==>...........................] - ETA: 28s - loss: 0.0817 - accuracy: 0.4866

2021-11-15 14:34:44.038013: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-11-15 14:34:44.038029: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.


 2/10 [=====>........................] - ETA: 13s - loss: 0.5031 - accuracy: 0.4787

2021-11-15 14:34:45.766367: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2021-11-15 14:34:45.768640: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.
2021-11-15 14:34:45.771341: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: output/first_run/train/plugins/profile/2021_11_15_14_34_45

2021-11-15 14:34:45.773793: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to output/first_run/train/plugins/profile/2021_11_15_14_34_45/Kaushals-MacBook-Pro.local.trace.json.gz
2021-11-15 14:34:45.780973: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: output/first_run/train/plugins/profile/2021_11_15_14_34_45

2021-11-15 14:34:45.781229: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to output/first_run/train/plugins/profile/2021_11_15_14_34_45/Kau



<keras.callbacks.History at 0x161957cd0>

In [5]:
# Save model
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)
model.save_weights("model.h5")

In [6]:
# Test images
Xtest = rgb2lab(1.0/255*X[split:])[:,:,:,0]
Xtest = Xtest.reshape(Xtest.shape+(1,))
Ytest = rgb2lab(1.0/255*X[split:])[:,:,:,1:]
Ytest = Ytest / 128
print(model.evaluate(Xtest, Ytest, batch_size=batch_size))

0.010436208918690681


In [7]:
color_me = []
for filename in os.listdir('Test/'):
    color_me.append(img_to_array(load_img('Test/'+filename)))
color_me = np.array(color_me, dtype=float)
color_me = rgb2lab(1.0/255*color_me)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))

# Test model
output = model.predict(color_me)
output = output * 128

# Output colorizations
for i in range(len(output)):
    cur = np.zeros((256, 256, 3))
    cur[:,:,0] = color_me[i][:,:,0]
    cur[:,:,1:] = output[i]
    imsave("result/img_"+str(i)+".png", lab2rgb(cur))

