# Super Resolution GAN (SRGAN) training

### Mount Google Drive (Colab only)

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
# google.colab only exists inside Google Colab; skip this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

# Navigate to repository (IPython magic: changes the notebook's working dir)
%cd /content/drive/MyDrive/Github/SRGAN

# Pin albumentations to 0.4.6 — presumably the version the project's dataset
# code was written against; verify before upgrading.
!pip install albumentations==0.4.6

### Import needed modules

In [None]:
import config
import torch
from torch import nn
# Optimization algorithms
import torch.optim as optim
# Dataset manager
from torch.utils.data import DataLoader

from torchvision.models import vgg19

## 0. Define and prepare data

In [None]:
from dataset import MyImageFolder

# Training images are read from the "new_data" directory.
dataset = MyImageFolder(root_dir="new_data")

# Report how many samples were found and the directory they came from
# (root_dir joined with the first entry of dataset.class_names).
_n_samples = len(dataset)
_sample_dir = f"{dataset.root_dir}/{dataset.class_names[0]}"
print(f"{_n_samples} samples in dir {_sample_dir}")

## 1. Create model

### Initialize Generator and Discriminator

In [None]:
from model import Generator, Discriminator

# Build both networks for 3-channel input and move them to config.DEVICE
# (generator first, then discriminator).
gen, disc = (
    net_cls(in_channels=3).to(config.DEVICE)
    for net_cls in (Generator, Discriminator)
)

## 2. Loss and optimizer

### Loss

In [None]:
class VGGLoss(nn.Module):
    """Perceptual loss: MSE between VGG19 feature maps of two image batches.

    The feature extractor is the first 36 layers of torchvision's VGG19
    ``features`` stack, loaded with ``pretrained=True``, put in eval mode,
    moved to ``config.DEVICE`` and frozen.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favour of ``weights=...``; kept for the pinned environment.
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(config.DEVICE)
        self.loss = nn.MSELoss()

        # The extractor is a fixed metric: its weights must never be updated.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between VGG features of ``input`` and ``target``.

        Args:
            input: generated images (gradients flow back through these).
            target: reference images to compare against.

        NOTE(review): inputs are fed to VGG without ImageNet mean/std
        normalisation here; confirm the dataset's normalisation is suitable.
        """
        vgg_input_features = self.vgg(input)
        # Bug fix: the target branch previously re-encoded ``input``, which made
        # this loss identically zero and removed the perceptual term entirely.
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)

# Criteria used by the training step functions:
#   bce          - adversarial loss, applied to raw discriminator logits
#   mse          - pixel-wise loss between generated and real images
#   vgg_loss_fun - perceptual loss computed in VGG19 feature space
bce, mse = nn.BCEWithLogitsLoss(), nn.MSELoss()
vgg_loss_fun = VGGLoss()

### Optimizers

In [None]:
# Generator and discriminator get independent Adam optimisers sharing the
# configured learning rate and the same (beta1, beta2) pair.
_ADAM_BETAS = (0.9, 0.999)
opt_gen, opt_disc = (
    optim.Adam(net.parameters(), lr=config.LEARNING_RATE, betas=_ADAM_BETAS)
    for net in (gen, disc)
)

## 3. Training

### Train discriminator and generator functions

In [None]:
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
def train_discriminator(D, opt, fake, high_res, bce):
    """Perform a single optimisation step on the discriminator.

    Real images are scored against softened "real" targets 1 - 0.1*u, where
    u ~ U[0, 1) is drawn per element (one-sided label smoothing). Generated
    images are scored against 0. ``fake`` is detached, so this step sends no
    gradient into the generator.

    Args:
        D: discriminator network; its outputs go straight into ``bce``
            (BCEWithLogitsLoss in this notebook, so presumably raw logits).
        opt: optimiser that updates D's parameters.
        fake: generator output for the current batch.
        high_res: real high-resolution batch.
        bce: criterion called as ``bce(prediction, target)``.

    Returns:
        The summed real + fake loss tensor.
    """
    opt.zero_grad()

    # Real branch, with smoothed targets drawn after the forward pass.
    real_scores = D(high_res)
    soft_real_targets = torch.ones_like(real_scores) - 0.1 * torch.rand_like(real_scores)
    real_term = bce(real_scores, soft_real_targets)

    # Fake branch: block gradients from reaching the generator.
    fake_scores = D(fake.detach())
    fake_term = bce(fake_scores, torch.zeros_like(fake_scores))

    total = real_term + fake_term
    total.backward()
    opt.step()
    return total

### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
def train_generator(D, opt, fake, high_res, vgg_loss, mse, bce=None,
                    adv_weight=1e-3, vgg_weight=0.006):
    """Perform a single optimisation step on the generator.

    The total loss is
        adv_weight * bce(D(fake), 1) + vgg_weight * vgg_loss(fake, high_res)
        + mse(fake, high_res)
    ``fake`` is NOT detached, so gradients reach the generator through it.

    Args:
        D: discriminator network used to score ``fake``.
        opt: optimiser that updates the generator's parameters.
        fake: generator output for the current batch (must carry its graph).
        high_res: real high-resolution batch.
        vgg_loss: perceptual criterion called as ``vgg_loss(fake, high_res)``.
        mse: pixel-wise criterion called as ``mse(fake, high_res)``.
        bce: adversarial criterion. Previously this was read from a notebook
            global of the same name; it defaults to ``nn.BCEWithLogitsLoss()``,
            which is what that global held.
        adv_weight: weight of the adversarial term (default 1e-3).
        vgg_weight: weight of the perceptual term (default 0.006).

    Returns:
        The total (pre-step) loss tensor.
    """
    if bce is None:
        bce = nn.BCEWithLogitsLoss()

    opt.zero_grad()

    pred_fake = D(fake)

    adv_term = adv_weight * bce(pred_fake, torch.ones_like(pred_fake))
    vgg_term = vgg_weight * vgg_loss(fake, high_res)
    mse_term = mse(fake, high_res)

    loss = adv_term + vgg_term + mse_term

    # backward() also accumulates gradients in D's parameters. Only ``opt``
    # steps here; D's grads must be zeroed before D is stepped
    # (train_discriminator calls zero_grad first).
    loss.backward()
    opt.step()

    return loss

### Load data

In [None]:
# Mini-batch loader over the training set; shuffle=True reshuffles each epoch.
loader = DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=config.NUM_WORKERS,
)

### Load last checkpoint

In [None]:
from utils import load_checkpoint

# Optionally resume both networks (and their optimisers) from disk,
# generator first, then discriminator.
if config.LOAD_MODEL:
    for _ckpt_path, _net, _net_opt in (
        (config.CHECKPOINT_GEN, gen, opt_gen),
        (config.CHECKPOINT_DISC, disc, opt_disc),
    ):
        load_checkpoint(_ckpt_path, _net, _net_opt, config.LEARNING_RATE)

### Training loop

In [None]:
from utils import plot_examples, save_checkpoint, plot_loss
from time import process_time
import time
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
from matplotlib.ticker import MaxNLocator


def _print_training_header():
    """Print the run configuration (re-printed after each output refresh)."""
    print(f"SRGAN training: \n")
    print(f" Total training samples: {len(dataset)}\n Number of epochs: {config.NUM_EPOCHS}\n Mini batch size: {config.BATCH_SIZE}\n Number of batches: {len(loader)}\n Learning rate: {config.LEARNING_RATE}\n")


def _redraw_loss_plot(ax, disc_history, gen_history):
    """Redraw the per-epoch mean losses on ``ax`` from scratch."""
    # Clearing first avoids stacking a new pair of lines on every epoch.
    # Passing True to grid() avoids the on/off toggling of a bare ax.grid().
    ax.clear()
    epochs = np.arange(len(disc_history))
    ax.plot(epochs, disc_history, label='Discriminator loss', marker='o', color='b')
    ax.plot(epochs, gen_history, label='Generator loss', marker='o', color='r')
    ax.set_title('Evolution of losses through epochs')
    ax.set(xlabel='epochs', ylabel='loss')
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))


_print_training_header()

# Mean discriminator / generator loss of every completed epoch.
loss_disc = []
loss_gen = []

# Wall-clock stopwatch. process_time() only counts this process's CPU time,
# so it misses time spent waiting on the GPU and on DataLoader workers.
t0 = time.perf_counter()

fig, ax = plt.subplots(figsize=(10, 6), dpi=80)

for epoch in range(config.NUM_EPOCHS):
    disc_total = 0.0
    gen_total = 0.0
    n_batches = 0

    for low_res, high_res in loader:
        # Send images to device
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        # Generate fake (high_res) image from low_res
        fake = gen(low_res)

        loss_disc_e = train_discriminator(disc, opt_disc, fake, high_res, bce)
        loss_gen_e = train_generator(disc, opt_gen, fake, high_res, vgg_loss_fun, mse)

        disc_total += loss_disc_e.item()
        gen_total += loss_gen_e.item()
        n_batches += 1

    # End of epoch. The old check `idx == config.BATCH_SIZE - 1` fired at an
    # arbitrary batch, or never when BATCH_SIZE > len(loader). It also logged
    # a single batch's loss rather than the epoch's.
    if n_batches > 0:
        loss_disc.append(disc_total / n_batches)
        loss_gen.append(gen_total / n_batches)
        _redraw_loss_plot(ax, loss_disc, loss_gen)

        display.clear_output(wait=True)
        _print_training_header()
        # Display current figure, then pause briefly so the output can render
        display.display(fig)
        time.sleep(0.1)

        # Print progress every epoch
        print(
            f"Epoch [{epoch + 1}/{config.NUM_EPOCHS}] - "
            f"Loss D: {loss_disc[-1]:.4f}, Loss G: {loss_gen[-1]:.4f}"
        )

    if config.SAVE_MODEL:
        save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

