In [None]:
import numpy as np
import pandas as pd
import keras
import tensorflow as tf
from sklearn.model_selection import train_test_split
from time import time
from skimage import io

from numpy.random import seed
seed(2)
from tensorflow import set_random_seed
set_random_seed(2)

import matplotlib.pyplot as plt
from scipy.misc import imread
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import h5py
import warnings
warnings.filterwarnings('ignore')

In [None]:

"""
Args:
    base: base folder name in which files are located
    SRF: image_SRF folder to select
    n_ids: number of ids to generated from 1...n_ids
Returns:
    A numpy character array of filenames
"""
def gen_filenames_labels(base, SRF, n_ids):
    basepath = base + SRF
    
    dataset_arr = np.chararray(shape=(n_ids, 2), itemsize=128)
    
    for i in range(1, n_ids+1):
        Y = basepath + 'img_{:03}_SRF_2_HR.png'.format(i)
        X = basepath + 'img_{:03}_SRF_2_bicubic.png'.format(i)
        X_dash = basepath + 'img_{:03}_SRF_2_LR.png'.format(i)
        
        dataset_arr[i-1, 0] = X_dash
        dataset_arr[i-1, 1] = Y
    
    return dataset_arr

dataset = gen_filenames_labels('train/', 'image_SRF_2/', 100)
val_dataset = gen_filenames_labels('Set14/', 'image_SRF_2/', 5)

# Column 0 = low-resolution inputs, column 1 = high-resolution targets.
# The Set14 subset is used as the held-out ("test") split for validation.
X_train, y_train = dataset[:, 0], dataset[:, 1]
X_test, y_test = val_dataset[:, 0], val_dataset[:, 1]
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)


In [None]:
"""
Class to store parameters which would be required to build and train model
"""

class ModelParameters:
    def __init__(self):
        self.batch_size = 256
        self.input_height = 32
        self.input_width = 32
        self.label_height = 64
        self.label_width = 64
        self.n_channels = 3
        self.n_epochs = 30
        self.decay_rate = 0.9
        self.decay_steps = 10000

params = ModelParameters()

def decode_ycbcr(image):
    """Convert an image in the (Y, Cr, Cb) layout produced by ``encode_ycbcr`` back to RGB.

    This is the exact algebraic inverse of the forward transform
    Y = 0.299 R + 0.587 G + 0.114 B, Cr = 0.713 (R - Y) + 0.5, Cb = 0.564 (B - Y) + 0.5.

    Args:
        image: array-like whose last axis holds the (Y, Cr, Cb) channels, e.g. a
            single ``(H, W, 3)`` image or a ``(N, H, W, 3)`` batch.

    Returns:
        A float64 ``numpy`` array of the same shape with (R, G, B) on the last axis.
        Values are not clipped.
    """
    image = np.asarray(image, dtype=np.float64)
    delta = 0.5

    Y = image[..., 0]
    Cr = image[..., 1]
    Cb = image[..., 2]

    R = Y + (Cr - delta) / 0.713
    B = Y + (Cb - delta) / 0.564
    # Solve the luma equation for G using the recovered R and B.
    G = (Y - 0.299 * R - 0.114 * B) / 0.587

    return np.stack([R, G, B], axis=-1)

def encode_ycbcr(image):
    """Convert an RGB image tensor to the (Y, Cr, Cb) representation used for training.

    Uses Y = 0.299 R + 0.587 G + 0.114 B, Cr = 0.713 (R - Y) + 0.5 and
    Cb = 0.564 (B - Y) + 0.5. For RGB values in [0, 1] the 0.5 offset keeps both
    chroma channels inside [0, 1].

    Args:
        image: float tensor whose last axis holds (R, G, B). Works for a single
            ``(H, W, 3)`` image and for batched ``(..., H, W, 3)`` input.

    Returns:
        Tensor of the same shape with the channels stacked as (Y, Cr, Cb).
    """
    # `...` indexing / axis=-1 is identical to [:, :, k] / axis=2 for rank-3
    # input, and also accepts batches.
    R = image[..., 0]
    G = image[..., 1]
    B = image[..., 2]

    delta = 0.5

    Y = 0.299 * R + 0.587 * G + 0.114 * B
    Cr = (R - Y) * 0.713 + delta
    Cb = (B - Y) * 0.564 + delta

    return tf.stack([Y, Cr, Cb], axis=-1)


