In [2]:
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
import numpy as np
import os
import random
import tensorflow as tf
import pickle

## Preprocessing

In [2]:
# Load the tundra images into X, with pixel values divided by 255.
# Decoding every file is slow, so the resulting array is cached in X_tundra.pkl
# and reused on later runs.
# NOTE(review): pickle.load can execute arbitrary code from the file; only keep
# this cache if the pickle is one this notebook wrote itself.
try:
    with open('X_tundra.pkl', 'rb') as handle:
        X = pickle.load(handle)
except (OSError, EOFError, pickle.UnpicklingError):
    # Cache is missing, unreadable or truncated: rebuild it from the image folder.
    # Anything else (e.g. KeyboardInterrupt) is deliberately left to propagate.
    path = './unzipped_images/tundra/'
    # os.listdir returns entries in arbitrary order; sort them so the positional
    # train/validation split applied to X later is reproducible across machines.
    # Assumes every file in the folder is an image of the same size, otherwise
    # np.array cannot stack them into one array -- TODO confirm.
    X = np.array([
        img_to_array(load_img(os.path.join(path, filename))) / 255
        for filename in sorted(os.listdir(path))
    ])
    with open('X_tundra.pkl', 'wb') as handle:
        pickle.dump(X, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [3]:
# Positional hold-out split: the first 4500 images are used for training and
# whatever remains in X is used for validation.
n_train = 4500
Xtrain, Xval = X[:n_train], X[n_train:]

In [8]:
# Random augmentation applied on the fly to every training batch.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)

batch_size = 32


def train_generator(batch_size):
    """Yield (input, target) training batches for the colourisation model.

    Each augmented RGB batch drawn from ``Xtrain`` is converted to Lab.
    The input is the L channel divided by 100, with a trailing channel
    axis added; the target is the two remaining (a, b) channels divided
    by 128.

    Args:
        batch_size: number of images requested from the Keras iterator per batch.

    Yields:
        Tuple ``(lightness, ab)`` of arrays for one batch.
    """
    augmented_batches = train_datagen.flow(Xtrain, batch_size=batch_size)
    for rgb in augmented_batches:
        lab = rgb2lab(rgb)
        # Scale L to [0, 1] because neural networks prefer small input values.
        lightness = lab[..., 0] / 100
        ab = lab[..., 1:] / 128
        yield (lightness[..., None], ab)

In [9]:
# No augmentation for validation: images are passed through unchanged.
val_datagen = ImageDataGenerator()

def val_generator(batch_size):
    """Yield (input, target) validation batches for the colourisation model.

    Batches are drawn from the held-out ``Xval`` images (not ``Xtrain``), so
    the monitored ``val_loss`` reflects data the model has not been trained on.
    Each RGB batch is converted to Lab; the input is the L channel divided by
    100 with a trailing channel axis, the target is the (a, b) channels
    divided by 128.

    Args:
        batch_size: number of images requested from the Keras iterator per batch.

    Yields:
        Tuple ``(X_batch, Y_batch)`` of arrays for one batch.
    """
    for batch in val_datagen.flow(Xval, batch_size=batch_size):
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0] / 100 # scale to [0, 1] bc neural networks prefer small input values
        Y_batch = lab_batch[:,:,:,1:] / 128
        yield (X_batch.reshape(X_batch.shape+(1,)), Y_batch)

## Model

In [10]:
def VGG_like(weights_path=None):
    """Build the convolutional encoder/decoder used for colourisation.

    The network takes a 256x256 single-channel input, halves the spatial
    resolution three times with stride-2 convolutions, then restores it with
    three 2x upsampling stages, ending in a 2-channel tanh output.

    Args:
        weights_path: optional path to saved weights; loaded into the model
            when truthy.

    Returns:
        The (uncompiled) Keras ``Sequential`` model.
    """
    def relu_conv(filters, **extra):
        # Every hidden layer is a 3x3 'same' convolution with ReLU.
        return Conv2D(filters, (3, 3), activation='relu', padding='same', **extra)

    model = Sequential()
    model.add(InputLayer(input_shape=(256, 256, 1)))

    # Encoder: each width gets a plain conv followed by a stride-2 conv.
    for width in (64, 128, 256):
        model.add(relu_conv(width))
        model.add(relu_conv(width, strides=2))
    model.add(relu_conv(512))

    # Decoder: upsample by 2, then convolve, while narrowing the channel count.
    for width in (256, 128, 64):
        model.add(UpSampling2D((2, 2)))
        model.add(relu_conv(width))

    # Two output channels squashed into [-1, 1] by tanh.
    model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))

    if weights_path:
        model.load_weights(weights_path)

    return model

In [11]:
# Fresh, randomly initialised network (no weights file is passed), trained
# with Adam to minimise the mean squared error on the predicted channels.
model = VGG_like()
model.compile(loss='mse', optimizer='adam')

## Train

In [None]:
# Keras expects integer step counts. Round up so one epoch walks through the
# whole array once, including the final, smaller batch that flow() emits when
# the sample count is not a multiple of batch_size.
train_steps = int(np.ceil(len(Xtrain) / batch_size))
val_steps = int(np.ceil(len(Xval) / batch_size))

# The epoch count is only an upper bound: EarlyStopping ends training after
# 20 epochs without a val_loss improvement, and ModelCheckpoint keeps the
# best-scoring model on disk in best_model_small.h5.
hist = model.fit_generator(
    train_generator(batch_size),
    steps_per_epoch=train_steps,
    epochs=10000,
    validation_data=val_generator(batch_size),
    validation_steps=val_steps,
    callbacks = [EarlyStopping(monitor='val_loss', patience=20),
             ModelCheckpoint(filepath='best_model_small.h5', monitor='val_loss', save_best_only=True)]
)

Epoch 1/10000
Epoch 2/10000
