# Super Resolution

## Imports and Config

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Silence TensorFlow's info and warning log output (must be set before TensorFlow is imported)
import os
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

In [None]:
import wandb
import PIL

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from wandb.keras import WandbCallback
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from IPython.display import Image

from data import Images
from utils import TrackTraining
from models.edsr import edsr
from models.srcnn import srcnn
from models.fsrcnn import fsrcnn

In [None]:
# --- Paths ---
DIR_HIGHRES = './images/original/'
DIR_SPLITS = './images/'
WEIGHTS_DIR = './weights/'
RESULTS_DIR = './results/'
SAVED_DIR = './saved/'
CREATE_FOLDERS = False # Creates folder splits the first time

# --- Data ---
BATCH_SIZE = 16
FACTOR = 4  # Upscaling factor, 2-4
INPUT_SIZE = 256//FACTOR  # Side length of the low-resolution model input
REPEAT_COUNT = 10  # Data Augmentation: number of augmented copies appended to the training set
RGB = True # If false, only Y channel (luminance) in YUV is used
CHANNELS = 3 if RGB else 1

# --- Training ---
EPOCHS = 20
LOSS = 'mean_absolute_error'
MODEL = 'edsr'  # One of 'srcnn', 'fsrcnn', 'edsr'
MODEL_NAME = f'{MODEL}-x{FACTOR}-a{REPEAT_COUNT}-c{CHANNELS}-e{EPOCHS}-{LOSS}'

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Build two instances of the selected architecture:
#   model         -> compiled and trained in the Training section
#   model_trained -> receives the saved weights in the Evaluation section

# SRCNN
if MODEL == 'srcnn':
    model = srcnn(factor=FACTOR, channels=CHANNELS)
    model_trained = srcnn(factor=FACTOR, channels=CHANNELS)

# FSRCNN
elif MODEL == 'fsrcnn':
    model = fsrcnn(factor=FACTOR, channels=CHANNELS)
    model_trained = fsrcnn(factor=FACTOR, channels=CHANNELS)

# EDSR baseline from https://arxiv.org/abs/1707.02921
elif MODEL == 'edsr':
    model = edsr(factor=FACTOR, residual_scaling=None, channels=CHANNELS)
    model_trained = edsr(factor=FACTOR, residual_scaling=None, channels=CHANNELS)

# Fail early on a typo instead of hitting a NameError on `model` further down
else:
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected 'srcnn', 'fsrcnn' or 'edsr'")

In [None]:
# Start a Weights & Biases run and name it after the model configuration
wandb.init()
wandb.run.name = MODEL_NAME

# Record the hyperparameters of this run on the W&B config object
config = wandb.config
for key, value in {
    'batch_size': BATCH_SIZE,
    'factor': FACTOR,
    'input_size': INPUT_SIZE,
    'repeat_count': REPEAT_COUNT,
    'rgb': RGB,
    'epochs': EPOCHS,
    'loss': LOSS,
    'model': MODEL,
}.items():
    setattr(config, key, value)

## Data Preparation

### Splits of high resolution images
Generate splits and obtain dataset of high resolution images

In [None]:
# Dataset helper pointing at the original images and at the folder holding the splits
images = Images(path=DIR_HIGHRES, split_path=DIR_SPLITS)

# High-resolution train / validation / test datasets
# (CREATE_FOLDERS is forwarded as createFolders; see the config cell)
train_ds, val_ds, test_ds = images.get_high_res_partitions(
    createFolders=CREATE_FOLDERS,
)
print(
    f'High res images: {len(train_ds)} (training), '
    f'{len(val_ds)} (validation), {len(test_ds)}(test)'
)

Use only Y channel from YUV model (luminance) if RGB is False