"""
Builds a tensorflow graph to be used for loading files into memory. ETL process

Args:
    filename: filename
    ksizes: A list of ints that has length >= 4. 
            The size of the sliding window for each dimension of images
    strides: A list of ints that has length >= 4. 
            1-D of length 4. How far the centers of two consecutive patches are in the images. 
            Must be: [1, stride_rows, stride_cols, 1].
    rates: A list of ints that has length >= 4. 1-D of length 4. 
            Must be: [1, rate_rows, rate_cols, 1]. 
            This is the input stride, specifying how far two consecutive patch samples are in the input. 
            Equivalent to extracting patches with patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1), 
            followed by subsampling them spatially by a factor of rates. 
            This is equivalent to rate in dilated (a.k.a. Atrous) convolutions.
            
Returns:
    Batch of patches extracted from the image.
"""
def get_img_from_file(filename, ksizes, kstrides, rates, height, width, channels, color_space='YCbCr'):
    file = tf.read_file(filename)
    img = tf.image.decode_png(file, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    
    if color_space is 'YCbCr':
        img = encode_ycbcr(img)
    
    patches = tf.image.extract_image_patches(tf.expand_dims(img, 0), 
                                             ksizes, 
                                             kstrides, 
                                             rates, 
                                             padding='VALID')

    input_img_batch = tf.squeeze(patches)
    shape = tf.shape(input_img_batch)
    input_img_batch = tf.reshape(input_img_batch, [shape[0]*shape[1], 
                                                   height, 
                                                   width, 
                                                   channels])
    return input_img_batch

In [None]:
"""
A class for building graph that loads and initializes iterators for loading data into graph
"""

class DataLoader:
    def __init__(self, X_train, y_train, X_test, y_test, params):
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.params = params
        self.in_ksizes = [1, params.input_height, params.input_width, 1]
        self.out_ksizes = [1, params.label_height, params.label_width, 1]
        self.in_kstrides = [1, 10, 10, 1]
        self.out_kstrides = [1, 20, 20, 1]
        self.rates = [1, 1, 1, 1]
        
    """
    Args:
        num_threads: number of parallel calls to be made while parsing
        num_prefetch: number of batches to be pre loaded before training
        
    Returns:
        Next batch to be served to model
    
    """
    def build_iterators(self, num_threads=8, num_prefetch=8):
        def parse_fn(filename, label):
            input_img_batch = get_img_from_file(filename, 
                                                self.in_ksizes, 
                                                self.in_kstrides, 
                                                self.rates, 
                                                self.params.input_height, 
                                                self.params.input_width, 
                                                self.params.n_channels)

            ground_img_batch = get_img_from_file(label, 
                                                 self.out_ksizes, 
                                                 self.out_kstrides, 
                                                 self.rates, 
                                                 self.params.label_height, 
                                                 self.params.label_width, 
                                                 self.params.n_channels)

            return [input_img_batch, ground_img_batch]
        
 
        #train dataset graph
        filenames = self.X_train
        labels = self.y_train
        
        train_dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        train_dataset = train_dataset.shuffle(len(filenames))
        train_dataset = train_dataset.map(parse_fn, num_parallel_calls=num_threads)
        train_dataset = train_dataset.apply(tf.contrib.data.unbatch())

        train_dataset = train_dataset.batch(self.params.batch_size)
        self.train_dataset = train_dataset.prefetch(num_prefetch)
        
        #val dataset graph
        filenames = self.X_test
        labels = self.y_test
        
        val_dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        val_dataset = val_dataset.shuffle(len(filenames))
        val_dataset = val_dataset.map(parse_fn, num_parallel_calls=num_threads)
        val_dataset = val_dataset.apply(tf.contrib.data.unbatch())
        val_dataset = val_dataset.batch(self.params.batch_size)
        self.val_dataset = val_dataset.prefetch(num_prefetch)

        #Make iterator
        iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                                                   train_dataset.output_shapes)
        next_element = iterator.get_next()
        
        self.train_init_op = iterator.make_initializer(self.train_dataset)
        self.val_init_op = iterator.make_initializer(self.val_dataset)
        
        return next_element
        

