### SRGAN

This notebook implements the SRGAN model, along with training and test data creation.

In [None]:
"""
Import Library
"""
from torch import nn
import h5py
import numpy as np
import glob
import os
from PIL import Image
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import torch
from tqdm import tqdm
from collections import namedtuple
import copy
import math
from torch.autograd import Variable
import pandas as pd

In [None]:
"""
SRGAN model
"""
class Generator(nn.Module):
    """SRGAN generator: upscale an RGB image by ``scale_factor``.

    Layout:
    - a 9x9 conv + PReLU stem (``block1``);
    - five residual blocks (``block2``-``block6``);
    - a 3x3 conv + BN (``block7``);
    - a long skip connection from the stem;
    - ``log2(scale_factor)`` x2 sub-pixel upsampling stages plus a final
      9x9 conv back to 3 channels (``block8``).

    The output is squashed into [0, 1].

    Args:
        scale_factor: overall upscaling factor. Each ``UpsampleBLock(32, 2)``
            doubles the spatial size, so this must be a positive power of
            two (1, 2, 4, 8, ...).

    Raises:
        ValueError: if ``scale_factor`` is not a positive power of two.
            Previously such values were silently truncated (e.g. 3 built a
            single x2 stage), producing an output of the wrong size.
    """

    def __init__(self, scale_factor):
        super(Generator, self).__init__()
        if scale_factor <= 0:
            raise ValueError(
                f'scale_factor must be a positive power of two, got {scale_factor!r}')
        # math.log2 is exact for powers of two; the round-trip check below
        # rejects anything else instead of truncating it.
        upsample_block_num = int(round(math.log2(scale_factor)))
        if 2 ** upsample_block_num != scale_factor:
            raise ValueError(
                f'scale_factor must be a positive power of two, got {scale_factor!r}')

        # Module creation order is kept identical so seeded initialisation and
        # state_dict keys (block1..block8) match existing checkpoints.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(32)
        self.block3 = ResidualBlock(32)
        self.block4 = ResidualBlock(32)
        self.block5 = ResidualBlock(32)
        self.block6 = ResidualBlock(32)
        self.block7 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32)
        )
        block8 = [UpsampleBLock(32, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(32, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Return the upscaled image with values in [0, 1]."""
        stem = self.block1(x)
        features = stem
        for block in (self.block2, self.block3, self.block4, self.block5, self.block6):
            features = block(features)
        features = self.block7(features)
        # Long skip connection around the residual trunk, then upsample.
        out = self.block8(stem + features)
        # Map tanh's [-1, 1] range onto [0, 1].
        return (torch.tanh(out) + 1) / 2


class Discriminator(nn.Module):
    """SRGAN discriminator: score each RGB image as real (near 1) or fake (near 0).

    The network is laid out as follows:
    - a 3x3 conv + LeakyReLU stem;
    - seven Conv-BN-LeakyReLU stages, with every other stage halving the
      resolution;
    - global average pooling and two 1x1 convs producing one logit per image.
    """

    # (in_channels, out_channels, stride) for each Conv-BN-LeakyReLU stage.
    _STAGES = (
        (32, 32, 2),
        (32, 64, 1),
        (64, 64, 2),
        (64, 128, 1),
        (128, 128, 2),
        (128, 256, 1),
        (256, 512, 2),
    )

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        for cin, cout, stride in self._STAGES:
            layers += [
                nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, kernel_size=1),
        ]
        # A single flat Sequential keeps the state_dict keys as net.0 ... net.26.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (N,) tensor of per-image probabilities in (0, 1)."""
        # After the 1x1 pooling and 1-channel conv, the logits are (N, 1, 1, 1).
        logits = self.net(x)
        return torch.sigmoid(logits.view(x.size(0)))


class ResidualBlock(nn.Module):
    """Shape-preserving residual block: x + BN(Conv(PReLU(BN(Conv(x)))))."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # Creation order matters for seeded initialisation; keep it fixed.
        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = self.bn2(self.conv2(self.prelu(self.bn1(self.conv1(x)))))
        return x + branch


class UpsampleBLock(nn.Module):
    """Sub-pixel upsampler: conv to C*r^2 channels, PixelShuffle by r, then PReLU.

    Args:
        in_channels: channel count C, preserved from input to output.
        up_scale: spatial upscaling factor r.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))

In [None]:
"""
Dataset feeding
"""
class CustomDataset(Dataset):
    """Map-style dataset reading paired samples from the 'lr' and 'hr' datasets of an HDF5 file.

    The file is reopened on every access rather than held open on the instance.
    NOTE(review): presumably 'lr'/'hr' hold low-/high-resolution patch pairs of
    equal length — shapes and dtypes are not visible here; confirm against the
    script that writes the .h5 files.

    Args:
        h5_file: path to the HDF5 file.
    """

    def __init__(self, h5_file):
        super(CustomDataset, self).__init__()
        self.h5_file = h5_file

    def _open(self):
        """Open the backing HDF5 file read-only."""
        return h5py.File(self.h5_file, 'r')

    def __getitem__(self, idx):
        """Return the (lr, hr) pair stored at ``idx``."""
        with self._open() as f:
            lr_sample, hr_sample = f['lr'][idx], f['hr'][idx]
        return lr_sample, hr_sample

    def __len__(self):
        """Return the number of entries in the 'lr' dataset."""
        with self._open() as f:
            count = len(f['lr'])
        return count


In [None]:
"""
Loss Functions
"""
from torchvision.models.vgg import vgg16

# TV loss is optional but implemented in paper
class TVLoss(nn.Module):
    """Total-variation penalty on an image batch.

    Sum the squared differences between vertically and horizontally adjacent
    pixels. Normalise each sum by the per-sample element count of the shifted
    slice. Add the two terms, multiply by 2 and ``tv_loss_weight``, then divide
    by the batch size.

    Args:
        tv_loss_weight: scalar multiplier applied to the result.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        dims = x.size()
        n, height, width = dims[0], dims[2], dims[3]
        below = x[:, :, 1:, :]
        right = x[:, :, :, 1:]
        vertical = (below - x[:, :, :height - 1, :]).pow(2).sum() / self.tensor_size(below)
        horizontal = (right - x[:, :, :, :width - 1]).pow(2).sum() / self.tensor_size(right)
        return self.tv_loss_weight * 2 * (vertical + horizontal) / n

    @staticmethod
    def tensor_size(t):
        """Return the product of dims 1-3 (elements per sample for a 4-D tensor)."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]

class GeneratorLoss(nn.Module):
    """Composite SRGAN generator objective.

    loss = MSE(pixels) + 0.001 * adversarial + 0.006 * VGG-feature MSE
           + 2e-8 * total variation

    The VGG16 feature extractor is pretrained, frozen and kept in eval mode.
    NOTE(review): images are fed to VGG as-is, without ImageNet mean/std
    normalisation — confirm this is intended.
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        pretrained = vgg16(pretrained=True, progress=False)
        # First 31 modules of VGG16's convolutional `features` stack.
        feature_layers = list(pretrained.features)[:31]
        frozen = nn.Sequential(*feature_layers).eval()
        frozen.requires_grad_(False)
        self.loss_network = frozen
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Combine the loss terms.

        Args:
            out_labels: discriminator scores for ``out_images``.
            out_images: generator output.
            target_images: ground-truth high-resolution images.
        """
        # Pushes D's score on generated images towards 1.
        adversarial = torch.mean(1 - out_labels)
        fake_features = self.loss_network(out_images)
        real_features = self.loss_network(target_images)
        perceptual = self.mse_loss(fake_features, real_features)
        pixel = self.mse_loss(out_images, target_images)
        smoothness = self.tv_loss(out_images)
        return pixel + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * smoothness



In [None]:
"""
Setup data loader
"""
batch_size = 1

# Training pairs: shuffled, and any incomplete final batch is dropped.
train_dataset = CustomDataset('dataset/train/train_full_s.h5')
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

# Validation pairs: read in stored order.
eval_dataset = CustomDataset('dataset/val/val_full_s.h5')
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=batch_size,
)


