In [None]:
import glob
import math
import tensorflow as tf
import numpy as np
import skimage as ski
from scipy.special import softmax
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

### Model Definition

In [None]:
# Reset Keras' global state so that re-running this cell builds a fresh model instead of
# accumulating state from earlier runs (presumably this also restarts the automatic layer-name
# counters used for the unnamed ZeroPadding2D/concatenate layers below — confirm for your TF version).
tf.keras.backend.clear_session()

def init():
    """Return a new Glorot (Xavier) normal kernel initializer.

    A fresh instance is built on every call, so each layer gets its own initializer object.
    """
    initializer = tf.keras.initializers.GlorotNormal()
    return initializer
def reg():
    """Return the L2 kernel regularizer used by every conv layer.

    The paper uses a weight decay of 1e-3. A Caffe weight decay corresponds to a penalty of
    (decay / 2) * sum(w ** 2), whereas Keras' l2(l) adds l * sum(w ** 2), so the coefficient
    is halved to 5e-4 when converting the Caffe model to Keras.
    See https://bbabenko.github.io/weight-decay/
    """
    halved_decay = 5e-4
    return tf.keras.regularizers.l2(l=halved_decay)

## COLORIZATION
# Colorization branch (Zhang et al.): the input is the mean-centered L channel (224x224x1) and the
# output (conv8_313) holds per-pixel logits over the 313 quantized ab bins at 56x56 resolution.
# Stages 1-3 each halve the spatial size (224 -> 112 -> 56 -> 28) with an explicit (1, 1)
# ZeroPadding2D followed by a stride-2 'valid' 3x3 conv (presumably to mirror Caffe's padding).
# Every stage ends with a BatchNormalization without learned offset/scale (center=False, scale=False).
# Layer names are set explicitly, presumably to match Zhang et al.'s released weights so they can be
# loaded with load_weights(..., by_name=True) — confirm against the .h5 files.