In [None]:
from sklearn.feature_extraction.image import extract_patches_2d
from tqdm import tqdm
from model import FSRCNN
"""
Model class for training and inference
Contains:
    build(): Builds the graphs
    fit(): trains the network for params.n_epochs
    score(): For scoring PSNR and loss on validation dataset
    predict(): Takes the filename as arg and returns prediction.
"""
class Model:
    def __init__(self, params, data_loader):
        self.params = params
        self.data_loader = data_loader
        self.d = 32
        self.s = 8
        self.m = 1
        
        #Set mode to 'train' for training and 'infer' for prediction on your own images.
        #Set this parameter before running the build function.
        self.mode = 'train'
        
    
    """
    Builds the training and inference graphs
        
    """
    def build(self):
        decay_rate = self.params.decay_rate
        decay_steps = self.params.decay_steps
        
        next_element = self.data_loader.build_iterators()
        self.y = next_element[1]
        
        gt_image_summ = tf.summary.image('gt_image', self.y)
        
        with tf.name_scope('inference'):
            self.input_img_placeholder = tf.placeholder(tf.float32, shape=(None, self.params.input_height, 
                                                                           self.params.input_width, 
                                                                           self.params.n_channels))

        
        with tf.name_scope('convolutional'):
            if self.mode == 'train':
                x = next_element[0]
                input_image_summ = tf.summary.image('input_image', x)
                
            elif self.mode == 'infer':
                x = self.input_img_placeholder
                
            x, self.model_summ = FSRCNN(x)
            
            self.output = x
            
            frac = self.output.get_shape().as_list()[1]/self.y.get_shape().as_list()[1]
            print(frac, self.output.get_shape().as_list(), self.y.get_shape().as_list()[1])
            self.y = tf.image.central_crop(self.y, central_fraction=frac)
            pred_image_summ = tf.summary.image('pred_image', x)
                        
        with tf.name_scope('loss'):
            
            self.loss = tf.losses.mean_squared_error(self.y, self.output)
            
        with tf.name_scope('metrics'):
            self.psnr = tf.reduce_mean(tf.image.psnr(self.output, self.y, max_val=1))
            self.ssim = tf.reduce_mean(tf.image.ssim(self.output, self.y, max_val=1))
        
        with tf.name_scope('train'):
            self.g_step_tensor = tf.Variable(0, trainable=False)
            self.learning_rate = tf.placeholder(tf.float32, shape=None, name='learning_rate')
            init_lr = self.learning_rate
            decayed_lr = tf.train.exponential_decay(init_lr, self.g_step_tensor,
                                                    decay_steps, decay_rate, staircase=True)
            lr_summ = tf.summary.scalar('lr', decayed_lr)
            
            self.optimizer = tf.train.AdamOptimizer(decayed_lr)
            self.grads_and_vars = self.optimizer.compute_gradients(self.loss)
            self.train_op = self.optimizer.apply_gradients(self.grads_and_vars, global_step=self.g_step_tensor)
             
        with tf.name_scope('performance'):
            
            train_loss_summ = tf.summary.scalar('train_loss', self.loss)
            train_psnr_summ = tf.summary.scalar('train_psnr', self.psnr)
            train_ssim_summ = tf.summary.scalar('train_ssim', self.ssim)
            self.train_stats = tf.summary.merge([train_loss_summ, 
                                            train_psnr_summ, 
                                            train_ssim_summ])
            
            valid_loss_summ = tf.summary.scalar('valid_loss', self.loss)
            valid_psnr_summ = tf.summary.scalar('valid_psnr', self.psnr)
            valid_ssim_summ = tf.summary.scalar('valid_ssim', self.ssim)
            self.valid_stats = tf.summary.merge([valid_loss_summ, 
                                            valid_psnr_summ, 
                                            valid_ssim_summ])
            grads_summ = []
            l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
            
            for gv in self.grads_and_vars:
                if 'conv2d' in gv[1].name:
                    name = gv[1].name.replace(':', '_')
                    grads_summ.append(tf.summary.scalar(name, l2_norm(gv[1])))
                    
            self.performance_summ = tf.summary.merge([grads_summ, lr_summ])
            self.image_stats = tf.summary.merge([gt_image_summ, 
                                                 input_image_summ, 
                                                 pred_image_summ])

    def fit(self, lr, summ_writer):
        epochs = self.params.n_epochs
        saver = tf.train.Saver()
        with tf.Session() as sess:
            
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            #saver.restore(sess, 'dense_models/srcnn_0.0007.ckpt')
            summ_writer.add_graph(sess.graph)
            
            val_losses = [1000.0]
            for epoch in range(epochs):
                train_len = 102
                loss_all = []
                psnr_all = []
                
                
                print('\n**************************')
                print('Epoch: ' + str(epoch))
                with tqdm(total=train_len) as pbar:
                    sess.run(self.data_loader.train_init_op)
                    
                    for steps in range(train_len):
                        train, loss, psnr, mod_summ, train_summ, perf_summ, image_summ = sess.run([self.train_op, 
                                                                                         self.loss, 
                                                                                         self.psnr, 
                                                                                         self.model_summ,
                                                                                         self.train_stats, 
                                                                                         self.performance_summ,
                                                                                         self.image_stats], 
                                                                                         feed_dict={self.learning_rate: lr})
                        
                        
                        summ_writer.add_summary(train_summ, global_step=tf.train.global_step(sess, self.g_step_tensor))
                        summ_writer.add_summary(perf_summ, global_step=tf.train.global_step(sess, self.g_step_tensor))
                        summ_writer.add_summary(image_summ, global_step=tf.train.global_step(sess, self.g_step_tensor))
                        summ_writer.add_summary(mod_summ, global_step=tf.train.global_step(sess, self.g_step_tensor))
                        
                        loss_all.append(loss)
                        psnr_all.append(psnr)
                        pbar.set_description('loss: {:.4f} -- psnr: {:.4f}'.format(float(loss), psnr))
                        pbar.update(1)
                    
                print('train_loss: {:.4f} -- train_psnr: {:.4f}'.format(np.mean(loss_all), 
                                                                        np.mean(psnr_all)))
                

                print('lr:{}'.format(lr))
                
                
                val_loss = self.score(sess)
                valid_summ = sess.run(self.valid_stats)
                summ_writer.add_summary(valid_summ, global_step=tf.train.global_step(sess, self.g_step_tensor))
                
                if val_loss < min(val_losses):
                    print('Saving model with val_loss: {:.4f} to file: unnamed_cnn_{:.4f}.ckpt'.format(val_loss,
                                                                                                 val_loss))
                    saver.save(sess, 'dense_models/unnamed_cnn_{:.4f}.ckpt'.format(val_loss))
                else:
                    print('val_loss did not improve from last best: {:.4}'.format(min(val_losses)))
                print('**************************\n')
                val_losses.append(val_loss)

    def score(self, session):

        with session.as_default():

            test_len = 3

            loss_all = []
            psnr_all = []

            with tqdm(total=test_len) as pbar:
                session.run(self.data_loader.val_init_op)

                for steps in range(test_len):
                    loss, psnr = session.run([self.loss, self.psnr])

                    loss_all.append(loss)
                    psnr_all.append(psnr)

                    pbar.set_description('val_loss: {:.4f} -- val_psnr: {:.4f}'.format(loss, psnr))
                    pbar.update(1)

            print('val_loss: ' + str(np.mean(loss_all)), ' -- val_psnr: ' + str(np.mean(psnr_all)))
            return np.mean(loss_all)


