In [18]:
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.models import Model, load_model
from keras.layers import Dense, GlobalAveragePooling2D, Input, Conv2D, UpSampling2D, Reshape, RepeatVector, concatenate
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
import numpy as np
import os
import random
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt

In [19]:
# Load the saved Inception-ResNet-v2 network from disk; `embed` below uses its
# predictions as the image embedding.
inception = load_model(filepath='inception_resnet_v2_model.h5')



In [20]:
# Build the predict function of the loaded network up front.
# NOTE(review): presumably required so that `inception.predict` can be called
# from inside the data generators during training -- confirm. This is a
# private Keras method (leading underscore) and may not exist in other versions.
inception._make_predict_function()

## Preprocessing

In [21]:
def preprocess_input(im):
    """Rescale pixel values from [0, 255] to [-1, 1].

    Works on scalars and on NumPy arrays (element-wise).
    """
    unit_range = im / 255.0      # [0, 255] -> [0, 1]
    return 2 * unit_range - 1.0  # [0, 1]   -> [-1, 1]

In [22]:
def turn_gray(im):
    """Convert RGB pixel data in [0, 255] to a 3-channel gray image in [0, 1].

    The gray level is the L (lightness) channel of the Lab conversion divided
    by 100, copied into three identical channels.

    Parameters
    ----------
    im : numpy.ndarray
        RGB data whose last axis has length 3: a single image (H, W, 3) or a
        batch of images (N, H, W, 3), with values in [0, 255].

    Returns
    -------
    numpy.ndarray
        Array with the same shape as `im`.
    """
    # Index the *last* axis. `[:, :, 0]` selects the L channel only for a
    # single (H, W, 3) image; on a 4-D batch it selects pixel column 0 instead.
    lightness = rgb2lab(im / 255)[..., 0] / 100
    # Replicate along a new trailing axis so batches are handled the same way
    # as single images.
    return np.repeat(lightness[..., np.newaxis], 3, axis=-1)

In [23]:
path = './unzipped_images/tundra/'

# Load every file in `path` as a float image array and stack them into one
# array. The listing is sorted because os.listdir returns entries in arbitrary
# order, and the train/validation split further down is a fixed index cut:
# without a stable order the split would differ between runs and machines.
# NOTE(review): assumes every file in the directory is an image and that all
# images share one size (the model input below is 256x256) -- confirm.
X = np.array([
    img_to_array(load_img(os.path.join(path, filename)))
    for filename in sorted(os.listdir(path))
])

In [24]:
# Hold-out split: the first 4500 images are used for training, everything
# after them for validation.
Xtrain, Xval = X[:4500], X[4500:]

In [35]:
def embed(images):
    """Return the predictions of `inception` for a batch of RGB images.

    Each image is converted to a 3-channel gray image, rescaled to [-1, 1],
    resized to 299x299 and passed to `inception.predict`.

    Parameters
    ----------
    images : numpy.ndarray
        Batch of RGB images, shape (N, H, W, 3), values in [0, 255].

    Returns
    -------
    The output of `inception.predict` for the batch (one row per image;
    presumably the 1000-dimensional vector expected by the embedder input --
    confirm against the saved model).
    """
    # Convert image by image so the lightness channel is always taken from a
    # single (H, W, 3) image. Slicing `[:, :, 0]` on a whole 4-D batch would
    # pick pixel column 0 rather than the L channel, and the resize below
    # would then stretch that strip into a meaningless network input.
    gray = np.stack([turn_gray(image) for image in images])
    # turn_gray returns values in [0, 1]; preprocess_input expects [0, 255].
    scaled = preprocess_input(gray * 255)
    resized = resize(scaled, (scaled.shape[0], 299, 299, 3))
    return inception.predict(resized)

In [26]:
batch_size = 32

# Random augmentations applied to the training images: rotations, shifts,
# shear, zoom and horizontal flips.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)


def train_generator(batch_size):
    """Yield augmented training batches as `([lightness, embedding], ab)`.

    For each RGB batch produced by `train_datagen.flow(Xtrain, ...)`:
      * `lightness` is the Lab L channel divided by 100, with a trailing
        channel axis of length 1 (encoder input);
      * `embedding` is `embed(batch)` (embedder input);
      * `ab` holds the Lab a/b channels divided by 128 (regression target).
    """
    for rgb_batch in train_datagen.flow(Xtrain, batch_size=batch_size):
        lab = rgb2lab(rgb_batch / 255)
        # Scale to [0, 1] because neural networks prefer small input values.
        lightness = lab[..., 0] / 100
        ab_target = lab[..., 1:] / 128
        yield ([lightness[..., np.newaxis], embed(rgb_batch)], ab_target)

In [27]:
# No augmentation for validation: images are passed through unchanged.
val_datagen = ImageDataGenerator()

