In [1]:
import tarfile, tensorflow as tf, matplotlib.pyplot as plt, operator, random, pickle, glob, os, bcolz
import numpy as np
from keras.layers import *
from keras.applications.vgg16 import VGG16
from keras.models import Model, Sequential
from bcolz_array_iterator import BcolzArrayIterator

Using TensorFlow backend.


In [2]:
# Per-channel offsets subtracted from images before they are fed to VGG16.
# NOTE(review): presumably the ImageNet RGB channel means used by VGG-style
# models — confirm against the weights' documentation.
rn_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)
# Centre a 4-D batch on rn_mean, then reverse its last axis
# (presumably RGB -> BGR channel order, as the VGG weights expect — confirm).
preproc = lambda x: (x - rn_mean)[:, :, :, ::-1]

In [3]:
# Inverse of preproc: reshape x to shape s, reverse the last axis back,
# add rn_mean, and clamp the result to the [0, 255] range.
deproc = lambda x,s: np.clip(x.reshape(s)[:, :, :, ::-1] + rn_mean, 0, 255)

In [4]:
# Open the two on-disk bcolz arrays: the low-resolution inputs and the
# high-resolution targets (file names suggest 72px and 288px images).
arr_lr, arr_hr = map(bcolz.open, ('trn_resized_72.bc', 'trn_resized_288.bc'))

In [5]:
# Sanity check: show the shape of the high-resolution array.
print(arr_hr.shape)

(19439, 288, 288, 3)


In [6]:
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    """Apply convolution -> batch normalisation -> optional ReLU to `x`.

    Args:
        x: input tensor.
        filters: number of convolution filters.
        size: side length of the square convolution kernel.
        stride: convolution strides.
        mode: padding mode forwarded to the convolution layer.
        act: when True the result goes through a ReLU; when False the
            batch-normalised tensor is returned as is.

    Returns:
        The output tensor of the last layer applied.
    """
    conv = Convolution2D(filters, (size, size), strides=stride, padding=mode)
    out = conv(x)
    out = BatchNormalization()(out)
    if not act:
        return out
    return Activation('relu')(out)

In [7]:
def res_block(ip, nf=64):
    """Residual block: two stride-1 3x3 conv_blocks plus a skip connection.

    The second conv_block is built with act=False, so no ReLU is applied
    before its output is summed with the block input.

    Args:
        ip: input tensor; also used as the skip branch.
        nf: number of filters for both convolutions.

    Returns:
        The element-wise sum of the convolution branch and `ip`.
    """
    unit_stride = (1, 1)
    branch = conv_block(ip, nf, 3, unit_stride)
    branch = conv_block(branch, nf, 3, unit_stride, act=False)
    return Add()([branch, ip])

In [8]:
def up_block(x, filters, size):
    """Upsample `x`, then apply convolution -> batch normalisation -> ReLU.

    Args:
        x: input tensor.
        filters: number of convolution filters.
        size: side length of the square convolution kernel.

    Returns:
        The activated output tensor.
    """
    # UpSampling2D is used with its default scale factor.
    upsampled = UpSampling2D()(x)
    conv_out = Convolution2D(filters, (size, size), padding='same')(upsampled)
    normed = BatchNormalization()(conv_out)
    return Activation('relu')(normed)

In [9]:
def get_model(arr):
    """Build the super-resolution generator graph.

    The input shape is taken from `arr.shape[1:]` (everything but the first
    axis). The network is a 9x9 conv_block, four residual blocks, two
    up_blocks and a final 9x9 tanh convolution with 3 filters.

    Args:
        arr: array-like whose `.shape[1:]` gives the per-sample input shape.

    Returns:
        (inp, outp): the Keras input tensor and the output tensor, the latter
        rescaled from tanh's [-1, 1] range to [0, 255].
    """
    inp = Input(arr.shape[1:])
    net = conv_block(inp, 64, 9, (1, 1))
    for _ in range(4):
        net = res_block(net)
    for _ in range(2):
        net = up_block(net, 64, 3)
    net = Conv2D(3, (9, 9), activation="tanh", padding="same")(net)
    # Map tanh output from [-1, 1] onto [0, 255].
    outp = Lambda(lambda t: (t + 1) * 127.5)(net)
    return inp, outp

