## Sources

The sources for this implementation can be found at:
- https://richzhang.github.io/colorization/
- https://github.com/richzhang/colorization
- https://github.com/foamliu/Colorful-Image-Colorization

## Code

#### Useful imports

In [1]:
import os
import numpy as np
import cv2 as cv
import sklearn.neighbors as nn
import keras.backend as K
import tensorflow as tf
from keras.layers import Input, Conv2D, BatchNormalization, UpSampling2D
from keras.models import Model

#### Initial configuration of input size, temperature and epsilon.

In [2]:
# Spatial size the network is built for; every input image is resized to it.
img_rows = img_cols = 256
channel = 3  # presumably the colour-channel count (BGR); not referenced by the cells shown here
# Number of quantized ab colour bins the network predicts a probability for.
num_classes = 313
# Added to the bin probabilities before taking the log so empty bins stay finite.
epsilon = 1e-6  # default 1e-6
epsilon_sqr = epsilon ** 2

# Neighbour count for the NearestNeighbors finder fitted on the ab bins.
nb_neighbors = 5
# Annealing temperature: probabilities p are re-weighted as (p + epsilon) ** (1 / T)
# and renormalized before the bin colours are averaged.
T = 0.8

#### Implementation of an original function to enhance the saturation of the output images.

In [3]:
def enhance_saturation(hsvImg, *, weight=None, center_point=None,
                       tightness=None, power=None):
    """Boost the saturation channel of an HSV image and return it as BGR.

    Each saturation value ``s`` is multiplied by the bell-shaped gain

        1 + weight * exp(-|(s / 255 - center_point) * tightness| ** power)

    which equals ``1 + weight`` at ``s / 255 == center_point`` and falls
    back towards 1 away from it. The boosted values are clipped to
    [0, 255] *before* being stored back in the image's dtype, so large
    gains saturate instead of wrapping around in an 8-bit channel.

    Args:
        hsvImg: HSV image, e.g. from ``cv.cvtColor(bgr, cv.COLOR_BGR2HSV)``;
            channel 1 is saturation. The array is not modified.
        weight: Overrides the module-level ``enhance_weight`` when given.
        center_point: Overrides the module-level ``center`` when given.
        tightness: Overrides the module-level ``bell_tighness`` when given.
        power: Overrides the module-level ``exponentiality`` when given.

    Returns:
        The enhanced image converted back to BGR with ``cv.COLOR_HSV2BGR``.
    """
    # Fall back to the notebook-level tuning globals, looked up at call time.
    weight = enhance_weight if weight is None else weight
    center_point = center if center_point is None else center_point
    tightness = bell_tighness if tightness is None else tightness
    power = exponentiality if power is None else power

    # Work on a float copy so the caller's array is untouched and the product
    # cannot overflow the integer channel before it is clipped.
    out = hsvImg.copy()
    sat = out[..., 1].astype(np.float64)
    gain = weight * np.exp(-np.absolute((sat / 255 - center_point) * tightness) ** power) + 1
    out[..., 1] = np.clip(sat * gain, 0, 255).astype(out.dtype)
    return cv.cvtColor(out, cv.COLOR_HSV2BGR)

#### Configuration of the enhancement parameters.

In [4]:
# Saturation-enhancement settings used by enhance_saturation().
enhance = True  # set to False to skip the enhancement step entirely
enhance_weight = 0.55  # extra gain at the peak of the bell curve, range 0-1 (default 0.55)
bell_tighness = 3  # larger values make the bell narrower (default 5)
exponentiality = 4  # larger values give a flatter top and steeper sides (default 3)
center = 0.2  # normalized saturation (0-1) where the gain peaks (default 1/5)

#### Construction of the Model

In [5]:
def build_encoder_decoder():
    """Build the ColorNet encoder/decoder used for colorization.

    The network is a sequence of groups of 3x3 ReLU convolutions, each group
    followed by a BatchNormalization layer:

    * groups conv1-conv3 end with a stride-2 convolution (3 downsamplings),
    * groups conv5 and conv6 use a dilation rate of 2,
    * group conv8 is preceded by a 2x UpSampling2D.

    A final 1x1 convolution named ``pred`` produces a softmax over
    ``num_classes`` bins at every output position.

    Layer names and creation order must stay exactly as they are so the
    pretrained weight file can be loaded into the model.

    Returns:
        A Keras ``Model`` named "ColorNet" with a ``(img_rows, img_cols, 1)`` input.
    """
    window = (3, 3)
    # (filters, conv layer names, dilation rate, last conv is stride-2, upsample before group)
    groups = (
        (64, ('conv1_1', 'conv1_2'), 1, True, False),
        (128, ('conv2_1', 'conv2_2'), 1, True, False),
        (256, ('conv3_1', 'conv3_2', 'conv3_3'), 1, True, False),
        (512, ('conv4_1', 'conv4_2', 'conv4_3'), 1, False, False),
        (512, ('conv5_1', 'conv5_2', 'conv5_3'), 2, False, False),
        (512, ('conv6_1', 'conv6_2', 'conv6_3'), 2, False, False),
        (256, ('conv7_1', 'conv7_2', 'conv7_3'), 1, False, False),
        (128, ('conv8_1', 'conv8_2', 'conv8_3'), 1, False, True),
    )

    gray_input = Input(shape=(img_rows, img_cols, 1))
    features = gray_input
    for filters, layer_names, dilation, downsample, upsample in groups:
        if upsample:
            features = UpSampling2D(size=(2, 2))(features)
        last_name = layer_names[-1]
        for layer_name in layer_names:
            step = (2, 2) if (downsample and layer_name == last_name) else (1, 1)
            features = Conv2D(filters, window, activation='relu', padding='same',
                              strides=step, dilation_rate=dilation,
                              name=layer_name)(features)
        features = BatchNormalization()(features)

    probabilities = Conv2D(num_classes, (1, 1), activation='softmax',
                           padding='same', name='pred')(features)
    return Model(inputs=gray_input, outputs=probabilities, name="ColorNet")

