In [None]:
import sys

from google.colab import drive

# Mount Google Drive so the lab's code, data and checkpoints are reachable
# under /content/drive. force_remount=True refreshes an existing mount.
drive.mount('/content/drive', force_remount=True)

# Make the lab's local modules (models, datasets, utils, solver) importable.
# Guarded so that re-running this cell does not stack duplicate sys.path entries.
CODE_DIR = '/content/drive/MyDrive/Colab Notebooks/DL Lab/Lab7/code'
if CODE_DIR not in sys.path:
    sys.path.append(CODE_DIR)

Mounted at /content/drive


In [None]:
import time

import torch
import torch.backends.cudnn as cudnn
from torch import nn
from easydict import EasyDict as edict

# Local lab modules (found through the sys.path entry added in the first cell).
# The star import stays between the explicit imports exactly as before, so any
# name it re-exports shadows / is shadowed in the same order as the original.
from models import Generator, Discriminator, TruncatedVGG19
from datasets import SRDataset
from utils import *
from solver import train

### Alternating training with Generator Initialisation

In [None]:
from solver import train_init

In [None]:
# ---------------------------------------------------------------
# Experiment configuration: one EasyDict read by every later cell
# ---------------------------------------------------------------
config = edict()

# Data locations and patch geometry
config.csv_folder = '/content/drive/MyDrive/Colab Notebooks/DL Lab/Lab7/data'
config.HR_data_folder = '/content/drive/MyDrive/Colab Notebooks/DL Lab/Lab7/data/DIV2K_train_HR'
config.LR_data_folder = '/content/drive/MyDrive/Colab Notebooks/DL Lab/Lab7/data/DIV2K_train_LR_bicubic/X4'
config.crop_size = 96
config.scaling_factor = 4

# Generator hyper-parameters (consumed by Generator(config))
config.G = edict(
    large_kernel_size=9,
    small_kernel_size=3,
    n_channels=64,
    n_blocks=16,
)

# Discriminator hyper-parameters (consumed by Discriminator(config))
config.D = edict(
    kernel_size=3,
    n_channels=64,
    n_blocks=8,
    fc_size=1024,
)

# Checkpoints to resume from
# generator-only checkpoint written by the initialisation phase, None if none
config.checkpoint_init = '/content/drive/MyDrive/Colab Notebooks/DL Lab/Lab7/code/saved_checkpoints/checkpoint_generator.pth.tar'
# path to model (SRGAN) checkpoint, None if none
config.checkpoint = '/content/drive/MyDrive/Colab Notebooks/DL Lab/Lab7/code/saved_checkpoints/checkpoint_srgan2.pth.tar'

# Training schedule
config.batch_size = 64
config.start_epoch = 0
config.epochs = 150
config.workers = 4
# indices i and j in the definition of the VGG loss; see paper
config.vgg19_i = 5
config.vgg19_j = 4
# coefficient weighting the adversarial loss inside the perceptual loss
config.beta = 1e-3
config.print_freq = 100
# learning rate of the Adam optimizers (generator and discriminator)
config.lr = 1e-3
# SGD settings for the generator's optimizer in the initialisation phase
config.SGD = edict(lr=5e-2, momentum=0.9)

# Default device: GPU when one is visible, CPU otherwise
config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True

In [None]:
# Build or restore the networks and their optimizers.
#
# Three situations, decided by the two checkpoint paths in `config`:
#   * config.checkpoint set        -> resume adversarial (SRGAN) training
#   * only config.checkpoint_init  -> resume / finish generator initialisation,
#                                     start adversarial training from scratch
#   * neither                      -> everything is created fresh
#
# NOTE(review): both checkpoint files pickle whole model/optimizer objects, so
# torch.load unpickles arbitrary code - only load files you trust. Recent
# PyTorch releases default torch.load to weights_only=True, which rejects such
# files; pass weights_only=False there if needed - confirm against the
# installed version.


def _trainable_parameters(model):
    """Return an iterator over the parameters of `model` that require gradients."""
    return filter(lambda p: p.requires_grad, model.parameters())


