https://www.kaggle.com/code/suraj520/srgan-psnr-19db-10-epoch-know-train-infer

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torchvision.models.vgg import vgg16
from math import exp
import torch
import torch.nn.functional as F
import torchvision.utils as utils
from torch.autograd import Variable
from tqdm import tqdm
import math
import pandas as pd
import os
from os import listdir
import numpy as np
from PIL import Image
from os.path import join

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

upscale_factor = 8
crop_size = 88
num_epochs = 10

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def is_image(filename):
    return any(
        filename.endswith(extension)
        for extension in [".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"]
    )


def calc_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)


def train_high_res_transform(crop_size):
    return transforms.Compose([transforms.RandomCrop(crop_size), transforms.ToTensor()])


def train_low_res_transform(crop_size, upscale_factor):
    return transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ]
    )

In [3]:
class TrainDataFromFolder(Dataset):
    def __init__(self, data_dir, crop_size, upscale_factor):
        super().__init__()
        self.image_file_names = [
            join(data_dir, x) for x in listdir(data_dir) if is_image(x)
        ]
        crop_size = calc_valid_crop_size(crop_size, upscale_factor)
        self.high_res_transform = train_high_res_transform(crop_size)
        self.low_res_transform = train_low_res_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        hr_image = self.high_res_transform(Image.open(self.image_file_names[index]))
        lr_image = self.low_res_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_file_names)


class ValDataFromFolder(Dataset):
    def __init__(self, data_dir, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.image_file_names = [
            join(data_dir, x) for x in listdir(data_dir) if is_image(x)
        ]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_file_names[index])
        w, h = hr_image.size
        crop_size = calc_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = transforms.Resize(
            crop_size // self.upscale_factor, interpolation=Image.BICUBIC
        )
        hr_scale = transforms.Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = transforms.CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restored_image = hr_scale(lr_image)
        return (
            transforms.ToTensor()(lr_image),
            transforms.ToTensor()(hr_restored_image),
            transforms.ToTensor()(hr_image),
        )

    def __len__(self):
        return len(self.image_file_names)


train_set = TrainDataFromFolder(
    "ebany_research/weight_superresolution/kaggle_esrgan/dataset/train/",
    crop_size=crop_size,
    upscale_factor=upscale_factor,
)
val_set = ValDataFromFolder(
    "ebany_research/weight_superresolution/kaggle_esrgan/dataset/valid", upscale_factor=upscale_factor
)
train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)

In [4]:
next(iter(train_loader))[0].shape

torch.Size([1, 3, 11, 11])

In [18]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        return x + residual


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels * up_scale**2, kernel_size=3, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x


class Generator(nn.Module):
    def __init__(self, scale_factor):
        super(Generator, self).__init__()
        upsample_block_num = int(math.log(scale_factor, 2))

        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU()
        )

        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)
        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        )

    def forward(self, x):
        batch_size = x.size()[0]
        return torch.sigmoid(self.net(x).view(batch_size))


class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]

        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])

        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, : h_x - 1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, : w_x - 1], 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]


class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        # vgg = vgg16(pretrained=False)
        # loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        # for param in loss_network.parameters():
        #     param.requires_grad = False
        # self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        # self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        adversial_loss = torch.mean(1 - out_labels)
        perception_loss = self.mse_loss(out_images, target_images)
        image_loss = self.mse_loss(out_images, target_images)
        # tv_loss = self.tv_loss(out_images)
        return (
            image_loss
            + 0.001 * adversial_loss
            + 0.006 * perception_loss
            # + 2e-8 * tv_loss
        )

In [20]:
netG = Generator(upscale_factor)
netD = Discriminator()

generator_criterion = GeneratorLoss()

In [19]:
import gc
gc.collect()
del netG
del netD

In [21]:
generator_criterion = generator_criterion.to(device)
netG = netG.to(device)
netD = netD.to(device)

optimizerG = optim.Adam(netG.parameters(), lr=0.0002)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002)

In [10]:
device

device(type='cuda')

