In [1]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard 
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import os
import random
import tensorflow as tf

Using TensorFlow backend.


In [2]:
# Get images
#X = []
#i = 0
#for filename in os.listdir('Train/'):
#    X.append(img_to_array(load_img('Train/'+filename)))
#    i += 1
#    if i > 10:
#        break
#X = np.array(X, dtype=float)
#Xtrain = 1.0/255*X

# Total number of images to learn from
all_id = list(range(10))

valid_split = 0.1
test_split = 0.1

# Slice boundaries into all_id: [0, v_index) is the training part,
# [v_index, t_index) the validation part, and everything after t_index the test part.
v_index = int(len(all_id) * (1 - valid_split - test_split))
t_index = int(len(all_id) * (1 - test_split))

train_ids, valid_ids, test_ids = all_id[:v_index], all_id[v_index:t_index], all_id[t_index:]

partition = dict(train=train_ids, validation=valid_ids, test=test_ids)

# Load Inception-ResNet-v2 with ImageNet weights (classification head included)
# and keep a handle to the current default TensorFlow graph on the model object;
# create_inception_embedding() enters that graph before calling predict().
inception = InceptionResNetV2(weights='imagenet', include_top=True)
inception.graph = tf.get_default_graph()

In [3]:
# Second model input: the 1000-element vector produced by create_inception_embedding().
embed_input = Input(shape=(1000,), name='embed_input')

# Encoder: a stack of 3x3 'same' ReLU convolutions. The three stride-2 layers
# halve the spatial size each time (256 -> 128 -> 64 -> 32).
encoder_input = Input(shape=(256, 256, 1,), name='encoder_input')
encoder_output = encoder_input
for n_filters, stride in [(64, 2), (128, 1), (128, 2), (256, 1),
                          (256, 2), (512, 1), (512, 1), (256, 1)]:
    encoder_output = Conv2D(n_filters, (3, 3), activation='relu',
                            padding='same', strides=stride)(encoder_output)

# Fusion: tile the embedding over the 32x32 encoder grid, concatenate it with
# the encoder features along the channel axis, then mix with a 1x1 convolution.
fusion_output = RepeatVector(32 * 32)(embed_input)
fusion_output = Reshape((32, 32, 1000))(fusion_output)
fusion_output = concatenate([encoder_output, fusion_output], axis=3)
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)

# Decoder: 3x3 ReLU convolutions, the first two each followed by 2x upsampling ...
decoder_output = fusion_output
for n_filters, upsample_after in [(128, True), (64, True), (32, False), (16, False)]:
    decoder_output = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(decoder_output)
    if upsample_after:
        decoder_output = UpSampling2D((2, 2))(decoder_output)
# ... then a 2-channel tanh head and a final 2x upsampling back to 256x256.
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)

model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

In [4]:
def create_inception_embedding(grayscaled_rgb):
    """Run the module-level ``inception`` model on a batch of images.

    Args:
        grayscaled_rgb: iterable of image arrays. Each one is resized to
            299x299x3 before being stacked into a single batch.

    Returns:
        The result of ``inception.predict`` on the resized batch after it
        has been passed through ``preprocess_input``.
    """
    resized_batch = np.array(
        [resize(picture, (299, 299, 3), mode='constant') for picture in grayscaled_rgb]
    )
    model_input = preprocess_input(resized_batch)
    # Predict with the graph stored on the model as the default graph.
    with inception.graph.as_default():
        return inception.predict(model_input)