def _load_checkpoint(path):
    """Load the checkpoint dict stored at `path`.

    Tensors are remapped onto config.device so that a checkpoint written on a
    GPU runtime can still be opened on a CPU-only one.
    """
    return torch.load(path, map_location=config.device)


if config.checkpoint is not None:
    # Resume adversarial training: both networks and both Adam optimizers
    # come straight from the SRGAN checkpoint.
    checkpoint = _load_checkpoint(config.checkpoint)
    config.start_epoch = checkpoint['epoch'] + 1
    generator = checkpoint['generator']
    discriminator = checkpoint['discriminator']
    optimizer_g = checkpoint['optimizer_g']
    optimizer_d = checkpoint['optimizer_d']
    print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))

else:
    if config.checkpoint_init is not None:
        # Generator (and its initialisation-phase optimizer) from the init checkpoint
        checkpoint = _load_checkpoint(config.checkpoint_init)
        config.start_epoch_init = checkpoint['epoch'] + 1
        generator = checkpoint['model']
        optimizer_g_init = checkpoint['optimizer']
        print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))
    else:
        # Fresh generator and its optimizer for the initialisation phase
        generator = Generator(config)
        optimizer_g_init = torch.optim.SGD(params=_trainable_parameters(generator),
                                           lr=config.SGD.lr,
                                           momentum=config.SGD.momentum)
        # The initialisation loop reads this; without a default a fresh run
        # fails with AttributeError before training starts.
        config.start_epoch_init = 0

    # Adversarial phase always starts from scratch in this branch:
    # new generator optimizer, new discriminator and its optimizer.
    optimizer_g = torch.optim.Adam(params=_trainable_parameters(generator),
                                   lr=config.lr)
    discriminator = Discriminator(config)
    optimizer_d = torch.optim.Adam(params=_trainable_parameters(discriminator),
                                   lr=config.lr)


Loaded checkpoint from epoch 33.



In [None]:
# Truncated VGG19 network used in the loss calculation; the truncation point
# is selected by the (i, j) indices from the config.
truncated_vgg19 = TruncatedVGG19(
    i=config.vgg19_i,
    j=config.vgg19_j,
)
# Switch to evaluation mode; as the cell's last expression this also displays the module.
truncated_vgg19.eval()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