In [None]:
# Wire the train/validation filename arrays and hyper-parameters into the loader.
data_loader = DataLoader(X_train=X_train, y_train=y_train,
                         X_test=X_test, y_test=y_test,
                         params=params)

In [None]:
# Create the model (its mode defaults to 'train') and assemble the graph.
model = Model(params=params, data_loader=data_loader)
model.build()

In [None]:
# Learning-rate search: each fit() call re-initialises all variables, so every
# candidate rate is trained from scratch and logged to its own summary directory.
for lr in [0.001, 0.002, 0.003, 0.004, 0.005]:
    hparam_str = 'lr_{:}'.format(lr)
    writer = tf.summary.FileWriter('summaries/prelu/' + hparam_str)
    try:
        model.fit(lr, writer)
    finally:
        # Flush pending events and release the event file before the next run.
        writer.close()


In [None]:
# Colab setup: these cells download the train and test sets and set up TensorBoard.
# Run them in order (Step 1 -> Step 4) before the cells that read train/ and Set14/.

# Step 1 - download the dataset archives.
# NOTE(review): presumably one archive holds the training images (its image_SRF_2/
# folder is moved under train/ later in this cell) and the other holds Set14, used
# for validation -- confirm against the archive contents.

#!wget "https://uofi.box.com/shared/static/kfahv87nfe8ax910l85dksyl2q212voc.zip" Set5
!wget "https://uofi.box.com/shared/static/65upg43jjd0a4cwsiqgl6o6ixube6klm.zip"
!wget "https://uofi.box.com/shared/static/igsnfieh4lz68l926l8xbklwsnnk8we9.zip"

