# SRCNN - Data augmentation
This notebook implements the super-resolution convolutional neural network, SRCNN, as described by [Dong *et al.*](https://link.springer.com/chapter/10.1007/978-3-319-10593-2_13), but expands on the methodology by investigating further data augmentation including:
 * rotation
 * colour channel swapping

This notebook only covers 2 scaling factors: $2\times$ and $4\times$. Results for each can be acquired by changing the scaling factor and re-running the notebook. See the **SRCNN - Baseline model** notebook for further information on the basic model and general methodology.

## Summary
Results from this notebook are summarized below. Using $32 \times 32$ training patches ($33 \times 33$ for scaling factor 3), an image patch stride of 14, a training batch size of 64 and 500 epochs, training took roughly 5 minutes on a GTX 1080 Ti.

| Scaling factor | Image Set | Rotations | Channel Swap | Multi-size |  Bicubic - PSNR (mean) | SRCNN - PSNR (mean) | PSNR Mean Improvement | Bicubic - SSIM (mean) | SRCNN - SSIM (mean) | SSIM Mean Improvement |
| :------------- | :-------: | :-------: | :----------: | :--------: | :-------------------: | :-----------------: | :-------------------: | :-------------------: | :-----------------: | :-------------------: |
| $2\times$ | Set14 | $\checkmark$ | $\checkmark$ | $\checkmark$ | $23.54$ dB | $24.36$ dB | $+0.82$ dB | $0.65$ | $0.69$ | $+0.04$ |


In [1]:
# Imports.
import os
import shutil
import site
from random import shuffle

import numpy as np
from PIL import Image
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import Sequence

# Make the local 'srcnn_tf2' package importable before importing from it.
site.addsitedir('../')
from srcnn_tf2.data.preprocessing import create_xy_patches, import_from_file, scale_batch, center_crop
from srcnn_tf2.data.plotting import n_compare
from srcnn_tf2.data.oom import SRCNNTrainingGenerator
from srcnn_tf2.model.srcnn_model import SRCNN

# Data locations: relative paths to the training set (T91) and the two
# evaluation sets (Set5, Set14).
(training_folder,
 set5_eval_folder,
 set14_eval_folder) = ('../../../sr_data/T91',
                       '../../../sr_data/Set5',
                       '../../../sr_data/Set14')

---
## Data augmentation: Rotation, channel swap, and alternative patching sizes


In [None]:
# Data options.
# Note: 'scaling_factor' should evenly divide into 'y_image_size'.
y_sizes = [(32, 32), (48, 48), (64, 64)]  # Target image size, patches extracted from T91 inputs.
rotations = [0, 90, 180, 270]
channel_combos = [(0,1,2), (1,2,0), (2,0,1)]
scaling_factor = 2
patch_stride = 14
blur_kernel = -1  # Negative applies blur before downscaling, positive applies blur after downscaling
epochs_per_loop = 100
batch_size = 32
# Number of patch pairs written to disk per (size, rotation, channel order) combination.
# Set to None to keep every extracted patch.
max_patches_per_combo = 10

y_folder = '../../../sr_data/srcnn_training_temp/ydata'
x_folder = '../../../sr_data/srcnn_training_temp/xdata'

# Start from empty output folders. The folders may not exist on a first run,
# so only remove them when they are actually present.
for folder in (y_folder, x_folder):
    if os.path.isdir(folder):
        shutil.rmtree(folder)
    os.makedirs(folder)

# Targets have to be cropped to the output size of the network that will be
# trained on them. This reference instance uses the same hyper-parameters as the
# models built in the training/comparison cell that follows; it is only queried
# for its padding mode and crop size and is never trained.
# NOTE(review): keep these arguments in sync with the SRCNN(...) call in the training loop.
crop_reference_model = SRCNN(num_channels=3, f1=9, f3=5, n1=64, n2=32, nlin_layers=1,
                             activation='relu', optimizer='adam', loss='mse', metrics=['accuracy'],
                             padding='valid', batch_norm=False)

# Build data on disk.
counter = 0      # Running index; makes every file name unique.
filenames = []   # One (x_path, y_path) tuple per saved patch pair.
for y_image_size in y_sizes:
    im_size_str = f"{y_image_size[0]}x{y_image_size[1]}"
    for rots in rotations:
        rot_str = f"{rots}deg"
        for channels in channel_combos:
            chan_str = "c" + "".join(str(c) for c in channels)

            # Data extraction
            # ---------------
            xdata, ydata = create_xy_patches(training_folder,
                                             scaling_factor,
                                             patch_size=y_image_size,
                                             patch_stride=patch_stride,
                                             blur_kernel=blur_kernel,
                                             rotations=[rots], swap_channels=channels)
            # Pre-upscale the inputs to the target size (done outside the model for on-disk data).
            xdata = scale_batch(xdata, ydata.shape[1:3])
            if crop_reference_model.padding != 'same':
                ydata = center_crop(ydata, crop_reference_model.get_crop_size())
            xdata, ydata = xdata[:max_patches_per_combo], ydata[:max_patches_per_combo]

            for xd, yd in zip(xdata, ydata):
                stem = f"{counter}_{im_size_str}_{rot_str}_{chan_str}"
                filename_x = f"{x_folder}/{stem}_x.npy"
                filename_y = f"{y_folder}/{stem}_y.npy"
                filenames.append((filename_x, filename_y))
                np.save(filename_x, xd)
                np.save(filename_y, yd)
                counter += 1

print(f"Training data has been augmented and saved to disk. {len(filenames)} images have been saved.")

shuffle(filenames) # Shuffle the file names

# Subsets that vary a single augmentation while holding the other two fixed.
# The augmentation settings are encoded in the file name, so filter on the x path.
filenames_rotation_only = [f for f in filenames if '_c012' in f[0] and '_32x32' in f[0]]
filenames_chanswap_only = [f for f in filenames if '_90deg' in f[0] and '_32x32' in f[0]]
filenames_imsize_only = [f for f in filenames if '_c012' in f[0] and '_90deg' in f[0]]

# Test list contains tuples of the form (rotations, channel swap, image size, filenames)
test_list = [
    (True, False, False, filenames_rotation_only),
    (False, True, False, filenames_chanswap_only),
    (False, False, True, filenames_imsize_only),
    (True, True, True, filenames)
]

In [None]:
# Train one fresh model per augmentation subset and record a Markdown table row
# with its mean PSNR / SSIM on Set14 next to the bicubic baseline.
result_strings = []
testing_images_14 = import_from_file(set14_eval_folder)
checkmark = '$\\checkmark$'

for rot, chan, im, files in test_list:
    # Build new model.
    srcnn_m = SRCNN(num_channels=3, f1=9, f3=5, n1=64, n2=32, nlin_layers=1,
                    activation='relu', optimizer='adam', loss='mse', metrics=['accuracy'], padding='valid', batch_norm=False)

    # Instantiate generator and train.
    # NOTE(review): the generator batch size (64) differs from 'batch_size' (32) in the data options — confirm which is intended.
    train_gen = SRCNNTrainingGenerator(files, 64)
    # 'Model.fit' accepts generators directly ('fit_generator' is deprecated). Pass the
    # epoch count explicitly, otherwise Keras trains for a single epoch.
    srcnn_m.model.fit(train_gen, epochs=epochs_per_loop, verbose=0)

    # Evaluate performance.
    psnr_14, psnr_bicubic_14 = srcnn_m.benchmark(testing_images_14, metric='psnr', return_metrics=True)
    ssim_14, ssim_bicubic_14 = srcnn_m.benchmark(testing_images_14, metric='ssim', return_metrics=True)
    psnr_mean, psnr_bicubic_mean = np.mean(psnr_14), np.mean(psnr_bicubic_14)
    ssim_mean, ssim_bicubic_mean = np.mean(ssim_14), np.mean(ssim_bicubic_14)

    # Record performance.
    # 'rot', 'chan' and 'im' are booleans flagging which augmentation was varied.
    rot_check = checkmark if rot else ''
    chan_check = checkmark if chan else ''
    mult_check = checkmark if im else ''
    # The '+' format flag prints an explicit sign, so a negative improvement
    # renders as '-0.12' rather than '+-0.12'.
    result_strings.append(
        f"| ${scaling_factor}\\times$ | Set14 | {rot_check} | {chan_check} | {mult_check} |"+
        f" ${psnr_bicubic_mean:.2f}$ dB | ${psnr_mean:.2f}$ dB |"+
        f" ${psnr_mean - psnr_bicubic_mean:+.2f}$ dB |"+
        f" ${ssim_bicubic_mean:.2f}$ | ${ssim_mean:.2f}$ | ${ssim_mean - ssim_bicubic_mean:+.2f}$ |"
    )
    
    

In [None]:
# Emit the collected results as a Markdown table: header row, alignment row,
# then one row per experiment.
table_lines = [
    "| Scaling factor | Image Set | Rotations | Channel Swap | Multi-size |  Bicubic - PSNR (mean) | SRCNN - PSNR (mean) | PSNR Mean Improvement | Bicubic - SSIM (mean) | SRCNN - SSIM (mean) | SSIM Mean Improvement |",
    "| :------------- | :-------: | :-------: | :----------: | :--------: | :-------------------: | :-----------------: | :-------------------: | :-------------------: | :-----------------: | :-------------------: |",
]
table_lines.extend(result_strings)
print("\n".join(table_lines))

### Model definition
Define the model first. Data must be imported on the fly as there are too many options to store the full training data set in memory at the same time (potentially).

In [2]:
# Model trained sequentially on every augmentation combination further down.
# The recorded summary output shows a 256-filter first convolution, four
# 128-filter convolutions and a 3-channel output convolution.
srcnn_model = SRCNN(
    num_channels=3,
    f1=11, f3=7,
    n1=256, n2=128,
    nlin_layers=4,
    activation='relu', padding='valid', batch_norm=False,
    optimizer='adam', loss='mse', metrics=['accuracy'],
)

srcnn_model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
conv2d (Conv2D)              (None, None, None, 256)   93184     
_________________________________________________________________
conv2d_1 (Conv2D)            (None, None, None, 128)   32896     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, None, None, 128)   16512     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, None, None, 128)   16512     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, None, None, 128)   16512     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, None, None, 3)    

### Import and preprocess
High-resolution target images, $I_y$, are 32 x 32 pixel sub-images extracted from the original T91 image set. This is done by passing a 32 x 32 pixel window over the originals at a stride of 14 pixels. Target images are then given a Gaussian blur and downscaled by the scaling factor to produce the low-resolution input images, $I_x$. Pre-upscaling for the model is performed as part of the model class.

Training data is saved to disk temporarily as it is too large to hold the complete set in memory. A custom generator is built to pull data from disk and the `.fit_generator()` method will have to be called directly on the model. Pre-upscaling will be done *outside* the SRCNN class (unlike with in-memory data).

#### Multi-size patching
Three patch sizes, all extracted at a stride of 14, will be attempted: $32 \times 32$, $48 \times 48$ and $64 \times 64$.

#### Rotation
Training patches are generated at four rotations: 0°, 90°, 180° and 270°.

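#### Colour channel swapping
Training patches are generated with three cyclic orderings of the colour channels: (0, 1, 2), (1, 2, 0) and (2, 0, 1).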

In [None]:
# Data options.
# Note: 'scaling_factor' should evenly divide into 'y_image_size'.
scaling_factor = 4
y_sizes = [(32, 32), (48, 48), (64, 64)]  # Target image size, patches extracted from T91 inputs.
rotations = [0, 90, 180, 270]
channel_combos = [(0,1,2), (1,2,0), (2,0,1)]
patch_stride = 14
blur_kernel = -1  # Negative applies blur before downscaling, positive applies blur after downscaling
epochs_per_loop = 100
batch_size = 32

x_folder = '../../../sr_data/srcnn_training_temp/xdata'
y_folder = '../../../sr_data/srcnn_training_temp/ydata'

# Training settings shared by every pass over the augmentation grid.
fit_options = dict(epochs=epochs_per_loop, batch_size=batch_size, validation_split=0.1, verbose=0)

# Every (patch size, rotation, channel order) combination, patch size varying slowest.
augmentation_grid = [(size, angle, order)
                     for size in y_sizes
                     for angle in rotations
                     for order in channel_combos]

# Run training: the same model keeps training as each combination's patches are extracted in turn.
for y_image_size, rots, channels in augmentation_grid:
    # Data extraction
    # ---------------
    xdata, ydata = create_xy_patches(training_folder,
                                     scaling_factor,
                                     patch_size=y_image_size,
                                     patch_stride=patch_stride,
                                     blur_kernel=blur_kernel,
                                     rotations=[rots], swap_channels=channels)
    print(f"\nRotation: {rots} | Image size: {y_image_size} | Channel order: {channels}")
    print(f"\tTarget data size (number of images x image shape x channels): {ydata.shape}")
    print(f"\tTraining data input size (number of images x image shape x channels): {xdata.shape}")

    # Model training
    # --------------
    srcnn_model.fit(xdata=xdata, ydata=ydata, **fit_options)


#srcnn_model.plot_training(figsize=(16, 8), plot_vars=['loss', 'val_loss'])


Rotation: 0 | Image size: (32, 32) | Channel order: (0, 1, 2)
	Target data size (number of images x image shape x channels): (22623, 32, 32, 3)
	Training data input size (number of images x image shape x channels): (22623, 8, 8, 3)
100 epochs completed in 5.0 minutes 22.86 seconds, approx. 3.23 seconds per epoch.

Rotation: 0 | Image size: (32, 32) | Channel order: (1, 2, 0)
	Target data size (number of images x image shape x channels): (22623, 32, 32, 3)
	Training data input size (number of images x image shape x channels): (22623, 8, 8, 3)
100 epochs completed in 5.0 minutes 21.87 seconds, approx. 3.22 seconds per epoch.

Rotation: 0 | Image size: (32, 32) | Channel order: (2, 0, 1)
	Target data size (number of images x image shape x channels): (22623, 32, 32, 3)
	Training data input size (number of images x image shape x channels): (22623, 8, 8, 3)
100 epochs completed in 5.0 minutes 24.43 seconds, approx. 3.24 seconds per epoch.

Rotation: 90 | Image size: (32, 32) | Channel order

### Evaluate model
We evaluate the model using the Set14 images alone. These are run through our benchmarking method, which measures both PSNR and SSIM.

In [None]:
# Load both evaluation sets; only Set14 is benchmarked here (the Set5 calls are disabled).
testing_images_5 = import_from_file(set5_eval_folder)
testing_images_14 = import_from_file(set14_eval_folder)

# Benchmark PSNR first, then SSIM. Each call returns a (SRCNN, bicubic) pair of
# metric values, which are averaged with np.mean when the results row is printed.
(psnr_14, psnr_bicubic_14), (ssim_14, ssim_bicubic_14) = [
    srcnn_model.benchmark(testing_images_14, metric=metric_name, return_metrics=True)
    for metric_name in ('psnr', 'ssim')
]

#psnr_5, psnr_bicubic_5 = srcnn_model.benchmark(testing_images_5, metric='psnr', return_metrics=True)
#ssim_5, ssim_bicubic_5 = srcnn_model.benchmark(testing_images_5, metric='ssim', return_metrics=True)

#### PSNR

In [None]:
# Disabled: PSNR benchmark calls with return_metrics=False
# (presumably these display the results instead of returning them — confirm before re-enabling).
#srcnn_model.benchmark(testing_images_5, metric='psnr', return_metrics=False)
#srcnn_model.benchmark(testing_images_14, metric='psnr', return_metrics=False)

#### SSIM

In [None]:
# Disabled: SSIM benchmark calls with return_metrics=False
# (presumably these display the results instead of returning them — confirm before re-enabling).
#srcnn_model.benchmark(testing_images_5, metric='ssim', return_metrics=False)
#srcnn_model.benchmark(testing_images_14, metric='ssim', return_metrics=False)

In [None]:
# Print a one-row Markdown results table for the model trained on all augmentations.
# An augmentation gets a checkmark when more than one option was configured for it.
checkmark = '$\\checkmark$'
rot_check = checkmark if len(rotations) > 1 else ''
chan_check = checkmark if len(channel_combos) > 1 else ''
mult_check = checkmark if len(y_sizes) > 1 else ''

# Average each metric once and reuse the value for both the column and the improvement.
psnr_mean, psnr_bicubic_mean = np.mean(psnr_14), np.mean(psnr_bicubic_14)
ssim_mean, ssim_bicubic_mean = np.mean(ssim_14), np.mean(ssim_bicubic_14)

print(f"| Scaling factor | Image Set | Rotations | Channel Swap | Multi-size |  Bicubic - PSNR (mean) | SRCNN - PSNR (mean) | PSNR Mean Improvement | Bicubic - SSIM (mean) | SRCNN - SSIM (mean) | SSIM Mean Improvement |")
print(f"| :------------- | :-------: | :-------: | :----------: | :--------: | :-------------------: | :-----------------: | :-------------------: | :-------------------: | :-----------------: | :-------------------: |")
# The '+' format flag prints an explicit sign, so a negative improvement
# renders as '-0.12' rather than '+-0.12'.
print(f"| ${scaling_factor}\\times$ | Set14 | {rot_check} | {chan_check} | {mult_check} | ${psnr_bicubic_mean:.2f}$ dB | ${psnr_mean:.2f}$ dB | ${psnr_mean - psnr_bicubic_mean:+.2f}$ dB |", end='')
print(f" ${ssim_bicubic_mean:.2f}$ | ${ssim_mean:.2f}$ | ${ssim_mean - ssim_bicubic_mean:+.2f}$ |")

---
### CIFAR-10 Examples
We use the CIFAR-10 dataset as a "real world" application where there is no target with which to compare. We upscale using both the trained SRCNN and bicubic interpolation, and compare visually (there is no metric in this case).

Images are saved with tags so they can be included in the summary and compared across the three scaling factors.

In [None]:
# Data import and definition.
# Maps a class name to an index into the CIFAR-10 training images
# (presumably one hand-picked example per class — verify against the dataset labels).
d_example_index = {'airplane': 30,
                   'automobile': 32,
                   'bird': 90,
                   'cat': 91,
                   'deer': 130,
                   'dog': 156,
                   'frog': 72,
                   'horse': 152,
                   'ship': 62,
                   'truck': 122}

# 'cifar10' is imported in the notebook's import cell.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Only the training images are needed; release everything else straight away.
del x_test, y_test, y_train

# Labels and images are built in the same (dict insertion) order so they stay aligned.
label_list = list(d_example_index)
# Stack the selected images and rescale pixel values from [0, 255] to [0, 1].
im_list = np.array([x_train[index] for index in d_example_index.values()]) / 255.0

del x_train
# Save image string.
im_prefix = f'srcnn_rotation_channelSwap_multisizeImage_{scaling_factor}x_'

In [None]:
# Upscale the CIFAR-10 examples with the trained network.
im_pred = srcnn_model.predict(im_list)

# Build the bicubic reference at the full upscaled size, then trim its border
# so that it has the same size as the network output along axis 1.
upscaled_side = im_list.shape[1]*scaling_factor
edge_to_remove = (upscaled_side - im_pred.shape[1])//2
# NOTE(review): the size is passed as (shape[2], shape[1]) here, whereas the training cell passes
# ydata.shape[1:3]; identical for square images, but confirm the order scale_batch expects.
im_scale = center_crop(
    images=scale_batch(im_list, (im_list.shape[2]*scaling_factor, upscaled_side)),
    remove_edge=edge_to_remove)

# Sanity check: border removed per side, upscaled size, network output size.
print(edge_to_remove, upscaled_side, im_pred.shape[1])

In [None]:
def _size_tag(image):
    """Format an image's size as '[<axis 1> x <axis 0>]' for use in a plot title."""
    return f'[{image.shape[1]} x {image.shape[0]}]'


# One comparison figure per class: original, bicubic upscale, SRCNN upscale.
for label, image_raw, image_pred, image_scale in zip(label_list, im_list, im_pred, im_scale):
    captions = [
        f'Original: {label.upper()} - {_size_tag(image_raw)}',
        f'Bicubic Interpolation x{scaling_factor} - {_size_tag(image_scale)}',
        f'SRCNN x{scaling_factor} - {_size_tag(image_pred)}',
    ]
    n_compare(im_list=[image_raw, image_scale, image_pred],
              label_list=captions,
              figsize=(12,5))

    #im = Image.fromarray(np.uint8(image_pred*255))
    #im.save(f"results/{im_prefix}{label}.png")