### Setup

In [1]:
# Mount Google Drive inside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Change into the project's src directory on the mounted drive and list its contents
# (IPython magic / shell escape; these two lines are not plain Python).
%cd '/content/drive/My Drive/CSC420/CSC420_project-main/src'
!ls

Mounted at /content/drive
/content/drive/My Drive/CSC420/CSC420_project-main/src
 arch			   infer.py	 results_1	 train_out
 arch_st		   __init__.py	 results_1e-06	 train.py
'Copy of finetune.ipynb'   log.log	 results_1e-07	 util
 data			   main.py	 results_1e-09	 util.ipynb
 finetune.ipynb		   pretrained	 test_out	 val_out
 generate_lr.py		   results_0.1	 test.py


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

import argparse
import logging
logging.basicConfig(filename='./log.log')

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn.functional as F
from arch.srgan_model import Generator, Discriminator
from arch.vgg19 import vgg19
from arch.losses import TVLoss, perceptual_loss
from util import arg_util
import pathlib
from PIL import Image
import random
import multiprocessing

### Data 

In [4]:
# Augmentation pipeline passed to the training dataset as its `transform`.
# The random affine warp comes first, ToTensor follows, and the random
# grayscale / flip ops are applied after the ToTensor conversion.
# NOTE(review): `resample` / `fillcolor` are the older RandomAffine keyword names;
# newer torchvision releases use `interpolation` / `fill` instead — confirm against
# the installed torchvision version.
aug = transforms.Compose([
    transforms.RandomAffine(
        degrees=180, 
        translate=(0.2, 0.2), 
        scale=(0.7, 1.3),
        shear=40,
        resample=Image.BICUBIC, 
        fillcolor=255
    ),
    transforms.ToTensor(),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    # Colour jitter is currently disabled.
    # torchvision.transforms.ColorJitter(
    #     brightness=0.01, 
    #     contrast=0.2, 
    #     saturation=0.1, 
    #     hue=0.01
    # )
])

In [None]:
# Lower-case file suffixes (each including the leading dot) treated as images.
IMG_EXTENSIONS = set(['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tiff'])
def is_image(path):
    """Return True when `path` ends in one of the suffixes in IMG_EXTENSIONS.

    Args:
        path: a `pathlib.Path` or a plain path string.

    Returns:
        bool: whether the (case-insensitive) file suffix is a known image suffix.
    """
    # Path.suffix always starts with '.', so every entry in IMG_EXTENSIONS must too.
    return pathlib.Path(path).suffix.lower() in IMG_EXTENSIONS

class LowResGroundTruthDataset(Dataset):
    """Paired low-resolution / ground-truth image dataset for training an SR model.

    Images are paired by file name between `lr_dir` and `gt_dir`. Each item is a
    dict with keys 'img_filename', 'img_lr' and 'img_gt'; both images are tensors
    scaled to [-1, 1].

    Args:
        lr_dir: directory holding the low-resolution images.
        gt_dir: directory holding the ground-truth images.
        memcache: when True, every image pair is loaded once in the constructor.
        transform: optional callable applied to both images of a pair under the
            same torch seed. It must return a tensor, because the result is
            rescaled arithmetically afterwards.
        strict_filename_intersection: when True, any file name present in only
            one of the two directories raises ValueError.

    Raises:
        ValueError: on mismatched file names when `strict_filename_intersection`.
    """
    def __init__(self, lr_dir, gt_dir, memcache=False, transform=None,
                 strict_filename_intersection=True):
        super().__init__()
        # NOTE(review): the purpose of this name-mangled attribute is not visible
        # in this file; it is kept unchanged.
        self._DataLoader__initialized = False
        self.lr_dir = pathlib.Path(lr_dir)
        self.gt_dir = pathlib.Path(gt_dir)
        self.memcache = memcache
        self.transform = transform

        # Attempt filename matching.
        self.lr_image_filepaths = [f for f in self.lr_dir.glob('*') if is_image(f)]
        self.gt_image_filepaths = [f for f in self.gt_dir.glob('*') if is_image(f)]

        lr_image_filenames = set(map(os.path.basename, self.lr_image_filepaths))
        gt_image_filenames = set(map(os.path.basename, self.gt_image_filepaths))
        intersect_filenames = lr_image_filenames.intersection(gt_image_filenames)
        if strict_filename_intersection:
            mismatched_filenames = (lr_image_filenames.union(gt_image_filenames)).difference(intersect_filenames)
            if len(mismatched_filenames) > 0:
                raise ValueError(f"Mismatched filenames in lr_dir and gt_dir: {str(mismatched_filenames)}")

        # Sorted so that indices are stable between runs.
        self.image_filenames = list(sorted(intersect_filenames))
        self.image_lr_gt_pairs = []

        # Load the images up front if we want to cache them in memory.
        if self.memcache:
            self.image_lr_gt_pairs = [self._load_pair(name) for name in self.image_filenames]

    def _load_pair(self, img_filename):
        """Open the LR and GT images named `img_filename` as RGB PIL images."""
        img_lr = Image.open(os.path.join(self.lr_dir, img_filename)).convert("RGB")
        img_gt = Image.open(os.path.join(self.gt_dir, img_filename)).convert("RGB")
        return img_lr, img_gt

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, i):
        img_filename = self.image_filenames[i]
        if self.memcache:
            img_lr, img_gt = self.image_lr_gt_pairs[i]
        else:
            img_lr, img_gt = self._load_pair(img_filename)

        # Apply Data Augmentation
        if self.transform is not None:
            # Re-seed torch before each call so both LR and GT images get the same
            # random transforms. The bound must be an int: random.randint rejects
            # float arguments such as 1e7 on recent Python versions.
            seed = random.randint(0, 10 ** 7)
            torch.manual_seed(seed)
            img_lr = self.transform(img_lr)
            torch.manual_seed(seed)
            img_gt = self.transform(img_gt)
        else:
            to_tensor = torchvision.transforms.ToTensor()
            img_lr = to_tensor(img_lr)
            img_gt = to_tensor(img_gt)

        # Apply Normalization from [0, 1] -> [-1, 1]
        img_lr = (img_lr * 2) - 1.0
        img_gt = (img_gt * 2) - 1.0

        return {
            'img_filename': img_filename,
            'img_lr': img_lr,
            'img_gt': img_gt
        }


