# 0. Imports & setting up

In [None]:
import torch
from torch import nn
from torch import optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from src.srgan import Discriminator, Generator
from src.loss import GeneratorLoss
from src.data import SuperResolutionImageDataset
from torchsummary import summary
import torchvision.utils as torchvision_utils
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from IPython.display import HTML

In [None]:
# Run on Apple's Metal backend (MPS) when this PyTorch build reports it as
# available; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

# 1. Weights

In [None]:
# CHANGE
def weights_init(m):
    """Re-initialise the parameters of a conv or batch-norm module in place.

    Meant to be passed to ``nn.Module.apply`` so that it is called on every
    sub-module. Modules are matched by class name:

    * name contains ``Conv``      -> weight ~ N(0.0, 0.02)
    * name contains ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0

    Modules that match by name but carry no ``weight`` of their own (for
    example a container class called ``ConvBlock`` or a batch-norm layer
    built with ``affine=False``) are left untouched instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        # Nothing to initialise on this module itself.
        return

    if 'Conv' in classname:
        nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

# 2. Training

In [None]:
# Checkpointing configuration
# Directory that holds the "generator" and "discriminator" checkpoint files.
checkpoint_path: str = "./checkpoints"
# A checkpoint is written whenever the epoch number is a multiple of this.
checkpoint_period: int = 500
# True: resume from the files in checkpoint_path; False: start from fresh weights.
restore_network: bool = False

In [None]:
# ---- Input parameters ----

# Root folder handed to the dataset.
data_path = "./data/Renders/"

# Down-scaling factor: the low-resolution size is the high-resolution size // r.
r = 4
# Channel count appended to the image dimensions given to the networks.
n_channels = 3
# Second constructor argument of the Generator
# (presumably the number of residual blocks -- TODO confirm against src.srgan).
B = 1

# Data loading
batch_size_train = 128
batch_size_validation = 128
workers = 1  # DataLoader num_workers

# Reproducible dataset split: seed for the generator and the
# train / test / validation fractions (in that order).
seed = 1317
train_test_val_split = [0.7, 0.15, 0.15]

In [None]:
# ---- Training parameters ----

# Adam settings shared by both optimizers (beta1 is the first of the two betas).
lr = 1e-4
beta1 = 0.9

# Upper bound (exclusive) of the epoch range.
num_epochs = 1000

# The LR schedulers are stepped on epochs that are a multiple of this value.
update_optimizer = 1000
# The discriminator is trained only on batches whose index is a multiple of this value.
update_discriminator = 2

In [None]:
# Spatial size of the high-resolution crops and of their low-resolution
# counterparts (each side integer-divided by the scaling factor r).
hr_size = (128, 128)
lr_size = tuple(side // r for side in hr_size)

# Same sizes with the channel count appended as a third entry.
hr_dimension = hr_size + (n_channels,)
lr_dimension = lr_size + (n_channels,)

In [None]:
# `transform` takes a random hr_size crop; `target_transform` resizes to lr_size.
# (How the dataset combines the two is defined in src.data -- not visible here.)
crop_transform = transforms.Compose([
    transforms.RandomCrop(hr_size),
])
downscale_transform = transforms.Compose([
    # transforms.GaussianBlur(3,1),
    transforms.Resize(lr_size),
])
dataset = SuperResolutionImageDataset(
    root=data_path,
    transform=crop_transform,
    target_transform=downscale_transform,
)

# Seeded split into train / test / validation subsets (fractions in that order).
random_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset, validation_dataset = torch.utils.data.random_split(
    dataset,
    train_test_val_split,
    generator=random_generator,
)

# Both loaders shuffle and use the same number of worker processes.
loader_options = dict(shuffle=True, num_workers=workers)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, **loader_options)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size_validation, **loader_options)

In [None]:
# Build both networks and move them to the selected device.
# nn.Module.to() returns the module itself, so construction and placement can
# be chained (this relies on both classes being nn.Module subclasses, which
# their later use via .apply() / .parameters() implies).
netD = Discriminator(hr_dimension).to(device)
netG = Generator(lr_dimension, B).to(device)

In [None]:
# Either resume both networks and their optimizers from the checkpoint files,
# or start from freshly initialised weights and new optimizers.
if restore_network:
    # map_location remaps the stored tensors onto the device selected for this
    # run, so a checkpoint written on another device can still be loaded.
    gen_load = torch.load(f'{checkpoint_path}/generator', map_location=device)
    dis_load = torch.load(f'{checkpoint_path}/discriminator', map_location=device)

    # Checkpoints are written after the stored epoch has finished, so training
    # resumes with the following epoch instead of repeating it.
    initial_epoch = gen_load['epoch'] + 1

    # The optimizers are created with default hyper-parameters; the real ones
    # (learning rate, betas, moment estimates) come from the saved state.
    netG.load_state_dict(gen_load['model_state_dict'])
    optimizerG = optim.Adam(netG.parameters())
    optimizerG.load_state_dict(gen_load['optimizer_state_dict'])

    netD.load_state_dict(dis_load['model_state_dict'])
    optimizerD = optim.Adam(netD.parameters())
    optimizerD.load_state_dict(dis_load['optimizer_state_dict'])
else:
    initial_epoch = 1
    netD.apply(weights_init)
    netG.apply(weights_init)
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Target values for the discriminator's output. The "real" target is 0.90
# rather than 1.0 (one-sided label smoothing).
fake_label = 0.0
real_label = 0.90

# Generator loss (project-defined; the training loop reads its
# "adversarial_loss", "pixel_loss" and "perceptual_loss" entries).
criterionG = GeneratorLoss(device)

# Discriminator (adversarial) loss: binary cross-entropy on D's output.
criterionD = nn.BCELoss()

In [None]:
# Exponential learning-rate decay: every scheduler.step() multiplies the
# optimizer's learning rate by `lr_decay_gamma`.
# The schedulers' `verbose` flag is deprecated in current PyTorch releases, so
# the learning rates are reported explicitly through get_last_lr() instead.
lr_decay_gamma = 0.1
lr_schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=lr_decay_gamma)
lr_schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=lr_decay_gamma)
print(f"Initial learning rates -> D: {lr_schedulerD.get_last_lr()}, G: {lr_schedulerG.get_last_lr()}")

In [None]:
# Training Loop
#
# Alternates discriminator and generator updates over the training set,
# records the per-iteration losses, periodically stores a grid of generated
# images and writes checkpoints for both networks.

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

# Number of iterations between two stored sample grids. Storing a grid on
# every iteration costs an extra generator pass per batch and makes img_list
# grow without bound.
sample_period = 500

# Placeholder statistics, reported until the discriminator has been updated.
errD = criterionD(torch.tensor([1.]), torch.tensor([0.]))
D_x = 0.5
D_G_z1 = 0.5


def save_checkpoint(name, epoch, model, optimizer, loss):
    """Write one training checkpoint to ``{checkpoint_path}/{name}``.

    Args:
        name: file name inside ``checkpoint_path``.
        epoch: epoch that has just finished.
        model: network whose ``state_dict`` is stored.
        optimizer: optimizer whose ``state_dict`` is stored.
        loss: last loss value (a Python float) of that network.
    """
    # Make sure the target directory exists before the first save.
    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            },
            f"{checkpoint_path}/{name}"
    )


num_batches = len(train_dataloader)

print("Starting Training Loop...")
# For each epoch
for epoch in range(initial_epoch, num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(train_dataloader, 0):

        # The discriminator is trained only on every `update_discriminator`-th batch.
        updateD = (i % update_discriminator == 0)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # Also clears the gradients that the previous generator step left on D.
        netD.zero_grad()
        # Format batch (despite the names, both tensors live on `device`)
        hr_cpu = data["hr_sample"].to(device)
        lr_cpu = data["lr_sample"].to(device)
        b_size = hr_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        # Generate fake image batch with G
        fake = netG(lr_cpu)

        if updateD:
            ## Train with all-real batch
            output = netD(hr_cpu).view(-1)
            errD_real = criterionD(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            label.fill_(fake_label)
            # detach(): this pass must not propagate gradients into G
            output = netD(fake.detach()).view(-1)
            errD_fake = criterionD(output, label)
            # Gradients accumulate (sum) with those of the real batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Error of D is the sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # D may just have been updated, so run the fake batch through it again
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG_all = criterionG(output, label, hr_cpu, fake)
        adv_lossG = errG_all["adversarial_loss"]
        pixel_lossG = errG_all["pixel_loss"]
        perceptual_lossG = errG_all["perceptual_loss"]
        errG = 1e-3*adv_lossG + 0.5*(perceptual_lossG + pixel_lossG) # No 6e-3 since we are normalizing VGG output
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f(total)/%.4f(Adv)/%.4f(Pixel)/%.4f(Perc)\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, num_batches,
                     errD.item(), errG.item(), adv_lossG.item(), pixel_lossG.item(), perceptual_lossG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Keep a grid of G's output on the current low-resolution batch every
        # `sample_period` iterations and on the very last iteration.
        is_last_iteration = (epoch == num_epochs-1) and (i == num_batches-1)
        if (iters % sample_period == 0) or is_last_iteration:
            with torch.no_grad():
                fake = netG(lr_cpu).detach().cpu()
            img_list.append(torchvision_utils.make_grid(fake, padding=2, normalize=True))

        iters += 1

    # Decay the learning rates at the end of every `update_optimizer`-th epoch
    if epoch != 0 and epoch % update_optimizer == 0:
        lr_schedulerD.step()
        lr_schedulerG.step()
        print(f"Learning rates -> D: {lr_schedulerD.get_last_lr()}, G: {lr_schedulerG.get_last_lr()}")

    # Write checkpoints at the end of every `checkpoint_period`-th epoch;
    # each file stores the loss of its own network.
    if epoch != 0 and epoch % checkpoint_period == 0:
        save_checkpoint("generator", epoch, netG, optimizerG, errG.item())
        save_checkpoint("discriminator", epoch, netD, optimizerD, errD.item())