### Imports

In [26]:
import os
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_tensor
from torch.utils.data import Dataset
import torch
from torch import nn
from torchvision.models import mobilenet_v2
import math
import torch.nn.functional as F
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import time

### Utils

In [27]:
def gaussian(window_size, sigma):
    """Build a normalized 1-D Gaussian kernel.

    Args:
        window_size: Number of taps in the kernel.
        sigma: Standard deviation of the Gaussian, measured in taps.

    Returns:
        A 1-D tensor of length ``window_size``, centred on index
        ``window_size // 2`` and divided by its own sum.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = []
    for tap in range(window_size):
        weights.append(math.exp(-(tap - center) ** 2 / two_sigma_sq))
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()


def create_window(window_size, channel):
    """Create the 2-D Gaussian window used by the SSIM convolutions.

    The window is the outer product of a 1-D Gaussian (sigma fixed at 1.5)
    with itself, replicated once per channel so it can be used as a grouped
    (depthwise) convolution weight.

    Args:
        window_size: Side length of the square window.
        channel: Number of image channels the window will be applied to.

    Returns:
        A contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product column @ column^T yields the separable 2-D kernel.
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute the structural similarity (SSIM) between two image batches.

    Local means, variances and the covariance are estimated by a grouped
    convolution with ``window`` (one filter per channel, zero padding of
    ``window_size // 2``).

    Args:
        img1: First image batch, a 4-D tensor accepted by ``F.conv2d``.
        img2: Second image batch, same shape as ``img1``.
        window: Convolution weight, one single-channel filter per image channel.
        window_size: Side length of ``window``; only used to derive the padding.
        channel: Number of image channels (used as ``groups``).
        size_average: When true return the mean over the whole SSIM map;
            otherwise return one mean per batch element.

    Returns:
        A 0-dim tensor if ``size_average`` is true, else a 1-D tensor with one
        value per batch element.
    """
    pad = window_size // 2

    def _local_mean(tensor):
        # Window-weighted local average, computed independently per channel.
        return F.conv2d(tensor, window, padding=pad, groups=channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    # Small constants that keep the ratio stable where the denominators are
    # close to zero. NOTE(review): these values presumably assume pixel
    # intensities in [0, 1] -- confirm against callers.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if not size_average:
        # Collapse channel, then the two spatial axes, leaving the batch axis.
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()


class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    The window is rebuilt only when the incoming batch has a different channel
    count, dtype or device than the cached one.

    Args:
        window_size: Side length of the Gaussian window.
        size_average: Forwarded to ``_ssim``; when true a single scalar is
            returned, otherwise one value per batch element.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2`` (both 4-D tensors)."""
        (_, channel, _, _) = img1.size()

        window = self.window
        # Compare dtype and device explicitly: the legacy ``.type()`` string
        # does not include the CUDA device index, so a window cached on one
        # GPU could otherwise be reused for an input living on another GPU.
        cache_is_valid = (
            channel == self.channel
            and window.dtype == img1.dtype
            and window.device == img1.device
        )

        if not cache_is_valid:
            window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype
            )
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: build a fresh Gaussian window and evaluate ``_ssim``.

    Args:
        img1: First image batch (4-D tensor).
        img2: Second image batch, same shape as ``img1``.
        window_size: Side length of the Gaussian window.
        size_average: When true return a single scalar, otherwise one value
            per batch element.

    Returns:
        The value produced by ``_ssim`` for the two batches.
    """
    _n, channel, _h, _w = img1.size()

    kernel = create_window(window_size, channel)
    # Put the window on the same CUDA device as the input, then match its type.
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, channel, size_average)
    

def is_image_file(filename):
    """Return True if ``filename`` ends with a recognised image extension.

    Only the all-lowercase and all-uppercase spellings of ``.png``, ``.jpg``
    and ``.jpeg`` are accepted; mixed case such as ``.Png`` is not.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    # str.endswith accepts a tuple and matches if any suffix fits.
    return filename.endswith(image_extensions)

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``."""
    remainder = crop_size % upscale_factor
    valid_size = crop_size - remainder
    return valid_size

class ValDataset(Dataset):
    """Validation dataset yielding (filename, low-res, high-res) triples.

    Every image file found directly inside ``dataset_dir`` is loaded as RGB.
    The high-resolution target is the image resized so that both sides are
    multiples of ``upscale_factor``; the low-resolution input is the same
    image resized to ``1 / upscale_factor`` of that size. Both resizes use
    bicubic interpolation.

    Args:
        dataset_dir: Directory containing the evaluation images.
        upscale_factor: Integer super-resolution factor.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDataset, self).__init__()
        self.upscale_factor = upscale_factor
        # List the directory once so the bare names and the full paths can
        # never disagree, and sort because os.listdir returns entries in
        # arbitrary order (sorting keeps runs reproducible).
        self.single_filenames = sorted(x for x in os.listdir(dataset_dir) if is_image_file(x))
        self.image_filenames = [os.path.join(dataset_dir, x) for x in self.single_filenames]

    def __getitem__(self, index):
        """Return ``(filename, lr_tensor, hr_tensor)`` for the image at ``index``."""
        # convert() returns an independent copy, so the file handle can be
        # released as soon as the pixels have been decoded.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = source.convert('RGB')

        # Largest size not exceeding the original that the factor divides evenly.
        image_width = (hr_image.width // self.upscale_factor) * self.upscale_factor
        image_height = (hr_image.height // self.upscale_factor) * self.upscale_factor

        hr_scale = transforms.Resize((image_height, image_width), interpolation=Image.BICUBIC)
        lr_scale = transforms.Resize(
            (image_height // self.upscale_factor, image_width // self.upscale_factor),
            interpolation=Image.BICUBIC,
        )
        # The LR input is derived from the original image, not from the
        # already-resized HR target.
        lr_image = lr_scale(hr_image)
        hr_image = hr_scale(hr_image)
        return self.single_filenames[index], to_tensor(lr_image), to_tensor(hr_image)

    def __len__(self):
        """Number of image files discovered in the dataset directory."""
        return len(self.image_filenames)

### Define model

In [40]:
class SeperableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (depthwise) spatial convolution is followed by a 1x1
    (pointwise) convolution that mixes channels and sets the output width.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True):
        super().__init__()
        # groups == in_channels gives every input channel its own filter.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            groups=in_channels, bias=bias,
        )
        # 1x1 convolution that recombines the per-channel responses.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        spatial = self.depthwise(x)
        return self.pointwise(spatial)
    

    
class ConvBlock(nn.Module):
    """Separable convolution followed by optional batch norm and activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        use_act: Whether ``forward`` applies the activation.
        use_bn: Whether to batch-normalise; when true the convolution is
            created without bias.
        discriminator: Selects ``LeakyReLU(0.2)`` instead of a per-channel
            ``PReLU`` as the activation.
        **kwargs: Forwarded to ``SeperableConv2d`` (kernel size, stride,
            padding).
    """

    def __init__(self, in_channels, out_channels, use_act=True, use_bn=True, discriminator=False, **kwargs):
        super().__init__()

        self.use_act = use_act
        self.cnn = SeperableConv2d(in_channels, out_channels, bias=not use_bn, **kwargs)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        # The activation module is created even when use_act is False, so its
        # parameters (for PReLU) are still part of the state dict.
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Run conv -> norm, then the activation if ``use_act`` is set."""
        out = self.bn(self.cnn(x))
        if self.use_act:
            out = self.act(out)
        return out


class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling stage: separable conv -> pixel shuffle -> PReLU.

    Args:
        in_channels: Channel count of both the input and the output.
        scale_factor: Spatial upscaling factor applied by the pixel shuffle.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()

        expanded_channels = in_channels * scale_factor ** 2
        self.conv = SeperableConv2d(in_channels, expanded_channels, kernel_size=3, stride=1, padding=1)
        # PixelShuffle folds the extra channels into space:
        # (in_channels * r^2, H, W) -> (in_channels, H * r, W * r).
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        """Upsample ``x`` spatially by ``scale_factor``."""
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)
        

class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers with an identity skip connection.

    The first block applies its activation; the second does not, so the skip
    connection adds the input to an un-activated feature map.

    Args:
        in_channels: Channel count, preserved by the block.
    """

    def __init__(self, in_channels):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_kwargs)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **conv_kwargs)

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        residual = self.block2(self.block1(x))
        return residual + x
    
    