input_ = tf.keras.Input(shape=(224, 224, 1))
# Stage 1: 64 channels, 224x224 -> 112x112.
conv1_1 = tf.keras.layers.Conv2D(64, (3, 3), name='conv1_1', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(input_)
cov1_2pad = tf.keras.layers.ZeroPadding2D((1, 1))(conv1_1)
cov1_2 = tf.keras.layers.Conv2D(64, (3, 3), name='conv1_2', activation='relu', padding='valid', strides=(2, 2),
                                kernel_initializer=init(), kernel_regularizer = reg())(cov1_2pad)
conv1_2norm = tf.keras.layers.BatchNormalization(name='conv1_2norm', center=False, scale=False)(cov1_2)

# Stage 2: 128 channels, 112x112 -> 56x56.
conv2_1 = tf.keras.layers.Conv2D(128, (3, 3), name='conv2_1', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv1_2norm)
conv2_2pad = tf.keras.layers.ZeroPadding2D((1, 1))(conv2_1)
conv2_2 = tf.keras.layers.Conv2D(128, (3, 3), name='conv2_2', activation='relu', padding='valid', strides=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv2_2pad)
conv2_2norm = tf.keras.layers.BatchNormalization(name='conv2_2norm', center=False, scale=False)(conv2_2)

# Stage 3: 256 channels, 56x56 -> 28x28.
conv3_1 = tf.keras.layers.Conv2D(256, (3, 3), name='conv3_1', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv2_2norm)
conv3_2 = tf.keras.layers.Conv2D(256, (3, 3), name='conv3_2', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv3_1)
conv3_3pad = tf.keras.layers.ZeroPadding2D((1, 1))(conv3_2)
conv3_3 = tf.keras.layers.Conv2D(256, (3, 3), name='conv3_3', activation='relu', padding='valid', strides=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv3_3pad)
conv3_3norm = tf.keras.layers.BatchNormalization(name='conv3_3norm', center=False, scale=False)(conv3_3)

# Stage 4: 512 channels, stays at 28x28.
conv4_1 = tf.keras.layers.Conv2D(512, (3, 3), name='conv4_1', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv3_3norm)
conv4_2 = tf.keras.layers.Conv2D(512, (3, 3), name='conv4_2', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv4_1)
conv4_3 = tf.keras.layers.Conv2D(512, (3, 3), name='conv4_3', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv4_2)
conv4_3norm = tf.keras.layers.BatchNormalization(name='conv4_3norm', center=False, scale=False)(conv4_3)

# Stages 5-6: dilated (rate 2) convs enlarge the receptive field while keeping 28x28.
conv5_1 = tf.keras.layers.Conv2D(512, (3, 3), name='conv5_1', activation='relu', padding='same', dilation_rate=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv4_3norm)
conv5_2 = tf.keras.layers.Conv2D(512, (3, 3), name='conv5_2', activation='relu', padding='same', dilation_rate=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv5_1)
conv5_3 = tf.keras.layers.Conv2D(512, (3, 3), name='conv5_3', activation='relu', padding='same', dilation_rate=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv5_2)
conv5_3norm = tf.keras.layers.BatchNormalization(name='conv5_3norm', center=False, scale=False)(conv5_3)

conv6_1 = tf.keras.layers.Conv2D(512, (3, 3), name='conv6_1', activation='relu', padding='same', dilation_rate=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv5_3norm)
conv6_2 = tf.keras.layers.Conv2D(512, (3, 3), name='conv6_2', activation='relu', padding='same', dilation_rate=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv6_1)
conv6_3 = tf.keras.layers.Conv2D(512, (3, 3), name='conv6_3', activation='relu', padding='same', dilation_rate=(2,2),
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv6_2)
conv6_3norm = tf.keras.layers.BatchNormalization(name='conv6_3norm', center=False, scale=False)(conv6_3)

# Stage 7: plain 3x3 convs, still 28x28x512.
conv7_1 = tf.keras.layers.Conv2D(512, (3, 3), name='conv7_1', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv6_3norm)
conv7_2 = tf.keras.layers.Conv2D(512, (3, 3), name='conv7_2', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv7_1)
conv7_3 = tf.keras.layers.Conv2D(512, (3, 3), name='conv7_3', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv7_2)
conv7_3norm = tf.keras.layers.BatchNormalization(name='conv7_3norm', center=False, scale=False)(conv7_3)

# Stage 8: a stride-2 transposed conv upsamples 28x28 -> 56x56 (256 channels).
conv8_1 = tf.keras.layers.Conv2DTranspose(256, (4, 4), name='conv8_1', activation='relu', padding='same', strides=(2,2),
                                          kernel_initializer=init(), kernel_regularizer = reg())(conv7_3norm)
conv8_2 = tf.keras.layers.Conv2D(256, (3, 3), name='conv8_2', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv8_1)
conv8_3 = tf.keras.layers.Conv2D(256, (3, 3), name='conv8_3', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv8_2)

# 1x1 classifier over the 313 ab bins. There is no activation, so these are raw logits
# (the compile step uses from_logits=True).
conv8_313 = tf.keras.layers.Conv2D(313, (1, 1), name='conv8_313',
                                   kernel_initializer=init(), kernel_regularizer = reg())(conv8_3)

## IMAGE SEGMENTATION
# Segmentation branch: it taps the normalized outputs of colorization stages 5, 6 and 7
# (each 28x28x512). Each tap goes through three dilated 3x3 Conv2DTranspose layers (stride 1,
# 'same' padding, so 28x28 is kept). The taps are concatenated (3 x 128 = 384 channels), then
# upsampled to 56x56 to predict per-pixel logits for 21 classes (the class count used for the
# PASCAL VOC 2012 masks in preprocess_segm_y_).

deconv5_1 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv5_1', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(conv5_3norm)
deconv5_2 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv5_2', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(deconv5_1)
deconv5_3 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv5_3', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(deconv5_2)
deconv5_3norm = tf.keras.layers.BatchNormalization(name='deconv5_3norm', center=False, scale=False)(deconv5_3)

deconv6_1 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv6_1', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(conv6_3norm)
deconv6_2 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv6_2', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(deconv6_1)
deconv6_3 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv6_3', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(deconv6_2)
deconv6_3norm = tf.keras.layers.BatchNormalization(name='deconv6_3norm', center=False, scale=False)(deconv6_3)

deconv7_1 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv7_1', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(conv7_3norm)
deconv7_2 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv7_2', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(deconv7_1)
deconv7_3 = tf.keras.layers.Conv2DTranspose(128, (3, 3), name='deconv7_3', activation='relu', padding='same', dilation_rate=(2,2),
                                            kernel_initializer=init(), kernel_regularizer = reg())(deconv7_2)
deconv7_3norm = tf.keras.layers.BatchNormalization(name='deconv7_3norm', center=False, scale=False)(deconv7_3)

# Fuse the three 28x28x128 taps along the channel axis.
concat = tf.keras.layers.concatenate([deconv5_3norm, deconv6_3norm, deconv7_3norm])

# Upsample 28x28 -> 56x56 so the segmentation output matches the colorization output size.
conv9_1 = tf.keras.layers.Conv2DTranspose(256, (4, 4), name='conv9_1', activation='relu', padding='same', strides=(2,2),
                                          kernel_initializer=init(), kernel_regularizer = reg())(concat)
conv9_2 = tf.keras.layers.Conv2D(256, (3, 3), name='conv9_2', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv9_1)
conv9_3 = tf.keras.layers.Conv2D(256, (3, 3), name='conv9_3', activation='relu', padding='same',
                                 kernel_initializer=init(), kernel_regularizer = reg())(conv9_2)

# 1x1 classifier over 21 classes. There is no activation, so these are raw logits.
conv9_21 = tf.keras.layers.Conv2D(21, (1, 1), name='conv9_21',
                                  kernel_initializer=init(), kernel_regularizer = reg())(conv9_3)

# Two-output model. Losses (compile) and targets (MyDataGen) are keyed by the output layer
# names 'conv8_313' and 'conv9_21'.
model = tf.keras.Model(inputs=input_, outputs=[conv8_313, conv9_21])

model.summary()

# Adam with a small learning rate. beta_1/beta_2/epsilon/amsgrad are spelled out even though
# they equal Keras' defaults.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=3e-05,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
)

