In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch.nn as nn
from torchvision.models import vgg19, VGG19_Weights
import torch
from torch import optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import tqdm
import random

In [None]:
# Select the compute device. `torch.cuda.is_available` has to be *called*:
# the bare function object is always truthy, which would pick 'cuda' even on
# machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
def augment_images(start, end, orig_size, transform_type):
    """Write transposed copies of training images back into the training set.

    For every index ``i`` in ``range(start, end)`` and for both the
    ``low_res`` and ``high_res`` folders, ``./SRGANData/train/<folder>/<i>.png``
    is opened, transposed with ``transform_type`` and saved next to it as
    ``<orig_size + i>.png``.

    Args:
        start: First image index to augment (inclusive).
        end: Last image index to augment (exclusive).
        orig_size: Offset added to ``i`` to form the output file name.
        transform_type: Value passed to ``PIL.Image.Image.transpose``.
    """
    for i in range(start, end):
        for res_quality in ['low_res', 'high_res']:
            # The context manager closes the source file promptly; without it
            # every iteration leaves an open handle until garbage collection.
            with Image.open(f'./SRGANData/train/{res_quality}/{i}.png') as im:
                im_t = im.transpose(transform_type)
            im_t.save(f"./SRGANData/train/{res_quality}/{orig_size+i}.png")

# Augment the training images with indices 0..684: each slice of the index range
# gets a different transpose, and the results are written as indices 685..1369
# (orig_size=685). Re-running this cell overwrites those same output files.
augment_images(0, 171, 685, Image.Transpose.ROTATE_180)
augment_images(171, 342, 685, Image.Transpose.FLIP_TOP_BOTTOM)
augment_images(342, 514, 685, Image.Transpose.FLIP_LEFT_RIGHT)
augment_images(514, 685, 685, Image.Transpose.TRANSVERSE)

In [None]:
# Square edge lengths (in pixels) of the generator input and target images.
low_res_size = 128
high_res_size = 256


def _resize_to_tensor(edge):
    """Build the array -> PIL -> (edge x edge) resize -> tensor pipeline."""
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((edge, edge)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


# Same pipeline for both resolutions; only the target edge length differs.
transform_low = _resize_to_tensor(low_res_size)
transform_high = _resize_to_tensor(high_res_size)

In [None]:
class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset.

    Expects ``root_dir`` to contain a ``low_res`` and a ``high_res`` folder.
    Files are paired by position after sorting both listings by name, so the
    i-th low-res file is matched with the i-th high-res file.
    NOTE(review): this assumes both folders use matching file names - confirm.
    """

    def __init__(self, root_dir):
        super(ImageDataset, self).__init__()
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sort both listings so
        # that zipping them pairs corresponding files deterministically.
        files_low = sorted(os.listdir(os.path.join(root_dir, "low_res")))
        files_high = sorted(os.listdir(os.path.join(root_dir, "high_res")))
        self.data = list(zip(files_low, files_high))

    def __len__(self):
        """Number of (low_res, high_res) file pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``index``-th pair as ``(low_res, high_res)`` tensors.

        Each image is loaded, reduced to three RGB channels and passed through
        ``transform_low`` / ``transform_high`` respectively.
        """
        #Stores file name of the png
        img_low_file, img_high_file = self.data[index]

        #Gets their paths w.r.t current working directory
        low_res_pth = os.path.join(self.root_dir, "low_res", img_low_file)
        high_res_pth = os.path.join(self.root_dir, "high_res", img_high_file)

        #Numpy arrays of size (H,W,C=3). convert("RGB") drops an alpha channel
        #and also handles grayscale/palette files, which plain [:, :, :3]
        #slicing cannot. The context manager closes the file handle.
        with Image.open(low_res_pth) as low_img:
            low_res = np.array(low_img.convert("RGB"))
        with Image.open(high_res_pth) as high_img:
            high_res = np.array(high_img.convert("RGB"))

        #Tensors of shape (C=3, H, W)
        low_res = transform_low(low_res)
        high_res = transform_high(high_res)

        return low_res, high_res

In [None]:
#Hyperparameters
lr = 3e-4  # learning rate shared by the generator and discriminator Adam optimizers
epochs = 7  # number of passes of the training loop over train_loader
batch_size = 5  # batch size used by both the train and val DataLoaders

In [None]:
#This is used as nothing more than a function. It is not trained.
#Note that the original paper mentions only using the first 19 layers of vgg19

class vggL(nn.Module):
    """Perceptual (VGG feature-space) loss.

    Both inputs are passed through the first 25 modules of a pretrained
    VGG19 ``features`` stack and the mean squared error between the two
    feature maps is returned.
    """

    def __init__(self):
        super().__init__()
        self.vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:25].eval().to(device)
        # The extractor is a fixed function, not a trainable part of the model:
        # freezing its weights stops autograd from computing (and storing)
        # gradients for them on every generator backward pass. Gradients still
        # flow through the network to the *inputs*.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.MSE = nn.MSELoss()

    def forward(self, first, second):
        """Return the MSE between the VGG feature maps of ``first`` and ``second``.

        NOTE(review): inputs are fed to VGG as-is; the pretrained weights are
        normally used with ImageNet mean/std normalisation - confirm whether
        that should be applied here.
        """
        vgg_first = self.vgg(first)
        vgg_second = self.vgg(second)
        perceptual_loss = self.MSE(vgg_first, vgg_second)
        return perceptual_loss

In [None]:
class ConvBlock(nn.Module):
    """Conv2d, then optional BatchNorm2d, then optional LeakyReLU(0.2).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        use_activation: When False, ``forward`` returns the (normalised)
            convolution output without applying the LeakyReLU.
        use_BatchNorm: When False, the normalisation step is an identity.
        **kwargs: Forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, use_activation=True, use_BatchNorm=True, **kwargs):
        super().__init__()
        self.use_activation = use_activation
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs)
        if use_BatchNorm:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()
        # Always constructed (it has no parameters); only applied on request.
        self.ac = nn.LeakyReLU(0.2)

    def forward(self, x):
        normalised = self.bn(self.cnn(x))
        if not self.use_activation:
            return normalised
        return self.ac(normalised)