# TODO: Refactor LowResDataSet to match LowResGroundTruthDataset above.
class LowResDataSet(Dataset):
    """Dataset of low-resolution images only (no ground truth).

    Each item is a dict with keys 'img_filename' and 'img_lr', where 'img_lr' is a
    float32 numpy array in (C, H, W) layout scaled to [-1, 1].

    Args:
        lr_dir: directory holding the low-resolution images (str or pathlib.Path).
        memcache: when True, every image is loaded and normalised in the constructor.
    """
    def __init__(self, lr_dir, memcache=False):
        super().__init__()
        # Accept plain strings as well as Path objects; `.glob` needs a Path.
        self.lr_dir = pathlib.Path(lr_dir)
        self.memcache = memcache

        # Collect image file names in a stable (sorted) order.
        self.lr_image_filepaths = [f for f in self.lr_dir.glob('*') if is_image(f)]
        self.image_filenames = sorted(list(map(os.path.basename, self.lr_image_filepaths)))
        self.image_lr = []

        # Load the images if we want to cache them in memory.
        if self.memcache:
            # Append rather than index-assign: the cache list starts out empty.
            self.image_lr = [self._load(name) for name in self.image_filenames]

    def _load(self, img_filename):
        """Read one image as RGB and return it as float32 (C, H, W) in [-1, 1]."""
        img_lr = np.array(Image.open(os.path.join(self.lr_dir, img_filename)).convert("RGB")).astype(np.uint8)
        img_lr = (img_lr / 127.5) - 1.0
        return img_lr.transpose(2, 0, 1).astype(np.float32)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, i):
        img_filename = self.image_filenames[i]
        img_lr = self.image_lr[i] if self.memcache else self._load(img_filename)
        return {
            'img_filename': img_filename,
            'img_lr': img_lr
        }

### Metrics and Testing

In [None]:
# Data-loading settings shared by the datasets and loaders below.
memcache=True
batch_size=8
num_workers=multiprocessing.cpu_count()

# Model settings: `scale` and `model_res_count` are passed to Generator
# (as `scale` / `num_block`); `patch_size * scale` is passed to Discriminator.
scale=4
patch_size=24
model_res_count=16

# Weights loaded into the generator before training.
transfer_generator_path=arg_util.path_abs("pretrained/SRResNet.pt")

# VGG feature layer used by the perceptual loss, then the loss weighting coefficients.
# feat_layer='relu2_2'
feat_layer='relu5_4'
vgg_rescale_coeff=0.006
adv_coeff=1e-3
tv_loss_coeff=0.0

# Use CUDA when available; the bare expression displays the selected device.
t_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t_device

device(type='cuda')

In [None]:
import csv
from skimage.color import rgb2ycbcr
from skimage.metrics import peak_signal_noise_ratio
from torch.utils.data import DataLoader

