In [12]:
import torch
from torch import nn
from tqdm import tqdm
import time
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import cv2

In [9]:
class convolutionBlock(nn.Module):
    """A 2-D convolution (bias enabled) optionally followed by LeakyReLU(0.2).

    Args:
        inputChannels: Number of channels of the incoming tensor.
        outputChannels: Number of channels produced by the convolution.
        useActivation: When truthy, an in-place LeakyReLU with negative slope
            0.2 is applied after the convolution; otherwise the convolution
            output is passed through an ``nn.Identity``.
        **kwargs: Forwarded verbatim to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``). ``bias=True`` is always supplied here,
            so ``bias`` must not be passed again.
    """

    def __init__(self, inputChannels, outputChannels, useActivation, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(inputChannels, outputChannels, bias=True, **kwargs)
        if useActivation:
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.activation = nn.Identity()

    def forward(self, x):
        """Return the convolution of ``x``, activated if so configured."""
        features = self.cnn(x)
        return self.activation(features)


# Upsampling block
class UpsampleBlock(nn.Module):
    """Nearest-neighbour upsampling, then a 3x3 convolution and LeakyReLU(0.2).

    Args:
        inputChannels: Channel count of the input; the convolution keeps the
            same number of channels.
        scale_factor: Factor by which ``nn.Upsample`` enlarges the spatial
            dimensions (the default of 2 doubles height and width).
    """

    def __init__(self, inputChannels, scale_factor=2):
        super().__init__()
        # Nearest-neighbour interpolation (the original author notes this is
        # what the paper uses).
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
        self.conv = nn.Conv2d(
            inputChannels,
            inputChannels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Upsample ``x``, convolve it, and apply the activation."""
        enlarged = self.upsample(x)
        refined = self.conv(enlarged)
        return self.activation(refined)


# Dense residual block
class DenseResidualBlock(nn.Module):
    """Densely connected stack of five 3x3 convolutions with a scaled residual.

    Each convolution receives the channel-wise concatenation of the block
    input and the outputs of every preceding convolution. The first four
    convolutions emit ``channels`` feature maps and are followed by LeakyReLU;
    the fifth maps back to ``inputChannels`` with no activation so that its
    output can be added to the block input.

    Args:
        inputChannels: Channel count of the block input (and output).
        channels: Number of feature maps produced by each of the first four
            convolutions.
        residual_beta: Factor applied to the last convolution's output before
            it is added to the input (0.2 is the constant the original author
            attributes to the paper).
    """

    def __init__(self, inputChannels, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        # ModuleList registers the layers created in the loop as sub-modules.
        self.blocks = nn.ModuleList()

        # Five convolution layers; layer i sees the input plus i earlier outputs.
        for i in range(5):
            is_last = i == 4
            self.blocks.append(
                convolutionBlock(
                    inputChannels + channels * i,
                    inputChannels if is_last else channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    useActivation=not is_last,
                )
            )

    def forward(self, x):
        """Return ``x + residual_beta * dense_stack(x)``."""
        features = x
        # Grow the feature stack with every layer except the last one; the
        # last layer's output is only used for the residual sum, so
        # concatenating it would allocate a tensor that is never read.
        for block in self.blocks[:-1]:
            features = torch.cat([features, block(features)], dim=1)
        out = self.blocks[-1](features)
        # Scaled residual connection (an element-wise add, not a concat).
        return self.residual_beta * out + x


# Residual-in-Residual Dense Block (RRDB)
class RRDB(nn.Module):
    """Three chained ``DenseResidualBlock``s wrapped in a scaled skip connection.

    Args:
        inputChannels: Channel count passed to every inner dense block.
        residual_beta: Factor applied to the chained blocks' output before it
            is added back to the input. It is used only for this outer skip
            connection; the inner blocks are built with their own default.
    """

    def __init__(self, inputChannels, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        dense_blocks = [DenseResidualBlock(inputChannels) for _ in range(3)]
        self.rrdb = nn.Sequential(*dense_blocks)

    def forward(self, x):
        """Return ``rrdb(x) * residual_beta + x``."""
        refined = self.rrdb(x)
        return refined * self.residual_beta + x


# Generator block
class Generator(nn.Module):
    """Super-resolution generator: conv stem, RRDB trunk, two upsamplers, conv head.

    With the two default ``UpsampleBlock``s (each doubling height and width)
    the output is four times the spatial size of the input.

    Args:
        inputChannels: Channel count of the input image; the output has the
            same number of channels.
        numChannels: Width of the internal feature maps.
        numBlocks: Number of ``RRDB`` modules in the trunk.
    """

    def __init__(self, inputChannels=3, numChannels=64, numBlocks=23):
        super().__init__()
        # Stem: lift the image into ``numChannels`` feature maps.
        self.initial = nn.Conv2d(inputChannels, numChannels, 3, 1, 1, bias=True)
        # Trunk of residual-in-residual dense blocks.
        trunk = [RRDB(numChannels) for _ in range(numBlocks)]
        self.residuals = nn.Sequential(*trunk)
        # Convolution applied to the trunk output before the global skip.
        self.conv = nn.Conv2d(numChannels, numChannels, 3, 1, 1)
        self.upsamples = nn.Sequential(
            *[UpsampleBlock(numChannels) for _ in range(2)]
        )
        # Head: map the features back to image channels.
        self.final = nn.Sequential(
            nn.Conv2d(
                numChannels, numChannels,
                kernel_size=3, stride=1, padding=1, bias=True,
            ),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(
                numChannels, inputChannels,
                kernel_size=3, stride=1, padding=1, bias=True,
            ),
        )

    def forward(self, x):
        """Return the upscaled image produced from ``x``."""
        stem = self.initial(x)
        trunk_out = self.conv(self.residuals(stem))
        # Global skip connection around the whole trunk.
        merged = trunk_out + stem
        enlarged = self.upsamples(merged)
        return self.final(enlarged)


# Discriminator block
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs one raw score per image.

    A stack of 3x3 ``convolutionBlock``s (LeakyReLU after each) is followed by
    adaptive average pooling to 6x6 and a two-layer fully connected head. No
    sigmoid is applied to the final score.

    Args:
        inputChannels: Channel count of the input image.
        features: Output channel count of each convolution, in order. Layers
            at odd positions (0-based) use stride 2 and therefore halve the
            spatial size; the others use stride 1.
    """

    def __init__(
        self,
        inputChannels=3,
        # Tuple rather than list: a mutable default would be shared by all calls.
        features=(64, 64, 128, 128, 256, 256, 512, 512),
    ):
        super().__init__()
        layers = []
        in_channels = inputChannels
        for idx, out_channels in enumerate(features):
            layers.append(
                convolutionBlock(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1 + idx % 2,  # alternate stride 1, 2, 1, 2, ...
                    padding=1,
                    useActivation=True,
                ),
            )
            in_channels = out_channels

        self.blocks = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Size the head from the last conv's channel count so a custom
            # ``features`` sequence does not break on a hard-coded 512.
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return the classifier score for each image in ``x``."""
        x = self.blocks(x)
        return self.classifier(x)


# Initialize weights
def initialize_weights(model, scale=0.1):
    """Re-initialise the weights of every Conv2d and Linear layer in ``model``.

    Each matching layer's weight is drawn with Kaiming-normal initialisation
    (the original author notes this follows the reference source code) and
    then multiplied in place by ``scale``. Biases are left untouched.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
        scale: Multiplier applied to the freshly initialised weights.
    """
    scaled_types = (nn.Conv2d, nn.Linear)
    for module in model.modules():
        if not isinstance(module, scaled_types):
            continue
        weights = module.weight.data
        nn.init.kaiming_normal_(weights)
        weights *= scale


In [10]:
# Testing the model
def test():
    """Smoke test: run a random low-resolution batch through both networks.

    Builds a default ``Generator`` and ``Discriminator``, feeds a random
    5x3x24x24 batch through the generator and its output through the
    discriminator, and prints the shape of each result.
    """
    generator = Generator()
    discriminator = Discriminator()
    batch_size, channels, side = 5, 3, 24
    fake_low_res = torch.randn((batch_size, channels, side, side))
    upscaled = generator(fake_low_res)
    verdict = discriminator(upscaled)

    for tensor in (upscaled, verdict):
        print(tensor.shape)


# Run the smoke test when this cell executes; it prints the two output shapes.
test()

torch.Size([5, 3, 96, 96])
torch.Size([5, 1])