class UpsampleBlock(nn.Module):
    """Conv2d -> PixelShuffle -> PReLU upsampling stage.

    The convolution expands the channels by ``scale_factor ** 2`` so that the
    pixel shuffle returns to ``in_channels`` channels. With kernel_size=2,
    stride=1 and padding=1 the convolution also grows each spatial dimension
    by one pixel, so an ``H x W`` input becomes
    ``scale_factor * (H + 1) x scale_factor * (W + 1)``.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=2, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.ac = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.ac(shuffled)


class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with a skip connection around them.

    The first block applies its activation, the second does not; the input is
    added to the result. Channel count and spatial size are preserved
    (kernel_size=3, stride=1, padding=1).
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.b1 = ConvBlock(in_channels, in_channels, **conv_kwargs)
        self.b2 = ConvBlock(in_channels, in_channels, use_activation=False, **conv_kwargs)

    def forward(self, x):
        residual = self.b2(self.b1(x))
        return residual + x

In [None]:
class Generator(nn.Module):
    """SRGAN-style generator producing a 2x upscaled image in [0, 1].

    Spatial sizes for an ``H x H`` input (all convolutions use stride 1):
      * ``initial`` (7x7, padding 4) grows the map to ``H + 2``;
      * the residual trunk and ``conv`` (3x3, padding 1) keep ``H + 2``;
      * ``up`` (UpsampleBlock, scale 2) yields ``2 * (H + 3) = 2H + 6``;
      * ``final`` (9x9, padding 1) shrinks by 6, giving exactly ``2H``.
    The individual paddings therefore only work in combination.

    Args:
        in_channels: Channels of the input and output images.
        num_channels: Width of the internal feature maps.
        num_blocks: Number of ResidualBlocks in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=8):
        super().__init__()
        #K=9 in the paper
        self.initial = ConvBlock(
            in_channels, num_channels, kernel_size=7, stride=1, padding=4, use_BatchNorm=False
        )
        trunk = [ResidualBlock(num_channels) for _ in range(num_blocks)]
        self.res = nn.Sequential(*trunk)
        self.conv = ConvBlock(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_activation=False
        )
        self.up = UpsampleBlock(num_channels, scale_factor=2)
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=1)

    def forward(self, x):
        shallow = self.initial(x)
        # Long skip connection around the whole residual trunk.
        deep = self.conv(self.res(shallow)) + shallow
        upsampled = self.up(deep)
        #Sigmoid activation not mentioned in the paper
        return torch.sigmoid(self.final(upsampled))


