# Super Resolution GAN (SRGAN) implementation

### Mount Google Drive if running in Colab

In [None]:
# Colab only: mount Google Drive at /content/drive so files stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Navigate to repository
%cd /content/drive/MyDrive/Github/SRGAN

# Install a pinned albumentations release.
# NOTE(review): presumably required by the project's dataset/config modules — confirm.
!pip install albumentations==0.4.6

### Import needed modules

In [None]:
import config
import torch
from torch import nn
# Optimization algorithms
import torch.optim as optim
# Dataset manager
from torch.utils.data import DataLoader

from torchvision.models import vgg19

## 0. Define and prepare data

In [None]:
from dataset import MyImageFolder

# Build the dataset from the "new_data" directory. The training loop further
# down unpacks every batch drawn from it as a (low_res, high_res) pair.
dataset = MyImageFolder(root_dir="new_data")
# Report the sample count together with the root directory and first class name.
print(f"{len(dataset)} samples in dir {dataset.root_dir}/{dataset.class_names[0]}")

## 1. Create model

### Building blocks (Convolutional, Upsample, Residual)

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by optional batch norm and optional activation.

    Layer order is Conv2d -> BatchNorm2d (or Identity) -> LeakyReLU(0.2) / PReLU.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        discriminator: If True the activation is ``LeakyReLU(0.2)``; otherwise
            a ``PReLU`` with one slope per output channel.
        use_act: If False, ``forward`` skips the activation.
        use_bn: If False, batch norm is replaced by ``nn.Identity`` and the
            convolution is given a bias term.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act

        # The convolution carries a bias only when no batch norm follows it.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        # The activation module is always created, even when use_act is False,
        # so the set of registered sub-modules does not depend on use_act.
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Apply conv and norm, then the activation unless it is disabled."""
        normalized = self.bn(self.cnn(x))
        if not self.use_act:
            return normalized
        return self.act(normalized)

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: Conv2d -> PixelShuffle -> PReLU.

    The 3x3 convolution expands the channels by ``scale_factor ** 2`` and
    ``PixelShuffle`` rearranges them into space, so the output keeps ``in_c``
    channels while height and width grow by ``scale_factor``.

    Args:
        in_c: Number of input (and output) channels.
        scale_factor: Spatial upscaling factor.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        expanded_channels = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(in_c, expanded_channels, kernel_size=3, stride=1, padding=1)
        # (in_c * r^2, H, W) -> (in_c, H * r, W * r) with r = scale_factor
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        """Return ``x`` upscaled spatially by the configured factor."""
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)

class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks wrapped in an identity skip connection.

    The first ConvBlock keeps its activation; the second is built with
    ``use_act=False``. Channel count and spatial size are preserved, so the
    input can be added directly to the result.

    Args:
        in_channels: Number of channels entering and leaving the block.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Shared settings: 3x3 kernel, stride 1, padding 1 keeps H and W.
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_args)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **conv_args)

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        residual = self.block2(self.block1(x))
        return residual + x

### Generator

In [None]:
class Generator(nn.Module):
    """SRGAN generator: residual trunk followed by two x2 sub-pixel upsamplers.

    Pipeline: 9x9 ConvBlock (no batch norm) -> ``num_blocks`` residual blocks
    -> 3x3 ConvBlock (no activation) added to the stem output -> two
    ``UpsampleBlock(…, 2)`` stages (x4 overall) -> 9x9 Conv2d -> tanh.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        num_channels: Width of the internal feature maps.
        num_blocks: Number of residual blocks in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(
            in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
        )

        trunk = [ResidualBlock(num_channels) for _ in range(num_blocks)]
        self.residuals = nn.Sequential(*trunk)

        self.convblock = ConvBlock(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False
        )

        # Two consecutive x2 stages give the overall x4 upscaling.
        stages = [UpsampleBlock(num_channels, 2) for _ in range(2)]
        self.upsamples = nn.Sequential(*stages)

        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map a low-resolution batch to a x4 upscaled batch with values in [-1, 1]."""
        stem = self.initial(x)
        trunk_out = self.convblock(self.residuals(stem))
        # Long skip connection around the whole residual trunk.
        upscaled = self.upsamples(trunk_out + stem)
        return torch.tanh(self.final(upscaled))

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """SRGAN discriminator: a stack of strided ConvBlocks plus an MLP head.

    One ``ConvBlock`` (3x3, LeakyReLU) is built per entry of ``features``.
    Blocks at odd positions use stride 2 and the others stride 1; the very
    first block has no batch norm. The head pools to 6x6, flattens, and maps
    to a single raw logit per image (no sigmoid is applied here).

    Args:
        in_channels: Number of channels of the input images.
        features: Output channel count of each successive conv block.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,  # 1, 2, 1, 2, ... halves H/W every second block
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        # After the loop ``in_channels`` is the channel count leaving the conv
        # stack, so the head matches any ``features`` sequence instead of
        # assuming the last entry is 512.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return one unnormalised real/fake logit per input image, shape (N, 1)."""
        x = self.blocks(x)
        return self.classifier(x)

