In [1]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard 
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import os
import random
import tensorflow as tf

Using TensorFlow backend.


In [6]:
all_id = list(range(1000, 1100))

valid_split = 0.1
test_split = 0.1

v_index = int(len(all_id)*(1-valid_split-test_split))
t_index = int(len(all_id)*(1-test_split))

train_ids = all_id[:v_index]
valid_ids = all_id[v_index:t_index]
test_ids = all_id[t_index:]

print("Length of train set: " + str(len(train_ids)))
print("Length of validation set: " + str(len(valid_ids)))
print("Length of test set: " + str(len(test_ids)))

partition = {'train': train_ids, 'validation': valid_ids, 'test': test_ids}

Length of train set: 80
Length of validation set: 10
Length of test set: 10


In [7]:
class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, batch_size=10, dim=(256,256,1), shuffle=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *(256,256,1)), dtype=float)
        Y = np.empty((self.batch_size, *(256,256,2)), dtype=float)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            img = img_to_array(load_img('floydhub_dataset/' + str(ID) + '.jpg'))
            img = 1.0/255*img
            img = rgb2lab(img)
            
            gray = img[:,:,0]
            ab = img[:,:,1:] / 128
            
            # Store sample
            X[i,] = gray.reshape((256,256,1))

            # Store class
            Y[i,] = ab

        return X, Y

In [6]:
#model = Sequential()
#model.add(InputLayer(input_shape=(256, 256, 1)))
#model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
#model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2))
#model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
#model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))
#model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
#model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))
#model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
#model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
#model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
#model.add(UpSampling2D((2, 2)))
#model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
#model.add(UpSampling2D((2, 2)))
#model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
#model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
#model.add(UpSampling2D((2, 2)))
#model.compile(optimizer='rmsprop', loss='mse')

In [8]:
encoder_input = Input(shape=(256, 256, 1,), name='encoder_input')
encoder_output = Conv2D(64, (3,3), activation='relu', padding='same', strides=2)(encoder_input)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)

#Decoder
decoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = Conv2D(32, (3,3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(16, (3,3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)

model = Model(inputs=encoder_input, outputs=decoder_output)

In [9]:
# Image transformer
#datagen = ImageDataGenerator(
#        shear_range=0.2,
#        zoom_range=0.2,
#        rotation_range=20,
#        horizontal_flip=True)

# Generate training data
#batch_size = 10
#def image_a_b_gen(batch_size):
#    for batch in datagen.flow(Xtrain, batch_size=batch_size):
#        lab_batch = rgb2lab(batch)
#        X_batch = lab_batch[:,:,:,0]
#        Y_batch = lab_batch[:,:,:,1:] / 128
#        yield (X_batch.reshape(X_batch.shape+(1,)), Y_batch)


training_generator = DataGenerator(partition['train'], batch_size=10)

# Train model      
tensorboard = TensorBoard(log_dir="/output/beta_run")
model.compile(optimizer='rmsprop', loss='mse')
model.fit_generator(training_generator, callbacks=[tensorboard], epochs=2, steps_per_epoch=8)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x14e56fb4080>

In [7]:
# Save model
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)
model.save_weights("model.h5")

In [11]:
# Test images
#Xtest = rgb2lab(1.0/255*X[split:])[:,:,:,0]
#Xtest = Xtest.reshape(Xtest.shape+(1,))
#Ytest = rgb2lab(1.0/255*X[split:])[:,:,:,1:]
#Ytest = Ytest / 128
#print(model.evaluate(Xtest, Ytest, batch_size=batch_size))

In [10]:
# Change to '/data/images/Test/' to use all the 500 images
color_me = []
for filename in os.listdir('Test/'):
	color_me.append(img_to_array(load_img('Test/'+filename)))
color_me = np.array(color_me, dtype=float)
color_me = rgb2lab(1.0/255*color_me)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))

# Test model
output = model.predict(color_me)
output = output * 128

# Output colorizations
for i in range(len(output)):
	cur = np.zeros((256, 256, 3))
	cur[:,:,0] = color_me[i][:,:,0]
	cur[:,:,1:] = output[i]
	imsave("result/img_"+str(i)+".png", lab2rgb(cur))

  .format(dtypeobj_in, dtypeobj_out))