In [10]:
# Build the generator graph, sized from the low-resolution array's sample shape.
inp,outp=get_model(arr_lr)

In [11]:
# Per-sample shape of the high-resolution images (first axis dropped).
shp = arr_hr.shape[1:]

# VGG16 without its classifier head, fed with preprocessed images from vgg_inp.
vgg_inp = Input(shp)
vgg = VGG16(
    include_top=False,
    weights='imagenet',
    input_tensor=Lambda(preproc)(vgg_inp),
)
# Freeze every VGG layer so its weights are not updated during training.
for layer in vgg.layers:
    layer.trainable = False

In [12]:
def get_outp(m, ln):
    """Return the output of layer `block{ln}_conv2` from model `m`."""
    layer_name = f'block{ln}_conv2'
    return m.get_layer(layer_name).output
# Feature extractor exposing the conv2 activations of VGG blocks 1, 2 and 3.
vgg_content = Model(vgg_inp, [get_outp(vgg, blk) for blk in (1, 2, 3)])
# vgg1: features of the high-res image; vgg2: features of the generator output.
vgg1, vgg2 = vgg_content(vgg_inp), vgg_content(outp)

In [13]:
def mean_sqr_b(diff):
    """Root-mean-square of `diff` over every axis except the first.

    Args:
        diff: backend tensor whose first axis is kept.

    Returns:
        The per-entry RMS values with a new leading axis of size 1 added.
    """
    reduce_axes = list(range(1, K.ndim(diff)))
    rms = K.sqrt(K.mean(diff**2, reduce_axes))
    return K.expand_dims(rms, 0)

In [14]:
# Weights applied to the feature-difference terms, one per compared layer.
w=[0.1, 0.8, 0.1]
def content_fn(x):
    """Weighted sum of mean_sqr_b distances between paired feature maps.

    Args:
        x: indexable of tensors; element i is compared with element
            i + len(w), for i in range(len(w)).

    Returns:
        The sum over i of mean_sqr_b(x[i] - x[i + len(w)]) * w[i].
    """
    n = len(w)
    res = 0
    for i, weight in enumerate(w):
        res += mean_sqr_b(x[i] - x[i + n]) * weight
    return res

In [15]:
# Training model: takes the low-res and high-res images, outputs the content
# loss computed from both sets of VGG features (the two lists concatenated).
m_sr = Model(
    [inp, vgg_inp],
    Lambda(content_fn)(vgg1 + vgg2),
)
# The model output is itself the loss value, so MAE is used as the objective.
m_sr.compile(optimizer='adam', loss='mae')

In [16]:
def train(bs, niter=10):
    """Run `niter` training steps of `m_sr` on batches drawn from the bcolz arrays.

    The model's output is the content loss itself, so every target is zero.

    Args:
        bs: batch size requested from the iterator; batches are also capped
            at this many samples.
        niter: number of batches to train on.

    Returns:
        None. `m_sr` is updated in place.
    """
    targ = np.zeros((bs, 1))
    bc = BcolzArrayIterator(arr_hr, arr_lr, batch_size=bs)
    for _ in range(niter):
        hr, lr = next(bc)
        hr, lr = hr[:bs], lr[:bs]
        # A batch may hold fewer than `bs` samples (e.g. the tail of the
        # dataset — NOTE(review): confirm against BcolzArrayIterator); trim
        # the targets so inputs and targets always agree on sample count.
        m_sr.train_on_batch([lr, hr], targ[:len(hr)])

In [None]:
# Number of entries along the first axis of the high-resolution array.
print(len(arr_hr))

19439


In [None]:
# Time 100 training iterations at batch size 64.
%time train(64, 100)