# Model training (X4 upscaling)
The X4 model is built on top of the previous X2 model by cascading a new X2 model to it. When the training of the previous X2 model is finished, the weights of the X2 model are kept as the first part of the X4 model, which only needs to learn the mapping from X2 to X4. In this notebook, 12000 training patches are extracted to train the X4 model at scale 4.

In [1]:
from tensorflow import config

# Enable memory growth on every visible GPU before any model is built.
gpu_devices = config.experimental.list_physical_devices('GPU')

for device in gpu_devices:
    config.experimental.set_memory_growth(device, True)

## Load training images from directory

In [2]:
# load training images from directory

import os
import numpy as np
from PIL import Image

# LR_train_path = './datasets/DIV2K_train_LR_unknown/X4/'
LR_train_path = './datasets/DIV2K_train_LR_bicubic/X4/'
HR_train_path = './datasets/DIV2K_train_HR/'


def load_images(directory):
    """Load every image found under ``directory`` as a NumPy array.

    Files are visited in sorted order so that the i-th low-resolution image
    and the i-th high-resolution image come from the same file position in
    their respective folders.

    Args:
        directory: folder to scan (sub-folders are scanned too).

    Returns:
        A list with one array per image file, in sorted filename order.
    """
    images = []
    for dirpath, subdirs, filenames in os.walk(directory):
        # Sorting in place also fixes the order in which os.walk descends.
        subdirs.sort()
        for filename in sorted(filenames):
            # Skip hidden files such as macOS '.DS_Store'.
            if filename.startswith('.'):
                continue
            # Join with the directory currently being walked, not the root,
            # so files inside sub-folders resolve to a real path.
            with Image.open(os.path.join(dirpath, filename)) as img:
                images.append(np.asarray(img))
    return images


LR_train_imgs = load_images(LR_train_path)
HR_train_imgs = load_images(HR_train_path)

print(len(LR_train_imgs))
print(len(HR_train_imgs))

(12000, 48, 48, 3)
(12000, 192, 192, 3)
(48, 48, 3)
(192, 192, 3)


## Preprocess (patch extraction + normalization)

In [3]:
# randomly extract patches from training images (X4 upscaling)

from extract_patches import *

# Patch settings: LR patch size, number of patches, and the LR -> HR scale factor.
patch_height, patch_width = 48, 48
patch_num = 12000
up_scale = 4

LR_patch_train, HR_patch_train = train_patch(
    LR_train_imgs, HR_train_imgs,
    patch_height, patch_width, patch_num, up_scale,
)

# Sanity check: print the LR shape, then the HR shape.
print(LR_patch_train.shape, HR_patch_train.shape, sep='\n')

(12000, 48, 48, 3)
(12000, 192, 192, 3)


In [None]:
# normalize imgs from 0~255 to 0~1

def normalize(imgs):
    """Rescale pixel intensities from the 0~255 range to the 0~1 range.

    Args:
        imgs: value(s) supporting true division by an int, e.g. a NumPy array.

    Returns:
        ``imgs`` divided by 255.
    """
    max_intensity = 255
    return imgs / max_intensity

# Scale both patch sets to 0~1.
# NOTE: the names are rebound in place, so running this cell twice divides by 255 twice.
LR_patch_train = normalize(LR_patch_train)
HR_patch_train = normalize(HR_patch_train)

print(LR_patch_train.shape, HR_patch_train.shape, sep='\n')

## Load X2 model

In [4]:
# define the perceptual_loss_x2 so that the X2 model can be loaded

from keras.applications.vgg19 import VGG19
from keras.layers import Input, Lambda
from keras.models import Model
import keras


def get_VGG19(input_size):
    """Build a frozen VGG19 feature extractor up to ``block2_conv2``.

    Args:
        input_size: shape tuple passed to ``Input`` (e.g. ``(96, 96, 3)``).

    Returns:
        ``(vgg_input, vgg_output)``: the input tensor and the output tensor of
        the ``block2_conv2`` layer, ready to be wrapped in a ``Model``.
    """
    vgg_input = Input(input_size)
    # The original passed the undefined name `vgg_inp` here, raising NameError.
    vgg = VGG19(include_top=False, input_tensor=vgg_input)
    # Freeze every layer: VGG is only used to compute the loss.
    for layer in vgg.layers:
        layer.trainable = False
    vgg_output = vgg.get_layer('block2_conv2').output

    return vgg_input, vgg_output

