# Colab
If you plan to make changes or modifications, please use the GitHub repository: https://github.com/monishramadoss/SRGAN

In [None]:
# Print the GPU(s) visible to this runtime and their current memory usage.
!nvidia-smi
# Reinstall tqdm from the `colab` branch archive of the chengs/tqdm fork;
# the dataset cell below imports its progress bar via `tqdm.auto`.
!pip install --force https://github.com/chengs/tqdm/archive/colab.zip

Wed Apr  8 02:08:22 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

# Hyperparameters

In [None]:
UPSCALE_FACTOR, BATCHSIZE, EPOCHS, LOWRES = (
    2,     # UPSCALE_FACTOR: HR crops are LOWRES * UPSCALE_FACTOR pixels square
    32,    # BATCHSIZE: images per training batch
    1000,  # EPOCHS: number of adversarial training epochs
    56,    # LOWRES: side length (pixels) of the low-resolution generator input
)

# Datasets

In [None]:
import os
from tqdm.auto import tqdm
import urllib.request
import zipfile

DEVSET_URL = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
TRAINSET_URL = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
# Local paths of the downloaded archives.
DEVSET = "./data/DIV2K_valid_HR.zip"
TRAINSET = "./data/DIV2K_train_HR.zip"
# Folders holding the extracted images. The archives are unzipped into
# ./data/dev and ./data/train, and each archive's top-level DIV2K_*_HR
# directory is what ImageFolder(root='./data/train' / './data/dev') picks up
# as its single class folder, so the images live one level below those roots.
DEVDATA_FOLDER = "./data/dev/DIV2K_valid_HR"
TRAINDATA_FOLDER = "./data/train/DIV2K_train_HR"

class TqdmUpTo(tqdm):
    """tqdm progress bar usable as a ``urllib.request.urlretrieve`` reporthook.

    ``urlretrieve`` reports cumulative progress as (blocks transferred so far,
    block size in bytes, total size in bytes); ``update_to`` converts that into
    the incremental ``update`` calls tqdm expects.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar so that it shows ``b * bsize`` bytes.

        Args:
            b: number of blocks transferred so far.
            bsize: size of each block in bytes.
            tsize: total size in bytes. ``None`` or a negative value leaves the
                total unchanged; urlretrieve passes -1 when the server sends no
                Content-Length.
        """
        # Ignore the -1 "unknown size" sentinel rather than storing a negative total.
        if tsize is not None and tsize >= 0:
            self.total = tsize
        # tqdm.update takes an increment, so subtract what has already been counted.
        self.update(b * bsize - self.n)


# Create the working directories; exist_ok makes this a no-op on re-runs.
for _directory in ('./data', './data/train', './data/dev', './checkpoint'):
    os.makedirs(_directory, exist_ok=True)