In [None]:
"""
Setup network parameter
"""
# Overall super-resolution factor and number of training epochs.
upscale_factor, num_epoch = 4, 20

# Seed torch's global RNG so weight initialisation is reproducible.
torch.manual_seed(123)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
"""
Setup network
"""
netG = Generator(upscale_factor)
netD = Discriminator()
generator_criterion = GeneratorLoss()

# Modules are created on the CPU; move them only when a GPU is present.
if torch.cuda.is_available():
    for module in (netG, netD, generator_criterion):
        module.to(device)

# Adam with PyTorch's default hyper-parameters for both players.
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

In [None]:
"""
Util function to measure error
"""
class AverageMeter(object):
    """Keep the most recent value and a count-weighted running mean.

    Attributes:
        val: last value passed to ``update``.
        sum: sum of ``val * n`` over all updates.
        count: sum of ``n`` over all updates.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

"""
Calculate PSNR
"""
def calc_psnr(img1, img2):
    """Return the PSNR in dB between two tensors, assuming a peak value of 1.0.

    The MSE is averaged over every element, so a batch yields a single value.
    Identical inputs give an infinite PSNR.
    """
    mse = (img1 - img2).pow(2).mean()
    return 10. * torch.log10(1. / mse)

In [None]:
# Adversarial training of netG/netD.
# - Each epoch: a training pass, per-epoch checkpoints, then PSNR on the
#   validation set.
# - The generator weights with the best validation PSNR are written to
#   weight_srgan/best.pth, which the test cell loads.

# Per-epoch history of the logged metrics.
results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': []}
best_weights = copy.deepcopy(netG.state_dict())
best_epoch = 0
best_psnr = 0.0

weights_dir = 'weight_srgan'
os.makedirs(weights_dir, exist_ok=True)

for epoch in range(1, num_epoch + 1):
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    # ---- training ----
    netG.train()
    netD.train()

    # drop_last=True, so exactly len(loader) full batches are iterated.
    with tqdm(total=len(train_dataloader) * train_dataloader.batch_size) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epoch))

        for inputs, labels in train_dataloader:
            cur_batch = inputs.size(0)
            running_results['batch_sizes'] += cur_batch

            real_img = labels.to(device, dtype=torch.float)
            z = inputs.to(device, dtype=torch.float)
            fake_img = netG(z)

            # (1) Update D. fake_img is detached so D's loss does not build a
            # graph into G, and D steps before G's loss touches D's gradients.
            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img.detach()).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward()
            optimizerD.step()

            # (2) Update G against the freshly updated D. Gradients this
            # backward leaves on D are cleared by netD.zero_grad() next batch.
            netG.zero_grad()
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # Running sums, weighted by batch size; g_score is D(G(z)) after
            # D's update in this iteration.
            running_results['g_loss'] += g_loss.item() * cur_batch
            running_results['d_loss'] += d_loss.item() * cur_batch
            running_results['d_score'] += real_out.item() * cur_batch
            running_results['g_score'] += fake_out.item() * cur_batch

            t.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epoch, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))
            t.update(cur_batch)

    torch.save(netG.state_dict(), os.path.join(weights_dir, 'netG_epoch_%d.pth' % epoch))
    torch.save(netD.state_dict(), os.path.join(weights_dir, 'netD_epoch_%d.pth' % epoch))

    # ---- validation ----
    netG.eval()
    epoch_psnr = AverageMeter()

    with torch.no_grad():
        for inputs, labels in eval_dataloader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            preds = netG(inputs)
            # .item() keeps the meter and best_psnr as plain Python floats.
            epoch_psnr.update(calc_psnr(preds, labels).item(), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(netG.state_dict())

    # Per-epoch averages; max() guards against an empty training loader.
    seen = max(running_results['batch_sizes'], 1)
    results['d_loss'].append(running_results['d_loss'] / seen)
    results['g_loss'].append(running_results['g_loss'] / seen)
    results['d_score'].append(running_results['d_score'] / seen)
    results['g_score'].append(running_results['g_score'] / seen)
    results['psnr'].append(epoch_psnr.avg)

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
# The evaluation cell loads this file.
torch.save(best_weights, os.path.join(weights_dir, 'best.pth'))

In [None]:
"""
Evaluate the model with test set
"""
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

model = Generator(upscale_factor).to(device)
own_state = model.state_dict()

# Copy checkpoint tensors into the model in place, one key at a time:
# - a key the model does not have raises KeyError;
# - model keys missing from the checkpoint keep their current values.
checkpoint = torch.load('weight_srgan/best.pth', map_location=lambda storage, loc: storage)
for name, tensor in checkpoint.items():
    if name not in own_state:
        raise KeyError(name)
    own_state[name].copy_(tensor)

model.eval()


In [None]:
from pytorch_ssim import pytorch_ssim


def _load_rgb_tensor(path):
    """Load an image as a (1, 3, H, W) float32 tensor in [0, 1] on `device`."""
    arr = np.array(Image.open(path).convert('RGB')).astype(np.float32)
    arr = np.transpose(arr, axes=[2, 0, 1])
    arr /= 255.0
    return torch.from_numpy(arr).to(device).unsqueeze(0)


# Super-resolve every LR test image and score it against its HR counterpart.
# Per-image and mean PSNR/SSIM are printed; predictions are saved as PNGs.
lr_image_path = 'dataset/test/images_stage5/*.png'
hr_image_path = 'dataset/test/images_stage3/*.png'
# glob returns files in arbitrary order; sort both lists so the LR/HR pairing
# by index is deterministic. Assumes matching files sort to the same position.
lr_image_list = sorted(glob.glob(lr_image_path))
hr_image_list = sorted(glob.glob(hr_image_path))

if not lr_image_list:
    raise FileNotFoundError('no test images match {}'.format(lr_image_path))
if len(lr_image_list) != len(hr_image_list):
    raise ValueError('found {} LR images but {} HR images'.format(
        len(lr_image_list), len(hr_image_list)))

output_dir = 'result_srgan_new'
os.makedirs(output_dir, exist_ok=True)

psnr_total = 0.0
ssim_total = 0.0

for i, (lr_path, hr_path) in enumerate(zip(lr_image_list, hr_image_list)):
    image = _load_rgb_tensor(lr_path)
    label = _load_rgb_tensor(hr_path)

    with torch.no_grad():
        preds = model(image).clamp(0.0, 1.0)

    psnr = calc_psnr(label, preds).item()
    ssim = float(pytorch_ssim.ssim(label, preds))
    psnr_total += psnr
    ssim_total += ssim
    print('PSNR: {:.2f}'.format(psnr))
    print('SSIM: {:.2f}'.format(ssim))

    # Round to the nearest 8-bit level (+0.5 before the truncating .byte()).
    output = preds.mul(255.0).add_(0.5).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
    Image.fromarray(output).save(os.path.join(output_dir, f'img_{i}.png'))

psnr_total /= len(lr_image_list)
ssim_total /= len(lr_image_list)
print('PSNR_T: {:.4f}'.format(psnr_total))
print('SSIM_T: {:.4f}'.format(ssim_total))