## Use content loss to create a super-resolution network

In [1]:
%matplotlib inline

import importlib
import sys
sys.path.insert(0, '../util')
import utils; importlib.reload(utils)
from utils import *

Using TensorFlow backend.


In [2]:
# Root directory of the pre-resized bcolz image arrays loaded in the next cell.
dpath = './data/'

In [3]:
# Low-resolution network inputs and their high-resolution counterparts.
# The trailing `[:]` slices the whole on-disk bcolz array, presumably
# materialising it as an in-memory NumPy array (confirm against bcolz docs).
arr_lr = bcolz.open(dpath+'trn_resized_72.bc')[:]
arr_hr = bcolz.open(dpath+'trn_resized_288.bc')[:]

In [4]:
# Keyword arguments shared by the `fit` calls below: Keras' own progress
# output is switched off (verbose=0) and a TQDM notebook callback is used instead.
parms = {'verbose': 0, 'callbacks': [TQDMNotebookCallback(leave_inner=True)]}

In [5]:
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    """Square convolution followed by batch normalisation and an optional ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of convolution filters.
        size: Height and width of the (square) kernel.
        stride: Convolution stride, forwarded as ``subsample``.
        mode: Padding mode, forwarded as ``border_mode``.
        act: If true, finish with a ReLU; otherwise return the
            batch-normalised tensor unchanged.

    Returns:
        The resulting Keras tensor.
    """
    convolved = Convolution2D(filters, size, size, subsample=stride, border_mode=mode)(x)
    normalised = BatchNormalization(mode=2)(convolved)
    if not act:
        return normalised
    return Activation('relu')(normalised)

def res_block(ip, nf=64):
    """Residual block: two 3x3, stride-1 conv blocks summed with the input.

    The second conv block skips its ReLU, so the tensor added back to ``ip``
    is the raw batch-normalised convolution output.

    Args:
        ip: Input Keras tensor (the element-wise sum presumably requires it
            to already carry ``nf`` channels -- confirm at call sites).
        nf: Number of filters used by both convolutions.

    Returns:
        The element-wise sum of the transformed branch and ``ip``.
    """
    unit_stride = (1, 1)
    branch = conv_block(ip, nf, 3, unit_stride)
    branch = conv_block(branch, nf, 3, unit_stride, act=False)
    return merge([branch, ip], mode='sum')

def deconv_block(x, filters, size, shape, stride=(2,2)):
    """Transposed convolution followed by batch normalisation and ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of deconvolution filters.
        size: Height and width of the (square) kernel.
        shape: Per-sample output shape as a tuple; a leading ``None`` batch
            entry is prepended before it is handed to the layer.
        stride: Deconvolution stride, forwarded as ``subsample``.

    Returns:
        The activated Keras tensor.
    """
    batch_output_shape = (None,) + shape
    deconv = Deconvolution2D(filters, size, size, subsample=stride,
                             border_mode='same', output_shape=batch_output_shape)
    expanded = deconv(x)
    normalised = BatchNormalization(mode=2)(expanded)
    return Activation('relu')(normalised)

def up_block(x, filters, size):
    """Upsample, then apply a same-padded convolution, batch norm and ReLU.

    Upsampling uses ``UpSampling2D`` with its default arguments.

    Args:
        x: Input Keras tensor.
        filters: Number of convolution filters.
        size: Height and width of the (square) kernel.

    Returns:
        The activated Keras tensor.
    """
    enlarged = keras.layers.UpSampling2D()(x)
    convolved = Convolution2D(filters, size, size, border_mode='same')(enlarged)
    normalised = BatchNormalization(mode=2)(convolved)
    return Activation('relu')(normalised)


In [6]:
# Upsampling network assembled from the blocks defined above: a 9x9 stem
# convolution, four residual blocks and two upsampling blocks map a
# low-resolution image to an output intended to match its high-resolution version.

inp = Input(arr_lr.shape[1:])
x = conv_block(inp, 64, 9, (1, 1))
for i in range(4):
    x = res_block(x)
x = up_block(x, 64, 3)
x = up_block(x, 64, 3)
# The tanh-activated 3-filter convolution yields values in [-1, 1] ...
x = Convolution2D(3, 9, 9, activation='tanh', border_mode='same')(x)
# ... which the final Lambda rescales to the [0, 255] range.
outp = Lambda(lambda t: (t + 1) * 127.5)(x)

In [7]:
# Per-sample shape of the high-resolution targets, used to declare the VGG
# input. It is read from the data as a plain tuple of ints; the `.shape`
# attribute of the symbolic `outp` tensor is backend-specific and is not a
# static tuple, so it is not a portable argument for `Input`.
shp = arr_hr.shape[1:]

# Subtract the ImageNet mean and reverse the last (channel) axis -- presumably
# RGB -> BGR to match VGG's expected input; confirm against `imagenet_mean`.
preproc = lambda x: (x-imagenet_mean)[:,:,:,::-1]
# Inverse of `preproc` for NumPy arrays: reshape to `s`, reverse the channel
# axis back, add the mean and clip to the valid [0, 255] pixel range.
deproc = lambda x,s: np.clip(x.reshape(s)[:, :, :, ::-1] + imagenet_mean, 0, 255)


Training this network works almost exactly like optimizing the pixels directly in our previous implementations. The idea is to feed two images to VGG16 and compare their convolutional activations at chosen layers (below, the first convolution of blocks 1–3). The two images are the target image (the same picture as the input, but at higher resolution) and the output of the upsampling network we just defined, which we hope will learn to produce a high-resolution image.

The key, then, is to train the upsampling network to produce an image whose VGG16 convolutional activations differ as little as possible from those of the target; the paper refers to this as a "perceptual loss". In doing so, we are able to train a network that can upsample an image and recreate the higher-resolution details.

In [8]:
# Input placeholder for the second model input (fed with `arr_hr` when fitting).
vgg_inp=Input(shp)
# VGG16 without its fully connected top, wired to receive the `preproc`-ed input.
vgg= VGG16(include_top=False, input_tensor=Lambda(preproc)(vgg_inp))

In [9]:
# Only the upsampling network should learn; VGG is used purely to compute the
# loss, so every one of its layers is marked as not trainable.
for l in vgg.layers: l.trainable=False

An important difference in training for super-resolution is the loss function. We use what's known as a perceptual loss function, which is simply the content loss computed on VGG activations (here, a weighted sum over several layers).

In [10]:
def get_outp(m, ln):
    """Return the output tensor of layer ``block{ln}_conv1`` of model ``m``.

    Args:
        m: Object exposing ``get_layer(name)``.
        ln: Block identifier interpolated into the layer name.
    """
    layer_name = f'block{ln}_conv1'
    target_layer = m.get_layer(layer_name)
    return target_layer.output

# Feature extractor returning the first-convolution outputs of VGG blocks 1, 2 and 3.
vgg_content = Model(vgg_inp, [get_outp(vgg, o) for o in [1,2,3]])

# vgg1: features of whatever is fed to `vgg_inp`;
# vgg2: features of the upsampling network's output `outp`.
vgg1 = vgg_content(vgg_inp)
vgg2 = vgg_content(outp)

In [11]:
def mean_sqr_b(diff):
    """Root-mean-square of ``diff`` over every axis except the first.

    Args:
        diff: Backend tensor; axis 0 is kept, all later axes are averaged.

    Returns:
        The per-entry RMS values with a new leading axis of length 1.
    """
    reduce_axes = list(range(1, K.ndim(diff)))
    root_mean_sq = K.sqrt(K.mean(diff**2, reduce_axes))
    # NOTE(review): the new axis is inserted at position 0, so the result is
    # laid out as (1, n) rather than (n, 1) -- confirm this is intended.
    return K.expand_dims(root_mean_sq, 0)

In [12]:
# Relative weight of each compared feature layer.
w = [0.1, 0.8, 0.1]
def content_fn(x):
    """Weighted sum of per-layer RMS differences between two feature sets.

    ``x`` holds ``2 * len(w)`` tensors: entry ``i`` is compared with entry
    ``i + len(w)``, and that pair's ``mean_sqr_b`` value is scaled by ``w[i]``.
    """
    n = len(w)
    return sum(mean_sqr_b(x[i] - x[i + n]) * w[i] for i in range(n))

In [13]:
# End-to-end training model: takes both inputs and outputs the weighted
# content loss. `vgg1+vgg2` concatenates the two lists of feature tensors,
# which is the layout `content_fn` indexes into.
m_sr = Model([inp, vgg_inp], Lambda(content_fn)(vgg1+vgg2))
# All-zero targets, one row per high-resolution training image.
targ = np.zeros((arr_hr.shape[0], 1))

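Finally, we compile this chain of models and train it on the original low-resolution images together with their high-resolution counterparts. Because the model's output is already the loss value, the target we pass is a vector of zeros: minimizing the MSE against zero minimizes the perceptual loss itself. Keras still requires a target array when calling `fit`, which is why this placeholder is needed.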

In [15]:
# Adam optimizer with MSE; the model output is compared against the all-zero `targ`.
m_sr.compile('adam', 'mse')
# Print the layer-by-layer architecture.
m_sr.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 72, 72, 3)     0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 72, 72, 64)    15616       input_1[0][0]                    
____________________________________________________________________________________________________
batchnormalization_1 (BatchNorma (None, 72, 72, 64)    256         convolution2d_1[0][0]            
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 72, 72, 64)    0           batchnormalization_1[0][0]       
___________________________________________________________________________________________

In [None]:
# The positional arguments after `targ` are the batch size (8) and the number of epochs (2).
m_sr.fit([arr_lr, arr_hr], targ, 8, 2, **parms)

A Jupyter Widget

A Jupyter Widget

In [None]:
# Display the current value of the optimizer's learning rate.
K.get_value(m_sr.optimizer.lr)


In [None]:
# We use learning rate annealing to get a better fit: report the current
# learning rate, set it to 1e-4, then train one more epoch with batch size 16.
# K.get_value is used because printing the `lr` variable directly shows the
# variable object's repr, not the numeric learning rate.
print(K.get_value(m_sr.optimizer.lr))
K.set_value(m_sr.optimizer.lr, 1e-4)
m_sr.fit([arr_lr, arr_hr], targ, 16, 1, **parms)