def _download_and_extract(url, archive_path, extract_dir, desc):
    """Download ``url`` to ``archive_path`` and unzip it into ``extract_dir``.

    Does nothing if ``archive_path`` already exists. The archive is first
    written to ``archive_path + '.part'`` and renamed to ``archive_path`` only
    after extraction succeeds. An interrupted download or extraction therefore
    never leaves a file that a later run would mistake for a finished one.
    """
    if os.path.exists(archive_path):
        return
    partial_path = archive_path + '.part'
    with TqdmUpTo(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
        urllib.request.urlretrieve(url, partial_path, reporthook=t.update_to)
    with zipfile.ZipFile(partial_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    # Mark the archive complete only now; os.replace is atomic on the same filesystem.
    os.replace(partial_path, archive_path)


_download_and_extract(TRAINSET_URL, TRAINSET, './data/train', "Div2k Train Set")
_download_and_extract(DEVSET_URL, DEVSET, './data/dev', "Div2k Valid Set")








# Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

def swish(x):
    """Swish / SiLU activation, computed elementwise as ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

class FeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG19 prefix used for the perceptual loss.

    Keeps layers ``0..feature_layer`` (inclusive) of ``vgg19().features``. With
    the default ``feature_layer=11``, in torchvision's VGG19 layout this should
    end at the ReLU after the first conv of the third block ("relu3_1"); confirm
    with ``print(cnn.features)``. The pretrained weights expect ImageNet-normalized
    inputs.

    The VGG parameters are frozen. They are never optimized, and leaving
    ``requires_grad=True`` would make every backward pass compute and
    accumulate useless gradients for them. Gradients still flow through the
    network to its input.
    """

    def __init__(self, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        cnn = torchvision.models.vgg19(pretrained=True)
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])
        for param in self.features.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        return self.features(x)


class Discriminator(nn.Module):
    """SRGAN-style discriminator that returns one probability per image.

    Layers:
    - Eight 3x3 convolutions (64 -> 512 channels). Every second one has
      stride 2, giving 16x total downsampling.
    - Batch-norm after every conv except the first, and swish activations.
    - A 1x1 conv down to a single channel, averaged over all spatial positions
      and passed through a sigmoid.

    Returns:
        Tensor of shape ``(batch, 1)`` with values in (0, 1), as consumed by
        ``nn.BCELoss``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(512)
        # A 1x1 kernel needs no padding. padding=1 added a ring of bias-only
        # outputs that the global average pool in forward() then mixed into every
        # score. The weight shape is unchanged, so existing checkpoints still load.
        self.conv9 = nn.Conv2d(512, 1, 1, stride=1, padding=0)

    def forward(self, x):
        x = swish(self.conv1(x))
        x = swish(self.bn2(self.conv2(x)))
        x = swish(self.bn3(self.conv3(x)))
        x = swish(self.bn4(self.conv4(x)))
        x = swish(self.bn5(self.conv5(x)))
        x = swish(self.bn6(self.conv6(x)))
        x = swish(self.bn7(self.conv7(x)))
        x = swish(self.bn8(self.conv8(x)))
        x = self.conv9(x)

        # Global average pool over H x W, then one sigmoid score per image.
        return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)


class Generator(nn.Module):
    """SRGAN-style super-resolution generator.

    Layers, in order:
    - A 9x9 input conv with swish.
    - ``n_residual_blocks`` residual blocks.
    - A 3x3 conv + batch-norm whose output is added back onto the first feature
      map (long skip connection).
    - One 2x ``UpsampleBlock`` per factor of two in ``upscale_factor``.
    - A final 9x9 conv to 3 channels.

    Args:
        n_residual_blocks: number of residual blocks.
        upscale_factor: overall spatial upscaling (int). It must be a positive
            power of two (1, 2, 4, 8, ...) because each upsample stage exactly
            doubles height and width.
        n_filters: channel width used throughout the network.
        inplace: unused; kept so existing call sites keep working.

    Raises:
        ValueError: if ``upscale_factor`` is not a positive power of two.
    """

    def __init__(self, n_residual_blocks, upscale_factor=2, n_filters=64, inplace=False):
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = upscale_factor
        # log2(factor) doubling stages are needed. The old `factor // 2` only
        # matched for 2 and 4 (e.g. 8 built four stages = 16x, 3 silently gave 2x).
        self.n_upsample_blocks = self._num_upsample_blocks(upscale_factor)
        self.conv1 = nn.Conv2d(3, n_filters, 9, stride=1, padding=4)

        # Submodules keep their 'residual_block<i>' / 'upsample<i>' names so that
        # previously saved state_dicts still load.
        for i in range(self.n_residual_blocks):
            self.add_module('residual_block' + str(i + 1), ResidualBlock(n_filters, 3, n_filters, 1))

        self.conv2 = nn.Conv2d(n_filters, n_filters, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(n_filters)
        for i in range(self.n_upsample_blocks):
            self.add_module('upsample' + str(i + 1), UpsampleBlock(n_filters, n_filters))
        self.conv3 = nn.Conv2d(n_filters, 3, 9, stride=1, padding=4)

    @staticmethod
    def _num_upsample_blocks(upscale_factor):
        """Return log2(upscale_factor), the number of 2x upsample stages."""
        if upscale_factor < 1 or upscale_factor & (upscale_factor - 1):
            raise ValueError(
                'upscale_factor must be a positive power of two, got %r' % (upscale_factor,))
        return upscale_factor.bit_length() - 1

    def forward(self, x):
        x = swish(self.conv1(x))
        # The residual blocks return new tensors and never write to their input,
        # so `x` can be shared with the skip connection without a clone.
        y = x
        for i in range(self.n_residual_blocks):
            y = getattr(self, 'residual_block' + str(i + 1))(y)

        x = self.bn2(self.conv2(y)) + x

        for i in range(self.n_upsample_blocks):
            x = getattr(self, 'upsample' + str(i + 1))(x)

        return self.conv3(x)


class UpSampleConvLayer(nn.Module):
    """Resize-then-convolve upsampling layer.

    The input is first enlarged by ``upsample`` with ``nn.Upsample`` (its default
    mode), then reflection-padded by ``kernel_size // 2`` on every side, and
    finally convolved with an unpadded ``nn.Conv2d(in_channels, out_channels,
    kernel_size, stride)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=2):
        super(UpSampleConvLayer, self).__init__()
        self.upsample = upsample
        self.upsample_layer = nn.Upsample(scale_factor=upsample)
        pad_width = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(pad_width)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        enlarged = self.upsample_layer(x)
        padded = self.reflection_pad(enlarged)
        return self.conv(padded)


