# Modified IDN

Adapted from code in https://github.com/Zheng222/IDN-tensorflow and https://github.com/krasserm/super-resolution.

You don't need to download the DIV2K dataset yourself: the required parts are downloaded automatically by the `DIV2K` class. By default, DIV2K images are stored in the `.div2k` folder in the project's root directory.

In [None]:
import os
import matplotlib.pyplot as plt
from tensorflow_addons.optimizers import CyclicalLearningRate
from numba import cuda
from data import DIV2K
from model.idn import idn
from train import IdnTrainer

%matplotlib inline

In [None]:
# Reset the current CUDA device through numba; this helps recover after
# out-of-memory errors (per the numba docs, reset() destroys the device's
# context together with its allocations).
# Guarded so the notebook still runs on machines without a usable CUDA GPU,
# where get_current_device() would otherwise raise.
if cuda.is_available():
    device = cuda.get_current_device()
    device.reset()

In [None]:
# Location of the exported model weights (needed by the demo section).
# The "x2" suffix should match the upscaling factor `scale` used for training.
weights_dir = 'weights/idn-x2'
weights_file = os.path.join(weights_dir, 'weights.h5')
# Create the directory up front so the weights can later be saved into it.
os.makedirs(weights_dir, exist_ok=True)

In [None]:
# Split sizes used later to size epochs and the validation subset:
# 800 training images and 100 validation images.
items_in_trainingset = 800
items_in_validationset = 100

# Upscaling factor shared by the data loaders and the model.
scale = 2

# DIV2K data for the 'train' and 'valid' subsets, both with the 'bicubic'
# downgrade (constructed in that order).
train_data, valid_data = (
    DIV2K(scale=scale, subset=subset_name, downgrade='bicubic')
    for subset_name in ('train', 'valid')
)

In [None]:
batch_size = 16

# Random transforms are enabled for the training pipeline only; validation
# uses one image per batch with repeat_count=1.
train_ds = train_data.dataset(
    batch_size=batch_size,
    random_transform=True,
)
valid_ds = valid_data.dataset(
    batch_size=1,
    random_transform=False,
    repeat_count=1,
)

## Training

In [None]:
# Bounds of the cyclical learning rate: the lower bound is one tenth of the
# upper bound.
maximal_learning_rate = 7e-3
initial_learning_rate = maximal_learning_rate / 10

epochs_within_each_step = 2
number_of_cycles = 3

# Optimizer iterations in one pass over the training set.
iterations_in_epoch = items_in_trainingset // batch_size
# Iterations to go from one learning-rate bound to the other (half a cycle).
step_size = iterations_in_epoch * epochs_within_each_step
# Iterations in one full learning-rate cycle (up and back down).
steps_per_cycle = step_size * 2
# Kept for the later cells, which use `cycles` to mean "steps per cycle".
cycles = steps_per_cycle
training_steps = steps_per_cycle * number_of_cycles

In [None]:
def _halve_each_cycle(cycle):
    """Return the amplitude factor for cycle number `cycle`: 1, 1/2, 1/4, ..."""
    return 1 / (2.0 ** (cycle - 1))


# Learning rate rises from initial_learning_rate towards maximal_learning_rate
# over `step_size` iterations and falls back over the next `step_size`; the
# swing is halved on every new cycle ("triangular2" policy).
cyclical_learning_rate_schedule = CyclicalLearningRate(
    initial_learning_rate=initial_learning_rate,
    maximal_learning_rate=maximal_learning_rate,
    step_size=step_size,
    scale_fn=_halve_each_cycle,
)

In [None]:
# Trainer for an IDN model at the configured upscaling factor. The checkpoint
# directory name is derived from `scale` so runs at different factors do not
# share (and overwrite) each other's checkpoints.
trainer = IdnTrainer(
    model=idn(scale=scale),
    checkpoint_dir=f'.ckpt/idn-x{scale}',
    learning_rate=cyclical_learning_rate_schedule,
)

In [None]:
# Train for `training_steps` iterations, evaluating on the first
# `items_in_validationset` validation images once per learning-rate cycle.
# save_best_only=True presumably keeps only checkpoints that improve the
# validation score - confirm in IdnTrainer.
trainer.train(
    train_ds,
    valid_ds.take(items_in_validationset),
    steps=training_steps,
    evaluate_every=cycles,
    save_best_only=True,
)