# Both heads output raw logits (their last conv has no activation), hence from_logits=True.
colorization_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0)
segmentation_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0)

model.compile(
    optimizer=optimizer,
    loss={'conv8_313': colorization_loss, 'conv9_21': segmentation_loss},
    # The paper weights the losses 1:100 so they have a similar magnitude. In my tests,
    # 1:10 or even 1:1 seemed more fitting (depending on initialization), so the default
    # equal weighting is used:
    # loss_weights = [1, 100],
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

tf.keras.utils.plot_model(model, 'graph.png', show_shapes=True)

In [None]:
# Optionally initialize the colorization branch with the weights released by Zhang et al.
# (uncomment one of the lines below; by_name=True loads only layers whose names match):
# kmeans_init = weights before training, initialized with k-means.
# pretrained = weights after training.
# model.load_weights('models/zhang_kmeans_init.h5', by_name=True)
# model.load_weights('models/zhang_pretrained.h5', by_name=True)

### Image Processing & Data Reading Functions

In [None]:
## GLOBAL VARIABLES

# Precomputed color resources.
gamut_vals_p = 'resources/gamut_vals.npy'
color_weights_p = 'resources/color_weights.npy'

# Split lists, one image id per line.
# NOTE(review): test_txt_p points at the validation list — presumably because VOC2012 does not
# publish test-set segmentation masks; confirm this is intended.
train_txt_p = 'resources/train.txt'
val_txt_p = 'resources/val.txt'
test_txt_p = 'resources/val.txt'

# PASCAL VOC 2012 images and their segmentation masks (directory + extension per id).
img_p, img_ext = 'data/VOCdevkit/VOC2012/JPEGImages/', '.jpg'
segm_p, segm_ext = 'data/VOCdevkit/VOC2012/SegmentationClass/', '.png'

# Output locations, both derived from the model name.
model_name = 'MVI'
model_p = f'models/{model_name}'
tensorboard_p = f'tensorboard/{model_name}'

# ab values of the quantized color bins; used to compute the "expected value" of the
# predicted color distribution.
gamut_vals = np.squeeze(np.load(gamut_vals_p))
# Per-bin weights used to rebalance colors based on their rarity.
color_weights = np.load(color_weights_p)
# Euclidean (Minkowski, p=2) neighbor index over the bins, used to find the closest
# colors when smoothing labels (soft encoding).
neighbors = NearestNeighbors(algorithm='auto', metric='minkowski', p=2).fit(gamut_vals)

## PREPROCESSING

def preprocess_x(imgs):
    """Convert a list of images into one batched network input.

    Each image goes through preprocess_x_, and the results are stacked along a new
    leading batch axis.
    """
    return np.stack(list(map(preprocess_x_, imgs)))

def preprocess_x_(img):
    """Turn one image into the network input: its mean-centered L (lightness) channel.

    Args:
        img: an RGB image of shape (H, W, 3). Grayscale (H, W) and RGBA (H, W, 4) images
            are converted to RGB first.

    Returns:
        float array of shape (224, 224, 1) holding L - 50.
    """
    # Without this, resize() to (224, 224, 3) would interpolate the channel axis itself,
    # e.g. blending the alpha channel of an RGBA image into its colors.
    if img.ndim == 2:
        img = ski.color.gray2rgb(img)
    elif img.shape[-1] == 4:
        img = ski.color.rgba2rgb(img)
    img = ski.transform.resize(img, (224, 224, 3))
    img = ski.color.rgb2lab(img)
    img = img[:, :, :1]
    # Zhang et al. subtract 50 from each input for "mean-centering", even though they don't
    # mention it directly in the paper. rgb2lab() returns L in [0, 100] (resize() has already
    # converted the pixels to floats in [0, 1]), so this maps L to [-50, 50].
    # The pretrained model does not perform well without it.
    img = img - 50
    return img

def preprocess_img_y(imgs):
    """Stack the colorization targets of a list of images into one batch array."""
    targets = [preprocess_img_y_(image) for image in imgs]
    return np.stack(targets)

def preprocess_img_y_(img):
    """Build the colorization target for one image.

    The image is downscaled to the 56x56 output resolution. Its ab channels are then
    soft-encoded over the gamut bins and reweighted by color rarity.
    """
    small = ski.transform.resize(img, (56, 56, 3))
    ab = ski.color.rgb2lab(small)[:, :, 1:]
    encoded = soft_encode(ab)
    # Classes are rebalanced here in preprocessing rather than in the loss. Cross-entropy is
    # linear in the target, so scaling the target equals weighting the loss, and no custom
    # loss is needed.
    return rebalance_classes(encoded)

def soft_encode(img):
    """Soft-encode an (H, W, 2) array of ab values over the quantized gamut bins.

    Each pixel is spread over its 10 nearest bins. Weights come from a Gaussian kernel
    (sigma = 5) on the distance and are normalized to sum to 1 per pixel.

    Args:
        img: array of shape (H, W, 2) with ab values (56x56 for the training targets).

    Returns:
        float array of shape (H, W, n_bins), where n_bins = len(gamut_vals).
    """
    height, width = img.shape[:2]
    n_pixels = height * width
    n_bins = len(gamut_vals)
    ab = np.reshape(img, (n_pixels, 2))
    # In the paper n_neighbors is 5, but according to a GitHub issue it should be 10.
    # https://github.com/richzhang/colorization/issues/59
    dist, ind = neighbors.kneighbors(ab, n_neighbors=10)
    sigma = 5
    # "Weights the neighbors proportionally to their distance from the ground truth using
    # a Gaussian kernel with sigma = 5." softmax() of the log-kernel normalizes each
    # pixel's Gaussian weights so they sum to 1.
    weights = softmax(-(dist ** 2) / (2 * sigma ** 2), axis=-1)
    encoded = np.zeros((n_pixels, n_bins))
    # Scatter each pixel's weights into the columns of its nearest bins.
    encoded[np.arange(n_pixels)[:, np.newaxis], ind] = weights
    return np.reshape(encoded, (height, width, n_bins))

def rebalance_classes(img):
    """Scale every pixel's distribution by the rarity weight of its dominant bin.

    For an encoding produced by soft_encode, the dominant (argmax) bin is the ground-truth
    color's nearest bin. Its entry in color_weights multiplies that pixel's whole distribution.
    """
    dominant_bin = img.argmax(axis=-1)
    pixel_weight = color_weights[dominant_bin][:, :, np.newaxis]
    return img * pixel_weight

def preprocess_segm_y(masks):
    """Stack the one-hot segmentation targets of a list of class-index masks into one batch."""
    return np.stack([preprocess_segm_y_(m) for m in masks])

def preprocess_segm_y_(mask):
    """Build the segmentation target for one class-index mask.

    Returns a float32 one-hot array of shape (56, 56, 21). Any index outside 0..20
    (e.g. a 255 "void" value, if the mask contains one) becomes an all-zero vector.
    """
    # order=0, preserve_range=True and anti_aliasing=False keep the integer class ids intact.
    # The same settings are used during postprocessing.
    small = ski.transform.resize(mask, (56, 56), order=0, preserve_range=True, anti_aliasing=False).astype('uint8')
    # Compare every pixel against all class ids; this matches tf.one_hot(small, 21).
    class_ids = np.arange(21)
    return (small[..., np.newaxis] == class_ids).astype(np.float32)

## POSTPROCESSING

def postprocess_color(y_test, test_imgs):
    """Colorize each image using the matching entry of the conv8_313 output batch."""
    colorized = []
    for logits, image in zip(y_test, test_imgs):
        colorized.append(postprocess_color_(logits, image))
    return colorized
    
def postprocess_color_(y_test, test_img):
    """Colorize one image from the network's color logits.

    Args:
        y_test: conv8_313 logits for one image, shape (56, 56, 313).
        test_img: the original image. It is assumed RGB; grayscale (H, W) and RGBA (H, W, 4)
            images are converted to RGB. Only its L channel is used, and its height and
            width set the output size.

    Returns:
        uint8 RGB image with the same height and width as test_img.
    """
    # Turn logits into probabilities, with T being the temperature parameter described in the paper.
    T = 0.38
    probs = softmax(y_test / T, axis=-1)
    # Turn probabilities into ab values by computing their "expected value" over the bins.
    ab_vals = probs @ gamut_vals
    ab_vals = ski.transform.resize(ab_vals, (test_img.shape[0], test_img.shape[1], 2))
    # rgb2lab() requires three channels, so normalize grayscale/RGBA inputs first.
    if test_img.ndim == 2:
        test_img = ski.color.gray2rgb(test_img)
    elif test_img.shape[-1] == 4:
        test_img = ski.color.rgba2rgb(test_img)
    lightness = ski.color.rgb2lab(test_img)[:, :, :1]
    out_img = ski.color.lab2rgb(np.concatenate((lightness, ab_vals), axis=-1))
    return (255 * np.clip(out_img, 0, 1)).astype('uint8')

def postprocess_segm(segm_out, test_imgs):
    """Turn each image's conv9_21 logits into a class-index mask at that image's size."""
    return list(map(postprocess_segm_, segm_out, test_imgs))
    
def postprocess_segm_(segm_out, test_img):
    """Convert one image's conv9_21 logits into a uint8 class-index mask sized like test_img."""
    labels = segm_out.argmax(axis=-1)
    target_shape = test_img.shape[:2]
    # Nearest-neighbor resize with no anti-aliasing, so class ids are never blended.
    return ski.transform.resize(labels, target_shape, order=0, preserve_range=True,
                                anti_aliasing=False).astype('uint8')

## DATA READING

class MyDataGen(tf.keras.utils.Sequence):
    """Keras Sequence yielding (L-channel input, {'conv8_313', 'conv9_21'} targets) batches.

    Args:
        usage: 'train', 'val' or 'test'; selects which id list is read.
        batch_size: number of images per batch (the last batch may be smaller).

    Raises:
        ValueError: if usage is not one of the three supported values.
    """

    def __init__(self, usage, batch_size):
        super().__init__()
        if usage == 'train':
            filename = train_txt_p
        elif usage == 'val':
            filename = val_txt_p
        elif usage == 'test':
            filename = test_txt_p
        else:
            # Previously fell through with `filename` unbound (UnboundLocalError below).
            raise ValueError(f"usage must be 'train', 'val' or 'test', got {usage!r}")

        self.batch_size = batch_size
        # Random mirroring is a training-time augmentation only. Applying it to validation
        # batches would make val_loss (monitored by the callbacks) needlessly noisy.
        self.augment = usage == 'train'

        with open(filename, 'r') as f:
            # Skip blank lines so they never turn into nonexistent file names.
            self.names = [name for name in f.read().splitlines() if name]
        np.random.shuffle(self.names)

    def __len__(self):
        return math.ceil(len(self.names) / self.batch_size)

    def __getitem__(self, index):
        start = index * self.batch_size
        batch_names = self.names[start:start + self.batch_size]
        imgs = [ski.io.imread(img_p + name + img_ext) for name in batch_names]
        masks = [ski.io.imread(segm_p + name + segm_ext, pilmode='P') for name in batch_names]
        # Half the time, mirror a whole training batch (images and masks together).
        if self.augment and np.random.random() < 1/2:
            imgs = [np.fliplr(img) for img in imgs]
            masks = [np.fliplr(mask) for mask in masks]
        return preprocess_x(imgs), {'conv8_313' : preprocess_img_y(imgs), 'conv9_21' : preprocess_segm_y(masks)}

    def on_epoch_end(self):
        # Reshuffle so each epoch sees differently composed batches.
        np.random.shuffle(self.names)

# Creates the color map associated with the PASCAL VOC2012 dataset - adapted from:
# https://gist.github.com/wllhf/a4533e0adebe57e3ed06d4b50c8419ae
def color_map(N=256, normalized=False):
    """Return the PASCAL VOC palette as an (N, 3) array.

    Bits 0/1/2 of each class index feed the R/G/B channels, starting from the most
    significant bit of each channel. Each following group of 3 index bits fills the
    next lower channel bit.

    Args:
        N: number of palette entries.
        normalized: if True, return float32 values in [0, 1]; otherwise uint8 in [0, 255].
    """
    remaining = np.arange(N)
    rgb = np.zeros((N, 3), dtype=np.int64)
    for shift in range(7, -1, -1):
        for channel in range(3):
            rgb[:, channel] |= ((remaining >> channel) & 1) << shift
        remaining = remaining >> 3

    if not normalized:
        return rgb.astype('uint8')
    return rgb.astype('float32') / 255
        
## CALLBACKS

# Save the full model after every epoch...
checkpoint_last = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_p + '_last.h5',
    save_freq='epoch',
    save_best_only=False,
    save_weights_only=False,
    verbose=1,
)
# ...and, separately, the model with the lowest validation loss seen so far.
checkpoint_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_p + '_best.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    save_freq='epoch',
    verbose=1,
)
# Divide the learning rate by 3 after 15 epochs without a val_loss improvement, wait 5 epochs
# before monitoring again, and never go below 1e-6.
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=1/3,
    patience=15,
    cooldown=5,
    min_delta=0.,
    min_lr=1e-06,
)
# Per-epoch logging only: no histograms, graph, images, embeddings or profiling.
tensorboard_logger = tf.keras.callbacks.TensorBoard(
    log_dir=tensorboard_p,
    update_freq='epoch',
    histogram_freq=0,
    write_graph=False,
    write_images=False,
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
)