class Discriminator(nn.Module):
    """SRGAN-style discriminator returning one raw logit per image.

    A stack of 3x3 ConvBlocks (one per entry of ``features``) is followed by
    adaptive average pooling to 8x8 and a two-layer MLP. No sigmoid is applied
    to the output.

    Args:
        in_channels: Channels of the input images.
        features: Output channel count of each ConvBlock, in order.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super(Discriminator, self).__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature, #Numbers as given in the paper. See the diagram labels
                    kernel_size=3,
                    stride = idx % 2 + 1, #Even idx gives s=1, Odd idx gives 2. This is as per the architecture proposed in the paper
                    padding=1,
                    use_activation=True,
                    use_BatchNorm=idx != 0 #Only the first block in this has no Batch Normalization
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)

        # After the loop `in_channels` is the channel count produced by the last
        # block, so the first Linear layer matches any `features` sequence
        # instead of assuming the final width is 512.
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d((8, 8)), #Not there in the paper. Pools a (N,C,H,W) tensor to (N,C,8,8)
            nn.Flatten(),
            nn.Linear(in_channels * 8 * 8, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.blocks(x)
        x = self.mlp(x)
        return x

In [None]:
# Build both networks on the selected device; each gets its own Adam optimizer
# configured with the shared learning rate `lr`.
gen = Generator(in_channels=3).to(device)
disc = Discriminator(in_channels=3).to(device)
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))

# BCEWithLogitsLoss applies the sigmoid itself, matching the discriminator's raw
# logit output; vggL compares VGG19 feature maps of two image batches.
bce = nn.BCEWithLogitsLoss()
vgg_loss = vggL()

# Paired low/high resolution datasets and their loaders.
train = ImageDataset(root_dir="./SRGANData/train/")
train_loader = DataLoader(train, batch_size=batch_size, num_workers = 2, shuffle=True)

# NOTE(review): val_loader is not referenced by the other cells visible here.
val = ImageDataset(root_dir="./SRGANData/val/")
val_loader = DataLoader(val, batch_size=batch_size, num_workers = 2, shuffle=True)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:04<00:00, 133MB/s]


In [None]:
# Loss values (plain Python floats) recorded every 20th batch by train_fn.
disc_loss_history = []
gen_loss_history = []
def train_fn(loader, disc, gen, opt_gen, opt_disc, bce, vgg_loss):
    """Run one pass over ``loader``, alternating discriminator and generator steps.

    For every batch the discriminator is updated on real images (target 1) and
    detached generator outputs (target 0); the generator is then updated with
    ``0.006 * vgg_loss + 1e-3 * adversarial_loss``. Every 20th batch both
    losses are appended to the module-level history lists.

    Args:
        loader: Iterable of ``(low_res, high_res)`` tensor batches.
        disc: Discriminator network.
        gen: Generator network.
        opt_gen: Optimizer for the generator parameters.
        opt_disc: Optimizer for the discriminator parameters.
        bce: Binary cross-entropy (with logits) criterion.
        vgg_loss: Callable returning the perceptual loss of two image batches.

    Returns:
        ``(gen_loss, disc_loss)`` of the last batch as detached CPU tensors
        (zeros if ``loader`` yields no batches).
    """
    # Zero-dim tensors (not ints) so the return statement also works when the
    # loader is empty.
    disc_loss = torch.tensor(0.0)
    gen_loss = torch.tensor(0.0)

    for idx, (low_res, high_res) in enumerate(loader):
        high_res = high_res.to(device)
        low_res = low_res.to(device)

        fake = gen(low_res)

        # --- Discriminator step: `fake` is detached so that this backward
        # pass does not reach the generator.
        disc_real = disc(high_res)
        disc_fake = disc(fake.detach())

        disc_loss_real = bce(disc_real, torch.ones_like(disc_real))
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))

        disc_loss = disc_loss_fake + disc_loss_real

        opt_disc.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        # --- Generator step: re-score the (non-detached) fakes with the
        # freshly updated discriminator.
        disc_fake = disc(fake)
        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = 0.006*vgg_loss(fake, high_res) #0.006 factor not mentioned in paper
        gen_loss = loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if (idx+1)%20 == 0:
            # .item() stores plain floats. Appending the tensors themselves
            # would keep device tensors (with their autograd history) alive
            # for the whole run; torch.no_grad() does not detach existing
            # tensors.
            disc_loss_history.append(disc_loss.item())
            gen_loss_history.append(gen_loss.item())

    return gen_loss.detach().cpu(), disc_loss.detach().cpu()

In [None]:
def plot_examples(gen):
    """Show one random validation sample: low-res input, prediction, high-res target.

    A random pair is drawn from ``./SRGANData/val/``, the generator is run on
    the low-res image in eval mode without gradients, and the three images are
    plotted side by side. The generator is put back into train mode afterwards,
    even if inference raises.

    Args:
        gen: Generator network mapping a ``(1, 3, H, W)`` batch to an image batch.

    Raises:
        ValueError: If the validation dataset contains no image pairs.
    """
    dataset_test = ImageDataset(root_dir="./SRGANData/val/")
    if len(dataset_test) == 0:
        raise ValueError("No validation image pairs found in ./SRGANData/val/")

    # Index the dataset directly: only a single pair is needed, so there is no
    # reason to spin up a DataLoader and iterate through its batches.
    low_res, high_res = dataset_test[random.randrange(len(dataset_test))]

    gen.eval()
    try:
        with torch.no_grad():
            upscaled_img = gen(low_res.to(device).unsqueeze(0))[0].cpu()
    finally:
        gen.train()

    # Create a figure with three subplots
    _, axs = plt.subplots(1, 3, figsize=(8, 4))
    panels = (
        ("low res", low_res),
        ("predicted", upscaled_img),
        ("high res", high_res),
    )
    for ax, (title, image) in zip(axs, panels):
        ax.set_axis_off()
        # imshow expects channels last: (C, H, W) -> (H, W, C)
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(title)

    plt.show()

In [None]:
# Main training loop: before each epoch a random validation sample is plotted
# with the current generator, then one full pass over train_loader is run.
for epoch in range(epochs):
     plot_examples(gen)
     # train_fn returns the generator/discriminator losses of the last batch.
     gen_loss, disc_loss = train_fn(train_loader, disc, gen, opt_gen, opt_disc, bce, vgg_loss)
     print("epoch ", epoch+1, "/", epochs)
     print(f"gen_loss: {gen_loss:.4f}, disc_loss:{disc_loss:.4f}")

In [None]:
## Saving the model

# srgan_G_param_file='SRGAN_G.pth'
# srgan_D_param_file='SRGAN_D.pth'
# torch.save(gen.state_dict(), srgan_G_param_file)
# torch.save(disc.state_dict(), srgan_D_param_file)

In [None]:
def upscale(model, image_path, return_image=False):
    """Upscale a single image file with ``model``.

    The image is loaded as RGB, resized to 128x128, and passed through the
    model as a batch of one. Inference runs without gradients, in eval mode
    when ``model`` is an ``nn.Module`` (its previous train/eval mode is
    restored afterwards), and on whichever device holds the model's weights.

    Args:
        model: Callable mapping a ``(1, 3, 128, 128)`` tensor to an image batch.
        image_path: Path of the image file to upscale.
        return_image: When False (default) the input and the result are plotted
            side by side and nothing is returned; when True the result is
            returned instead.

    Returns:
        ``None`` when plotting, otherwise the upscaled batch as a CPU tensor
        in channels-last layout ``(1, H, W, 3)``.
    """
    transform_low2 = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # convert("RGB") guarantees exactly three channels (drops alpha, expands
    # grayscale/palette images); the context manager closes the file handle.
    with Image.open(image_path) as source:
        img = transform_low2(source.convert("RGB")).unsqueeze(0)

    is_module = isinstance(model, torch.nn.Module)
    first_param = next(model.parameters(), None) if is_module else None
    # Feed the input on the model's device; a CPU tensor passed to a CUDA
    # model raises a device-mismatch error.
    target_device = first_param.device if first_param is not None else img.device

    was_training = is_module and model.training
    if is_module:
        # Eval mode keeps BatchNorm from using (and updating its running
        # statistics with) the statistics of this single image.
        model.eval()
    try:
        with torch.no_grad():
            # .cpu() so the result can be handed to matplotlib directly.
            upscaled_img = model(img.to(target_device)).detach().cpu()
    finally:
        if was_training:
            model.train()

    if not return_image:
        _, axs = plt.subplots(1, 2, figsize=(8,4))
        img = img.permute(0,2,3,1)[0]

        axs[0].set_axis_off()
        axs[0].imshow(img)
        axs[0].set_title("Original")

        axs[1].set_axis_off()
        axs[1].imshow(upscaled_img.permute(0, 2, 3, 1)[0])
        axs[1].set_title("Upscaled")

        plt.show()
    else:
        return upscaled_img.permute(0, 2, 3, 1)