TruncatedVGG19(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 

In [None]:
# Loss criteria
# Two independent MSE instances: one handed to train_init() for the generator
# initialisation phase, one handed to train() as the content loss.
init_loss_criterion, content_loss_criterion = nn.MSELoss(), nn.MSELoss()
# Binary cross-entropy on raw logits, handed to train() as the adversarial loss.
adversarial_loss_criterion = nn.BCEWithLogitsLoss()

In [None]:
# Place the three networks on the default device ...
generator, discriminator, truncated_vgg19 = [
    network.to(config.device)
    for network in (generator, discriminator, truncated_vgg19)
]
# ... followed by the two criteria used during adversarial training.
content_loss_criterion, adversarial_loss_criterion = [
    criterion.to(config.device)
    for criterion in (content_loss_criterion, adversarial_loss_criterion)
]

In [None]:
# Training data: the lab's SRDataset wrapped in a shuffling DataLoader
train_dataset = SRDataset(config=config, split='train')

# Loader settings kept together so they are easy to tweak in one place
loader_options = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.workers,
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

In [None]:
# Generator initialisation phase: train G on its own with init_loss_criterion
# before adversarial training starts.
config.n_epoch_init = 50

# Write the checkpoint to the same file the loading cell resumes from
# (config.checkpoint_init); a bare relative filename would land in the current
# working directory instead. Falls back to the original relative name when no
# init checkpoint path is configured.
init_checkpoint_path = (config.checkpoint_init
                        if config.checkpoint_init is not None
                        else 'checkpoint_generator.pth.tar')

if config.checkpoint is None:
    # start_epoch_init is only set when an init checkpoint was loaded;
    # a fresh run begins at epoch 0.
    start_epoch_init = getattr(config, 'start_epoch_init', 0)

    for epoch in range(start_epoch_init, config.n_epoch_init):
        train_init(train_loader=train_loader,
                   model=generator,
                   loss_criterion=init_loss_criterion,
                   optimizer=optimizer_g_init,
                   epoch=epoch,
                   device=config.device,
                   print_freq=config.print_freq)
        # Checkpoint after every epoch so the phase can be resumed
        torch.save({'epoch': epoch,
                    'model': generator,
                    'optimizer': optimizer_g_init},
                   init_checkpoint_path)
else:
    # An SRGAN checkpoint was loaded: the generator is already past this phase
    # and no initialisation optimizer exists, so there is nothing to run here.
    print("SRGAN checkpoint loaded; skipping generator initialisation.")

In [None]:
# Adversarial (SRGAN) training: one call to train() per epoch, followed by a
# checkpoint holding both networks and both optimizers.

# Write the checkpoint to the same file the loading cell resumes from
# (config.checkpoint); a bare relative filename would land in the current
# working directory, so a restarted runtime would resume from a stale file.
# Falls back to the original relative name when no checkpoint path is configured.
srgan_checkpoint_path = (config.checkpoint
                         if config.checkpoint is not None
                         else 'checkpoint_srgan2.pth.tar')

# Epoch at which the learning rate is reduced (just past the halfway point)
lr_decay_epoch = config.epochs // 2 + 1

for epoch in range(config.start_epoch, config.epochs):
    # At the halfway point, reduce learning rate to a tenth
    if epoch == lr_decay_epoch:
        adjust_learning_rate(optimizer_g, 0.1)
        adjust_learning_rate(optimizer_d, 0.1)

    # One epoch's training
    train(train_loader=train_loader,
          generator=generator,
          discriminator=discriminator,
          truncated_vgg19=truncated_vgg19,
          content_loss_criterion=content_loss_criterion,
          adversarial_loss_criterion=adversarial_loss_criterion,
          optimizer_g=optimizer_g,
          optimizer_d=optimizer_d,
          epoch=epoch,
          device=config.device,
          beta=config.beta,
          print_freq=config.print_freq)

    # Save checkpoint (whole objects, matching what the loading cell expects)
    torch.save({'epoch': epoch,
                'generator': generator,
                'discriminator': discriminator,
                'optimizer_g': optimizer_g,
                'optimizer_d': optimizer_d},
               srgan_checkpoint_path)

Epoch: [33][0/1250]----Batch Time 86.700 (86.700)----Data Time 83.904 (83.904)----Cont. Loss 0.4164 (0.4164)----Adv. Loss 0.3133 (0.3133)----Disc. Loss 1.6265 (1.6265)
Epoch: [33][100/1250]----Batch Time 6.306 (4.341)----Data Time 5.958 (3.966)----Cont. Loss 0.4268 (0.4380)----Adv. Loss 0.3133 (0.3133)----Disc. Loss 1.6265 (1.6265)
Epoch: [33][200/1250]----Batch Time 6.674 (3.129)----Data Time 6.325 (2.766)----Cont. Loss 0.4424 (0.4367)----Adv. Loss 0.3133 (0.3133)----Disc. Loss 1.6265 (1.6265)
Epoch: [33][300/1250]----Batch Time 6.267 (2.718)----Data Time 5.920 (2.358)----Cont. Loss 0.5032 (0.4378)----Adv. Loss 0.3133 (0.3133)----Disc. Loss 1.6265 (1.6265)
Epoch: [33][400/1250]----Batch Time 6.504 (2.514)----Data Time 6.157 (2.157)----Cont. Loss 0.4327 (0.4357)----Adv. Loss 0.3133 (0.3133)----Disc. Loss 1.6265 (1.6265)
Epoch: [33][500/1250]----Batch Time 6.795 (2.390)----Data Time 6.443 (2.034)----Cont. Loss 0.4779 (0.4348)----Adv. Loss 0.3133 (0.3133)----Disc. Loss 1.6265 (1.6265)
Ep