In [None]:
def processs_input(input):
    """Convert an RGB image tensor to YUV and return its three channels separately.

    Args:
        input: image tensor whose last axis holds the RGB channels.

    Returns:
        Tuple ``(y, u, v)`` of tensors obtained by splitting the YUV image
        into three equal parts along its last axis.
    """
    yuv = tf.image.rgb_to_yuv(input)
    # axis=-1 addresses the last (channel) axis whatever the rank of the input
    return tuple(tf.split(yuv, 3, axis=-1))

In [None]:
# In luminance-only mode keep just the Y channel of every image in each split
if not RGB:
    train_ds, val_ds, test_ds = (
        split.map(lambda x: processs_input(x)[0])
        for split in (train_ds, val_ds, test_ds)
    )

Display some images from train dataset

In [None]:
# 3x3 grid with the first nine training images
plt.figure(figsize=(5, 5))
color_map = None if RGB else 'gray'  # gray colormap only for single-channel images
for position, sample in enumerate(train_ds.take(9), start=1):
    ax = plt.subplot(3, 3, position)
    ax.imshow(sample.numpy().astype("uint32"), cmap=color_map)
    ax.axis("off")

### Data Augmentation

In [None]:
# Random flips (both directions) and random rotations used to synthesize extra training images
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.5),
])

# Training set = original images followed by REPEAT_COUNT augmented copies.
# With REPEAT_COUNT <= 0 the loop does not run and train_ds is left unchanged.
augmented_train_ds = train_ds
for _ in range(REPEAT_COUNT):
    augmented_copy = train_ds.map(
        # training=True is passed explicitly so the random layers are applied
        # when called from Dataset.map, outside of model.fit
        lambda image: data_augmentation(image, training=True),
        num_parallel_calls=AUTOTUNE,
    )
    augmented_train_ds = augmented_train_ds.concatenate(augmented_copy)
train_ds = augmented_train_ds

In [None]:
# Report the split sizes again now that the augmented copies are part of the training set
split_sizes = (len(train_ds), len(val_ds), len(test_ds))
print('High res images: {} (training), {} (validation), {}(test)'.format(*split_sizes))

In [None]:
# 3x3 grid of training images taken after skipping the first 1000 elements
# (nothing is drawn if the dataset holds 1000 images or fewer)
plt.figure(figsize=(5, 5))
color_map = None if RGB else 'gray'  # gray colormap only for single-channel images
for position, sample in enumerate(train_ds.skip(1000).take(9), start=1):
    ax = plt.subplot(3, 3, position)
    ax.imshow(sample.numpy().astype("uint32"), cmap=color_map)
    ax.axis("off")

### Low resolution images
Obtain low resolution image for each high resolution image

In [None]:
def downscale_image(image, input_size=INPUT_SIZE):
    """Scale an image down to ``input_size`` x ``input_size`` using bicubic downsampling.

    Args:
        image: image tensor to resize.
        input_size: side length of the square output (defaults to INPUT_SIZE).

    Returns:
        The antialiased, bicubically resized image with its values clamped to [0, 255].
    """
    target_shape = [input_size, input_size]
    resized = tf.image.resize(
        image,
        target_shape,
        method=tf.image.ResizeMethod.BICUBIC,
        antialias=True,
    )
    # Clamp to [0, 255] (presumably to remove bicubic overshoot)
    return tf.clip_by_value(resized, 0, 255)

In [None]:
def _with_low_res(high_res):
    """Return the (low_res, high_res) pair for one high-resolution image."""
    return downscale_image(high_res), high_res

# Every dataset element becomes an (input, target) pair: (low_res, high_res)
train_ds = train_ds.map(_with_low_res)
val_ds = val_ds.map(_with_low_res)
test_ds = test_ds.map(_with_low_res)

Display some pairs of images from train dataset with high and low resolutions

In [None]:
# Four rows; left column: high-resolution image, right column: its low-resolution version
plt.figure(figsize=(10, 10))
color_map = None if RGB else 'gray'  # gray colormap only for single-channel images
for row, (low_res, high_res) in enumerate(train_ds.take(4)):
    for column, picture in enumerate((high_res, low_res), start=1):
        panel = plt.subplot(4, 2, 2 * row + column)
        panel.imshow(picture.numpy().astype("uint32"), cmap=color_map)
        panel.axis("off")

