[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/khetansarvesh/CV/blob/main/low_res2high_res/gans.ipynb)

In [None]:
import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import kagglehub

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn as nn
from torch import optim
from torchvision.utils import save_image
from torchvision.models import vgg19

  check_for_updates()


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Turn on cuDNN's benchmark mode for the whole session.
torch.backends.cudnn.benchmark = True

# **Dataset**

In [None]:
# Fetch the Kaggle dataset through kagglehub and report the directory it returns.
path = kagglehub.dataset_download(
    "adityachandrasekhar/image-super-resolution"
)
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/adityachandrasekhar/image-super-resolution?dataset_version_number=2...


100%|██████████| 301M/301M [00:16<00:00, 19.1MB/s]


Extracting files...
Path to dataset files: /root/.cache/kagglehub/datasets/adityachandrasekhar/image-super-resolution/versions/2


In [None]:
class MyImageFolder(Dataset):
    """Paired low-/high-resolution image dataset.

    Images are read from ``<base>/low_res`` and ``<base>/high_res``. The i-th
    low-resolution file is paired with the i-th high-resolution file after both
    directory listings are sorted by name.

    NOTE(review): pairing by sorted position assumes both folders use matching
    file names (or at least the same sort order) -- confirm against the dataset.

    Args:
        base: Directory containing the ``high_res`` and ``low_res`` folders.
            Defaults to the kagglehub cache location used by this notebook.
    """

    DEFAULT_BASE = "/root/.cache/kagglehub/datasets/adityachandrasekhar/image-super-resolution/versions/2/dataset/train"

    def __init__(self, base=None):
        super().__init__()
        self.base = self.DEFAULT_BASE if base is None else base

        # os.listdir returns entries in arbitrary order; sorting both listings
        # keeps index i in one folder aligned with index i in the other.
        self.high_images = self._list_images(os.path.join(self.base, "high_res"))
        self.low_images = self._list_images(os.path.join(self.base, "low_res"))

        # Built once instead of on every __getitem__ call: scale pixels to
        # [-1, 1] (mean 0.5, std 0.5 per channel) and convert to a CHW tensor.
        self.transforms = A.Compose(
            [A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()]
        )

    @staticmethod
    def _list_images(directory):
        """Return the entries of ``directory`` sorted by name."""
        return sorted(os.listdir(directory))

    @staticmethod
    def _load_rgb(image_path):
        """Read ``image_path`` as an RGB uint8 array, closing the file afterwards."""
        with Image.open(image_path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        """Number of samples, defined by the high-resolution folder."""
        return len(self.high_images)

    def __getitem__(self, index):
        """Return the ``(low_res, high_res)`` tensor pair for ``index``."""
        # Wrap the index so either list can be shorter without raising.
        high_name = self.high_images[index % len(self.high_images)]
        low_name = self.low_images[index % len(self.low_images)]

        high_arr = self._load_rgb(os.path.join(self.base, "high_res", high_name))
        low_arr = self._load_rgb(os.path.join(self.base, "low_res", low_name))

        high_res = self.transforms(image=high_arr)["image"]
        low_res = self.transforms(image=low_arr)["image"]

        return low_res, high_res

In [None]:
# Wrap the paired-image dataset in a shuffling, multi-worker batch loader.
dataset = MyImageFolder()
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)



# **Modelling**

In [None]:
class Generator(nn.Module):
    """Convolutional encoder/decoder that maps an image to an image in [-1, 1].

    The network applies one 7x7 stem convolution, two stride-2 convolutions,
    two stride-2 transposed convolutions and a final 7x7 convolution, so the
    output has the same spatial size as the input when height and width are
    divisible by 4.

    Args:
        img_channels: Number of channels of the input and output images.
        num_features: Channel width of the stem; the stride-2 stages use
            2x and 4x this value.
        num_residuals: Accepted for signature compatibility only; this
            architecture contains no residual blocks, so the value is unused.
    """

    def __init__(self, img_channels=3, num_features=64, num_residuals=9):
        super().__init__()
        f = num_features
        # Layer order is unchanged, so state_dict keys (model.0, model.1, ...)
        # stay the same as before for the default arguments.
        self.model = nn.Sequential(
            # stem
            nn.Conv2d(img_channels, f, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(f),
            nn.ReLU(inplace=True),
            # stride-2 stages
            nn.Conv2d(f, f * 2, kernel_size=3, stride=2, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(f * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(f * 2, f * 4, kernel_size=3, stride=2, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(f * 4),
            nn.ReLU(inplace=True),
            # stride-2 transposed stages (output_padding=1 restores the exact 2x size)
            nn.ConvTranspose2d(f * 4, f * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(f * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(f * 2, f, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(f),
            nn.ReLU(inplace=True),
            # projection back to image channels
            nn.Conv2d(f, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
        )

    def forward(self, x):
        """Run the network and squash the result to [-1, 1] with tanh."""
        return torch.tanh(self.model(x))

In [None]:
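# NOTE: everything in this cell is commented out and therefore inactive. It keeps an
# alternative residual-block / PixelShuffle implementation for reference only; the
# Generator and Discriminator actually used are defined in their own cells.
# class ConvBlock(nn.Module):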
#     def __init__(
#         self,
#         in_channels,
#         out_channels,
#         discriminator=False,
#         use_act=True,
#         use_bn=True,
#         **kwargs,
#     ):
#         super().__init__()
#         self.use_act = use_act
#         self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
#         self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
#         self.act = (
#             nn.LeakyReLU(0.2, inplace=True)
#             if discriminator
#             else nn.PReLU(num_parameters=out_channels)
#         )

#     def forward(self, x):
#         return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))


# class UpsampleBlock(nn.Module):
#     def __init__(self, in_c, scale_factor):
#         super().__init__()
#         self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
#         self.ps = nn.PixelShuffle(scale_factor)  # in_c * 4, H, W --> in_c, H*2, W*2
#         self.act = nn.PReLU(num_parameters=in_c)

#     def forward(self, x):
#         return self.act(self.ps(self.conv(x)))


# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.block1 = ConvBlock(
#             in_channels,
#             in_channels,
#             kernel_size=3,
#             stride=1,
#             padding=1
#         )
#         self.block2 = ConvBlock(
#             in_channels,
#             in_channels,
#             kernel_size=3,
#             stride=1,
#             padding=1,
#             use_act=False,
#         )

#     def forward(self, x):
#         out = self.block1(x)
#         out = self.block2(out)
#         return out + x


# class Generator(nn.Module):
#     def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
#         super().__init__()
#         self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
#         self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
#         self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
#         self.upsamples = nn.Sequential(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2))
#         self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

#     def forward(self, x):
#         initial = self.initial(x)
#         x = self.residuals(initial)
#         x = self.convblock(x) + initial
#         x = self.upsamples(x)
#         return torch.tanh(self.final(x))


# class Discriminator(nn.Module):
#     def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
#         super().__init__()
#         blocks = []
#         for idx, feature in enumerate(features):
#             blocks.append(
#                 ConvBlock(
#                     in_channels,
#                     feature,
#                     kernel_size=3,
#                     stride=1 + idx % 2,
#                     padding=1,
#                     discriminator=True,
#                     use_act=True,
#                     use_bn=False if idx == 0 else True,
#                 )
#             )
#             in_channels = feature

#         self.blocks = nn.Sequential(*blocks)
#         self.classifier = nn.Sequential(
#             nn.AdaptiveAvgPool2d((6, 6)),
#             nn.Flatten(),
#             nn.Linear(512*6*6, 1024),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(1024, 1),
#         )

#     def forward(self, x):
#         x = self.blocks(x)
#         return self.classifier(x)

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional real/fake critic for 3-channel images.

    Three stride-2 4x4 convolutions followed by two stride-1 4x4 convolutions
    reduce the input to a 1-channel score map (one score per spatial location
    of the final feature map, not a single scalar per image).

    The forward pass returns raw logits. The training code in this notebook
    scores them with ``nn.BCEWithLogitsLoss``, which applies the sigmoid
    internally; applying another sigmoid here would squash the logits twice.
    Call ``torch.sigmoid`` on the output if probabilities are needed.
    """

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 1, 1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 1, padding_mode="reflect"),
        )

    def forward(self, x):
        """Return the unnormalised (logit) score map for ``x``."""
        return self.model(x)

# **Training**

In [None]:
# Build the generator first, then the discriminator, and move both to DEVICE.
gen, disc = Generator().to(DEVICE), Discriminator().to(DEVICE)

In [None]:
# Optimizers: one Adam instance per network, both with identical hyper-parameters.
opt_gen, opt_disc = (
    optim.Adam(network.parameters(), lr=1e-4, betas=(0.9, 0.999))
    for network in (gen, disc)
)

In [None]:
# losses
class VGGLoss(nn.Module):
    """Mean-squared error between VGG19 feature maps of two image batches.

    The first 36 layers of torchvision's pretrained VGG19 ``features`` stack
    are used as a fixed feature extractor: it is put in eval mode, moved to
    ``DEVICE`` and its parameters are excluded from gradient updates.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features[:36]
        self.vgg = backbone.eval().to(DEVICE)
        self.loss = nn.MSELoss()

        # The extractor is never trained; turn off gradients for its weights.
        self.vgg.requires_grad_(False)

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``."""
        return self.loss(self.vgg(input), self.vgg(target))


# Loss objects shared by the training cell: pixel MSE, adversarial BCE on
# logits, and the VGG feature-space loss.
mse, bce = nn.MSELoss(), nn.BCEWithLogitsLoss()
vgg_loss = VGGLoss()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 154MB/s]


In [None]:
for epoch in range(100):
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        # Move the current batch onto the training device.
        high_res = high_res.to(DEVICE)
        low_res = low_res.to(DEVICE)

        # ---- Discriminator step (generator held constant) ----
        fake = gen(low_res)
        disc_real = disc(high_res)
        # detach() keeps this backward pass from reaching the generator
        disc_fake = disc(fake.detach())

        # Real targets are drawn from (0.9, 1.0] rather than fixed at 1.
        disc_loss_real = bce(
            disc_real,
            torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real),
        )
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = disc_loss_fake + disc_loss_real

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # ---- Generator step (discriminator held constant) ----
        # Re-score the same fakes, this time with the graph attached to gen.
        disc_fake = disc(fake)
        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
        gen_loss = loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()


100%|██████████| 43/43 [01:46<00:00,  2.47s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.07s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.07s/it]
100%|██████████| 43/43 [00:45<00:00,  1.07s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.06s/it]
100%|██████████| 43/43 [00:45<00:00,  1.06s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.05s/it]
100%|██████████| 43/43 [00:45<00:00,  1.06s/it]
100%|██████████| 43/43 [00:45<00:00,  1.

# **Inference**

In [None]:
# `file` was never defined anywhere in the notebook, so this cell failed on a
# fresh kernel. Default to the first entry of test_images/ (sorted by name);
# replace with a specific file name to upscale a particular image.
file = sorted(os.listdir("test_images"))[0]

gen.eval()

# Force 3-channel RGB (the generator's first conv expects 3 input channels)
# and close the underlying file once the pixels are loaded.
with Image.open("test_images/" + file) as opened_image:
    image = opened_image.convert("RGB")

with torch.no_grad():
    # Use the same normalisation as the training dataset (mean 0.5, std 0.5
    # per channel), so the generator sees inputs in the range it was trained on.
    test_transform = A.Compose([A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
    transformed_img = test_transform(image=np.asarray(image))["image"]
    # unsqueeze(0) adds the batch dimension; the output is tanh-bounded in [-1, 1].
    upscaled_img = gen(transformed_img.unsqueeze(0).to(DEVICE))

In [None]:
# Show the generator output tensor produced by the inference cell.
upscaled_img