### Setup

In [1]:
# Mount Google Drive so the project files are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Change this to match your src folder location.
# Later cells also use cwd-relative paths (e.g. "train_out/..."), so they depend on this working directory.
%cd '/content/drive/My Drive/CSC420/CSC420_project-main/src'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/CSC420/CSC420_project-main/src


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
import pathlib
from PIL import Image
import random
import multiprocessing
import argparse
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import datasets, models, transforms

from arch.dataset import *
from arch.metrics import *
from arch.srgan_model import Generator, Discriminator
from arch.vgg19 import vgg19
from arch.losses import TVLoss, perceptual_loss
from util import arg_util

### Data and Metrics


In [3]:
# Setup Parameters
memcache = True             # passed through to LowResGroundTruthDataset
batch_size = 24
num_workers = multiprocessing.cpu_count()

scale = 4                   # generator upscaling factor
patch_size = 24             # NOTE(review): not referenced by any visible cell — confirm whether still needed
model_res_count = 16        # number of residual blocks passed to Generator(num_block=...)

# feat_layer='relu2_2'
feat_layer = 'relu5_4'      # VGG19 layer for the perceptual loss (train() also takes its own `feat_layer`)
vgg_rescale_coeff = 0.006   # weight of the perceptual-loss term
adv_coeff = 1e-3            # weight of the adversarial term
tv_loss_coeff = 0.0         # weight of the total-variation term (0 disables it)

t_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the pipeline draws from (augmentations, shuffling, init) so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Define Image Augmentations.
# `interpolation`/`fill` replace RandomAffine's old `resample`/`fillcolor` keywords,
# which torchvision deprecated and then removed (they raise TypeError on current releases).
aug = transforms.Compose([
    transforms.RandomAffine(
        degrees=180,
        translate=(0.2, 0.2),
        scale=(0.7, 1.3),
        shear=40,
        interpolation=transforms.InterpolationMode.BICUBIC,
        fill=255,
    ),
    transforms.RandomPerspective(
        distortion_scale=0.5,
        p=0.5,
        interpolation=transforms.InterpolationMode.BICUBIC,
        fill=255,
    ),
    transforms.ToTensor(),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
])

In [4]:
# Load the paired low-res / high-res training images.
gt_path = arg_util.path_abs("data/pokemon/hr/train/")
lr_path = arg_util.path_abs("data/pokemon/lr/train/")
lr_gt_dataset = LowResGroundTruthDataset(lr_dir=lr_path, gt_dir=gt_path, memcache=memcache, transform=aug)

# Directory that train() writes generator checkpoints into.
checkpoint_dir = arg_util.path_abs("train_out/")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

loader = DataLoader(
    lr_gt_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,  # drop the trailing partial batch so every batch has exactly batch_size items
)

# Build the generator and initialise it from the pretrained SRGAN weights.
transfer_generator_path = arg_util.path_abs("pretrained/SRGAN.pt")
generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
if transfer_generator_path:
    generator.load_state_dict(torch.load(transfer_generator_path, map_location=t_device))
    load_message = f"Loaded pre-trained model: {transfer_generator_path}"
    logging.info(load_message)
    print(load_message)
generator = generator.to(t_device).train()

Loaded pre-trained model: /content/drive/My Drive/CSC420/CSC420_project-main/src/pretrained/SRGAN.pt


In [5]:
# Build the metrics helper and measure the pretrained generator on the validation split.
metrics = MetricEval(lr_gt_dataset)
metrics.load_generator(generator=generator)

# Evaluated in this order; each call also reports its own score.
initial_scores = {
    metric_name: metrics.get_metric(mode="val", metric=metric_name)
    for metric_name in ("MSE", "PSNR", "VGG22", "VGG54")
}

# Reference scores for the pretrained SRGAN:
# Average MSE Score: 0.0025611999444663525
# Average PSNR Score: 27.24865229483844
# Average VGG22 Score: 0.015313171781599522
# Average VGG54 Score: 0.004093066323548555
initial_scores["VGG54"]

Average MSE Score: 0.0025611999444663525
Average PSNR Score: 27.24865229483844
Average VGG22 Score: 0.015313171781599522
Average VGG54 Score: 0.004093066323548555


tensor(0.0041, device='cuda:0')

### Training

Only run this section if you want to fine-tune the model, starting from the pretrained weights.

In [6]:
# Freeze the whole generator, then re-enable gradients only for its final layers
# (`last_conv.body` and `tail`), so fine-tuning updates just those parameters.
generator.requires_grad_(False)
generator.last_conv.body.requires_grad_(True)
generator.tail.requires_grad_(True)

# To fine-tune deeper, also unfreeze e.g. `generator.conv02` or `generator.body[15]`
# with `.requires_grad_(True)`.

In [7]:
# Discriminator for adversarial fine-tuning, placed on the training device in train mode.
discriminator = Discriminator(patch_size=256).to(t_device).train()