class MetricEval(object):
    """Evaluates a generator on the train, validation and test splits.

    The train split uses the dataset handed in by the caller; the validation and
    test splits are built here from the `data/pokemon/{lr,hr}/{val,test}` folders.
    Call `load_generator` before `get_metric`.

    Args:
        train_dataset: dataset used for the "train" mode.
        memcache: forwarded to the validation / test datasets.
        transform: forwarded to the validation / test datasets.
    """
    # Metric names accepted by get_metric (TODO: SSIM).
    VALID_METRICS = ("MSE", "PSNR", "VGG22", "VGG54")

    def __init__(self, train_dataset, memcache=True, transform=None):
        # Train, Validation, Test is the order of each list
        self.modes = ["train", "val", "test"]
        self.output_paths = [arg_util.path_abs(f"{mode}_out/") for mode in self.modes]
        self.lr_paths = [arg_util.path_abs(f"data/pokemon/lr/{mode}") for mode in self.modes]
        self.gt_paths = [arg_util.path_abs(f"data/pokemon/hr/{mode}") for mode in self.modes]

        dataset_params = {"memcache": memcache, "transform": transform}
        self.datasets = [train_dataset] + [
            LowResGroundTruthDataset(lr_dir=lr_path, gt_dir=gt_path, **dataset_params)
            for lr_path, gt_path in zip(self.lr_paths[1:], self.gt_paths[1:])
        ]

        loader_params = {"batch_size": batch_size, "shuffle": False, "drop_last": False, "num_workers": num_workers}
        self.loaders = [DataLoader(dataset, **loader_params) for dataset in self.datasets]

        self.tv_loss = TVLoss()

        # Set by load_generator.
        self.generator = None

        # The VGG network is only built the first time a VGG metric is requested.
        self.vgg_net = None
        self.vgg_loss = None

    def load_generator(self, generator_path=None, generator=None):
        """Select the generator to evaluate and put it in eval mode.

        Args:
            generator_path: path to a saved generator state dict; when given, a
                fresh Generator is built and the weights are loaded into it.
            generator: an already constructed generator, used when no path is given.

        Raises:
            ValueError: if neither argument is provided.
        """
        if generator_path is None and generator is None:
            raise ValueError(f"One of generator_path or generator must not be None.")

        if generator_path:
            self.generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
            self.generator.load_state_dict(torch.load(generator_path, map_location=t_device))
            self.generator = self.generator.to(t_device)
        else:
            self.generator = generator
        self.generator.eval()

    def get_metric(self, mode="val", metric="MSE", write_img=False):
        """Compute the per-batch average of `metric` over one split.

        Args:
            mode: one of "train", "val", "test".
            metric: one of "MSE", "PSNR", "VGG22", "VGG54". The VGG metrics are the
                weighted perceptual + total-variation + L2 generator loss at the
                corresponding VGG layer.
            write_img: when True, each prediction is also saved as
                `pred_<filename>` in the split's output directory.

        Returns:
            float: mean of the per-batch scores.

        Raises:
            ValueError: for an unknown mode or metric, or when the split yields
                no batches.
            RuntimeError: if no generator has been loaded.
        """
        if metric not in self.VALID_METRICS:
            raise ValueError(f"Unknown metric {metric!r}; expected one of {self.VALID_METRICS}.")
        idx = self.modes.index(mode)
        # Evaluate the generator chosen through load_generator, not a notebook global.
        generator = getattr(self, "generator", None)
        if generator is None:
            raise RuntimeError("No generator loaded; call load_generator() first.")
        if write_img:
            self.output_paths[idx].mkdir(parents=True, exist_ok=True)

        results = []
        with torch.no_grad():
            for lr_gt_datum in self.loaders[idx]:
                img_filenames = lr_gt_datum['img_filename']
                img_lrs = lr_gt_datum['img_lr'].to(t_device)
                img_gts = lr_gt_datum['img_gt'].to(t_device)

                img_preds, _ = generator(img_lrs)

                # Revert from [-1, 1] -> [0, 1]
                img_gts = ((img_gts + 1.) / 2.)
                img_preds = ((torch.clip(img_preds, -1., 1.) + 1.) / 2.)

                # Crop GT so it has the same spatial size as the prediction.
                img_gts = img_gts[:, :, :img_preds.shape[2], :img_preds.shape[3]]
                if metric == "MSE":
                    results.append(F.mse_loss(img_preds, img_gts).item())
                elif metric == "PSNR":
                    # Calculate psnr from ycbcr comparison. (N, H, W, C)
                    y_preds = img_preds.cpu().numpy().transpose(0, 2, 3, 1)
                    y_gt = img_gts.cpu().numpy().transpose(0, 2, 3, 1)

                    # Keep only the luma channel and trim a `scale`-pixel border.
                    y_preds = rgb2ycbcr(y_preds)[:, scale:-scale, scale:-scale, 0]
                    y_gt = rgb2ycbcr(y_gt)[:, scale:-scale, scale:-scale, 0]

                    psnr = peak_signal_noise_ratio(y_gt / 255., y_preds / 255., data_range=1.)
                    results.append(float(psnr))
                else:
                    # "VGG22" or "VGG54"
                    if self.vgg_net is None:
                        self.vgg_net = vgg19().to(t_device)
                        self.vgg_net = self.vgg_net.eval()
                        self.vgg_loss = perceptual_loss(self.vgg_net)

                    feat_layer = "relu2_2" if metric == "VGG22" else "relu5_4"
                    _percep_loss, hr_feat, sr_feat = self.vgg_loss(img_gts, img_preds, layer=feat_layer)

                    L2_loss = F.mse_loss(img_preds, img_gts)
                    percep_loss = vgg_rescale_coeff * _percep_loss
                    total_variance_loss = tv_loss_coeff * self.tv_loss(vgg_rescale_coeff * (hr_feat - sr_feat)**2)

                    g_loss = percep_loss + total_variance_loss + L2_loss
                    results.append(g_loss.item())

                if write_img:
                    for i in range(len(img_filenames)):
                        result = Image.fromarray((img_preds[i] * 255.).permute((1, 2, 0)).to(torch.uint8).cpu().numpy())
                        result.save(self.output_paths[idx] / f"pred_{img_filenames[i]}")
                        logging.info(f"Inference Output: {self.output_paths[idx] / f'pred_{img_filenames[i]}'}")

        if not results:
            raise ValueError(f"No batches available for mode {mode!r}; cannot compute {metric}.")
        average = sum(results) / len(results)
        print(f"Average {metric} Score: {average}")
        return average