class Generator(nn.Module):
    """Super-resolution generator built from depthwise-separable convolutions.

    Pipeline: 9x9 initial conv -> ``num_blocks`` residual blocks -> 3x3 conv
    with a long skip connection from the initial features -> a stack of x2
    upsampling stages -> 9x9 output conv squashed into ``[0, 1]``.

    Args:
        in_channels: Channels of the input (and output) image.
        num_channels: Width of the internal feature maps.
        num_blocks: Number of residual blocks.
        upscale_factor: Total spatial upscaling factor. It is realised as
            ``log2(upscale_factor)`` stages of x2, so it must be a positive
            power of two.

    Raises:
        ValueError: If ``upscale_factor`` is not a positive power of two.
    """

    def __init__(self, in_channels: int = 3, num_channels: int = 64, num_blocks: int = 16, upscale_factor: int = 4):
        super(Generator, self).__init__()

        factor = int(upscale_factor)
        if factor != upscale_factor or factor < 1 or factor & (factor - 1):
            raise ValueError(
                f"upscale_factor must be a positive power of two, got {upscale_factor!r}"
            )
        # Each UpsampleBlock doubles the resolution, so log2(factor) stages are
        # needed. The previous ``factor // 2`` only matched this for factors
        # 1, 2 and 4 (a factor of 8 produced four stages, i.e. x16).
        num_upsamples = factor.bit_length() - 1

        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residual = nn.Sequential(
            *[ResidualBlock(num_channels) for _ in range(num_blocks)]
        )
        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        self.upsampler = nn.Sequential(
            *[UpsampleBlock(num_channels, scale_factor=2) for _ in range(num_upsamples)]
        )
        self.final_conv = SeperableConv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Upscale the batch ``x``; output values lie in ``[0, 1]``."""
        initial = self.initial(x)
        x = self.residual(initial)
        # Long skip connection around the whole residual trunk.
        x = self.convblock(x) + initial
        x = self.upsampler(x)
        # tanh yields [-1, 1]; shift and scale into the [0, 1] image range.
        return (torch.tanh(self.final_conv(x)) + 1) / 2


class Discriminator(nn.Module):
    """Convolutional real/fake classifier for the GAN.

    A stack of ``ConvBlock`` layers (LeakyReLU activations, stride alternating
    1, 2, 1, 2, ... and no batch norm on the first layer) is followed by
    adaptive average pooling to 6x6 and a two-layer MLP with a sigmoid output.

    Args:
        in_channels: Channels of the input image.
        features: Output width of each convolutional block, in order.
    """

    def __init__(
        self,
        in_channels: int = 3,
        features: tuple = (64, 64, 128, 128, 256, 256, 512, 512),
    ) -> None:
        super(Discriminator, self).__init__()

        blocks = []
        channels = in_channels
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    channels,
                    feature,
                    kernel_size=3,
                    # Odd-indexed blocks halve the spatial resolution.
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Size the first linear layer from the actual final channel count
            # instead of a hard-coded 512, so custom ``features`` tuples work.
            nn.Linear(channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a probability in ``(0, 1)`` for each image in the batch."""
        x = self.blocks(x)
        return torch.sigmoid(self.classifier(x))

