# **Implementation of a Super-Resolution Convolutional Neural Network (SRCNN)**

### References
*   Image Super-Resolution using Deep Convolutional Neural Networks (paper)
*   ogreen8084/srcnn_part1 (github)

### Steps
*   Import datasets
*   Preprocess images
*   Create architecture
*   Train model

## Import Datasets

In [0]:
# Fetch the repository that contains the image datasets.
!git clone https://github.com/luizmanke/image-super-resolution.git
# Create the working folder the preprocessing cells read from.
!mkdir -p data/input
# Copy the .bmp files of the "Yang et al" dataset folder into the working folder.
!cp image-super-resolution/Datasets/Yang\ et\ al/*.bmp data/input

## Preprocess Images

In [0]:
# Import system libraries
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from PIL import Image
from sklearn.utils import shuffle

In [0]:
def reformat_image_and_blur(image_file_path, file_name, folder_name, blur=False, size=(128,128), dsfactor=2):
    """Load an image, pad it to a white square and resize it to ``size``.

    Args:
        image_file_path: Directory that contains the image.
        file_name: Name of the image file. A leading "._" is stripped
            before the file is opened.
        folder_name: Unused; kept so existing callers keep working.
        blur: When True, the padded image is first shrunk to
            ``size // dsfactor`` and then resized back up to ``size``.
        size: (width, height) of the returned image.
        dsfactor: Integer divisor applied to ``size`` for the shrink step.

    Returns:
        A NumPy array of the resized image with values divided by 255.
    """
    # Only a *leading* "._" is a prefix; testing with `in` would chop the
    # first two characters off any name that merely contains "._".
    if file_name.startswith("._"):
        file_name = file_name[2:]
    new_size = size[0] // dsfactor, size[1] // dsfactor

    # The context manager releases the file handle once the pixels have been
    # pasted onto the background.
    with Image.open(os.path.join(image_file_path, file_name)) as image:
        width, height = image.size
        bigside = max(width, height)
        # Centre the picture on a white square canvas so the later resize
        # does not distort its aspect ratio.
        background = Image.new('RGB', (bigside, bigside), (255, 255, 255))
        offset = (int(round((bigside - width) / 2, 0)),
                  int(round((bigside - height) / 2, 0)))
        background.paste(image, offset)

    if blur:
        # Image.LANCZOS is the same filter as the former Image.ANTIALIAS
        # alias, which newer Pillow releases no longer provide.
        background.thumbnail(new_size, Image.LANCZOS)
    back = background.resize(size, Image.BICUBIC)
    return np.array(back) / 255.

In [0]:
# Build two parallel arrays: the sharp target images and their degraded
# (shrunk, then re-enlarged) counterparts.
hr_imgs = []
bi_imgs = []
file_path = "data/input"

# os.listdir returns names in arbitrary order; sorting keeps the image order,
# and therefore any positional split of these arrays, identical across runs.
for image_name in sorted(os.listdir(file_path)):
    hr_imgs.append(reformat_image_and_blur(file_path, image_name, ''))
    bi_imgs.append(reformat_image_and_blur(file_path, image_name, '', blur=True))

hr_imgs = np.array(hr_imgs)
bi_imgs = np.array(bi_imgs)

In [0]:
# Positional split: the first 80% of the image pairs are used for training,
# the remainder for testing.
train_pct = 0.8
n_train = int(len(hr_imgs) * train_pct)

x_train, x_test = bi_imgs[:n_train], bi_imgs[n_train:]
y_train, y_test = hr_imgs[:n_train], hr_imgs[n_train:]

## Create Architecture

In [0]:
# Number of images per mini-batch.
batch_size = 19
# Batches are drawn from the training split only, so the batch count must be
# derived from x_train; counting the whole dataset would ask for batches that
# reach past the end of the training data.
N = x_train.shape[0]
num_batches = N // batch_size

In [0]:
def conv_layer(input_, filters, bias, alpha=0.2, final_layer=False, strides=[1,1,1,1]):
    """Apply a SAME-padded 2-D convolution, add the bias and activate.

    Args:
        input_: Tensor fed to ``tf.nn.conv2d``.
        filters: Convolution filter tensor.
        bias: Bias tensor added to the convolution output.
        alpha: Negative slope passed to the leaky ReLU.
        final_layer: When true, ``tanh`` is used instead of leaky ReLU.
        strides: Strides forwarded to ``tf.nn.conv2d``.

    Returns:
        The activated output tensor.
    """
    pre_activation = tf.nn.bias_add(
        tf.nn.conv2d(input_, filters, strides, padding='SAME'),
        bias,
    )
    if final_layer:
        return tf.tanh(pre_activation)
    return tf.nn.leaky_relu(pre_activation, alpha)

def deconv_layer(inputs, filters, bias, output_shape, alpha=0.1, final_layer=False, strides=[1,1,1,1], padding='SAME'):
    """Apply a transposed 2-D convolution, add the bias and activate.

    Args:
        inputs: Tensor fed to ``tf.nn.conv2d_transpose``.
        filters: Filter tensor of the transposed convolution.
        bias: Bias tensor added to the transposed-convolution output.
        output_shape: Output shape forwarded to ``tf.nn.conv2d_transpose``.
        alpha: Negative slope passed to the leaky ReLU.
        final_layer: When true, ``tanh`` is used instead of leaky ReLU.
        strides: Strides forwarded to ``tf.nn.conv2d_transpose``.
        padding: Padding mode forwarded to ``tf.nn.conv2d_transpose``.

    Returns:
        The activated output tensor.
    """
    pre_activation = tf.nn.bias_add(
        tf.nn.conv2d_transpose(inputs, filters, output_shape, strides, padding),
        bias,
    )
    if final_layer:
        return tf.tanh(pre_activation)
    return tf.nn.leaky_relu(pre_activation, alpha)

In [0]:
def init_filter(shape):
    """Create a float32 TensorFlow variable holding scaled Gaussian values.

    Standard-normal samples of the given shape are divided by
    ``sqrt(prod(shape[:-1]) + prod(shape[:-2]) * shape[-1])``.
    """
    spread = np.prod(shape[:-1]) + np.prod(shape[:-2]) * shape[-1]
    values = np.random.randn(*shape) / np.sqrt(spread)
    return tf.Variable(values, dtype=np.float32)

def init_bias(shape):
    """Create a zero-filled float32 bias variable of length ``shape[-1]``."""
    zeros = np.zeros(shape[-1])
    return tf.Variable(zeros, dtype=np.float32)

def init_bias_deconv(shape):
    """Create a zero-filled float32 bias variable of length ``shape[2]``."""
    zeros = np.zeros(shape[2])
    return tf.Variable(zeros, dtype=np.float32)

In [0]:
def next_batch(inputs, labels, batch_size, num_batches):
    """Yield ``num_batches`` consecutive (inputs, labels) slices as arrays.

    Each yielded pair covers the next ``batch_size`` positions of both
    sequences; slices that run past the end are shorter or empty.
    """
    start = 0
    for _ in range(num_batches):
        stop = start + batch_size
        yield np.array(inputs[start:stop]), np.array(labels[start:stop])
        start = stop

In [0]:
# Graph entry points: float32 batches of 128x128 three-channel images. The
# leading dimension is left open in the placeholder shape.
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 128, 128, 3])
labels = tf.placeholder(dtype=tf.float32, shape=[None, 128, 128, 3])

In [0]:
# Filter shapes for the seven layers. Each bias is sized from index 2 of the
# matching filter shape (see init_bias_deconv).
w1_init = [5, 5, 32, 3]
w2_init = [5, 5, 32, 32]
w3_init = [5, 5, 32, 32]
w4_init = [3, 3, 32, 32]
w5_init = [3, 3, 32, 32]
w6_init = [3, 3, 32, 32]
w7_init = [3, 3, 3, 32]

# Variables are created filter-then-bias, layer by layer, so the order in
# which they are built (and random numbers are drawn) stays w1, b1, ... w7, b7.
w1, b1 = init_filter(w1_init), init_bias_deconv(w1_init)
w2, b2 = init_filter(w2_init), init_bias_deconv(w2_init)
w3, b3 = init_filter(w3_init), init_bias_deconv(w3_init)
w4, b4 = init_filter(w4_init), init_bias_deconv(w4_init)
w5, b5 = init_filter(w5_init), init_bias_deconv(w5_init)
w6, b6 = init_filter(w6_init), init_bias_deconv(w6_init)
w7, b7 = init_filter(w7_init), init_bias_deconv(w7_init)

In [0]:
# Read the batch dimension from the tensor that is actually fed instead of
# hard-coding batch_size, so the same graph accepts batches of any size
# (the placeholders already leave that dimension open).
n_images = tf.shape(inputs)[0]
hidden_shape = tf.stack([n_images, 128, 128, 32])
image_shape = tf.stack([n_images, 128, 128, 3])

# Six 32-channel layers followed by a final layer that maps back to three
# channels and uses tanh (final_layer=True).
conv_layer1 = deconv_layer(inputs, w1, b1, hidden_shape)
conv_layer2 = deconv_layer(conv_layer1, w2, b2, hidden_shape)
conv_layer3 = deconv_layer(conv_layer2, w3, b3, hidden_shape)
conv_layer4 = deconv_layer(conv_layer3, w4, b4, hidden_shape)
conv_layer5 = deconv_layer(conv_layer4, w5, b5, hidden_shape)
conv_layer6 = deconv_layer(conv_layer5, w6, b6, hidden_shape)
pred = deconv_layer(conv_layer6, w7, b7, image_shape, final_layer=True)

## Train Model

In [0]:
def show_test_images(blurrys, samples, actuals):
    """Plot up to nine rows of (input, prediction, target) image triples.

    Args:
        blurrys: Sequence of input images; each is multiplied by 255 and
            cast to uint8 for display.
        samples: Sequence whose first element holds the predicted images;
            each prediction is mapped with ``x / 2 + 0.5`` before being
            multiplied by 255 and cast to uint8.
        actuals: Sequence of target images; each is multiplied by 255 and
            cast to uint8 for display.
    """
    print('BiCubic    SRCNN (Perpetual Loss)     Ground Truth')
    cols = 3
    # Never ask for more rows than there are images; a fixed nine-row grid
    # would index past the end of shorter inputs.
    rows = min(9, len(blurrys), len(samples[0]), len(actuals))
    if rows == 0:
        return

    blurrys2 = [(blurry * 255).astype(np.uint8) for blurry in blurrys[:rows]]
    samples2 = [((sample / 2.0 + 0.5) * 255).astype(np.uint8) for sample in samples[0][:rows]]
    actuals2 = [(actual * 255).astype(np.uint8) for actual in actuals[:rows]]

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(figsize=(10, 10), nrows=rows, ncols=cols,
                             sharex=True, sharey=True, squeeze=False)

    # Each row shows, left to right: input, prediction, target.
    for ax_row, triple in zip(axes, zip(blurrys2, samples2, actuals2)):
        for ax, picture in zip(ax_row, triple):
            ax.imshow(picture)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

    plt.show()

In [0]:
# Mean-squared error between the target images and the network output.
loss = tf.reduce_mean(tf.square(labels - pred))
train_op = tf.train.AdamOptimizer(0.00005).minimize(loss)
saver = tf.train.Saver()
epochs = 1001

# Only full batches of the training split are used in each epoch.
steps_per_epoch = len(x_train) // batch_size
if steps_per_epoch == 0:
    raise ValueError("x_train holds fewer images than batch_size")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        # Reshuffle the training pairs every epoch. Inputs and labels are
        # indexed with the same permutation so they stay aligned.
        order = np.random.permutation(len(x_train))
        # Create the generator once per epoch and iterate it; recreating it
        # on every step would hand back the first batch each time.
        batches = next_batch(x_train[order], y_train[order], batch_size, steps_per_epoch)
        for inputs_, labels_ in batches:
            # Rescale the targets from [0, 1] to [-1, 1].
            labels_ = (labels_ - 0.5) * 2

            train_loss, _ = sess.run([loss, train_op], feed_dict={inputs: inputs_, labels: labels_})
        # Report the last batch loss every 20 epochs.
        if e % 20 == 0:
            print("Epoch: %d train loss: %f" %(e, float(train_loss)))
        # Every 200 epochs also run the test inputs through the network and plot them.
        if e % 200 == 0:
            gen_samples = sess.run([pred], feed_dict={inputs: x_test})
            show_test_images(x_test, gen_samples, y_test)