class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` producing colorization training batches.

    Every batch is a tuple ``(inputs, Y)`` where ``inputs`` is the dict
    ``{'encoder_input': X, 'embed_input': I}``:

    * ``X`` - Lab lightness channel of each image, shape ``(n, H, W, 1)``
    * ``I`` - ``create_inception_embedding`` output for the grayscale version
      of each image
    * ``Y`` - Lab a/b channels divided by 128, shape ``(n, H, W, 2)``

    Images are read from ``'Train/<ID>.jpg'``.
    """

    def __init__(self, list_IDs, batch_size=32, dim=(256,256,1), shuffle=True):
        """Store the configuration and build the first epoch's index order.

        Args:
            list_IDs: sequence of image IDs; each maps to ``Train/<ID>.jpg``.
            batch_size: number of images per batch.
            dim: ``(height, width, channels)``. Height and width are the size
                images are loaded at; the channel entry is not used because
                the encoder input is always the single lightness channel.
            shuffle: whether to reshuffle the sample order every epoch.
        """
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Build and return batch number ``index`` as ``(inputs, Y)``."""
        # Positions (into list_IDs) that belong to this batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Translate positions to image IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load the given IDs and return ``({'encoder_input', 'embed_input'}, Y)``."""
        height, width = self.dim[0], self.dim[1]
        # Size the buffers by the IDs actually supplied so that no
        # uninitialised np.empty rows can ever be returned.
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, height, width, 1), dtype=float)
        Y = np.empty((n_samples, height, width, 2), dtype=float)
        grayscaled_batch = []

        for i, ID in enumerate(list_IDs_temp):
            # target_size is a no-op for images that already have this size and
            # resizes anything else instead of failing at the reshape below.
            img = img_to_array(load_img('Train/' + str(ID) + '.jpg',
                                        target_size=(height, width)))

            # Inception input is built from the unscaled (0-255) pixel values.
            grayscaled_batch.append(gray2rgb(rgb2gray(img)))

            lab = rgb2lab(1.0/255*img)

            # Store sample: lightness channel
            X[i,] = lab[:,:,0].reshape((height, width, 1))

            # Store target: a/b channels, scaled by 1/128
            Y[i,] = lab[:,:,1:] / 128

        # One Inception forward pass for the whole batch instead of one per image.
        I = np.asarray(create_inception_embedding(grayscaled_batch), dtype=float)

        return {'encoder_input': X, 'embed_input': I}, Y

# Image transformer
#datagen = ImageDataGenerator(
#        shear_range=0.2,
#        zoom_range=0.2,
#        rotation_range=20,
#        horizontal_flip=True)

#Generate training data
#This is the original batch_size; it has nothing to do with our new batch_size
#batch_size = 10

#def image_a_b_gen(batch_size):
#    for batch in datagen.flow(Xtrain, batch_size=batch_size):
#        grayscaled_rgb = gray2rgb(rgb2gray(batch))
#        embed = create_inception_embedding(grayscaled_rgb)
#        lab_batch = rgb2lab(batch)
#        X_batch = lab_batch[:,:,:,0]
#        X_batch = X_batch.reshape(X_batch.shape+(1,))
#        Y_batch = lab_batch[:,:,:,1:] / 128
#        #print(batch.shape)
#        print(embed.shape)
#        yield ([X_batch, embed], Y_batch)

# Batch loader for the training images (batch_size = how many images are loaded per step)
training_generator = DataGenerator(partition['train'], batch_size=1)

# Train model
model.compile(optimizer='rmsprop', loss='mse')
# steps_per_epoch = total_images / batch_size; take it from the generator so it
# stays correct when the split or batch_size changes (8 for the current setup).
model.fit_generator(training_generator, epochs=10, steps_per_epoch=len(training_generator))

Epoch 1/10


  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1e30d3e7710>

In [5]:
# Load the test images in sorted order so that result/img_<i>.png always maps
# to the same input file (os.listdir() returns entries in arbitrary order).
test_rgb = np.array(
    [img_to_array(load_img('Test/' + filename)) for filename in sorted(os.listdir('Test/'))],
    dtype=float,
)

# Inception input: keep the 0-255 pixel range, exactly as DataGenerator does
# during training. Dividing by 255 here first would hand preprocess_input()
# (called inside create_inception_embedding) a different value range at test
# time than the one the model was trained with.
gray_me = gray2rgb(rgb2gray(test_rgb))
color_me_embed = create_inception_embedding(gray_me)

# Encoder input: Lab lightness channel with a trailing channel axis.
color_me = rgb2lab(1.0/255*test_rgb)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))


# Test model
output = model.predict([color_me, color_me_embed])
output = output * 128  # undo the /128 scaling applied to the a/b targets in training

# Output colorizations
os.makedirs('result', exist_ok=True)  # imsave does not create missing directories
for i in range(len(output)):
    cur = np.zeros((256, 256, 3))
    cur[:,:,0] = color_me[i][:,:,0]   # lightness from the input image
    cur[:,:,1:] = output[i]           # predicted a/b channels
    imsave("result/img_"+str(i)+".png", lab2rgb(cur))

  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  .format(dtypeobj_in, dtypeobj_out))
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