### Performance
Batch the datasets, then improve input-pipeline performance with `cache` and `prefetch`

In [None]:
# Batch each split, cache the batches and prefetch them while the model is busy
train_ds, val_ds, test_ds = (
    split.batch(BATCH_SIZE).cache().prefetch(buffer_size=AUTOTUNE)
    for split in (train_ds, val_ds, test_ds)
)

## Training

**Loss:**   
The pixel-wise $L^2$ loss and the pixel-wise $L^1$ loss are frequently used loss functions for training super-resolution models. They measure the pixel-wise mean squared error and the pixel-wise mean absolute error, respectively, between an HR image $I^{HR}$ and an SR image $I^{SR}$. The pixel-wise $L^2$ loss directly optimizes PSNR. Experiments have shown that the pixel-wise $L^1$ loss can sometimes achieve even better performance.

In [None]:
# Adam optimizer with an exponentially decaying learning rate
# (starts at 1e-4, decay rate 0.9 per 5,000 steps)
lr_schedule = ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=5000,
    decay_rate=0.9)
OPTIMIZER = Adam(learning_rate=lr_schedule)

# Compile and train the model with the loss selected in the config cell (LOSS)
model.compile(optimizer=OPTIMIZER, loss=LOSS)
model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=[TrackTraining(), WandbCallback()])

# Save weights
os.makedirs(WEIGHTS_DIR, exist_ok=True)
model.save_weights(f'{WEIGHTS_DIR}weights-{MODEL_NAME}.h5')

# Save model: the target folder is SAVED_DIR, so that is the folder that must exist
os.makedirs(SAVED_DIR, exist_ok=True)
model.save(f'{SAVED_DIR}model-{MODEL_NAME}.h5')

## Evaluation

In [None]:
# Load the weights written by the training cell into the second model instance
trained_weights_path = os.path.join(WEIGHTS_DIR, f'weights-{MODEL_NAME}.h5')
model_trained.load_weights(trained_weights_path)

In [None]:
# Load one high-resolution test image and derive its low-resolution version
hr = tf.convert_to_tensor(PIL.Image.open('./images/test/hr/25.jpg'))
lr = downscale_image(hr)

# The model input is the RGB image, or only its luminance (Y) channel when RGB is False
lr_input = lr
if not RGB:
    # cb / cr hold the two chroma channels (U and V) returned by processs_input
    y, cb, cr = processs_input(lr)
    lr_input = y

# The model works on batches: add a batch axis, predict, take the single result back out
lr_batch = tf.expand_dims(lr_input, axis=0)
sr = model_trained(lr_batch)[0]
sr = tf.clip_by_value(sr, 0, 255)

if not RGB:
    # Only luminance was super-resolved: upscale the chroma channels bicubically to the
    # output size (height, width) and convert the recombined YUV image back to RGB
    sr_size = [sr.shape[0], sr.shape[1]]
    out_img_cb = tf.image.resize(cb, sr_size, method=tf.image.ResizeMethod.BICUBIC)
    out_img_cr = tf.image.resize(cr, sr_size, method=tf.image.ResizeMethod.BICUBIC)
    sr = tf.concat([sr, out_img_cb, out_img_cr], axis=-1)
    sr = tf.clip_by_value(tf.image.yuv_to_rgb(sr), 0, 255)

# Images and captions consumed by the plotting cell below
images = [lr, sr, hr]
titles = ["Low Resolution", "Super Resolution", "High Resolution"]

In [None]:
# A bare expression is only rendered when it is the last line of a cell,
# so the saved image is shown through an explicit display() call
from IPython.display import display