class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> swish -> conv -> BN, plus an identity skip.

    Args:
        channels: input channels.
        k: (odd) kernel size of both convolutions. Padding is ``k // 2``, so a
            stride-1 block preserves height and width for any odd ``k``.
        n: output channels of both convolutions.
        s: stride of both convolutions.

    The identity addition requires the output shape to equal the input shape,
    i.e. ``channels == n`` and ``s == 1``.
    """

    def __init__(self, channels=64, k=3, n=64, s=1):
        super(ResidualBlock, self).__init__()
        # Previously hard-coded to 1, which only preserved size for k == 3 and
        # made the skip addition fail for other kernel sizes.
        padding = k // 2
        self.conv1 = nn.Conv2d(channels, n, k, stride=s, padding=padding)
        self.bn1 = nn.BatchNorm2d(n)
        self.conv2 = nn.Conv2d(n, n, k, stride=s, padding=padding)
        self.bn2 = nn.BatchNorm2d(n)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = swish(y)
        y = self.conv2(y)
        y = self.bn2(y)
        # Out-of-place add: the input tensor is left untouched.
        y = y + x
        return y


class UpsampleBlock(nn.Module):
    """Learned 2x spatial upsampling: transposed conv -> batch-norm -> swish.

    ``ConvTranspose2d`` with kernel 4, stride 2 and padding 1 maps an ``H x W``
    input to ``(H - 1) * 2 - 2 + 4 = 2H`` by ``2W``. The transposed conv has no
    bias, presumably because the batch-norm that follows re-centres each channel
    anyway.
    """

    def __init__(self, in_channels, out_channels):
        super(UpsampleBlock, self).__init__()
        self.convT = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return swish(self.bn(self.convT(x)))

# Training Setup

In [None]:

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# HR training samples: random LOWRES*UPSCALE_FACTOR square crops as [0, 1] tensors.
transform = transforms.Compose([transforms.RandomCrop(LOWRES*UPSCALE_FACTOR),
                                transforms.ToTensor()])
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Exact inverse of `normalize`: Normalize computes (x - m') / s', and with
# m' = -mean/std, s' = 1/std that equals x * std + mean.
unnormalize = transforms.Normalize(mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
                                   std=[1.0 / s for s in IMAGENET_STD])
# Downscale an HR tensor crop to LOWRES and normalize it like the HR targets.
scale = transforms.Compose([transforms.ToPILImage(), transforms.Resize(LOWRES),
                            transforms.ToTensor(), normalize])
train_dataset = datasets.ImageFolder(root='./data/train', transform=transform)
dev_dataset = datasets.ImageFolder(root='./data/dev', transform=transform)

# shuffle: otherwise every epoch visits the images in the same order.
# drop_last: `low_res` and `ones_const` below hold exactly BATCHSIZE samples,
# so a shorter final batch is dropped rather than mismatching them.
train_dataloader = torch.utils.data.DataLoader(train_dataset, BATCHSIZE, shuffle=True, drop_last=True)
valid_dataloader = torch.utils.data.DataLoader(dev_dataset, 1)


content_criterion = nn.MSELoss()
# Fall back to CPU so the notebook can still run without a GPU (slowly).
_default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
GeneratorDevice = torch.device(_default_device)
DiscriminatorDevice = torch.device(_default_device)
adversarial_criterion = nn.BCELoss()

generator = Generator(16, UPSCALE_FACTOR)
discriminator = Discriminator()
feature_extractor = FeatureExtractor()

generator = generator.to(GeneratorDevice)
discriminator = discriminator.to(DiscriminatorDevice)
feature_extractor = feature_extractor.to(DiscriminatorDevice)
# Reusable CPU buffer for a batch of low-resolution inputs (overwritten before use).
low_res = torch.empty(BATCHSIZE, 3, LOWRES, LOWRES)
ones_const = torch.ones(BATCHSIZE, 1).to(DiscriminatorDevice)


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth





# Pretrain Generator

In [None]:
# Warm up the generator with a pure pixel-wise MSE loss before adversarial training.
PRETRAIN_EPOCHS = 2
optim_generator = optim.Adam(generator.parameters(), lr=0.0001)
for epoch in tqdm(range(PRETRAIN_EPOCHS), desc ='Pretraining'):
    for i, data in enumerate(train_dataloader):
        high_res_cpu, _ = data
        # Build the LR inputs from the un-normalized [0, 1] crops first (`scale`
        # normalizes its own output), then normalize the HR targets. Stacking
        # per actual batch works even when the final batch is shorter than BATCHSIZE.
        low_res_batch = torch.stack([scale(img) for img in high_res_cpu])
        high_res_real = torch.stack([normalize(img) for img in high_res_cpu]).to(GeneratorDevice)
        high_res_fake = generator(low_res_batch.to(GeneratorDevice))

        generator.zero_grad()
        generator_content_loss = content_criterion(high_res_fake, high_res_real)
        generator_content_loss.backward()
        optim_generator.step()





# Training Loop

In [None]:
generator_optimzer = optim.Adam(generator.parameters(), lr=0.00001)
discriminator_optimzer = optim.Adam(discriminator.parameters(), lr=0.00001)
for epoch in tqdm(range(EPOCHS), desc='Training'):
    for i, data in enumerate(train_dataloader):
        high_res_cpu, _ = data
        batch_size = high_res_cpu.size(0)
        # LR inputs come from the un-normalized crops; HR targets are then normalized.
        low_res_batch = torch.stack([scale(img) for img in high_res_cpu])
        high_res_real = torch.stack([normalize(img) for img in high_res_cpu]).to(GeneratorDevice)
        high_res_fake = generator(low_res_batch.to(GeneratorDevice))

        # Soft labels: real in [0.7, 1.0), fake in [0, 0.3). BCELoss targets must
        # lie in [0, 1]. The previous 0.7 + 0.5*U range reached 1.2, which makes
        # the loss unbounded below as D(real) -> 1.
        target_real = (torch.rand(batch_size, 1) * 0.3 + 0.7).to(DiscriminatorDevice)
        target_fake = (torch.rand(batch_size, 1) * 0.3).to(DiscriminatorDevice)
        ones_target = torch.ones(batch_size, 1, device=DiscriminatorDevice)
        high_res_real = high_res_real.to(DiscriminatorDevice)
        high_res_fake = high_res_fake.to(DiscriminatorDevice)

        # Train D. Detaching the fake batch stops this backward pass at D, so it
        # neither runs through the generator nor needs retain_graph.
        discriminator.zero_grad()
        discriminator_loss = (adversarial_criterion(discriminator(high_res_real), target_real)
                              + adversarial_criterion(discriminator(high_res_fake.detach()), target_fake))
        discriminator_loss.backward()
        discriminator_optimzer.step()

        # Feature Extractor. The real-image features are fixed targets and need no graph.
        with torch.no_grad():
            real_features = feature_extractor(high_res_real)
        fake_features = feature_extractor(high_res_fake)

        # Train G: pixel MSE + weighted VGG feature MSE + weighted adversarial term.
        generator.zero_grad()
        generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
        generator_adversarial_loss = adversarial_criterion(discriminator(high_res_fake), ones_target)
        generator_total_loss = generator_content_loss + 0.001 * generator_adversarial_loss
        generator_total_loss.backward()
        generator_optimzer.step()


    # Overwrite the checkpoints after every epoch.
    torch.save(generator.state_dict(), './checkpoint/generator_final.pth')
    torch.save(discriminator.state_dict(), './checkpoint/discriminator_final.pth')

KeyboardInterrupt: ignored

In [None]:
from torchvision.utils import save_image

# Super-resolve every validation crop with the last saved generator and write
# the real/fake image pairs as <index>_real.png / <index>_fake.png.
generator.load_state_dict(torch.load("./checkpoint/generator_final.pth", map_location=GeneratorDevice))
generator.eval()
low_res = torch.zeros((1, 3, LOWRES, LOWRES), device=GeneratorDevice)
# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for i, data in enumerate(valid_dataloader):
        high_res_real, _ = data
        low_res[0] = scale(high_res_real[0])
        high_res_fake = generator(low_res)
        # valid_dataloader has batch size 1, so the only sample is index 0.
        output = unnormalize(high_res_fake[0].cpu()).clamp(min=0, max=1)
        save_image(high_res_real, str(i) + '_real.png')
        save_image(output, str(i) + '_fake.png')