callbacks = [checkpoint_last, checkpoint_best, lr_on_plateau, tensorboard_logger]

### Training

In [None]:
# Batches of 40 images. MyDataGen shuffles its own id list (on construction and at every
# epoch end), so Keras' batch-order shuffling is turned off.
train_batches = MyDataGen('train', 40)
val_batches = MyDataGen('val', 40)
model.fit(train_batches, validation_data=val_batches, epochs=1000, shuffle=False,
          callbacks=callbacks, verbose=2)

### Testing

In [None]:
# Run the model on every image in data/test/ (sorted, so the output order is reproducible).
names = sorted(glob.glob('data/test/*'))
imgs = [ski.io.imread(name) for name in names]
x = preprocess_x(imgs)

color_out, segm_out = model.predict(x)

imgs_out, masks_out = postprocess_color(color_out, imgs), postprocess_segm(segm_out, imgs)

# PASCAL VOC palette: class id i is drawn with color_map()[i].
cmap = LinearSegmentedColormap.from_list('pascalVOC2012', color_map(256, True), N=256)
for a, b, c in zip(imgs, imgs_out, masks_out):
    # rgb2gray() expects RGB data, so grayscale inputs are shown as they are (alpha is dropped).
    bw = a if a.ndim == 2 else ski.color.rgb2gray(a[..., :3])
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    axes[0].imshow(bw, cmap='gray')
    axes[0].set_title('Grayscale input')
    axes[1].imshow(a)
    axes[1].set_title('Original')
    axes[2].imshow(b)
    axes[2].set_title('Colorized')
    # Pin vmin/vmax so each class id selects its own palette entry. By default imshow
    # rescales to the mask's own min/max, which assigns classes the wrong colors.
    # 'nearest' prevents blending between neighboring class ids.
    axes[3].imshow(c, cmap=cmap, vmin=0, vmax=255, interpolation='nearest')
    axes[3].set_title('Segmentation')
    for ax in axes:
        ax.axis('off')
    plt.show()