# Model training (X2 upscaling)
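With the final network architecture decided, this notebook trains a better X2 (scale 2) super-resolution model by using more patches and more epochs: 12,000 training patches (48×48 low-resolution, 96×96 high-resolution) are extracted from the DIV2K training images and used to train the model for 20 epochs.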

In [1]:
from tensorflow import config

# Turn on memory growth for every GPU that TensorFlow can see.
gpu_devices = config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    config.experimental.set_memory_growth(device, True)

## Load training images from directory

In [None]:
# load training images from directory

import os
import numpy as np
from PIL import Image

# LR_train_path = './datasets/DIV2K_train_LR_unknown/X2/'
LR_train_path = './datasets/DIV2K_train_LR_bicubic/X2/'
HR_train_path = './datasets/DIV2K_train_HR/'


def _load_image_dir(root):
    """Load every image found under `root` as a numpy array.

    Directories and file names are visited in sorted order so the result is
    deterministic. macOS '.DS_Store' entries are skipped.

    Args:
        root: directory to walk (sub-directories are included).

    Returns:
        A list of numpy arrays, one per image file, in traversal order.
    """
    images = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Sorting dirnames in place makes os.walk descend in a fixed order.
        dirnames.sort()
        for name in sorted(filenames):
            if name == '.DS_Store':
                continue
            # Join with the directory currently being walked (not the root),
            # so files inside sub-directories resolve correctly. The context
            # manager closes the file handle once the pixels are copied out.
            with Image.open(os.path.join(dirpath, name)) as img:
                images.append(np.asarray(img))
    return images


# LR and HR images are paired purely by position: both lists rely on the
# sorted file-name order lining up between the two directories.
LR_train_imgs = _load_image_dir(LR_train_path)
HR_train_imgs = _load_image_dir(HR_train_path)

print(len(LR_train_imgs))
print(len(HR_train_imgs))

## Preprocess (patch extraction + normalization)

In [2]:
# randomly extract patches from training images (X2 upscaling)

from extract_patches import *

# Patch-extraction settings passed to train_patch below.
# The recorded output of this cell shows LR patches of 48x48 and
# HR patches of 96x96 (i.e. patch size * up_scale).
patch_height = 48
patch_width = 48
patch_num = 12000
up_scale = 2

# train_patch presumably comes from the `extract_patches` star import above.
# NOTE(review): no random seed is set in this notebook, so the extracted
# patches may differ between runs -- confirm whether train_patch seeds internally.
LR_patch_train, HR_patch_train = train_patch(LR_train_imgs, HR_train_imgs, patch_height, patch_width, patch_num, up_scale)


# Sanity check: number of patches and per-patch dimensions.
print(LR_patch_train.shape)
print(HR_patch_train.shape)

(12000, 48, 48, 3)
(12000, 96, 96, 3)
(48, 48, 3)
(96, 96, 3)


In [3]:
# normalize imgs from 0~255 to 0~1

def normalize(imgs):
    """Scale pixel values from the 0~255 range down to 0~1.

    Args:
        imgs: a number or array-like supporting division by an int.

    Returns:
        `imgs` divided by 255.
    """
    max_pixel_value = 255
    return imgs / max_pixel_value

# WARNING: these assignments overwrite the raw patch arrays with their
# normalized versions, so this cell is not idempotent -- running it a second
# time divides by 255 again. Re-run the patch-extraction cell first if needed.
HR_patch_train = normalize(HR_patch_train)
LR_patch_train = normalize(LR_patch_train)

# Shapes are unchanged by normalization; printed as a sanity check.
print(LR_patch_train.shape)
print(HR_patch_train.shape)

(12000, 48, 48, 3)
(12000, 96, 96, 3)


## Build Network Architecture

In [4]:
# define subpixel layer

import tensorflow as tf
from keras.layers import Lambda

def pixelshuffler(input_shape, batch_size, scale=2):
    """Build a sub-pixel (pixel-shuffle) upscaling layer.

    The returned Lambda layer rearranges channel data into spatial blocks with
    `tf.nn.depth_to_space`, so height and width grow by `scale` and the
    channel count shrinks by `scale ** 2`.

    Args:
        input_shape: fallback 4-D shape (indexed as [_, height, width, channels])
            used by the shape callback only when it is called without an
            argument. Keras documents that an `output_shape` callable receives
            the input shape, so this fallback is normally not used.
        batch_size: value reported as the first dimension of the output shape.
        scale: upscaling factor (block size for depth_to_space).

    Returns:
        A keras `Lambda` layer performing the upscaling.
    """
    def subpixel_shape(input_shape=input_shape, batch_size=batch_size):
        height, width, channels = input_shape[1], input_shape[2], input_shape[3]

        # Unknown (None) dimensions stay unknown instead of raising TypeError,
        # so the layer also works in models built for variable-size inputs.
        # Only `scale` and plain values are captured here, keeping the
        # function's closure/defaults simple for Lambda serialization.
        return (batch_size,
                None if height is None else height * scale,
                None if width is None else width * scale,
                None if channels is None else channels // (scale ** 2))

    def pixelshuffle_upscale(x):
        return tf.nn.depth_to_space(input=x, block_size=scale)

    return Lambda(function=pixelshuffle_upscale, output_shape=subpixel_shape)

Using TensorFlow backend.


In [5]:
# define model architecture

from keras.models import Model, Sequential
from keras.layers import PReLU, Input, Conv2D, add
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint

def res_block(inputs):
    """Residual block: conv -> PReLU -> conv, plus a skip connection.

    Both convolutions use 64 filters of size 3x3 with 'same' padding, so
    `inputs` must already have 64 channels for the final addition to work.

    Args:
        inputs: input feature tensor.

    Returns:
        The element-wise sum of the block's output and `inputs`.
    """
    conv_args = dict(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')

    features = Conv2D(**conv_args)(inputs)
    features = PReLU(shared_axes=[1, 2])(features)
    features = Conv2D(**conv_args)(features)

    # skip connection
    return add([features, inputs])


def final_model(patch_height, patch_width, channel, upscale=2):
    """Build the super-resolution network.

    Layout: 9x9 conv + PReLU -> 8 residual blocks -> 3x3 conv with a global
    skip connection -> sub-pixel upscaling block -> 9x9 conv -> 1x1 conv with
    sigmoid activation.

    Args:
        patch_height: height of the low-resolution input.
        patch_width: width of the low-resolution input.
        channel: number of image channels (used for both input and output).
        upscale: spatial upscaling factor of the sub-pixel block.

    Returns:
        An uncompiled keras `Model` mapping (patch_height, patch_width, channel)
        inputs to outputs upscaled by `upscale` in height and width.
    """
    inputs = Input(shape=(patch_height, patch_width, channel))
    x_init = Conv2D(filters=64, kernel_size=(9, 9), strides=(1, 1), padding='same')(inputs)
    x = PReLU(shared_axes=[1, 2])(x_init)

    # residual blocks
    for _ in range(8):
        x = res_block(x)

    x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = add([x, x_init])

    # sub-pixel up_block: the conv must produce 64 * upscale^2 channels so that
    # depth_to_space leaves 64 channels after the shuffle (256 when upscale=2).
    shuffle_channels = 64 * upscale ** 2
    x = Conv2D(filters=shuffle_channels, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    # The shape hint is 4-D (batch, height, width, channels) because the shape
    # callback in pixelshuffler indexes positions 1..3.
    x = pixelshuffler(input_shape=(None, patch_height, patch_width, shuffle_channels),
                      batch_size=4, scale=upscale)(x)
    x = PReLU(shared_axes=[1, 2])(x)

    # output_block: emit as many channels as the input image has;
    # the sigmoid keeps predictions in the 0~1 range.
    output = Conv2D(filters=channel, kernel_size=(9, 9), strides=(1, 1), padding='same')(x)
    output = Conv2D(channel, (1, 1), activation='sigmoid', padding='same')(output)

    model = Model(inputs=inputs, outputs=output)

    return model

In [6]:
# define perceptual loss based on the first 5 layers of VGG19 model
# compare the X2 model output of size (96, 96)

from keras.applications.vgg19 import VGG19
from keras.layers import Input, Lambda
import keras

# get VGG network
def get_VGG19(input_size):
    """Create a frozen VGG19 and expose its 'block2_conv2' feature map.

    Args:
        input_size: shape of the images fed to VGG19, e.g. (96, 96, 3).

    Returns:
        A tuple (vgg_input, vgg_output): the input tensor and the output tensor
        of the 'block2_conv2' layer.
    """
    vgg_input = Input(input_size)
    vgg = VGG19(include_top=False, input_tensor=vgg_input)

    # VGG is only a fixed feature extractor here, so freeze every layer.
    for layer in vgg.layers:
        layer.trainable = False

    return vgg_input, vgg.get_layer('block2_conv2').output

def perceptual_loss_x2(y_true, y_pred):
    """Perceptual loss: MSE between VGG feature maps of target and prediction.

    Relies on the module-level `vgg_content` model to extract features.

    NOTE(review): the training targets are scaled to 0~1 and the network ends
    in a sigmoid, so VGG is fed 0~1 values here -- confirm that skipping the
    usual VGG input preprocessing is intentional.

    Args:
        y_true: ground-truth high-resolution batch.
        y_pred: model output batch.

    Returns:
        The mean squared error between the two feature tensors.
    """
    target_features = vgg_content(y_true)
    predicted_features = vgg_content(y_pred)
    return keras.losses.mean_squared_error(target_features, predicted_features)

# The VGG feature extractor takes 96x96x3 inputs, i.e. the same size as the
# X2 model output / HR patches. `vgg_content` is read as a global by
# perceptual_loss_x2, so this must run before the loss is evaluated.
vgg_input, vgg_output = get_VGG19(input_size=(96,96,3))
vgg_content = Model(vgg_input, vgg_output)
# Uncomment to inspect the truncated VGG19 layers:
#vgg_content.summary()

In [13]:
# train bicubic X2 model
# (uses whatever LR_patch_train / HR_patch_train currently hold)

model = final_model(48, 48, 3)
model.compile(
    optimizer=Adam(lr=1e-4),
    loss=perceptual_loss_x2,
    # NOTE(review): 'accuracy' is a classification metric and is likely
    # uninformative for image outputs -- consider dropping or replacing it.
    metrics=['accuracy'],
)

# Save to disk only when the validation loss improves.
checkpointer = ModelCheckpoint(
    filepath='./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5',
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# The last 20% of the patches are held out for validation.
history = model.fit(
    LR_patch_train,
    HR_patch_train,
    batch_size=4,
    epochs=20,
    validation_split=0.2,
    callbacks=[checkpointer],
    verbose=1,
)

Train on 9600 samples, validate on 2400 samples
Epoch 1/20

Epoch 00001: val_loss improved from inf to 0.34235, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5
Epoch 2/20

Epoch 00002: val_loss improved from 0.34235 to 0.31135, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5
Epoch 3/20

Epoch 00003: val_loss improved from 0.31135 to 0.27439, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5
Epoch 4/20

Epoch 00004: val_loss improved from 0.27439 to 0.26521, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5
Epoch 5/20

Epoch 00005: val_loss improved from 0.26521 to 0.24811, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5
Epoch 6/20

Epoch 00006: val_loss improved from 0.24811 to 0.24567, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X2.h5
Epoch 7/20

Epoch 00007: val_loss improved from 0.

In [7]:
# train unknown X2 model
# NOTE: this cell trains on whatever LR_patch_train currently holds. To train
# on the "unknown" degradation set, switch LR_train_path in the loading cell
# and re-run loading, patch extraction and normalization first.

model = final_model(48, 48, 3)
model.compile(
    optimizer=Adam(lr=1e-4),
    loss=perceptual_loss_x2,
    # NOTE(review): 'accuracy' is a classification metric and is likely
    # uninformative for image outputs -- consider dropping or replacing it.
    metrics=['accuracy'],
)

# Save to disk only when the validation loss improves.
checkpointer = ModelCheckpoint(
    filepath='./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5',
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# The last 20% of the patches are held out for validation.
history = model.fit(
    LR_patch_train,
    HR_patch_train,
    batch_size=4,
    epochs=20,
    validation_split=0.2,
    callbacks=[checkpointer],
    verbose=1,
)

Train on 9600 samples, validate on 2400 samples
Epoch 1/20

Epoch 00001: val_loss improved from inf to 0.48318, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5
Epoch 2/20

Epoch 00002: val_loss improved from 0.48318 to 0.42897, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5
Epoch 3/20

Epoch 00003: val_loss improved from 0.42897 to 0.39237, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5
Epoch 4/20

Epoch 00004: val_loss improved from 0.39237 to 0.37328, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5
Epoch 5/20

Epoch 00005: val_loss improved from 0.37328 to 0.36112, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5
Epoch 6/20

Epoch 00006: val_loss improved from 0.36112 to 0.34487, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.h5
Epoch 7/20

Epoch 00

In [8]:
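# save the training history of the most recent fit (the model itself is saved by the ModelCheckpoint callback during training)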
import pickle

# Persist the per-epoch metrics dict from the last model.fit call.
history_file = open('./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X2.pkl', 'wb')
with history_file:
    pickle.dump(history.history, history_file)

#model.save_weights('./model_and_history/final_mse_unknown_4848_12000_subpixel_X4 _weights.hdf5')
#model.save('./model_and_history/perceptual_baseline1_3232_model.h5')