def perceptual_loss_x2(y_true, y_pred):
    """Perceptual loss for the X2 stage.

    Both images are passed through ``vgg_content1`` and the loss is the mean
    squared error between the two feature maps.

    Args:
        y_true: ground-truth image batch.
        y_pred: predicted image batch.

    Returns:
        The result of ``keras.losses.mean_squared_error`` on the VGG features.
    """
    features_true, features_pred = vgg_content1(y_true), vgg_content1(y_pred)
    return keras.losses.mean_squared_error(features_true, features_pred)

# VGG19 feature extractor for 96x96 RGB inputs, used by perceptual_loss_x2.
vgg_input, vgg_output = get_VGG19(input_size=(96, 96, 3))
vgg_content1 = Model(inputs=vgg_input, outputs=vgg_output)
# vgg_content1.summary()

Using TensorFlow backend.


In [5]:
# load trained X2 model

import tensorflow as tf
from keras.models import load_model

# Restore the trained X2 network. The custom loss and `tf` are supplied so
# Keras can resolve them while deserializing the saved model.
x2_model = load_model(
    './models/final3_perceptual_unknown_4848_12000_subpixel_X2.h5',
    custom_objects={'perceptual_loss_x2': perceptual_loss_x2, 'tf': tf},
)
# x2_model.summary()

## Build Network Architecture

In [6]:
# define subpixel layer

import tensorflow as tf
from keras.layers import Lambda

def pixelshuffler(input_shape, batch_size, scale=2):
    """Build a sub-pixel (pixel shuffle) upscaling layer.

    The returned ``Lambda`` layer applies ``tf.nn.depth_to_space`` with
    ``block_size=scale``.

    Args:
        input_shape: only used as the default argument of the output-shape
            function; when that function is called with an explicit
            4-element shape, the explicit shape is used instead.
        batch_size: value reported as the first dimension of the output
            shape (may be ``None``).
        scale: upscaling factor per spatial dimension.

    Returns:
        A ``Lambda`` layer performing the pixel shuffle.
    """
    def subpixel_shape(input_shape=input_shape, batch_size=batch_size):
        # Expects (batch, height, width, channels). Unknown (None) dimensions
        # stay None instead of raising TypeError.
        height, width, channels = input_shape[1], input_shape[2], input_shape[3]
        return (
            batch_size,
            None if height is None else height * scale,
            None if width is None else width * scale,
            None if channels is None else channels // (scale ** 2),
        )

    def pixelshuffle_upscale(x):
        return tf.nn.depth_to_space(input=x, block_size=scale)

    return Lambda(function=pixelshuffle_upscale, output_shape=subpixel_shape)

In [7]:
# define model architecture

from keras.models import Model, Sequential
from keras.layers import PReLU, Input, Conv2D, add
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint

def res_block(inputs):
    """Residual block: Conv(3x3, 64) -> PReLU -> Conv(3x3, 64), plus a skip connection.

    Args:
        inputs: input tensor of the block.

    Returns:
        The element-wise sum of the convolved branch and ``inputs``.
    """
    conv_kwargs = dict(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')

    residual = Conv2D(**conv_kwargs)(inputs)
    residual = PReLU(shared_axes=[1, 2])(residual)
    residual = Conv2D(**conv_kwargs)(residual)

    return add([residual, inputs])


def final_model(patch_height, patch_width, channel, upscale=2):
    """Build the residual super-resolution network (convolve, then upsample).

    Layout: 9x9 conv -> PReLU -> 8 residual blocks -> 3x3 conv with a global
    skip connection -> sub-pixel upscaling block -> 9x9 conv -> 1x1 sigmoid
    conv producing 3 output channels.

    Args:
        patch_height: height of the input patches.
        patch_width: width of the input patches.
        channel: number of channels of the input patches.
        upscale: upscaling factor of the sub-pixel block (default 2).

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(shape=(patch_height, patch_width, channel))
    x_init = Conv2D(filters=64, kernel_size=(9, 9), strides=(1, 1), padding='same')(inputs)
    x = PReLU(shared_axes=[1, 2])(x_init)

    # residual blocks
    for _ in range(8):
        x = res_block(x)

    x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = add([x, x_init])

    # sub-pixel up_block
    # The shuffle divides the channel count by upscale**2, so produce
    # 64 * upscale**2 channels (256 for the default upscale=2) to keep 64
    # feature maps afterwards and stay divisible for any upscale factor.
    shuffle_filters = 64 * upscale ** 2
    x = Conv2D(filters=shuffle_filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    # batch_size=4 is the value reported as the static batch dimension.
    x = pixelshuffler(input_shape=(patch_height, patch_width, shuffle_filters),
                      batch_size=4, scale=upscale)(x)
    x = PReLU(shared_axes=[1, 2])(x)

    # output_block: sigmoid keeps predictions in the 0~1 range
    output = Conv2D(filters=3, kernel_size=(9, 9), strides=(1, 1), padding='same')(x)
    output = Conv2D(3, (1, 1), activation='sigmoid', padding='same')(output)

    return Model(inputs=inputs, outputs=output)

In [8]:
# define perceptual_loss_x4 to compare X4 model outputs of size (192, 192)

from keras.applications.vgg19 import VGG19
from keras.layers import Input
from keras.layers import Lambda
import keras

def get_VGG19(input_size):
    """Build a frozen VGG19 feature extractor up to ``block2_conv2``.

    NOTE: a helper with this same name is also defined earlier in the
    notebook; this definition replaces it once the cell runs.

    Args:
        input_size: shape tuple passed to ``Input`` (e.g. ``(192, 192, 3)``).

    Returns:
        ``(input_tensor, features)``: the input tensor and the output tensor
        of the ``block2_conv2`` layer.
    """
    input_tensor = Input(input_size)
    backbone = VGG19(include_top=False, input_tensor=input_tensor)

    # Freeze every layer: VGG is only used to compute the loss.
    for layer in backbone.layers:
        layer.trainable = False

    features = backbone.get_layer('block2_conv2').output
    return input_tensor, features

def perceptual_loss_x4(y_true, y_pred):
    """Perceptual loss for the X4 stage.

    Both images are passed through ``vgg_content2`` and the loss is the mean
    squared error between the two feature maps.

    Args:
        y_true: ground-truth image batch.
        y_pred: predicted image batch.

    Returns:
        The result of ``keras.losses.mean_squared_error`` on the VGG features.
    """
    features_true, features_pred = vgg_content2(y_true), vgg_content2(y_pred)
    return keras.losses.mean_squared_error(features_true, features_pred)

# VGG19 feature extractor for 192x192 RGB inputs, used by perceptual_loss_x4.
# NOTE: rebinds the module-level names vgg_input / vgg_output.
vgg_input, vgg_output = get_VGG19(input_size=(192, 192, 3))
vgg_content2 = Model(inputs=vgg_input, outputs=vgg_output)
# vgg_content2.summary()

In [9]:
# define the latter part of the integrated model

# Second stage of the cascade: takes 96x96x3 inputs and uses final_model's default upscale.
x4_model = final_model(patch_height=96, patch_width=96, channel=3)
# x4_model.summary()

In [10]:
# cascade two models to achieve progressive super-resolution
# set the first part of model not trainable

def integrated_network(base_model1, base_model2):
    """Chain two models into one ``Sequential`` network.

    Side effect: ``base_model1.trainable`` is set to ``False`` on the object
    passed in, so only ``base_model2`` is left trainable.

    Args:
        base_model1: first stage; frozen before being stacked.
        base_model2: second stage; fed with the output of ``base_model1``.

    Returns:
        A ``Sequential`` model containing ``base_model1`` then ``base_model2``.
    """
    # Freeze the first stage before stacking.
    base_model1.trainable = False

    return Sequential([base_model1, base_model2])

# Full X4 pipeline: frozen X2 model followed by the new trainable stage.
model = integrated_network(base_model1=x2_model, base_model2=x4_model)
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model_2 (Model)              (4, 96, 96, 3)            807311    
_________________________________________________________________
model_3 (Model)              (4, 192, 192, 3)          807311    
Total params: 1,614,622
Trainable params: 807,311
Non-trainable params: 807,311
_________________________________________________________________


In [12]:
# train bicubic_X4 model

import os

# NOTE(review): this cell and the "unknown" training cell fit the same `model`
# object and overwrite the same `history` variable. Running both in one kernel
# session continues training from the other run's weights; rebuild `model`
# before switching datasets.
checkpoint_path = './model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5'

# Create the output folder up front so the first checkpoint save cannot fail
# on a missing directory after a full epoch of training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model.compile(optimizer=Adam(lr=1e-4), loss=perceptual_loss_x4, metrics=['accuracy'])
# Only the weights with the best validation loss are kept on disk.
checkpointer = ModelCheckpoint(filepath=checkpoint_path, verbose=1,
                               monitor='val_loss', mode='auto', save_best_only=True)

history = model.fit(LR_patch_train, HR_patch_train, epochs=20, verbose=1,
                    batch_size=4, validation_split=0.2,
                    callbacks=[checkpointer]
                   )

Train on 9600 samples, validate on 2400 samples
Epoch 1/20

Epoch 00001: val_loss improved from inf to 1.21457, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5
Epoch 2/20

Epoch 00002: val_loss improved from 1.21457 to 1.18041, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5
Epoch 3/20

Epoch 00003: val_loss improved from 1.18041 to 1.15921, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5
Epoch 4/20

Epoch 00004: val_loss improved from 1.15921 to 1.14045, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5
Epoch 5/20

Epoch 00005: val_loss improved from 1.14045 to 1.13569, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5
Epoch 6/20

Epoch 00006: val_loss improved from 1.13569 to 1.12697, saving model to ./model_and_history/final3_perceptual_bi_4848_12000_subpixel_X4.h5
Epoch 7/20

Epoch 00007: val_loss improved from 1.

In [11]:
# train unknown X4 model

import os

# NOTE(review): this cell and the "bicubic" training cell fit the same `model`
# object and overwrite the same `history` variable. Running both in one kernel
# session continues training from the other run's weights; rebuild `model`
# before switching datasets.
checkpoint_path = './model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5'

# Create the output folder up front so the first checkpoint save cannot fail
# on a missing directory after a full epoch of training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model.compile(optimizer=Adam(lr=1e-4), loss=perceptual_loss_x4, metrics=['accuracy'])
# Only the weights with the best validation loss are kept on disk.
checkpointer = ModelCheckpoint(filepath=checkpoint_path, verbose=1,
                               monitor='val_loss', mode='auto', save_best_only=True)

history = model.fit(LR_patch_train, HR_patch_train, epochs=20, verbose=1,
                    batch_size=4, validation_split=0.2,
                    callbacks=[checkpointer]
                   )

Train on 9600 samples, validate on 2400 samples
Epoch 1/20

Epoch 00001: val_loss improved from inf to 1.79303, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5
Epoch 2/20

Epoch 00002: val_loss improved from 1.79303 to 1.69226, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5
Epoch 3/20

Epoch 00003: val_loss improved from 1.69226 to 1.63231, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5
Epoch 4/20

Epoch 00004: val_loss improved from 1.63231 to 1.60518, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5
Epoch 5/20

Epoch 00005: val_loss improved from 1.60518 to 1.56887, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5
Epoch 6/20

Epoch 00006: val_loss improved from 1.56887 to 1.56174, saving model to ./model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.h5
Epoch 7/20

Epoch 00

In [12]:
# save model and history
import pickle

# Persist the per-epoch metrics of the most recent `model.fit` call.
# NOTE(review): the file name says "unknown"; make sure `history` comes from that run.
history_path = './model_and_history/final3_perceptual_unknown_4848_12000_subpixel_X4.pkl'
with open(history_path, mode='wb') as history_file:
    pickle.dump(history.history, history_file)