# Obtain image with LR, SR, and HR side by side on an off-screen (Agg) figure
fig = Figure(figsize=(40, 10), dpi=300)
canvas = FigureCanvasAgg(fig)
for i, (image, title) in enumerate(zip(images, titles)):
    ax = fig.add_subplot(1, 3, i+1)
    ax.imshow(image.numpy().astype("uint32"))
    ax.set_title(title)
    ax.set_axis_off()

# Save and display images
result_path = f'{RESULTS_DIR}{MODEL_NAME}.png'
os.makedirs(RESULTS_DIR, exist_ok=True)
fig.savefig(result_path)
display(Image(filename=result_path))

# Log image into wandb
wandb.log({"test_img": wandb.Image(result_path)})

## Metrics

Scoring functions: PSNR, SSIM (structural similarity, named `ssmi` in the code) and MSE

In [None]:
def psnr(sr, hr):
    """Peak signal-to-noise ratio between a super-resolved and a reference image.

    Both images are converted with Keras' ``img_to_array`` and compared with
    ``tf.image.psnr`` using a maximum pixel value of 255.
    """
    to_array = tf.keras.preprocessing.image.img_to_array
    sr_array = to_array(sr)
    hr_array = to_array(hr)
    return tf.image.psnr(sr_array, hr_array, max_val=255)

def ssmi(sr, hr):
    """Structural similarity (SSIM) between a super-resolved and a reference image.

    Args:
        sr: super-resolved image tensor (must expose ``.numpy()``).
        hr: high-resolution reference image tensor (must expose ``.numpy()``).

    Returns:
        The SSIM value computed by ``tf.image.ssim`` with ``max_val=255`` on
        the 8-bit versions of both images.
    """
    def _as_uint8_batch(image):
        # Clip before the cast: float values outside [0, 255] would otherwise
        # wrap around when converted to uint8 (e.g. 256.0 -> 0)
        pixels = np.clip(image.numpy(), 0, 255).astype("uint8")
        # tf.image.ssim is called on a batch of one image; [0] below extracts its score
        return tf.expand_dims(pixels, axis=0)

    return tf.image.ssim(_as_uint8_batch(sr), _as_uint8_batch(hr), max_val=255)[0]

def mse(sr, hr):
    """Mean squared error between a super-resolved and a reference image.

    Both inputs are converted to float64 arrays before subtracting, so
    unsigned-integer images (e.g. uint8) cannot wrap around in the
    difference or in the square.

    Args:
        sr: super-resolved image (anything ``np.asarray`` can convert).
        hr: high-resolution reference image (same shape as ``sr``).

    Returns:
        The mean of the element-wise squared differences, as a float64 scalar.
    """
    difference = np.asarray(hr, dtype=np.float64) - np.asarray(sr, dtype=np.float64)
    return np.mean(np.square(difference))

In [None]:
# Per-image scores over the whole test set
psnr_values = []
ssmi_values = []
mse_values = []
for lr_batch, hr_batch in test_ds:
    # Clamp predictions to the valid pixel range [0, 255] before scoring,
    # as is done for the single-image evaluation
    sr_batch = tf.clip_by_value(model_trained(lr_batch), 0, 255)
    for sr, hr in zip(sr_batch, hr_batch):
        psnr_values.append(psnr(sr, hr).numpy())
        ssmi_values.append(ssmi(sr, hr).numpy())
        mse_values.append(mse(sr, hr))

# Averages over all test images
avg_psnr = sum(psnr_values)/len(psnr_values)
avg_ssmi = sum(ssmi_values)/len(ssmi_values)
avg_mse = sum(mse_values)/len(mse_values)

print("Average PSNR on test:", avg_psnr)
print("Average SSMI on test:", avg_ssmi)
print("Average MSE on test:", avg_mse)

In [None]:
# Send the averaged test-set metrics to the W&B run
test_metrics = {"PSNR_test": avg_psnr, "SSMI_test": avg_ssmi, "MSE_test": avg_mse}
wandb.log(test_metrics)

In [None]:
# Close the Weights & Biases run started with wandb.init()
wandb.finish()