In [0]:
# ---- Image geometry --------------------------------------------------------
img_rows = img_cols = 256   # network input height / width (images are resized to this)
channel = 3                 # channels of the source colour images

# ---- Training schedule -----------------------------------------------------
batch_size = 32
epochs = 100
patience = 50
num_train_samples = 1000    # used as steps_per_epoch (batches, not images) when fitting
num_valid_samples = 200     # used as validation_steps (batches, not images) when fitting

# ---- Model / colour quantisation ------------------------------------------
num_classes = 313           # softmax output channels; must equal the number of ab bins in pts_in_hull.npy
kernel = 3                  # spatial size of the square conv kernels
weight_decay = 0.1          # NOTE(review): not referenced by the visible cells; l2_reg sets its own factor
epsilon = 1e-8
nb_neighbors = 5            # nearest ab bins that share each pixel's soft-encoded target
# temperature parameter T
T = 0.8

In [0]:
import keras.backend as K
import tensorflow as tf
from keras.layers import Input, Conv2D, BatchNormalization, UpSampling2D
from keras.models import Model
from keras.regularizers import l2
from keras.utils import multi_gpu_model
from keras.utils import plot_model

# Single L2 kernel regulariser (factor 0.001) shared by the convolution layers of build_model.
l2_reg = l2(0.001)

def build_model():
    """Build the "ColorNet" colourisation network.

    Maps a single-channel ``(img_rows, img_cols, 1)`` input to a per-pixel softmax over
    ``num_classes`` bins. Three stride-2 convolutions downsample by 8 and one ``UpSampling2D``
    upsamples by 2, so for sizes divisible by 8 the output is
    ``(img_rows // 4, img_cols // 4, num_classes)``.

    Stages 4-8 use a single conv each; earlier revisions stacked three (``convN_2``/``convN_3``).

    Returns:
        An uncompiled ``keras.models.Model``.
    """

    def conv(tensor, filters, name, **kwargs):
        # ``kernel`` x ``kernel`` ReLU conv, He-initialised, with the shared L2 penalty. Every conv
        # goes through here so they are all regularised alike (conv3_3 used to lack l2_reg).
        return Conv2D(filters, (kernel, kernel), activation='relu', padding='same', name=name,
                      kernel_initializer="he_normal", kernel_regularizer=l2_reg, **kwargs)(tensor)

    input_tensor = Input(shape=(img_rows, img_cols, 1))

    # Stages 1-3: the last conv of each stage halves the spatial resolution.
    x = conv(input_tensor, 64, 'conv1_1')
    x = conv(x, 64, 'conv1_2', strides=(2, 2))
    x = BatchNormalization()(x)

    x = conv(x, 128, 'conv2_1')
    x = conv(x, 128, 'conv2_2', strides=(2, 2))
    x = BatchNormalization()(x)

    x = conv(x, 256, 'conv3_1')
    x = conv(x, 256, 'conv3_2')
    x = conv(x, 256, 'conv3_3', strides=(2, 2))
    x = BatchNormalization()(x)

    # Stages 4-7 run at 1/8 resolution; stages 5 and 6 widen the receptive field via dilation.
    x = conv(x, 512, 'conv4_1')
    x = BatchNormalization()(x)

    x = conv(x, 512, 'conv5_1', dilation_rate=2)
    x = BatchNormalization()(x)

    x = conv(x, 512, 'conv6_1', dilation_rate=2)
    x = BatchNormalization()(x)

    x = conv(x, 256, 'conv7_1')
    x = BatchNormalization()(x)

    # Stage 8: back up to 1/4 resolution.
    x = UpSampling2D(size=(2, 2))(x)
    x = conv(x, 128, 'conv8_1')
    x = BatchNormalization()(x)

    # 1x1 projection to a per-pixel distribution over the quantised ab bins.
    outputs = Conv2D(num_classes, (1, 1), activation='softmax', padding='same', name='pred')(x)

    return Model(inputs=input_tensor, outputs=outputs, name="ColorNet")

