# SRGAN training on DIV2K (with WDSR comparison)

In [None]:
import os
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf

from data import DIV2K
from model.srgan import SrGan

%matplotlib inline

In [None]:
# Turn on eager execution through the TF1-compatibility API so tensors can be
# inspected directly (e.g. `.numpy()` later in this notebook).
# NOTE(review): presumably a no-op on TensorFlow 2.x, where eager is the default.
tf.compat.v1.enable_eager_execution()

In [None]:
def show_current_time():
    """Print the current local wall-clock time as ``MM/DD/YY HH:MM:SS``.

    Called before and after the training cell so the training duration can be
    read off the notebook output.
    """
    # Spell out '%m/%d/%y' instead of the '%D' shorthand: '%D' is not a C89
    # strftime directive, so Python does not guarantee it on every platform,
    # whereas the explicit form yields the same text everywhere.
    current_time = datetime.now().strftime("%m/%d/%y %H:%M:%S")
    print(current_time)

In [None]:
# Activate GPU memory growth on every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')

if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        # (e.g. this fails if the cell is re-run after a model was built).
        print(e)
    else:
        # Keep the try body limited to the call that can raise; report the
        # device counts only after the configuration succeeded.
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")

In [None]:
# ---- Experiment configuration ------------------------------------------------
# NOTE(review): `depth` is not passed to SrGan below; it only appears in the
# checkpoint/weights paths and the demo model constructor — confirm intent.
depth = 16              # number of residual blocks
scale = 4               # super-resolution (upscaling) factor
downgrade = 'bicubic'   # downgrade operator passed to the DIV2K loader
steps = 100_000         # total number of training steps

In [None]:
# Export location for the trained weights; the demo section reloads the model
# from `weights_file`.
weights_dir = f'weights/srgan-{depth}-x{scale}'
# Create the directory up front (no-op if it already exists).
os.makedirs(weights_dir, exist_ok=True)
weights_file = os.path.join(weights_dir, 'weights.h5')

## Datasets

You don't need to download the DIV2K dataset as the required parts are automatically downloaded by the `DIV2K` class. By default, DIV2K images are stored in folder `.div2k` in the project's root directory.

In [None]:
# DIV2K training and validation subsets sharing the same scale and downgrade
# operator (constructed in that order: train first, then valid).
div2k_train, div2k_valid = (
    DIV2K(scale=scale, subset=subset, downgrade=downgrade)
    for subset in ('train', 'valid')
)

In [None]:
# Training pipeline: batches of 16 with random_transform enabled.
train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
# Validation pipeline: one image per batch, no random transform, and
# repeat_count=1 so a full pass over the validation set terminates.
valid_ds = div2k_valid.dataset(
    batch_size=1,
    random_transform=False,
    repeat_count=1,
)

## Training

### Pre-trained models

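If you want to skip training and directly run the demo below, download [weights-wdsr-b-32-x4.tar.gz](https://martin-krasser.de/sisr/weights-wdsr-b-32-x4.tar.gz) and extract the archive in the project's root directory. This will create a `weights/wdsr-b-32-x4` directory containing the weights of the pre-trained model. Note that these are WDSR-B weights, not SRGAN weights: they are used by the WDSR comparison cell at the end of this notebook, which additionally expects `weights/wdsr-a-32-x4` and `weights/wdsr-b-16-x4`.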

In [None]:
# Build the SRGAN wrapper. Checkpoints go to a per-configuration directory, and
# only the first 10 validation batches (1 image each) are handed to it as
# `valid_ds` to keep in-training validation cheap.
srgan = SrGan(
    scale=scale,
    checkpoint_dir=f'.ckpt/srgan-{depth}-x{scale}',
    valid_ds=valid_ds.take(10),
    steps=steps,
)
srgan.compile()

In [None]:
# Timestamp printed before training starts.
show_current_time()

In [None]:
# Resume training: run only the steps not yet recorded in the checkpoint's step
# counter (assumed to count completed training steps — confirm in SrGan).
remaining_steps = int(steps - srgan.checkpoint.step.numpy())
# Clamp at zero: Dataset.take() with a negative count yields the *entire*
# dataset, which would not terminate if train_ds repeats indefinitely
# (likely, since only valid_ds is built with repeat_count=1).
number_of_steps = max(0, remaining_steps)
print('Number of steps:', number_of_steps)

if number_of_steps > 0:
    srgan.fit(train_ds.take(number_of_steps), epochs=1)
else:
    print('Target number of steps already reached; skipping training.')

In [None]:
# Timestamp printed after training finishes.
show_current_time()

In [None]:
# Restore from checkpoint with highest PSNR.
# NOTE(review): which checkpoint gets restored is decided inside
# SrGan.restore() (model/srgan.py, not visible here); the line above states
# the intent — confirm.
srgan.restore()

In [None]:
# Evaluate model on the full validation set (valid_ds: batch size 1, one pass).
# `trainer` is never defined in this notebook (NameError); the SrGan instance
# built above is the object that owns the trained model.
# NOTE(review): assumes SrGan.evaluate(dataset) returns the mean PSNR as a
# scalar (eager tensor or number) — confirm against model/srgan.py.
psnr = srgan.evaluate(valid_ds)
# float() accepts both an eager scalar tensor and a plain number. '.3f' prints
# three decimals; the previous ':3f' meant "width 3, six decimals".
print(f'PSNR = {float(psnr):.3f}')

In [None]:
# Save weights to a separate location (reloaded by the demo section).
# `trainer` is never defined in this notebook (NameError); use the SrGan
# instance built above instead.
# NOTE(review): assumes SrGan exposes the network to export as `.model`,
# mirroring the original trainer-based code — confirm against model/srgan.py.
srgan.model.save_weights(weights_file)

## Demo

In [None]:
# Rebuild the network and load the weights exported to `weights_file` above.
# NOTE(review): `wdsr_model` is not defined or imported anywhere in this
# notebook, so this cell raises NameError as written. `weights_file` points at
# weights/srgan-..., i.e. the SRGAN run, so this presumably needs the SRGAN
# generator constructor instead of a WDSR one — confirm in model/srgan.py.
model = wdsr_model(scale=scale, num_res_blocks=depth)
model.load_weights(weights_file)

In [None]:
from model import resolve_single
from utils import load_image, plot_sample


def resolve_and_plot(lr_image_path):
    """Super-resolve one image with the notebook-level ``model`` and plot it.

    Loads the low-resolution image at ``lr_image_path``, runs it through
    ``resolve_single(model, ...)`` and hands input and output to
    ``plot_sample``.
    """
    low_res = load_image(lr_image_path)
    # Arguments are evaluated left to right, so the load -> resolve -> plot
    # order is the same as with separate statements.
    plot_sample(low_res, resolve_single(model, low_res))

In [None]:
# Demo: image 0869 (an x4 low-resolution crop, judging by the file name).
resolve_and_plot('demo/0869x4-crop.png')

In [None]:
# Demo: image 0829 (an x4 low-resolution crop, judging by the file name).
resolve_and_plot('demo/0829x4-crop.png')

In [None]:
# Demo: image 0851 (an x4 low-resolution crop, judging by the file name).
resolve_and_plot('demo/0851x4-crop.png')

In [None]:
# Load pre-trained WDSR models for comparison with the SRGAN demo above.
# `wdsr_a` / `wdsr_b` were never imported in this notebook (NameError).
# NOTE(review): assumes they live in model/wdsr.py — confirm the module path.
from model.wdsr import wdsr_a, wdsr_b


def _load_wdsr(build_fn, variant, num_res_blocks, upscale):
    """Build a WDSR model and load its pre-trained weights.

    Args:
        build_fn: model constructor (``wdsr_a`` or ``wdsr_b``), called as
            ``build_fn(scale=upscale, num_res_blocks=num_res_blocks)``.
        variant: ``'a'`` or ``'b'``; part of the weights directory name.
        num_res_blocks: number of residual blocks; also part of the path.
        upscale: super-resolution factor; also part of the path.

    Returns:
        The model with weights loaded from
        ``weights/wdsr-<variant>-<num_res_blocks>-x<upscale>/weights.h5``.
    """
    weights_path = os.path.join(
        f'weights/wdsr-{variant}-{num_res_blocks}-x{upscale}', 'weights.h5')
    wdsr = build_fn(scale=upscale, num_res_blocks=num_res_blocks)
    wdsr.load_weights(weights_path)
    return wdsr


model_a = _load_wdsr(wdsr_a, 'a', 32, scale)
model_b = _load_wdsr(wdsr_b, 'b', 32, scale)
model_b_16 = _load_wdsr(wdsr_b, 'b', 16, scale)


def resolve_and_plot(model, lr_image_path):
    """Super-resolve ``lr_image_path`` with ``model`` and plot input vs. output.

    Replaces the single-argument ``resolve_and_plot`` defined in the demo
    section so the model can be chosen per call. Relies on ``load_image``,
    ``resolve_single`` and ``plot_sample`` imported in that earlier cell.
    """
    lr = load_image(lr_image_path)
    sr = resolve_single(model, lr)
    plot_sample(lr, sr)

In [None]:
# Plot the same demo image with each WDSR variant:
# WDSR-A (32 blocks), WDSR-B (32 blocks), WDSR-B (16 blocks).
for _wdsr_variant in (model_a, model_b, model_b_16):
    resolve_and_plot(_wdsr_variant, 'demo/0851x4-crop.png')