In [22]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(netG))
print(count_parameters(netD))
print(count_parameters(generator_criterion))

881932
5067585
0


In [12]:
def gaussian(window_size, sigma):
    gauss = torch.Tensor(
        [
            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
            for x in range(window_size)
        ]
    )
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(
        _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    )
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = (
        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    )
    sigma2_sq = (
        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    )
    sigma12 = (
        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
        - mu1_mu2
    )

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

In [23]:
results = {
    "d_loss": [],
    "g_loss": [],
    "d_score": [],
    "g_score": [],
    "psnr": [],
    "ssim": [],
}


<div style="background-color:#F0E3D2; color:#19180F; font-size:15px; font-family:Verdana; padding:10px; border: 2px solid #19180F; border-radius:10px"> 
📌
    <b>Remarks</b>    <br>

1. Mean square loss is often not great since it compares pixel values, SSIM(Structural similarity) tries to capture the structure of image including noise via statistic. SSIM looks at groups of pixels to decipher whether two images are same or not.<br>
2. psnr is peak signal to noise ratio used as yet another metric.<br>
3. TV loss obtains better edges by doing total variation(TV) on the reconstructed image and the residual between the reconstructed image and the original image.<br>
4. Training model simultaneously.</div>

In [24]:
for epoch in range(1, num_epochs + 100):
    train_bar = tqdm(train_loader)
    running_results = {
        "batch_sizes": 0,
        "d_loss": 0,
        "g_loss": 0,
        "d_score": 0,
        "g_score": 0,
    }

    netG.train()
    netD.train()
    for data, target in train_bar:
        g_update_first = True
        batch_size = data.size(0)
        running_results["batch_sizes"] += batch_size

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target)
        if torch.cuda.is_available():
            real_img = real_img.cuda()
        z = Variable(data)
        if torch.cuda.is_available():
            z = z.cuda()
        fake_img = netG(z)

        netD.zero_grad()
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward(retain_graph=True)
        optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        ## The two lines below are added to prevent runetime error in Google Colab ##
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()
        ##
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.step()

        # loss for current batch before optimization
        running_results["g_loss"] += g_loss.item() * batch_size
        running_results["d_loss"] += d_loss.item() * batch_size
        running_results["d_score"] += real_out.item() * batch_size
        running_results["g_score"] += fake_out.item() * batch_size

        train_bar.set_description(
            desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f"
            % (
                epoch,
                num_epochs,
                running_results["d_loss"] / running_results["batch_sizes"],
                running_results["g_loss"] / running_results["batch_sizes"],
                running_results["d_score"] / running_results["batch_sizes"],
                running_results["g_score"] / running_results["batch_sizes"],
            )
        )

    netG.eval()

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        valid_results = {"mse": 0, "ssims": 0, "psnr": 0, "ssim": 0, "batch_sizes": 0}
        val_images = []
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valid_results["batch_sizes"] += batch_size
            lr = val_lr
            hr = val_hr
            if torch.cuda.is_available():
                lr = lr.cuda()
                hr = hr.cuda()
            sr = netG(lr)

            batch_mse = ((sr - hr) ** 2).data.mean()
            valid_results["mse"] += batch_mse * batch_size
            batch_ssim = ssim(sr, hr).item()
            valid_results["ssims"] += batch_ssim * batch_size
            valid_results["psnr"] = 10 * math.log10(
                (hr.max() ** 2) / (valid_results["mse"] / valid_results["batch_sizes"])
            )
            valid_results["ssim"] = (
                valid_results["ssims"] / valid_results["batch_sizes"]
            )
            val_bar.set_description(
                desc="[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f"
                % (valid_results["psnr"], valid_results["ssim"])
            )

    # if not os.path.exists("epochs/"):
    #     os.makedirs("epochs/")
    # save model parameters
    # if epoch % 10 == 0:
    #     torch.save(
    #         netG.state_dict(),
    #         "ebany_research/weight_superresolution/kaggle_esrgan/epochs/netG_epoch_%d_%d.pth"
    #         % (upscale_factor, epoch),
    #     )
    #     torch.save(
    #         netD.state_dict(),
    #         "ebany_research/weight_superresolution/kaggle_esrgan/epochs/netD_epoch_%d_%d.pth"
    #         % (upscale_factor, epoch),
    #     )
    # save loss\scores\psnr\ssim
    results["d_loss"].append(running_results["d_loss"] / running_results["batch_sizes"])
    results["g_loss"].append(running_results["g_loss"] / running_results["batch_sizes"])
    results["d_score"].append(
        running_results["d_score"] / running_results["batch_sizes"]
    )
    results["g_score"].append(
        running_results["g_score"] / running_results["batch_sizes"]
    )
    results["psnr"].append(valid_results["psnr"])
    results["ssim"].append(valid_results["ssim"])

    # if epoch % 10 == 0 and epoch != 0:
    #     out_path = 'statistics/'
    #     if not os.path.exists(out_path):
    #       os.makedirs(out_path)

    #     data_frame = pd.DataFrame(
    #         data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
    #               'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
    #         index=range(1, epoch + 1))
    #     data_frame.to_csv(out_path + 'srf_' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