In [0]:
import os
import random
from random import shuffle

import cv2 as cv
import numpy as np
import sklearn.neighbors as nn
from keras.utils import Sequence
from keras.datasets import cifar10
# Load CIFAR-10 when this cell runs; DataGenSequence reads its images from the module-level x_train.
# NOTE(review): Keras documents x_* as uint8 arrays of shape (N, 32, 32, 3) in RGB channel order —
# confirm for the installed version. y_train, x_test and y_test are not used by the visible cells.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()


def get_soft_encoding(image_ab, nn_finder, nb_q):
    """Soft-encode every pixel's (a, b) value as a distribution over ``nb_q`` ab bins.

    The bins returned by ``nn_finder.kneighbors`` for a pixel receive Gaussian weights
    (sigma = 5) of their distance, normalised to sum to 1; every other bin gets 0.

    Args:
        image_ab: array of shape (h, w, C), C >= 2, whose channels 0 and 1 hold a and b.
        nn_finder: fitted neighbour search exposing ``kneighbors(points) -> (dist, idx)``.
        nb_q: number of bins, i.e. the length of the output's last axis.

    Returns:
        Array of shape (h, w, nb_q) holding the per-pixel target distributions.
    """
    height, width = image_ab.shape[:2]
    # One row per pixel (row-major order), columns (a, b).
    points = np.stack((image_ab[:, :, 0].ravel(), image_ab[:, :, 1].ravel()), axis=1)

    distances, neighbour_ids = nn_finder.kneighbors(points)

    # Gaussian kernel over the neighbour distances, normalised per pixel.
    sigma = 5
    weights = np.exp(-(distances ** 2) / (2 * sigma ** 2))
    weights = weights / weights.sum(axis=1, keepdims=True)

    # Scatter the weights into a dense (pixels, nb_q) target.
    encoding = np.zeros((points.shape[0], nb_q))
    np.put_along_axis(encoding, neighbour_ids, weights, axis=1)
    return encoding.reshape(height, width, nb_q)