In [None]:
def test():
    """Smoke-test the models by printing generator and discriminator output shapes.

    A random batch of five 3-channel 24x24 images is passed through a fresh
    ``Generator`` and the result through a fresh ``Discriminator``, all
    inside a CUDA autocast context.
    """
    low_resolution = 24
    with torch.cuda.amp.autocast():
        lr_batch = torch.randn((5, 3, low_resolution, low_resolution))
        generator = Generator()
        sr_batch = generator(lr_batch)
        critic = Discriminator()
        logits = critic(sr_batch)

        # Generator output first, then discriminator output.
        for tensor in (sr_batch, logits):
            print(tensor.shape)

test()

### Initialize Generator and Discriminator

In [None]:
# Instantiate both networks for 3-channel images and move them to the configured device.
gen = Generator(in_channels=3).to(config.DEVICE)
disc = Discriminator(in_channels=3).to(config.DEVICE)

## 2. Loss and optimizer

### Loss

In [None]:
class VGGLoss(nn.Module):
    """Content loss: MSE between VGG19 feature maps of two image batches.

    The feature extractor is the first 36 layers of torchvision's pretrained
    VGG19 ``features`` stack, put in eval mode, moved to ``config.DEVICE``
    and frozen (``requires_grad = False`` on every parameter).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` — confirm against the installed version.
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(config.DEVICE)
        self.loss = nn.MSELoss()

        # The extractor is a fixed feature space, never a trainable part of the GAN.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between VGG features of ``input`` and ``target``.

        Args:
            input: Batch of generated images.
            target: Batch of reference images to compare against.
        """
        vgg_input_features = self.vgg(input)
        # Features must come from ``target``: extracting them from ``input``
        # a second time makes the loss identically zero.
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)

# Adversarial criterion: binary cross-entropy that takes raw logits
# (the Discriminator head ends in a plain Linear layer, without a sigmoid).
bce = nn.BCEWithLogitsLoss()
# Content criterion computed on VGG19 feature maps.
vgg_loss_fun = VGGLoss()

### Optimizers

In [None]:
# One Adam optimizer per network, both using the configured learning rate
# and betas (0.9, 0.999).
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))

## 3. Training

### Train discriminator and generator functions

In [None]:
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
def train_discriminator(D, opt, fake, high_res, bce):
    """Run one discriminator update and return its loss.

    Args:
        D: Discriminator network producing logits.
        opt: Optimizer holding D's parameters.
        fake: Generated batch; it is detached here, so no gradient reaches
            whatever produced it.
        high_res: Batch of real images.
        bce: Criterion called as ``bce(logits, targets)``.

    Returns:
        The summed real + fake loss tensor of this step.
    """
    opt.zero_grad()

    # Real batch: targets are one-sided smoothed labels drawn from (0.9, 1.0].
    real_logits = D(high_res)
    smoothed_ones = torch.ones_like(real_logits) - 0.1 * torch.rand_like(real_logits)
    real_term = bce(real_logits, smoothed_ones)

    # Fake batch: targets are all zeros.
    fake_logits = D(fake.detach())
    fake_term = bce(fake_logits, torch.zeros_like(fake_logits))

    total = real_term + fake_term
    total.backward()
    opt.step()

    return total

### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
def train_generator(D, opt, fake, high_res, vgg_loss, adv_criterion=None):
    """Run one generator update and return its loss.

    The loss is ``1e-3 * adversarial + 0.006 * content``, where the
    adversarial term rewards fakes that D scores as real and the content
    term is ``vgg_loss(fake, high_res)``.

    Args:
        D: Discriminator network producing logits.
        opt: Optimizer holding the generator's parameters.
        fake: Generated batch, still attached to the generator's graph.
        high_res: Batch of real images the fakes are compared against.
        vgg_loss: Callable ``(fake, high_res) -> loss`` for the content term.
        adv_criterion: Optional callable ``(logits, targets) -> loss`` for the
            adversarial term. When omitted, the notebook-level ``bce`` is
            used, so existing five-argument calls behave as before.

    Returns:
        The combined loss tensor of this step.
    """
    if adv_criterion is None:
        adv_criterion = bce

    # Reset gradients to zero
    opt.zero_grad()

    pred_fake = D(fake)

    # Targets are all ones: the generator wants D to call its output real.
    adv_loss = 1e-3 * adv_criterion(pred_fake, torch.ones_like(pred_fake))
    # Kept under its own name so the ``vgg_loss`` callable is not overwritten.
    content_loss = 0.006 * vgg_loss(fake, high_res)

    loss = adv_loss + content_loss

    # Backward pass
    loss.backward()
    # Update weights
    opt.step()

    return loss

### Load data

In [None]:
# Shuffled mini-batch loader over the dataset; batch size and worker count come
# from config, and pinned host memory is requested for the batches.
loader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=config.NUM_WORKERS)

### Load last checkpoint

In [None]:
from utils import load_checkpoint

# Optionally resume: load the generator and discriminator, together with their
# optimizers, from the checkpoint paths given in config.
# NOTE(review): config.LEARNING_RATE is passed along — presumably so the helper
# can reset the optimizer's learning rate after loading; confirm in utils.
if config.LOAD_MODEL:
    load_checkpoint(
        config.CHECKPOINT_GEN,
        gen,
        opt_gen,
        config.LEARNING_RATE
    )
    load_checkpoint(
        config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE
    )

### Training loop

In [None]:
from utils import plot_examples, save_checkpoint

# Summarise the run configuration before training starts.
print(f"Starting SRGAN training: \n")
print(f" Total training samples: {len(dataset)}\n Mini batch size: {config.BATCH_SIZE}\n Number of batches: {len(loader)}\n Learning rate: {config.LEARNING_RATE}\n")

# NOTE(review): gen/disc are never explicitly switched to train mode here; if
# plot_examples puts the generator in eval mode, confirm that it restores it.
for epoch in range(config.NUM_EPOCHS):
    for idx, (low_res, high_res) in enumerate(loader):

        # Send images to device
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        # Generate fake (high_res) image from low_res
        fake = gen(low_res)

        # The same fake batch feeds both updates: train_discriminator detaches
        # it, while train_generator backpropagates through it into the generator.
        loss_disc = train_discriminator(disc, opt_disc, fake, high_res, bce)
        loss_gen = train_generator(disc, opt_gen, fake, high_res, vgg_loss_fun)

        # On the first batch of every epoch, print the losses and call
        # plot_examples on the "low_res/" folder (presumably runs the generator
        # on the test images stored there — confirm in utils).
        if idx == 0:
            print( 
                f"Epoch [{epoch}/{config.NUM_EPOCHS} - "
                f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}]"
                )
            plot_examples("low_res/", gen, epoch)

    # Checkpoint both networks and their optimizers at the end of every epoch.
    if config.SAVE_MODEL:
        save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)