#### Preprocessing of the input images and generation of the output.


In [7]:
    # Build the network and load the pretrained weights.
    model_weights_path = 'weights/model_weights.hdf5'
    model = build_encoder_decoder()
    model.load_weights(model_weights_path)

    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None".
    model.summary()

    image_folder = 'demo_input_images'
    names_file = 'valid_images.txt'
    with open(names_file, 'r') as f:
        names = f.read().splitlines()
    samples = names
    # The model downsamples by 8 (three stride-2 convs) and upsamples by 2,
    # so predictions come out at 1/4 of the input resolution.
    h, w = img_rows // 4, img_cols // 4

    # Load the array of quantized ab values (one row per colour bin).
    q_ab = np.load("data/pts_in_hull.npy")
    nb_q = q_ab.shape[0]
    # a and b coordinate of every bin as row vectors; they broadcast against
    # the (pixels, bins) probability matrix and do not change between images.
    q_a = q_ab[:, 0].reshape((1, nb_q))
    q_b = q_ab[:, 1].reshape((1, nb_q))

    # Fit a NN to q_ab.
    # NOTE(review): nn_finder is not used by the loop below; kept in case
    # other cells rely on it.
    nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors, algorithm='ball_tree').fit(q_ab)

    if not os.path.exists('output_images'):
        os.makedirs('output_images')

    # cv.resize expects dsize as (width, height).
    dsize = (img_cols, img_rows)

    for i, image_name in enumerate(samples):
        filename = os.path.join(image_folder, image_name)
        print('Start processing image: {}'.format(filename))

        bgr = cv.imread(filename)
        gray = cv.imread(filename, 0)
        if bgr is None or gray is None:
            # cv.imread returns None instead of raising for missing/unreadable files.
            print('Skipping unreadable image: {}'.format(filename))
            continue
        # interpolation must be passed by keyword: the third positional
        # parameter of cv.resize is the output array `dst`.
        bgr = cv.resize(bgr, dsize, interpolation=cv.INTER_CUBIC)
        gray = cv.resize(gray, dsize, interpolation=cv.INTER_CUBIC)
        lab = cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
        x_test = np.empty((1, img_rows, img_cols, 1), dtype=np.float32)
        x_test[0, :, :, 0] = gray / 255.

        X_colorized = model.predict(x_test)
        X_colorized = X_colorized.reshape((h * w, nb_q))

        # Annealed mean: re-weight each bin probability as (p + epsilon) ** (1 / T)
        # and renormalize per pixel; T < 1 sharpens the distribution.
        X_colorized = np.exp(np.log(X_colorized + epsilon) / T)
        X_colorized = X_colorized / np.sum(X_colorized, 1)[:, np.newaxis]

        # Expected a and b value per pixel, upsampled back to the input size.
        X_a = np.sum(X_colorized * q_a, 1).reshape((h, w))
        X_b = np.sum(X_colorized * q_b, 1).reshape((h, w))
        X_a = cv.resize(X_a, dsize, interpolation=cv.INTER_CUBIC)
        X_b = cv.resize(X_b, dsize, interpolation=cv.INTER_CUBIC)
        # OpenCV's 8-bit Lab representation stores a and b offset by 128.
        X_a = X_a + 128
        X_b = X_b + 128

        # Keep the original lightness channel and use the predicted a/b.
        out_lab = np.zeros((img_rows, img_cols, 3), dtype=np.int32)
        out_lab[:, :, 0] = lab[:, :, 0]
        out_lab[:, :, 1] = X_a
        out_lab[:, :, 2] = X_b
        # Cubic interpolation can overshoot; clip so the uint8 cast cannot wrap.
        out_lab = np.clip(out_lab, 0, 255).astype(np.uint8)
        out_bgr = cv.cvtColor(out_lab, cv.COLOR_LAB2BGR)

        if enhance:
            hsvImg = cv.cvtColor(out_bgr, cv.COLOR_BGR2HSV)
            out_bgr = enhance_saturation(hsvImg)

        cv.imwrite('output_images/{}_image.png'.format(i), gray)
        cv.imwrite('output_images/{}_gt.png'.format(i), bgr)
        cv.imwrite('output_images/{}_out_t{}.png'.format(i, T), out_bgr)

Model: "ColorNet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 256, 256, 1)]     0         
_________________________________________________________________
conv1_1 (Conv2D)             (None, 256, 256, 64)      640       
_________________________________________________________________
conv1_2 (Conv2D)             (None, 128, 128, 64)      36928     
_________________________________________________________________
batch_normalization_8 (Batch (None, 128, 128, 64)      256       
_________________________________________________________________
conv2_1 (Conv2D)             (None, 128, 128, 128)     73856     
_________________________________________________________________
conv2_2 (Conv2D)             (None, 64, 64, 128)       147584    
_________________________________________________________________
batch_normalization_9 (Batch (None, 64, 64, 128)       512