[1/10] Loss_D: 0.9935 Loss_G: 0.0777 D(x): 0.5184 D(G(z)): 0.4872: 100%|██████████| 2/2 [00:00<00:00,  7.81it/s]
[converting LR images to SR images] PSNR: 11.9653 dB SSIM: 0.3308: 100%|██████████| 1/1 [00:00<00:00,  6.37it/s]
[2/10] Loss_D: 0.9865 Loss_G: 0.0594 D(x): 0.5002 D(G(z)): 0.4802: 100%|██████████| 2/2 [00:00<00:00,  7.89it/s]
[converting LR images to SR images] PSNR: 11.7401 dB SSIM: 0.3250: 100%|██████████| 1/1 [00:00<00:00,  6.40it/s]
[3/10] Loss_D: 0.9555 Loss_G: 0.0923 D(x): 0.5261 D(G(z)): 0.4710: 100%|██████████| 2/2 [00:00<00:00,  7.67it/s]
[converting LR images to SR images] PSNR: 11.4853 dB SSIM: 0.3194: 100%|██████████| 1/1 [00:00<00:00,  6.32it/s]
[4/10] Loss_D: 0.9170 Loss_G: 0.0190 D(x): 0.5484 D(G(z)): 0.4550: 100%|██████████| 2/2 [00:00<00:00,  8.21it/s]
[converting LR images to SR images] PSNR: 11.1219 dB SSIM: 0.3105: 100%|██████████| 1/1 [00:00<00:00,  6.50it/s]
[5/10] Loss_D: 0.8863 Loss_G: 0.0619 D(x): 0.5734 D(G(z)): 0.4403: 100%|██████████| 2/2 [00:00<0


<div style="background-color:#F0E3D2; color:#19180F; font-size:15px; font-family:Verdana; padding:10px; border: 2px solid #19180F; border-radius:10px"> 
📌
Evaluating trained model on test image    </div>


In [14]:
upscale_factor = 8
model_name = "netG_epoch_8_20.pth"
model = Generator(upscale_factor).eval()
device = torch.device("cuda")
model = model.to(device)
model.load_state_dict(torch.load("ebany_research/weight_superresolution/epochs/netG_epoch_8_10.pth"))

<All keys matched successfully>


<div style="background-color:#F0E3D2; color:#19180F; font-size:15px; font-family:Verdana; padding:10px; border: 2px solid #19180F; border-radius:10px"> 
📌
saving the output and displaying it    </div>


In [None]:
# pass any other image, if needed
image_name = "ebany_research/weight_superresolution/dataset/valid/0003.png"
image = Image.open(image_name)
image = Variable(transforms.ToTensor()(image)).unsqueeze(0).to(device)
out = model(image)
out_img = transforms.ToPILImage()(out[0].data.cpu())
out_img.save("output.jpeg")

In [16]:
image.shape

torch.Size([1, 3, 1356, 2040])

In [None]:
print(image.shape)
print(out_img.shape)