def val_generator(batch_size):
    """Yield validation batches as `([lightness, embedding], ab)`.

    Same batch layout as the training generator, but without augmentation and
    drawn from the held-out `Xval` images:
      * `lightness` is the Lab L channel divided by 100, with a trailing
        channel axis of length 1 (encoder input);
      * `embedding` is `embed(batch)` (embedder input);
      * `ab` holds the Lab a/b channels divided by 128 (regression target).
    """
    # Iterate over Xval, not Xtrain: validating on the training images makes
    # val_loss (and the early stopping / checkpointing that monitor it) blind
    # to overfitting.
    for batch in val_datagen.flow(Xval, batch_size=batch_size):
        lab_batch = rgb2lab(batch/255)
        X_batch = lab_batch[:,:,:,0] / 100 # scale to [0, 1] bc neural networks prefer small input values
        Y_batch = lab_batch[:,:,:,1:] / 128
        yield ([X_batch.reshape(X_batch.shape+(1,)), embed(batch)], Y_batch)

## Model

### Encoder

In [28]:
# Encoder: a stack of 3x3 ReLU convolutions over the 256x256 lightness image.
# Three of them use stride 2, so the spatial size is halved three times
# (256 -> 32); the final layer outputs 256 feature maps.
encoder_input = Input(shape=(256, 256, 1))
encoder_output = encoder_input
for n_filters, stride in [(64, 2), (128, 1), (128, 2), (256, 1),
                          (256, 2), (512, 1), (512, 1), (256, 1)]:
    encoder_output = Conv2D(n_filters, (3, 3), activation='relu',
                            padding='same', strides=stride)(encoder_output)

### Embedder

In [29]:
# Embedder: tile the 1000-dimensional embedding vector over a 32x32 grid
# (1024 = 32 * 32 copies) so it can be concatenated with the encoder features.
embedder_input = Input(shape=(1000,))
tiled_embedding = RepeatVector(32 * 32)(embedder_input)
embedder_output = Reshape((32, 32, 1000))(tiled_embedding)

### Fusion

In [30]:
# Fusion: stack encoder features and the tiled embedding along the channel
# axis, then mix them back down to 256 channels with a 1x1 convolution.
fusion_input = concatenate([encoder_output, embedder_output], axis=-1)
fusion_output = Conv2D(
    filters=256,
    kernel_size=(1, 1),
    activation='relu',
    padding='same',
)(fusion_input)

### Decoder

In [31]:
# Decoder: three (3x3 conv -> 2x upsampling) stages bring the 32x32 fused
# features back to 256x256, followed by two more convolutions. The last layer
# has 2 output channels with tanh activation, matching the a/b targets that
# the generators scale by 1/128.
decoder_output = fusion_output
for n_filters in (256, 128, 64):
    decoder_output = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(decoder_output)
    decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = Conv2D(32, (3, 3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)

In [32]:
# Full network: two inputs (lightness image, embedding vector) -> predicted
# a/b colour channels.
model = Model([encoder_input, embedder_input], decoder_output)

In [33]:
# Regress the a/b channels with mean squared error, optimised with Adam.
model.compile(loss='mse', optimizer='adam')

## Train

In [34]:
# Train until validation loss has not improved for 20 epochs (the epoch count
# is only an upper bound), keeping the best weights seen so far on disk.
hist = model.fit_generator(
    train_generator(batch_size),
    # Step counts must be whole numbers; round up so the final partial batch
    # is still included (plain `/` yields a float such as 140.625).
    steps_per_epoch=int(np.ceil(len(Xtrain) / batch_size)),
    epochs=10000,
    validation_data=val_generator(batch_size),
    validation_steps=int(np.ceil(len(Xval) / batch_size)),
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=20),
        ModelCheckpoint(filepath='tundra_final_model.h5', monitor='val_loss', save_best_only=True),
    ],
)


Epoch 1/10000


  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


(8, 299, 299, 3)
(8, 299, 299, 3)
(8, 299, 299, 3)
(8, 299, 299, 3)
(8, 299, 299, 3)
(8, 299, 299, 3)
  2/562 [..............................] - ETA: 1:52:39 - loss: 0.0171(8, 299, 299, 3)
  3/562 [..............................] - ETA: 1:16:21 - loss: 0.0367(8, 299, 299, 3)
  4/562 [..............................] - ETA: 58:52 - loss: 0.0308  (8, 299, 299, 3)
(8, 299, 299, 3)
  5/562 [..............................] - ETA: 49:30 - loss: 0.0268(8, 299, 299, 3)
(8, 299, 299, 3)
(8, 299, 299, 3)
  6/562 [..............................] - ETA: 43:35 - loss: 0.0243(8, 299, 299, 3)
(8, 299, 299, 3)
  7/562 [..............................] - ETA: 38:58 - loss: 0.0225(8, 299, 299, 3)
(8, 299, 299, 3)
  8/562 [..............................] - ETA: 35:26 - loss: 0.0213(8, 299, 299, 3)
(8, 299, 299, 3)
  9/562 [..............................] - ETA: 32:44 - loss: 0.0207(8, 299, 299, 3)
 10/562 [..............................] - ETA: 30:30 - loss: 0.0201(8, 299, 299, 3)
(8, 299, 299, 3)
 11/562 

KeyboardInterrupt: 

(8, 299, 299, 3)


In [None]:
# Print the layer-by-layer summary of the colorization model.
model.summary()