### Testing

In [29]:
# Super-resolution factor used when building the evaluation dataset.
UPSCALE_FACTOR = 4
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [33]:
def swiftsrgan_psnr_testing(_path, model_path, ds):
    """Evaluate a generator checkpoint on one dataset and save its outputs.

    Every image in ``_path`` is super-resolved, compared with its
    high-resolution target, and written to
    ``./model_evaluation/data/SR/<ds>/<filename>``.

    Args:
        _path: Directory containing the evaluation images.
        model_path: Checkpoint file; its ``'model'`` entry holds the generator
            state dict.
        ds: Dataset name, used as the output sub-directory.

    Returns:
        ``(psnr, ssim)``: PSNR in dB computed from the mean squared error
        averaged over all images, and the mean SSIM over all images. Both are
        0 when the dataset is empty.
    """
    netG = Generator(upscale_factor=UPSCALE_FACTOR).to(DEVICE)
    netG.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
    netG.eval()

    val_set = ValDataset(_path, upscale_factor=UPSCALE_FACTOR)
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)

    # save_image does not create missing directories, so make sure the target
    # exists before the first write.
    out_dir = os.path.join("./model_evaluation/data/SR", ds)
    os.makedirs(out_dir, exist_ok=True)

    valing_results = {
        "mse": 0,
        "psnr": 0,
        "batch_sizes": 0,
        "ssim": 0,
        "ssims": 0,
    }

    with torch.no_grad():
        for name, val_lr, val_hr in tqdm(val_loader, total=len(val_loader)):
            batch_size = val_lr.size(0)
            valing_results["batch_sizes"] += batch_size

            lr = val_lr.to(DEVICE)
            hr = val_hr.to(DEVICE)
            sr = netG(lr)

            # Running sums are weighted by batch size so the averages stay
            # per-image should the batch size ever change.
            batch_mse = ((sr - hr) ** 2).mean().item()
            valing_results["mse"] += batch_mse * batch_size
            batch_ssim = ssim(sr, hr).item()
            valing_results["ssims"] += batch_ssim * batch_size

            mean_mse = valing_results["mse"] / valing_results["batch_sizes"]
            # NOTE(review): the peak value comes from the current HR batch
            # only, so the final PSNR uses the last image's maximum -- confirm
            # this is intended rather than a fixed peak of 1.0.
            peak_sq = hr.max().item() ** 2
            if mean_mse > 0:
                valing_results["psnr"] = 10 * math.log10(peak_sq / mean_mse)
            else:
                # Perfect reconstruction so far: PSNR is unbounded.
                valing_results["psnr"] = float("inf")
            valing_results["ssim"] = (
                valing_results["ssims"] / valing_results["batch_sizes"]
            )

            # batch_size is 1, so name[0] is the file name of the single image.
            torchvision.utils.save_image(
                sr.cpu().squeeze(0),
                os.path.join(out_dir, name[0])
            )

    return valing_results["psnr"], valing_results["ssim"]