class DataGenSequence(Sequence):
    """Keras ``Sequence`` of (grayscale input, soft-encoded ab target) batches drawn from ``x_train``.

    ``usage == 'train'`` uses rows ``x_train[0:40000]``; any other value uses the held-out rows
    ``x_train[40000:50000]``. Item ``idx`` is a tuple ``(batch_x, batch_y)``:

    * ``batch_x``: float32, shape ``(n, img_rows, img_cols, 1)``, grayscale scaled to [0, 1];
    * ``batch_y``: float32, shape ``(n, img_rows // 4, img_cols // 4, nb_q)``, soft-encoded ab;

    with ``n <= batch_size`` (only the last batch can be short). Each sample is randomly
    flipped left-right, identically in input and target.
    """

    def __init__(self, usage, q_ab_path="/content/pts_in_hull.npy"):
        """
        Args:
            usage: ``'train'`` for the training split, anything else for the validation split.
            q_ab_path: ``.npy`` file of quantised ab bin centres, one (a, b) row per bin.
        """
        self.usage = usage
        start, stop = (0, 40000) if usage == 'train' else (40000, 50000)

        # Shuffle a private array of row indices rather than x_train itself, so the global
        # data (and therefore the train/valid split) is never mutated.
        self.ids = np.arange(start, stop)
        np.random.shuffle(self.ids)

        # Quantised ab bin centres; targets are soft-encoded over their nearest neighbours.
        q_ab = np.load(q_ab_path)
        self.nb_q = q_ab.shape[0]
        self.nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors, algorithm='ball_tree').fit(q_ab)

    def __len__(self):
        """Number of batches in one pass over this split (the last one may be partial)."""
        return int(np.ceil(len(self.ids) / float(batch_size)))

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(batch_x, batch_y)``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError("batch index %d out of range [0, %d)" % (idx, len(self)))

        first = idx * batch_size
        batch_ids = self.ids[first:first + batch_size]
        out_img_rows, out_img_cols = img_rows // 4, img_cols // 4

        batch_x = np.empty((len(batch_ids), img_rows, img_cols, 1), dtype=np.float32)
        batch_y = np.empty((len(batch_ids), out_img_rows, out_img_cols, self.nb_q), dtype=np.float32)

        for i_batch, image_id in enumerate(batch_ids):
            x, y = self._encode(x_train[image_id])

            # Random horizontal flip, applied to input and target together.
            if np.random.random_sample() > 0.5:
                x = np.fliplr(x)
                y = np.fliplr(y)

            batch_x[i_batch, :, :, 0] = x
            batch_y[i_batch] = y

        return batch_x, batch_y

    def _encode(self, rgb):
        """Turn one uint8 RGB image into ``(gray in [0, 1], soft-encoded ab target)``."""
        # keras.datasets.cifar10 yields RGB channel order, so use the RGB (not BGR) conversions.
        # cv.resize takes dsize as (width, height).
        gray = cv.cvtColor(rgb, cv.COLOR_RGB2GRAY)
        gray = cv.resize(gray, (img_cols, img_rows))

        lab = cv.cvtColor(rgb, cv.COLOR_RGB2LAB)
        lab = cv.resize(lab, (img_cols, img_rows))
        # interpolation must be passed by keyword: the third positional slot is `dst`.
        out_lab = cv.resize(lab, (img_cols // 4, img_rows // 4), interpolation=cv.INTER_CUBIC)

        # 8-bit Lab stores a and b offset by +128; recentre them on 0.
        out_ab = out_lab[:, :, 1:].astype(np.int32) - 128

        return gray / 255., get_soft_encoding(out_ab, self.nn_finder, self.nb_q)

    def on_epoch_end(self):
        """Reshuffle this split's sample order between epochs (x_train itself is untouched)."""
        np.random.shuffle(self.ids)


def train_gen():
    """Return a new DataGenSequence over the training split."""
    sequence = DataGenSequence(usage='train')
    return sequence


def valid_gen():
    """Return a new DataGenSequence over the validation split."""
    sequence = DataGenSequence(usage='valid')
    return sequence




In [0]:
import argparse

import keras
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.utils import multi_gpu_model
import matplotlib.pyplot as plt

#from utils import get_available_gpus, categorical_crossentropy_color
checkpoint_models_path = 'data/'

# NOTE: no checkpoint / early-stopping / LR-schedule callbacks are wired up in this cell, so
# `checkpoint_models_path` and `patience` are currently unused.

model = build_model()
sgd = keras.optimizers.SGD(lr=0.001, momentum=0.9, nesterov=True, clipnorm=5.)
# The network ends in a per-pixel softmax over the ab bins and the targets are soft-encoded
# distributions over the same bins, so train with (multinomial) cross-entropy, the standard
# loss for this formulation, rather than MSE on probability vectors.
model.compile(optimizer=sgd, loss='categorical_crossentropy')

# steps_per_epoch / validation_steps are counts of *batches* per epoch.
history = model.fit_generator(train_gen(),
                              steps_per_epoch=num_train_samples,
                              validation_data=valid_gen(),
                              validation_steps=num_valid_samples,
                              epochs=epochs,
                              verbose=1)

fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train')
ax.plot(history.history['val_loss'], label='validation')
ax.set(title='ColorNet loss per epoch', xlabel='epoch', ylabel='categorical cross-entropy')
ax.legend();


In [0]:
# Score on held-out batches (the training loss is already recorded in `history`).
# Only non-default arguments are passed, which also keeps the call valid on older Keras releases.
model.evaluate(x=valid_gen(), steps=100, verbose=1)


In [0]:
# Per-pixel bin probabilities for the first 10 validation batches. Keep the (very large) array in
# a variable and display only its shape instead of echoing it into the notebook output.
valid_probs = model.predict(valid_gen(), steps=10, verbose=0)
valid_probs.shape