# Stop the stopwatch
t1 = time.perf_counter()
print(f"Elapsed time: {t1-t0}")

# Save this figure explicitly. plt.savefig() targets pyplot's *current* figure;
# the old code closed `fig` early, so that call wrote a new blank figure.
fig.savefig("loss_evol.png")
# Release the figure now that it has been displayed and saved.
plt.close(fig)

In [None]:
import os

# os.walk yields (dirpath, dirnames, filenames); keep only the names of the
# files directly inside datasets/testing/.
_testing_root, _testing_subdirs, test_images = next(os.walk("datasets/testing/"))

## Test the generator with weights loaded from a saved checkpoint

In [None]:
import os

from utils import load_checkpoint, plot_examples

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

testing = True # If not testing: validation
r = 3 # Zoom factor


def _center_crop(img, r):
    """Return the centred window [s//2 - s//r, s//2 + s//r) on both spatial axes.

    ``img`` is indexed as (rows, cols, ...). Any trailing channel axis is kept.
    """
    rows, cols = img.shape[0], img.shape[1]
    return img[rows // 2 - rows // r: rows // 2 + rows // r,
               cols // 2 - cols // r: cols // 2 + cols // r]


# Initialize SRGAN Generator
gen = Generator(in_channels=3).to(config.DEVICE)
# Define optimizer for Generator (load_checkpoint expects one)
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))

# Load checkpoint (w&b of specified training); only the generator is needed here.
if config.LOAD_MODEL:
    load_checkpoint(
        config.CHECKPOINT_GEN,
        gen,
        opt_gen,
        config.LEARNING_RATE
    )

# NOTE(review): gen is not switched to eval() here; confirm plot_examples does
# so if the Generator contains BatchNorm/Dropout layers.

# Test generator in all images in testing folder (results are read back from
# datasets/testing/sr/ below)
plot_examples("datasets/testing/", gen, 0)

# Files directly inside each folder (os.walk(...)[2]). The listings are sorted
# because their order is filesystem-dependent and images are paired by position.
# NOTE(review): pairing assumes SR file names sort in the same order as the
# LR file names they came from; confirm against plot_examples.
sr_images = sorted(next(os.walk("datasets/testing/sr/"))[2])
test_images = sorted(next(os.walk("datasets/testing/"))[2])

# If validating, get hr images
hr_images = []
if not testing:
    hr_images = sorted(next(os.walk("new_data/hr/"))[2])

if len(sr_images) != len(test_images):
    print(f"Warning: {len(test_images)} test images but {len(sr_images)} SR images; "
          "only the first pairs are compared.")

os.makedirs("figures", exist_ok=True)

for i, (im, sr_name) in enumerate(zip(test_images, sr_images)):
    # Read lr and sr images, cropped to the central zoom window
    lr_im = mpimg.imread(f"datasets/testing/{im}")
    sr_im = mpimg.imread(f"datasets/testing/sr/{sr_name}")

    ims = [_center_crop(lr_im, r), _center_crop(sr_im, r)]
    titles = ["Low Resolution", "Super Resolution"]
    if not testing:
        # Cropped from the HR image's own shape; this matches the old
        # SR-derived indices when HR and SR have the same size.
        hr_im = mpimg.imread(f"new_data/hr/{hr_images[i]}")
        ims.append(_center_crop(hr_im, r))
        titles.append('High Resolution')

    # One subplot per image: show it, set its title, hide the axis
    fig, axs = plt.subplots(1, len(ims))
    for axis, image, title in zip(axs, ims, titles):
        axis.imshow(image)
        axis.set_title(title, fontsize=36)
        axis.axis('off')

    # Set size and save figure
    fig.set_size_inches((40, 40), forward=False)
    fig.savefig(f"figures/test_{im}.png", bbox_inches='tight')
    # Show each figure now instead of keeping one per image open until the
    # cell ends. Jupyter's inline backend also releases the figure after show().
    plt.show()