# ABPN

In [None]:
import os

import matplotlib.pyplot as plt
import tensorflow as tf

from data import DIV2K
from model.abpn import abpn
from train import AbpnTrainer

%matplotlib inline

In [None]:
# Super-resolution (upscaling) factor. It is also passed to DIV2K and used in
# the weights/checkpoint directory names further down.
# Downgrade operator passed to DIV2K.
scale, downgrade = 2, 'bicubic'

In [None]:
# Directory and `.h5` file holding the exported model weights; the demo
# section loads the model from `weights_file`.
weights_dir = 'weights/abpn-x{}'.format(scale)
weights_file = os.path.join(weights_dir, 'weights.h5')

# Create the directory up front (no error if it already exists) so that
# saving weights later does not fail on a missing parent directory.
os.makedirs(weights_dir, exist_ok=True)

## Datasets

You don't need to download the DIV2K dataset manually; the required parts are downloaded automatically by the `DIV2K` class. By default, DIV2K images are stored in the `.div2k` folder in the project's root directory.

In [None]:
# Training and validation subsets of DIV2K for the configured scale and
# downgrade operator (constructed in that order: train, then valid).
div2k_train, div2k_valid = (
    DIV2K(scale=scale, subset=subset, downgrade=downgrade)
    for subset in ('train', 'valid')
)

In [None]:
# Training pipeline: batches of 16 with random transforms enabled.
train_ds = div2k_train.dataset(batch_size=16,
                               random_transform=True)
# Validation pipeline: one image per batch, no random transforms and
# repeat_count=1 (presumably a single pass over the data — confirm in DIV2K).
valid_ds = div2k_valid.dataset(batch_size=1,
                               random_transform=False,
                               repeat_count=1)

## Training

### Pre-trained models

To skip training and run the demo below directly, download [weights-abpn-x2.tar.gz](https://github.com/freedomtan/some_super_resolution_tflite_models/blob/main/h5_weights/weights-abpn-x2.tar.gz?raw=true) and extract the archive into the project's root directory. This creates a `weights/abpn-x2` directory containing the weights of the pre-trained model. In that case, skip the training cells in this section: the last one saves newly trained weights into that same directory.

In [None]:
# Trainer wrapping a freshly built ABPN model; checkpoints are kept under
# `.ckpt/abpn-x{scale}`.
trainer = AbpnTrainer(
    model=abpn(scale=scale),
    checkpoint_dir=f'.ckpt/abpn-x{scale}',
)

In [None]:
# Train for 300,000 steps. Every 1,000 steps the model is evaluated on the
# first 10 validation batches (batch size 1, i.e. 10 images). With
# save_best_only=True a checkpoint is written only when evaluation PSNR
# has improved.
trainer.train(
    train_ds,
    valid_ds.take(10),
    steps=300_000,
    evaluate_every=1_000,
    save_best_only=True,
)

In [None]:
# Reload the saved checkpoint — the one that scored the highest validation
# PSNR during training.
trainer.restore()

In [None]:
# Evaluate the restored model on the complete validation set and print the
# PSNR value returned by the trainer.
psnrv = trainer.evaluate(valid_ds)
# '.3f' = three decimal places ('3f' would only set a minimum width of 3).
print(f'PSNR = {psnrv.numpy():.3f}')

In [None]:
# Export the current weights of the trained model to `weights_file`; the
# demo section builds a new model and loads its weights from that path.
trainer.model.save_weights(weights_file)

## Demo

In [None]:
# Build a fresh ABPN model for the configured scale and load the exported
# weights into it; `resolve_and_plot` below uses this global `model`.
model = abpn(scale=scale)
model.load_weights(weights_file)

In [None]:
from model import resolve_single
from utils import load_image, plot_sample

def resolve_and_plot(lr_image_path):
    """Super-resolve one image file with the global ``model`` and plot it.

    The image at ``lr_image_path`` is read with ``load_image``, passed through
    ``resolve_single(model, ...)``, and the input together with the result is
    handed to ``plot_sample``. Returns ``None``.
    """
    low_res = load_image(lr_image_path)
    plot_sample(low_res, resolve_single(model, low_res))

In [None]:
# NOTE(review): the demo file names carry an "x4" suffix while `scale` is 2;
# confirm these are the intended low-resolution inputs for this model.
resolve_and_plot(lr_image_path='demo/0869x4-crop.png')

In [None]:
resolve_and_plot(lr_image_path='demo/0829x4-crop.png')

In [None]:
resolve_and_plot(lr_image_path='demo/0851x4-crop.png')