import zipfile
import os
# Extract the downloaded dataset archives into the working directory.
DATASET_ARCHIVES = {'igsnfieh4lz68l926l8xbklwsnnk8we9.zip', '65upg43jjd0a4cwsiqgl6o6ixube6klm.zip'}
for fname in os.listdir():
    print(fname)
    if fname in DATASET_ARCHIVES:
        # `with` closes the archive even if extraction fails part-way through.
        with zipfile.ZipFile(fname, 'r') as zip_ref:
            zip_ref.extractall()
    

# Move the extracted image_SRF_2/ folder under train/, where the training filenames
# point (train/image_SRF_2/...). `-p` lets the cell be re-run without an error.
!mkdir -p train
!mv image_SRF_2 train



In [None]:
#Step 2
# Create the output directories: TensorBoard logs go to summaries/, and training
# checkpoints are saved under dense_models/ (the saver requires it to exist).
!mkdir -p summaries dense_models

In [None]:
#Step 3
#Tensorboard
# Download and unpack the ngrok binary; Step 4 runs it (./ngrok) to expose the
# local TensorBoard port through a public URL.
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip

In [None]:
#Step 4
# Launch TensorBoard in the background on port 6006, reading the lr-search logs
# written under summaries/prelu/.
LOG_DIR = './summaries/prelu'
get_ipython().system_raw(
    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
    .format(LOG_DIR)
)
# Expose port 6006 through an ngrok tunnel (also started in the background).
get_ipython().system_raw('./ngrok http 6006 &')
# Query ngrok's local API on port 4040 and print the tunnel's public URL.
# NOTE(review): ngrok is started in the background just above, so this query can
# race it; re-run the cell if the tunnel list is still empty.
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

In [None]:
# Optional clean-up: deletes every TensorBoard log under summaries/. It is disabled
# by default so that "Run All" does not wipe the logs produced by the cells above;
# set RESET_SUMMARIES = True to start from an empty log directory.
RESET_SUMMARIES = False
if RESET_SUMMARIES:
    import shutil
    # Like `rm -r`, a missing directory is not treated as fatal.
    shutil.rmtree('summaries', ignore_errors=True)