In [None]:
# Load Training Data
# Output locations for the trained generator / discriminator weights and for
# intermediate checkpoints; the containing directories are created if missing.
generator_path_out = arg_util.path_abs("train_out/SRGAN_g.pt")
generator_path_out.parent.mkdir(parents=True, exist_ok=True)

discriminator_path_out = arg_util.path_abs("train_out/SRGAN_d.pt")
discriminator_path_out.parent.mkdir(parents=True, exist_ok=True)

checkpoint_dir = arg_util.path_abs("train_out/")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Training split: ground-truth (hr) and low-resolution (lr) image folders.
gt_path = arg_util.path_abs("data/pokemon/hr/train/")
lr_path = arg_util.path_abs("data/pokemon/lr/train/")

# Training dataset with the augmentation pipeline `aug` applied to each pair.
lr_gt_dataset = LowResGroundTruthDataset(
    lr_dir=lr_path, gt_dir=gt_path, memcache=memcache,
    transform=aug
)

# Shuffled training loader and a generator optionally initialised from pre-trained weights.
loader = DataLoader(lr_gt_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
if transfer_generator_path:
    generator.load_state_dict(torch.load(transfer_generator_path, map_location=t_device))
    logging.info(f"Loaded pre-trained model: {transfer_generator_path}")
    print(f"Loaded pre-trained model: {transfer_generator_path}")
generator = generator.to(t_device)
# Assign to `_` so the cell does not display the module repr.
_ = generator.train()

Loaded pre-trained model: /content/drive/My Drive/CSC420/CSC420_project-main/src/pretrained/SRResNet.pt


In [None]:
# Rebuild the generator and load the SRGAN weights instead.
# NOTE(review): this replaces the generator created in the previous cell
# (which was loaded from `transfer_generator_path`).
generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
generator.load_state_dict(torch.load(arg_util.path_abs("pretrained/SRGAN.pt"), map_location=t_device))
generator = generator.to(t_device)
_ = generator.train()

In [None]:
# Build the evaluator around the training dataset and point it at the current generator.
# Note that load_generator switches the generator to eval mode.
metrics = MetricEval(lr_gt_dataset)
metrics.load_generator(generator=generator)

In [None]:
# metrics.get_metric(mode="val", metric="MSE")
# metrics.get_metric(mode="val", metric="PSNR")
# metrics.get_metric(mode="val", metric="VGG22")
# metrics.get_metric(mode="val", metric="VGG54")
# Average MSE Score: 0.013258594088256359
# Average PSNR Score: 20.03266583430994
# Average VGG22 Score: 0.03304675221443176
# Average VGG54 Score: 0.01411474496126175

## BELOW IS WORK IN PROGRESS


In [None]:
# metrics.get_metric(mode="val", metric="VGG54")

In [None]:
# torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())

In [None]:
# Freeze all layer weights except the last few, needs testing.
# First disable gradients for every generator parameter...
for param in generator.parameters():
    param.requires_grad = False
    
# ...then re-enable them only for the parameters under `last_conv.body`.
for param in generator.last_conv.body.parameters():
    param.requires_grad = True

# for param in generator.tail.parameters():
#     param.requires_grad = True

# for param in generator.conv02.parameters():
#     param.requires_grad = True

# for param in generator.body[15].parameters():
#     param.requires_grad = True

In [None]:
# Query NVML for the memory state of GPU 0.
# Import only the names used instead of `from pynvml import *`, so the notebook
# namespace is not flooded with NVML symbols.
from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit

nvmlInit()
h = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(h)
# Printed order: total, free, used — each divided by 1024**2.
print(f'{info.total/(1024**2)} {info.free/(1024**2)} {info.used/(1024**2)}')

15079.75 411.875 14667.875


In [None]:
from arch.blocks import *

class Discriminator(nn.Module):
    """Convolutional discriminator ending in a sigmoid probability.

    This definition shadows the `Discriminator` imported from `arch.srgan_model`.

    Args:
        img_feat: number of input image channels.
        n_feats: channel count of the first convolutions; each body block doubles it.
        kernel_size: accepted for signature compatibility; the layers below use a
            fixed kernel size of 3.
        act: activation module shared by all layers.
        num_of_block: number of `discrim_block` stages in the body.
        patch_size: accepted for signature compatibility but currently unused,
            because `linear_size` is hard-coded (the formula that used it is kept
            in the comment next to it).
    """

    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96):
        super(Discriminator, self).__init__()
        self.act = act

        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act)
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2)

        body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)]
        self.body = nn.Sequential(*body)

        # Hard-coded flattened feature size expected by the first linear layer.
        self.linear_size = 460800 # ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block))

        self.tail = nn.Sequential(
            nn.Linear(self.linear_size, 1024),
            self.act,
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one probability per input image, shape (batch, 1)."""
        x = self.conv01(x)
        x = self.conv02(x)
        x = self.body(x)
        # Flatten per sample. `view(-1, linear_size)` could silently fold samples
        # together (changing the batch dimension) when the input size does not
        # produce exactly `linear_size` features; flattening by batch makes such a
        # mismatch fail in the linear layer instead.
        x = x.view(x.size(0), -1)
        x = self.tail(x)
        return x


In [None]:
def train(init_lr=1e-4, pre_train_epoch=100, feat_layer="relu5_4"):
    """Fine-tune the global `generator` against a freshly created discriminator.

    Each epoch runs five passes over `loader` that update only the discriminator,
    then one pass that updates the generator with L2 + perceptual + adversarial
    (+ total-variation) loss, and finally evaluates PSNR / VGG22 / VGG54 on the
    validation split through the global `metrics` object.

    Args:
        init_lr: initial learning rate of the generator's Adam optimiser. Also
            used in the name of the results file (`results_<init_lr>`).
        pre_train_epoch: number of epochs to run.
        feat_layer: VGG feature layer name passed to the perceptual loss.

    Side effects:
        Updates `generator` in place, appends one line per epoch to
        `results_<init_lr>`, and saves generator checkpoints under `checkpoint_dir`.
    """
    # Initialize Losses
    vgg_net = vgg19().to(t_device)
    vgg_net = vgg_net.eval()
    vgg_loss = perceptual_loss(vgg_net)
    L2_MSE_loss = nn.MSELoss()
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()

    g_optim = optim.Adam(generator.parameters(), lr=init_lr)
    g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(g_optim, mode="min", factor=0.5, patience=30, cooldown=10, verbose=True)

    discriminator = Discriminator(patch_size = patch_size * scale)
    discriminator = discriminator.to(t_device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4)
    # The discriminator's scheduler must drive the discriminator's optimiser
    # (it was previously attached to g_optim, decaying the generator LR instead).
    d_scheduler = optim.lr_scheduler.StepLR(d_optim, step_size = 200, gamma = 0.1)

    checkpoint_modulo = (pre_train_epoch // 3) or pre_train_epoch
    for pre_epoch in range(1, pre_train_epoch + 1):
        logging.info(f"Pre-train Epoch [{pre_epoch}]: running.")

        ## Training Discriminator
        for _ in range(5):
            for lr_gt_datum in loader:
                img_lr, img_gt = lr_gt_datum['img_lr'].to(t_device), lr_gt_datum['img_gt'].to(t_device)
                # Only the discriminator is updated here, so no generator graph is needed.
                with torch.no_grad():
                    img_pred, _ = generator(img_lr)

                img_gt = img_gt[:, :, :img_pred.shape[2], :img_pred.shape[3]]

                fake_prob = discriminator(img_pred)
                real_prob = discriminator(img_gt)

                # Build the labels from the outputs so a short final batch
                # (the loader does not drop it) still matches in size.
                d_loss_real = cross_ent(real_prob, torch.ones_like(real_prob))
                d_loss_fake = cross_ent(fake_prob, torch.zeros_like(fake_prob))

                d_loss = d_loss_real + d_loss_fake

                d_optim.zero_grad()
                d_loss.backward()
                d_optim.step()

                print("Discriminator Loss:", d_loss.item())

            d_scheduler.step()

        ## Training Generator
        results = []
        for lr_gt_datum in loader:
            img_lr, img_gt = lr_gt_datum['img_lr'].to(t_device), lr_gt_datum['img_gt'].to(t_device)
            img_pred, _ = generator(img_lr)

            # Crop GT so it has the same spatial size as the prediction.
            img_gt = img_gt[:, :, :img_pred.shape[2], :img_pred.shape[3]]

            # Give the discriminator the same [-1, 1] value range it was trained on above.
            fake_prob = discriminator(img_pred)

            # The content losses are computed on [0, 1] images.
            img_gt_01 = ((img_gt + 1.) / 2.)
            img_pred_01 = ((torch.clip(img_pred, -1., 1.) + 1.) / 2.)
            _percep_loss, hr_feat, sr_feat = vgg_loss(img_gt_01, img_pred_01, layer=feat_layer)

            g_loss = L2_MSE_loss(img_pred_01, img_gt_01) + \
                vgg_rescale_coeff * _percep_loss + \
                adv_coeff * cross_ent(fake_prob, torch.ones_like(fake_prob)) + \
                tv_loss_coeff * tv_loss(vgg_rescale_coeff * (hr_feat - sr_feat)**2)
            print(g_loss.item(), end=" ")

            d_optim.zero_grad()
            g_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

            results.append(g_loss.item())

        # Log epoch statistics.
        epoch_summary = f"Pre-train Epoch [{pre_epoch}]: Average Train loss={sum(results)/len(results)}"
        logging.info(epoch_summary)
        print(epoch_summary)
        with open(f"results_{init_lr}", "a") as fp:
            # One line per epoch; without the newline all entries run together.
            fp.write(epoch_summary + "\n")

        # Validation metrics; load_generator puts the generator in eval mode.
        metrics.load_generator(generator=generator)
        psnr = metrics.get_metric(mode="val", metric="PSNR")
        vgg22 = metrics.get_metric(mode="val", metric="VGG22")
        vgg54 = metrics.get_metric(mode="val", metric="VGG54")

        generator.train()

        # Reduce the generator LR when the validation VGG54 loss plateaus.
        g_scheduler.step(vgg54)

        if pre_epoch % checkpoint_modulo == 0:
            checkpoint_filepath = (checkpoint_dir / f'pre_trained_model_{pre_epoch}.pt').absolute()
            torch.save(generator.state_dict(),  checkpoint_filepath)
            logging.info(f"Pre-train Epoch [{pre_epoch}]: saved model checkpoint: {checkpoint_filepath}")

In [None]:
# Average MSE Score: 0.013258594088256359
# Average PSNR Score: 20.03266583430994
# Average VGG22 Score: 0.03304675221443176
# Average VGG54 Score: 0.01411474496126175
# generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
# if transfer_generator_path:
#     generator.load_state_dict(torch.load(transfer_generator_path, map_location=t_device))
#     print(f"Loaded pre-trained model: {transfer_generator_path}")
# generator = generator.to(t_device)
# Put the current generator in training mode and run a single training session.
# NOTE(review): the output recorded for this cell is a RuntimeError; its cause
# is not shown in the saved output.
generator.train()

train(init_lr=1e-9, pre_train_epoch=500)

RuntimeError: ignored

In [None]:
# Learning-rate sweep: for each rate, rebuild the generator from the transfer
# weights, train it, and save the result as train_out/vgg54_<lr>.pt.
# `generator` and `generator_path_out` are notebook globals and are overwritten
# on every iteration.
for lr in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
    generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=model_res_count, scale=scale)
    if transfer_generator_path:
        generator.load_state_dict(torch.load(transfer_generator_path, map_location=t_device))
        print(f"Loaded pre-trained model: {transfer_generator_path}")
    generator = generator.to(t_device)
    _ = generator.train()

    train(init_lr=lr, pre_train_epoch=1000)
    generator_path_out = arg_util.path_abs(f"train_out/vgg54_{lr}.pt")
    generator_path_out.parent.mkdir(parents=True, exist_ok=True)
    torch.save(generator.state_dict(), generator_path_out)

In [None]:
# Ask PyTorch to release its cached, currently unused GPU memory.
torch.cuda.empty_cache()

In [None]:
# Save the current generator weights, creating the output directory if needed.
generator_path_out = arg_util.path_abs("train_out/9_SRResnet_pre.pt")
generator_path_out.parent.mkdir(parents=True, exist_ok=True)
torch.save(generator.state_dict(), generator_path_out)

In [None]:
# Score the test split with the VGG54 metric and write each prediction image
# to the test output directory as pred_<filename>.
metrics.get_metric(mode="test", metric="VGG54", write_img=True)

In [None]:
import matplotlib.pyplot as plt
import cv2

# Show Results: ground truth, prediction and low-resolution input side by side
# for a random subset of the test images.
pred_fp = "test_out/"
real_fp = "data/pokemon/hr/test/"
bad_fp = "data/pokemon/lr/test/"


def _read_rgb(image_path):
    """Read an image with OpenCV and return it in RGB order, or None if unreadable.

    cv2.imread returns channels in BGR order while matplotlib's imshow expects
    RGB, so the channels are converted before display.
    """
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        return None
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


for filename in os.listdir(real_fp):
    # Skip an image whenever the random draw is below 0.7.
    if np.random.random() < 0.7:
        continue

    print(filename)
    pred = _read_rgb(f'{pred_fp}pred_{filename}')
    real = _read_rgb(f'{real_fp}{filename}')
    lr = _read_rgb(f'{bad_fp}{filename}')
    if pred is None or real is None or lr is None:
        # e.g. the prediction has not been written yet, or the file is not an image.
        print(f"Skipping {filename}: could not read one of the images.")
        continue

    # Panels, left to right: ground truth, prediction, low-resolution input.
    f, ax = plt.subplots(1, 3, figsize=(8., 8.))
    for axis, image in zip(ax, (real, pred, lr)):
        axis.imshow(image)
        axis.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()

In [None]:
# Adversarial fine-tuning loop: alternates one discriminator update and one
# generator update per batch, printing losses every 2 epochs and saving both
# models every 500 epochs.
# NOTE(review): this cell uses names that are not defined anywhere in the visible
# notebook (`device`, `args`, `g_optim`, `fine_epoch`, `l2_loss`) and reads batch
# keys 'GT' / 'LR', whereas the datasets above yield 'img_gt' / 'img_lr'. It
# will not run as written — confirm whether it is meant to stay.
vgg_net = vgg19().to(device)
vgg_net = vgg_net.eval()

discriminator = Discriminator(patch_size = args.patch_size * args.scale)
discriminator = discriminator.to(device)
discriminator.train()

d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4)
# NOTE(review): the scheduler is attached to g_optim, not d_optim.
scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = 2000, gamma = 0.1)

VGG_loss = perceptual_loss(vgg_net)
cross_ent = nn.BCELoss()
tv_loss = TVLoss()
# Fixed-size label tensors: one row per sample of a full batch.
real_label = torch.ones((args.batch_size, 1)).to(device)
fake_label = torch.zeros((args.batch_size, 1)).to(device)

while fine_epoch < args.fine_train_epoch:
    
    scheduler.step()
    
    for i, tr_data in enumerate(loader):
        gt = tr_data['GT'].to(device)
        lr = tr_data['LR'].to(device)
                    
        ## Training Discriminator
        output, _ = generator(lr)
        fake_prob = discriminator(output)
        real_prob = discriminator(gt)
        
        d_loss_real = cross_ent(real_prob, real_label)
        d_loss_fake = cross_ent(fake_prob, fake_label)
        
        d_loss = d_loss_real + d_loss_fake

        g_optim.zero_grad()
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()
        
        ## Training Generator
        output, _ = generator(lr)
        fake_prob = discriminator(output)
        
        # The VGG loss receives images shifted from [-1, 1] to [0, 1].
        _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer)
        
        L2_loss = l2_loss(output, gt)
        percep_loss = args.vgg_rescale_coeff * _percep_loss
        adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)
        total_variance_loss = args.tv_loss_coeff * tv_loss(args.vgg_rescale_coeff * (hr_feat - sr_feat)**2)
        
        g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss
        
        g_optim.zero_grad()
        d_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        
    fine_epoch += 1

    # Report the losses of the last batch every second epoch.
    if fine_epoch % 2 == 0:
        print(fine_epoch)
        print(g_loss.item())
        print(d_loss.item())
        print('=========')

    if fine_epoch % 500 ==0:
        torch.save(generator.state_dict(), './model/SRGAN_gene_%03d.pt'%fine_epoch)
        torch.save(discriminator.state_dict(), './model/SRGAN_discrim_%03d.pt'%fine_epoch)

In [None]:
# Below is Work In Progress

In [None]:
# Train using perceptual & adversarial loss
# Adversarial training stage (work in progress): per batch, one discriminator
# update followed by one generator update; periodic checkpoints of both models.
# NOTE(review): `adversarial_train_epoch`, `g_optim` and `L2_MSE_loss` are not
# defined at notebook level in the visible cells (`g_optim` and `L2_MSE_loss`
# exist only as locals inside `train`), so this cell will not run as written.
if adversarial_train_epoch > 0:
    logging.info(f"Training using Adversarial loss for {adversarial_train_epoch} epochs.")

    # Set-up adversarial loss VGG network.
    vgg_net = vgg19().to(t_device)
    vgg_net = vgg_net.eval()

    discriminator = Discriminator(patch_size=patch_size * scale)
    discriminator = discriminator.to(t_device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
    # NOTE(review): the scheduler is attached to g_optim, not d_optim.
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()
    # Label tensors sized for a full batch; resized below for a short final batch.
    base_real_label = torch.ones((batch_size, 1)).to(t_device)
    base_fake_label = torch.zeros((batch_size, 1)).to(t_device)

    torch.autograd.set_detect_anomaly(True)
    checkpoint_modulo = (adversarial_train_epoch // 10) or adversarial_train_epoch
    for epoch in range(1, adversarial_train_epoch + 1):
        logging.info(f"Epoch [{epoch}]: running.")

        # NOTE(review): both optimisers are stepped here, before any backward pass
        # of this epoch — confirm this is intentional.
        d_optim.step()
        g_optim.step()
        scheduler.step()
        for batch_i, lr_gt_datum in enumerate(loader):
            img_lr, img_gt = lr_gt_datum['img_lr'].to(t_device), lr_gt_datum['img_gt'].to(t_device)
            img_hr_prediction, _ = generator(img_lr)

            # Train Discriminator
            fake_prob = discriminator(img_hr_prediction)
            real_prob = discriminator(img_gt)

            # Avoid mismatched label and probability length in case where batch is remainder of data, but not
            # a perfect fit.
            real_label = base_real_label
            fake_label = base_fake_label
            if len(base_real_label) != len(real_prob):
                real_label = torch.ones((len(real_prob), 1)).to(t_device)
                fake_label = torch.zeros((len(real_prob), 1)).to(t_device)

            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)

            d_loss = d_loss_real + d_loss_fake

            # Back-propagate Discriminator
            g_optim.zero_grad()
            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            # Train Generator
            img_hr_prediction, _ = generator(img_lr)
            fake_prob = discriminator(img_hr_prediction)

            l2_loss = L2_MSE_loss(img_hr_prediction, img_gt)
            # The VGG loss receives images shifted from [-1, 1] to [0, 1].
            percep_loss, hr_feat, sr_feat = VGG_loss((img_gt + 1.0) / 2.0, (img_hr_prediction + 1.0) / 2.0, layer=feat_layer)
            percep_loss = vgg_rescale_coeff * percep_loss
            adversarial_loss = adv_coeff * cross_ent(fake_prob, real_label)
            total_variance_loss = tv_loss_coeff * tv_loss(vgg_rescale_coeff * (hr_feat - sr_feat) ** 2)
            g_loss = percep_loss + adversarial_loss + total_variance_loss + l2_loss

            # Back-propagate Generator
            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

        # Log epoch statistics (losses of the last batch of the epoch).
        logging.info(f"Epoch [{epoch}]: g_loss={g_loss.item()} d_loss={d_loss.item()}")
        if epoch % checkpoint_modulo == 0:
            g_checkpoint_filepath = (checkpoint_dir / f'SRGAN_g_{epoch}.pt').absolute()
            d_checkpoint_filepath = (checkpoint_dir / f'SRGAN_d_{epoch}.pt').absolute()
            torch.save(generator.state_dict(),  g_checkpoint_filepath)
            torch.save(discriminator.state_dict(), d_checkpoint_filepath)
            logging.info(f"Pre-train Epoch [{epoch}]: saved model checkpoints: {g_checkpoint_filepath}, {d_checkpoint_filepath}")
    if discriminator_path_out:
        torch.save(discriminator.state_dict(), discriminator_path_out)
# The generator is saved regardless of whether the adversarial stage ran.
torch.save(generator.state_dict(), generator_path_out)



