**If training on colab, be sure to use a GPU (runtime > Change runtime type > GPU)**

In [1]:
# uncomment and run the lines below if running in google colab
!pip install tensorflow==2.4.3
!git clone https://github.com/jlaihong/image-super-resolution.git
!mv image-super-resolution/* ./

Collecting tensorflow==2.4.3
  Downloading tensorflow-2.4.3-cp38-cp38-manylinux2010_x86_64.whl (394.6 MB)
[K     |████████████████████████████████| 394.6 MB 87 kB/s  eta 0:00:01
Collecting grpcio~=1.32.0
  Downloading grpcio-1.32.0-cp38-cp38-manylinux2014_x86_64.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 1.4 MB/s eta 0:00:01
[?25hCollecting h5py~=2.10.0
  Downloading h5py-2.10.0-cp38-cp38-manylinux1_x86_64.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 1.6 MB/s eta 0:00:01
[?25hCollecting tensorflow-estimator<2.5.0,>=2.4.0
  Downloading tensorflow_estimator-2.4.0-py2.py3-none-any.whl (462 kB)
[K     |████████████████████████████████| 462 kB 1.4 MB/s eta 0:00:01
Collecting gast==0.3.3
  Downloading gast-0.3.3-py2.py3-none-any.whl (9.7 kB)
Collecting oauthlib>=3.0.0
  Using cached oauthlib-3.1.1-py2.py3-none-any.whl (146 kB)
Installing collected packages: oauthlib, grpcio, tensorflow-estimator, h5py, gast, tensorflow
  Attempting uninstall: grpcio


# SRResNet and SRGAN Training for Image Super Resolution

An Implementation of SRGAN: https://arxiv.org/pdf/1609.04802.pdf

In [2]:
import os
import time
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, MeanAbsoluteError
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import Mean
from PIL import Image

from datasets.div2k.parameters import Div2kParameters 
from datasets.div2k.loader import create_training_and_validation_datasets
from utils.dataset_mappings import random_crop, random_flip, random_rotate, random_lr_jpeg_noise
from utils.metrics import psnr_metric
from utils.config import config
from utils.callbacks import SaveCustomCheckpoint
from models.srresnet import build_srresnet
from models.srgan import build_discriminator


2021-12-03 02:49:57.180785: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-12-03 02:49:57.180805: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


## Prepare the dataset

In [3]:
dataset_key = "bicubic_x4"

data_path = config.get("data_path", "") 

div2k_folder = os.path.abspath(os.path.join(data_path, "div2k"))

dataset_parameters = Div2kParameters(dataset_key, save_data_directory=div2k_folder)

In [4]:
hr_crop_size = 96

In [5]:
train_mappings = [
    lambda lr, hr: random_crop(lr, hr, hr_crop_size=hr_crop_size, scale=dataset_parameters.scale), 
    random_flip, 
    random_rotate, 
    random_lr_jpeg_noise]

In [9]:
train_dataset, valid_dataset = create_training_and_validation_datasets(dataset_parameters, train_mappings)

valid_dataset_subset = valid_dataset.take(10) # only taking 10 examples here to speed up evaluations during training

Couldn't find directory:  /home/ubuntu/div2k/DIV2K_train_LR_bicubic/X4
/home/ubuntu/div2k
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip
Begin caching in /home/ubuntu/div2k/cache/DIV2K_train_LR_bicubic/X4/cache.
Completed caching in /home/ubuntu/div2k/cache/DIV2K_train_LR_bicubic/X4/cache.
Couldn't find directory:  /home/ubuntu/div2k/DIV2K_train_HR
/home/ubuntu/div2k
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Begin caching in /home/ubuntu/div2k/cache/DIV2K_train_HR/cache.
Completed caching in /home/ubuntu/div2k/cache/DIV2K_train_HR/cache.
Couldn't find directory:  /home/ubuntu/div2k/DIV2K_valid_LR_bicubic/X4
/home/ubuntu/div2k
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip
Begin caching in /home/ubuntu/div2k/cache/DIV2K_valid_LR_bicubic/X4/cache.
Completed caching in /home/ubuntu/div2k/cache/DIV2K_valid_LR_bicubic/X4/cache.
Couldn't find directory:  /home/ubun

## Train the SRResNet generator model

In [10]:
generator = build_srresnet(scale=dataset_parameters.scale)

In [11]:
checkpoint_dir=f'./ckpt/sr_resnet_{dataset_key}'

learning_rate=1e-4

checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                 epoch=tf.Variable(0),
                                 psnr=tf.Variable(0.0),
                                 optimizer=Adam(learning_rate),
                                 model=generator)

checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                directory=checkpoint_dir,
                                                max_to_keep=3)

if checkpoint_manager.latest_checkpoint:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {checkpoint.step.numpy()} with validation PSNR {checkpoint.psnr.numpy()}.')

In [12]:
training_steps = 1000

steps_per_epoch = 500

training_epochs = training_steps / steps_per_epoch

if checkpoint.epoch.numpy() < training_epochs:
    remaining_epochs = int(training_epochs - checkpoint.epoch.numpy())
    print(f"Continuing Training from epoch {checkpoint.epoch.numpy()}. Remaining epochs: {remaining_epochs}.")
    save_checkpoint_callback = SaveCustomCheckpoint(checkpoint_manager, steps_per_epoch)
    checkpoint.model.compile(optimizer=checkpoint.optimizer, loss=MeanSquaredError(), metrics=[psnr_metric])
    checkpoint.model.fit(train_dataset,validation_data=valid_dataset_subset, steps_per_epoch=steps_per_epoch, epochs=3, callbacks=[save_checkpoint_callback])
else:
    print("Training already completed. To continue training, increase the number of training steps")

Continuing Training from epoch 0. Remaining epochs: 2.
Epoch 1/3


2021-12-03 04:57:55.705927: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 191102976 exceeds 10% of free system memory.
2021-12-03 04:57:55.705927: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 191102976 exceeds 10% of free system memory.


  1/500 [..............................] - ETA: 52:07 - loss: 5307.0220 - psnr_metric: 11.8027

2021-12-03 04:57:58.680547: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 191102976 exceeds 10% of free system memory.
2021-12-03 04:57:58.680547: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 191102976 exceeds 10% of free system memory.


  2/500 [..............................] - ETA: 23:16 - loss: 5555.9426 - psnr_metric: 11.5170

2021-12-03 04:58:01.488736: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 191102976 exceeds 10% of free system memory.


Epoch 2/3
Epoch 3/3


In [13]:
weights_directory = f"weights/srresnet_{dataset_key}"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
checkpoint.model.save_weights(weights_file)

## Train SRGAN using SRResNet as the generator

In [14]:
generator = build_srresnet(scale=dataset_parameters.scale)
generator.load_weights(weights_file)

In [15]:
discriminator = build_discriminator(hr_crop_size=hr_crop_size)

In [16]:
layer_5_4 = 20
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(vgg.input, vgg.layers[layer_5_4].output)

In [17]:
binary_cross_entropy = BinaryCrossentropy()
mean_squared_error = MeanSquaredError()

In [18]:
learning_rate=PiecewiseConstantDecay(boundaries=[100000], values=[1e-4, 1e-5])

In [19]:
generator_optimizer = Adam(learning_rate=learning_rate)
discriminator_optimizer = Adam(learning_rate=learning_rate)

In [20]:
srgan_checkpoint_dir=f'./ckpt/srgan_{dataset_key}'

srgan_checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                       psnr=tf.Variable(0.0),
                                       generator_optimizer=Adam(learning_rate),
                                       discriminator_optimizer=Adam(learning_rate),
                                       generator=generator,
                                       discriminator=discriminator)

srgan_checkpoint_manager = tf.train.CheckpointManager(checkpoint=srgan_checkpoint,
                                                directory=srgan_checkpoint_dir,
                                                max_to_keep=3)

In [21]:
if srgan_checkpoint_manager.latest_checkpoint:
    srgan_checkpoint.restore(srgan_checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {srgan_checkpoint.step.numpy()} with validation PSNR {srgan_checkpoint.psnr.numpy()}.')

In [22]:
@tf.function
def train_step(lr, hr):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        lr = tf.cast(lr, tf.float32)
        hr = tf.cast(hr, tf.float32)

        sr = srgan_checkpoint.generator(lr, training=True)

        hr_output = srgan_checkpoint.discriminator(hr, training=True)
        sr_output = srgan_checkpoint.discriminator(sr, training=True)

        con_loss = calculate_content_loss(hr, sr)
        gen_loss = calculate_generator_loss(sr_output)
        perc_loss = con_loss + 0.001 * gen_loss
        disc_loss = calculate_discriminator_loss(hr_output, sr_output)

    gradients_of_generator = gen_tape.gradient(perc_loss, srgan_checkpoint.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, srgan_checkpoint.discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, srgan_checkpoint.generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, srgan_checkpoint.discriminator.trainable_variables))

    return perc_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    sr = preprocess_input(sr)
    hr = preprocess_input(hr)
    sr_features = perceptual_model(sr) / 12.75
    hr_features = perceptual_model(hr) / 12.75
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    return binary_cross_entropy(tf.ones_like(sr_out), sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    hr_loss = binary_cross_entropy(tf.ones_like(hr_out), hr_out)
    sr_loss = binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
    return hr_loss + sr_loss


In [23]:
perceptual_loss_metric = Mean()
discriminator_loss_metric = Mean()

step = srgan_checkpoint.step.numpy()
steps = 5000

monitor_folder = f"monitor_training/srgan_{dataset_key}"
os.makedirs(monitor_folder, exist_ok=True)

now = time.perf_counter()

for lr, hr in train_dataset.take(steps - step):
    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()

    perceptual_loss, discriminator_loss = train_step(lr, hr)
    perceptual_loss_metric(perceptual_loss)
    discriminator_loss_metric(discriminator_loss)

    if step % 1000 == 0:
        psnr_values = []
        
        for lr, hr in valid_dataset_subset:
            sr = srgan_checkpoint.generator.predict(lr)[0]
            sr = tf.clip_by_value(sr, 0, 255)
            sr = tf.round(sr)
            sr = tf.cast(sr, tf.uint8)
            
            psnr_value = psnr_metric(hr, sr)[0]
            psnr_values.append(psnr_value)
            psnr = tf.reduce_mean(psnr_values)
            
        image = Image.fromarray(sr.numpy())
        image.save(f"{monitor_folder}/{step}.png" )
        
        duration = time.perf_counter() - now
        
        now = time.perf_counter()
        
        print(f'{step}/{steps}, psnr = {psnr}, perceptual loss = {perceptual_loss_metric.result():.4f}, discriminator loss = {discriminator_loss_metric.result():.4f} ({duration:.2f}s)')
        
        perceptual_loss_metric.reset_states()
        discriminator_loss_metric.reset_states()
        
        srgan_checkpoint.psnr.assign(psnr)
        srgan_checkpoint_manager.save()

1000/5000, psnr = 17.35757064819336, perceptual loss = 0.1616, discriminator loss = 0.3667 (9433.51s)
2000/5000, psnr = 20.70804214477539, perceptual loss = 0.1568, discriminator loss = 0.3132 (9633.08s)
3000/5000, psnr = 20.38214111328125, perceptual loss = 0.1540, discriminator loss = 0.2769 (7655.62s)
4000/5000, psnr = 21.83553695678711, perceptual loss = 0.1511, discriminator loss = 0.3200 (7364.31s)
5000/5000, psnr = 19.311738967895508, perceptual loss = 0.1521, discriminator loss = 0.2438 (7286.52s)


In [24]:
weights_directory = f"weights/srgan_{dataset_key}"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
srgan_checkpoint.generator.save_weights(weights_file)