In [31]:
def swiftsrgan_time_testing(_path, model_path, ds):
    """Measure the average generator forward-pass time on one dataset.

    Args:
        _path: Directory containing the evaluation images.
        model_path: Checkpoint file; its ``'model'`` entry holds the generator
            state dict.
        ds: Dataset name. Unused here; kept so the signature mirrors
            ``swiftsrgan_psnr_testing``.

    Returns:
        Mean wall-clock seconds per forward pass (one image per pass), or
        ``0.0`` if the dataset is empty. No warm-up pass is run, so one-off
        start-up costs in the first pass are included in the average.
    """
    netG = Generator(upscale_factor=UPSCALE_FACTOR).to(DEVICE)
    netG.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
    netG.eval()

    val_set = ValDataset(_path, upscale_factor=UPSCALE_FACTOR)
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)

    on_cuda = DEVICE.type == "cuda"
    exec_time = 0.0
    nums = 0
    with torch.no_grad():
        for _name, val_lr, _val_hr in tqdm(val_loader, total=len(val_loader)):
            # Only the LR input is needed for timing; transfer it before the
            # clock starts so host-to-device copies are not measured.
            lr = val_lr.to(DEVICE)

            # CUDA kernels launch asynchronously: without synchronising, the
            # timer would stop before the GPU has finished the forward pass.
            if on_cuda:
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = netG(lr)
            if on_cuda:
                torch.cuda.synchronize()
            exec_time += time.perf_counter() - t0
            nums += 1

    # Avoid a ZeroDivisionError when the directory holds no images.
    if nums == 0:
        return 0.0
    return exec_time / nums
    

In [41]:
# Benchmark datasets to evaluate and the generator checkpoint to load.
dataset = ["Set5", "Set14"]
model_path = "./models/baseline.pth.tar"

for ds in dataset:
    dataset_dir = "./dataset/SR_testing_datasets/" + ds
    # Quality metrics first (this pass also writes the SR images), then timing.
    psnr, ssim_val = swiftsrgan_psnr_testing(dataset_dir, model_path, ds)
    exec_time = swiftsrgan_time_testing(dataset_dir, model_path, ds)
    print(
        f'Execution time in {ds} = {exec_time}',
        f'PSNR in {ds} = {psnr}',
        f'SSIM in {ds} = {ssim_val}',
        sep="\n",
    )

100%|██████████| 5/5 [00:03<00:00,  1.37it/s]
100%|██████████| 5/5 [00:01<00:00,  2.97it/s]


Execution time in Set5 = 0.32821102142333985
PSNR in Set5 = 27.922218817626515
SSIM in Set5 = 0.8342373847961426


100%|██████████| 14/14 [00:24<00:00,  1.73s/it]
100%|██████████| 14/14 [00:10<00:00,  1.36it/s]

Execution time in Set14 = 0.7187467302594867
PSNR in Set14 = 24.904779334216457
SSIM in Set14 = 0.7247450585876193