In [8]:
def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty.

    Keeps the per-pass loss averages from raising ZeroDivisionError when the
    loader yields no batches (possible with ``drop_last=True`` and a dataset
    smaller than ``batch_size``).
    """
    return sum(values) / len(values) if values else float("nan")


def _discriminator_pass(d_optim, logits_ce):
    """Run one pass over the global ``loader`` updating only the discriminator.

    The generator forward runs under ``torch.no_grad()``: the discriminator
    loss has no reason to build a graph through, or backpropagate into, the
    generator.

    Args:
        d_optim: optimizer over the discriminator's parameters.
        logits_ce: ``BCEWithLogitsLoss`` applied to the discriminator's logits.

    Returns:
        Mean discriminator loss over the pass (NaN if ``loader`` is empty).
    """
    losses = []
    for datum in loader:
        d_optim.zero_grad()

        img_lr = datum['img_lr'].to(t_device)
        img_gt = datum['img_gt'].to(t_device)
        with torch.no_grad():
            img_pred, _ = generator(img_lr)
        # Crop GT so it has the same spatial size as the generator output.
        img_gt = img_gt[:, :, :img_pred.shape[2], :img_pred.shape[3]]

        fake_prob = discriminator(img_pred)
        real_prob = discriminator(img_gt)
        # Labels shaped like the logits, so this does not depend on batch_size.
        d_loss = (logits_ce(real_prob, torch.ones_like(real_prob))
                  + logits_ce(fake_prob, torch.zeros_like(fake_prob)))

        d_loss.backward()
        d_optim.step()
        losses.append(d_loss.item())
    return _mean(losses)


def _generator_pass(g_optim, d_optim, vgg_loss, mse_loss, logits_ce, tv_loss, feat_layer):
    """Run one pass over the global ``loader`` updating the generator.

    Per-batch loss:
        MSE + vgg_rescale_coeff * perceptual + adv_coeff * adversarial
        + tv_loss_coeff * TV(vgg_rescale_coeff * feature-error map)

    Returns:
        Mean generator loss over the pass (NaN if ``loader`` is empty).
    """
    losses = []
    for datum in loader:
        d_optim.zero_grad()
        g_optim.zero_grad()

        img_lr = datum['img_lr'].to(t_device)
        img_gt = datum['img_gt'].to(t_device)
        img_pred, _ = generator(img_lr)
        # Crop GT so it has the same spatial size as the generator output.
        img_gt = img_gt[:, :, :img_pred.shape[2], :img_pred.shape[3]]

        # Score the raw prediction: the same representation the discriminator is
        # trained on in _discriminator_pass, so its logits are meaningful here.
        fake_prob = discriminator(img_pred)

        # Map [-1, 1] -> [0, 1] for the reconstruction / perceptual terms
        # (this rescale assumes the dataset yields images in [-1, 1] — confirm).
        gt_unit = (img_gt + 1.) / 2.
        pred_unit = (torch.clip(img_pred, -1., 1.) + 1.) / 2.

        percep, hr_feat, sr_feat = vgg_loss(gt_unit, pred_unit, layer=feat_layer)
        g_loss = (mse_loss(pred_unit, gt_unit)
                  + vgg_rescale_coeff * percep
                  + adv_coeff * logits_ce(fake_prob, torch.ones_like(fake_prob))
                  + tv_loss_coeff * tv_loss(vgg_rescale_coeff * (hr_feat - sr_feat) ** 2))

        g_loss.backward()
        g_optim.step()
        losses.append(g_loss.item())
    return _mean(losses)


def train(init_lr=1e-4, pre_train_epoch=100, feat_layer="relu5_4"):
    """Adversarially fine-tune the global ``generator`` against ``discriminator``.

    Each epoch runs ``max(5 - epoch // 5, 1)`` discriminator passes over
    ``loader`` (more early on so the discriminator can catch up), then one
    generator pass. It then evaluates PSNR and VGG54 on the validation split
    via ``metrics``. The generator learning rate is halved when VGG54 plateaus.
    A generator checkpoint is written to ``checkpoint_dir`` every
    ``pre_train_epoch // 3`` epochs (only at the final epoch if that is 0).

    Args:
        init_lr: initial Adam learning rate for the generator.
        pre_train_epoch: number of epochs to run.
        feat_layer: VGG19 layer used by the perceptual loss.
    """
    global generator, discriminator

    # Losses.
    vgg_net = vgg19().to(t_device).eval()
    vgg_loss = perceptual_loss(vgg_net)
    mse_loss = nn.MSELoss()
    logits_ce = nn.BCEWithLogitsLoss()
    tv_loss = TVLoss()

    # Optimizers / schedulers.
    g_optim = optim.Adam(generator.parameters(), lr=init_lr)
    g_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        g_optim, mode="min", factor=0.5, patience=10, cooldown=0, verbose=True)

    discriminator = discriminator.train()
    generator = generator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr=5e-5)
    d_scheduler = optim.lr_scheduler.StepLR(d_optim, step_size=200, gamma=0.1)

    checkpoint_modulo = (pre_train_epoch // 3) or pre_train_epoch
    for pre_epoch in range(1, pre_train_epoch + 1):
        logging.info(f"Pre-train Epoch [{pre_epoch}]: running.")

        # Train the discriminator more than the generator at the start to let it catch up.
        for _ in range(max(5 - (pre_epoch // 5), 1)):
            print("Discriminator Loss:", _discriminator_pass(d_optim, logits_ce))
            d_scheduler.step()

        avg_g_loss = _generator_pass(g_optim, d_optim, vgg_loss, mse_loss, logits_ce, tv_loss, feat_layer)

        # Log epoch statistics.
        logging.info(f"Pre-train Epoch [{pre_epoch}]: Average Train loss={avg_g_loss}")
        print(f"Pre-train Epoch [{pre_epoch}]: Average Train loss={avg_g_loss}")

        # Evaluate on the validation set; VGG54 drives the generator LR schedule.
        metrics.load_generator(generator=generator)
        metrics.get_metric(mode="val", metric="PSNR")  # reported for monitoring only
        vgg54 = metrics.get_metric(mode="val", metric="VGG54")

        generator = generator.train()
        g_scheduler.step(vgg54)

        if pre_epoch % checkpoint_modulo == 0:
            checkpoint_filepath = (checkpoint_dir / f'pre_trained_model_{pre_epoch}.pt').absolute()
            torch.save(generator.state_dict(), checkpoint_filepath)
            logging.info(f"Pre-train Epoch [{pre_epoch}]: saved model checkpoint: {checkpoint_filepath}")

In [10]:
# Resume from the saved adversarial checkpoints (set `load = False` to keep
# training the generator/discriminator currently in memory).
load = True
if load:
    generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
    generator.load_state_dict(torch.load(arg_util.path_abs("train_out/SRGAN_pre_adv_gen.pt"), map_location=t_device))
    generator = generator.to(t_device).train()
    # A freshly built Generator has every parameter trainable, which silently undoes
    # the earlier layer freezing. Re-freeze everything except the final layers
    # (`last_conv.body` and `tail`) so the same subset is fine-tuned.
    generator.requires_grad_(False)
    generator.last_conv.body.requires_grad_(True)
    generator.tail.requires_grad_(True)

    discriminator.load_state_dict(torch.load(arg_util.path_abs("train_out/SRGAN_pre_adv_dis.pt"), map_location=t_device))
    discriminator = discriminator.to(t_device).train()

# NOTE(review): init_lr=1e-9 makes generator updates nearly negligible — confirm this is intended.
train(init_lr=1e-9, pre_train_epoch=50)

In [11]:
# Reevaluate Metrics
metrics.load_generator(generator=generator)

metrics.get_metric(mode="val", metric="MSE")
metrics.get_metric(mode="val", metric="PSNR")
metrics.get_metric(mode="val", metric="VGG22")
metrics.get_metric(mode="val", metric="VGG54")

Average MSE Score: 0.0023944128770381212
Average PSNR Score: 27.449777365281214
Average VGG22 Score: 0.012860788963735104
Average VGG54 Score: 0.0035042911767959595


tensor(0.0035, device='cuda:0')

In [12]:
# Set `save_models = True` to save the current generator and discriminator weights.
# These are the checkpoint files that the resume and testing cells load.
save_models = False
if save_models:
    generator_path_out = arg_util.path_abs("train_out/SRGAN_pre_adv_gen.pt")
    discriminator_path_out = arg_util.path_abs("train_out/SRGAN_pre_adv_dis.pt")
    generator_path_out.parent.mkdir(parents=True, exist_ok=True)
    torch.save(generator.state_dict(), generator_path_out)
    torch.save(discriminator.state_dict(), discriminator_path_out)

### Testing and Visualization

In [13]:
# Load the fine-tuned weights. Paths are resolved with arg_util.path_abs,
# like the pretrained weights below, rather than relative to the cwd.
generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
generator.load_state_dict(torch.load(arg_util.path_abs("train_out/SRGAN_pre_adv_gen.pt"), map_location=t_device))
generator = generator.to(t_device).train()

discriminator.load_state_dict(torch.load(arg_util.path_abs("train_out/SRGAN_pre_adv_dis.pt"), map_location=t_device))
discriminator = discriminator.to(t_device).train()

# Load the original (un-fine-tuned) SRGAN weights to compare against.
generator2 = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
generator2.load_state_dict(torch.load(arg_util.path_abs("pretrained/SRGAN.pt"), map_location=t_device))
generator2 = generator2.to(t_device).train()
# NOTE(review): both generators are left in train() mode, as before. If the model has
# BatchNorm/Dropout, confirm whether save_test_metrics switches them to eval() itself.

In [None]:
# Write test-set predictions to the "results/" folder, passing the fine-tuned
# generator first and the original SRGAN generator second.
finetuned_and_baseline = (generator, generator2)
metrics.save_test_metrics(*finetuned_and_baseline)