<a href="https://colab.research.google.com/github/chang-heekim/Implementation_Deep_Learning_Paper/blob/main/SRGAN/SRGAN_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load Necessary Library

In [2]:
import glob
import numpy as np
from PIL import Image
import time

import torch
from torch import nn
from torchvision.models import vgg19
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable

# Define Models
- FeatureExtractor
- ResidualBlock
- Generator
- Discriminator


In [4]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(*list(vgg19(pretrained=True).features.children())[:18])

    def forward(self, input):
        return self.feature_extractor(input)

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
        )

    def forward(self, input):
        return input + self.conv_block(input)

In [41]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU()
        )

        residual_blocks = []
        for _ in range(n_residual_blocks):
            residual_blocks.append(ResidualBlock(64))
        self.residual_blocks = nn.Sequential(*residual_blocks)

        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8)
        )

        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)   

        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())
    
    def forward(self, input):
        out1 = self.conv1(input)
        out = self.residual_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2) 
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

In [10]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super().__init__()

        self.input_shape = input_shape
        c, h, w = self.input_shape
        patch_h, patch_w = int(h / (2 ** 4)), int(w / (2 ** 4))
        self.output_shape = (1, patch_h, patch_w)

        layers = []
        in_filters = c
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(self._discriminator_block(in_filters, out_filters, first_block=(i==0)))
            in_filters = out_filters
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def _discriminator_block(self, in_filters, out_filters, first_block=False):
        layers = []
        layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
        if not first_block:
            layers.append(nn.BatchNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
        layers.append(nn.BatchNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, input):
        return self.model(input)

In [7]:
def to_rgb(image):
    """
    Function that changes mode from non-rgb to rgb
    """
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image

# Define Custom Dataset Class

In [25]:
class CustomDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_h, hr_w = hr_shape
        self.lr_transform = transforms.Compose([
            transforms.Resize((hr_h // 4, hr_h // 4), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.hr_transform = transforms.Compose([
            transforms.Resize((hr_h, hr_h), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.files = sorted(glob.glob(root + '/*.*'))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])

        if img.mode != 'RGB':
            img = to_rgb(img)

        img_lr = self.lr_transform(img) 
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}
    
    def __len__(self):
        return len(self.files)

# Set up Hyper Parameters

In [38]:
import os 

os.makedirs("images", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

epochs = 200
root = '/content/Scene-Classification-Dataset/train-scene classification/train'
batch_size =4
lr = 0.0002
b1 = 0.5
b2 = 0.999
num_workers = 4
hr_height = 128
hr_width = 128
channels = 3
sample_interval = 100

# Set up Models & Criterion & Optimizer

In [51]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
hr_shape = (hr_height, hr_width)

generator = Generator().to(device)
discriminator = Discriminator(input_shape=(channels, *hr_shape)).to(device)

feature_extractor = FeatureExtractor().to(device)
feature_extractor.eval()

criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

In [50]:
data_loader = DataLoader(
    CustomDataset(root, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

  "Argument interpolation should be of type InterpolationMode instead of int. "


# Training

In [None]:
start_time = time.time()

for epoch in range(1, epochs + 1):
    for idx, imgs in enumerate(data_loader):
        imgs_lr = imgs['lr'].to(device)
        imgs_hr = imgs['hr'].to(device)

        valid = torch.cuda.FloatTensor(np.ones((imgs_lr.size(0), *discriminator.output_shape)))
        fake = torch.cuda.FloatTensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape)))
        
        # Training Generator
        optimizer_G.zero_grad()

        gen_hr = generator(imgs_lr)
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())

        loss_G = loss_content + 1e-3 * loss_GAN

        loss_G.backward()
        optimizer_G.step()

        # Training Discriminator
        optimizer_D.zero_grad()

        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        if (idx + 1) % sample_interval == 0:
            # Save image grid with upsampled inputs and SRGAN outputs
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_lr, gen_hr), -1)
            save_image(img_grid, "images/%d.png" % (idx + 1), normalize=False)
            print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Time elapsed (seconds): %f]' %
                  (epoch, epochs, idx + 1, len(data_loader), loss_D.item(), loss_G.item(), time.time() - start_time))
            # Save model checkpoints
            torch.save(generator.state_dict(), "saved_models/generator_%d_latest.pth" % epoch)
            torch.save(discriminator.state_dict(), "saved_models/discriminator_%d_latest.pth" % epoch)

[Epoch 1/200] [Batch 100/6084] [D loss: 0.014545] [G loss: 1.308950] [Time elapsed (seconds): 11.557683]
[Epoch 1/200] [Batch 200/6084] [D loss: 0.005165] [G loss: 1.157285] [Time elapsed (seconds): 22.945740]
[Epoch 1/200] [Batch 300/6084] [D loss: 0.004603] [G loss: 1.045570] [Time elapsed (seconds): 34.344923]
[Epoch 1/200] [Batch 400/6084] [D loss: 0.003354] [G loss: 1.181427] [Time elapsed (seconds): 45.836504]
[Epoch 1/200] [Batch 500/6084] [D loss: 0.002076] [G loss: 1.191220] [Time elapsed (seconds): 57.327128]
[Epoch 1/200] [Batch 600/6084] [D loss: 0.001783] [G loss: 1.116439] [Time elapsed (seconds): 68.888235]
[Epoch 1/200] [Batch 700/6084] [D loss: 0.001699] [G loss: 1.118929] [Time elapsed (seconds): 80.464166]
[Epoch 1/200] [Batch 800/6084] [D loss: 0.003504] [G loss: 1.579327] [Time elapsed (seconds): 92.066290]
[Epoch 1/200] [Batch 900/6084] [D loss: 0.002869] [G loss: 1.026241] [Time elapsed (seconds): 103.692144]
[Epoch 1/200] [Batch 1000/6084] [D loss: 0.001100] [G 