In [None]:
# Mean PSNR over the validation images, included for reference only: PSNR is
# a logarithmic metric, so an arithmetic mean of per-image values is not a
# strictly valid comparison.
psnrv = trainer.evaluate(valid_ds.take(items_in_validationset))
# '.3f' prints three decimal places ('3f' would only set a minimum width of 3).
print(f'PSNR = {psnrv.numpy():.3f}')

In [None]:
# Fine-tune with both learning-rate bounds reduced by two orders of magnitude
# (divided by 100). The schedule has to be rebuilt: the existing
# CyclicalLearningRate object captured the old bounds when it was constructed,
# so reassigning the two float variables alone would not change the learning
# rate the trainer uses.
maximal_learning_rate = maximal_learning_rate / 100
initial_learning_rate = initial_learning_rate / 100
cyclical_learning_rate_schedule = CyclicalLearningRate(
    initial_learning_rate=initial_learning_rate,
    maximal_learning_rate=maximal_learning_rate,
    step_size=step_size,
    scale_fn=lambda x: 1 / (2.0 ** (x - 1)),
)
# NOTE(review): a new model is built here; its weights presumably come from
# the checkpoint that IdnTrainer restores from checkpoint_dir - confirm. If the
# optimizer step is restored as well, the schedule resumes at a later cycle
# number and scale_fn shrinks the amplitude further.
trainer = IdnTrainer(
    model=idn(scale=scale),
    checkpoint_dir=f'.ckpt/idn-x{scale}',
    learning_rate=cyclical_learning_rate_schedule,
)

In [None]:
# Extend the target step count by one full learning-rate cycle (two
# half-cycles of `step_size` iterations) and resume training.
training_steps += 2 * step_size
trainer.train(
    train_ds,
    valid_ds.take(items_in_validationset),
    steps=training_steps,
    evaluate_every=cycles,
    save_best_only=True,
)

In [None]:
# Mean PSNR over the validation images after fine-tuning, included for
# reference only: PSNR is a logarithmic metric, so an arithmetic mean of
# per-image values is not a strictly valid comparison.
psnrv = trainer.evaluate(valid_ds.take(items_in_validationset))
# '.3f' prints three decimal places ('3f' would only set a minimum width of 3).
print(f'PSNR = {psnrv.numpy():.3f}')

In [None]:
# Export the trainer's model weights to weights_file for the demo.
# NOTE(review): training used save_best_only=True, so these are presumably the
# weights after the last training step, not necessarily those of the best
# checkpoint - confirm whether the best checkpoint should be restored first.
trainer.model.save_weights(weights_file)

## Demo

In [None]:
# Build an IDN model for the same upscaling factor and load the exported
# weights into it; resolve_and_plot() uses this global `model`.
model = idn(scale=scale)
model.load_weights(weights_file)

In [None]:
from model import resolve_single
from utils import load_image, plot_sample

def resolve_and_plot(lr_image_path):
    """Super-resolve one image with the global `model` and plot it.

    The image at `lr_image_path` is loaded, passed through `resolve_single`
    with the module-level `model`, and both the low-resolution input and the
    super-resolved output are handed to `plot_sample`.
    """
    low_res = load_image(lr_image_path)
    plot_sample(low_res, resolve_single(model, low_res))

In [None]:
# Demo: super-resolve demo/0869-crop.png and plot input vs. output.
resolve_and_plot('demo/0869-crop.png')

In [None]:
# Demo: super-resolve demo/0829-crop.png and plot input vs. output.
resolve_and_plot('demo/0829-crop.png')

In [None]:
# Demo: super-resolve demo/0851-crop.png and plot input vs. output.
resolve_and_plot('demo/0851-crop.png')

In [None]:
# Demo: super-resolve demo/0855-crop.png and plot input vs. output.
resolve_and_plot('demo/0855-crop.png')

In [None]:
# Demo: super-resolve a second crop, demo/0855-crop2.png, and plot input vs. output.
resolve_and_plot('demo/0855-crop2.png')