###Import

In [1]:
import os
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_tensor
from torch.utils.data import Dataset
import torch
from torch import nn
from torchvision.models import mobilenet_v2
import math
import torch.nn.functional as F
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import numpy as np
import imageio
import shutil

###Ult

In [2]:
def gaussian(window_size, sigma):
    gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

def tensors_to_imgs(x):
    for i in range(len(x)):
        x[i] = x[i].squeeze(0).data.cpu().numpy()
        x[i] = x[i].clip(0, 255).round()
        x[i] = x[i].transpose(1, 2, 0).astype(np.uint8)
    return x

def imgs_to_tensors(x):
    for i in range(len(x)):
        x[i] = x[i].transpose(2, 0, 1)
        x[i] = np.expand_dims(x[i], axis=0)
        x[i] = torch.Tensor(x[i].astype(float))
    return x


def rgb2y(rgb):
    return np.dot(rgb[...,:3], [65.738/256, 129.057/256, 25.064/256]) + 16


def compute_PSNR(out, lbl):
    [out, lbl] = tensors_to_imgs([out, lbl])
    out = rgb2y(out)
    lbl = rgb2y(lbl)
    out = out.clip(0, 255).round()
    lbl = lbl.clip(0, 255).round()
    diff = out - lbl
    rmse = np.sqrt(np.mean(diff**2))
    psnr = 20*np.log10(255/rmse)
    return psnr

###Define model

In [3]:
import torch
from torch import nn

class SeperableConv2d(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode='zeros'):
        super(SeperableConv2d, self).__init__(
            nn.Conv2d(
                in_channels=in_channels, out_channels=in_channels, groups=in_channels, 
                kernel_size=kernel_size, padding='same', dilation=dilation,
                bias=bias, padding_mode=padding_mode
            ),
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )


class UpsampleBlock(nn.Sequential):
    def __init__(self, in_channels, scale_factor):
        super(UpsampleBlock, self).__init__(
            SeperableConv2d(in_channels, in_channels * scale_factor**2, kernel_size=3),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(num_parameters=in_channels),
            SeperableConv2d(in_channels, in_channels * scale_factor**2, kernel_size=3),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(num_parameters=in_channels)
        )


class ConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, use_bn=False, use_ffc=False, use_act=True, discriminator=False, **kwargs):
        if use_ffc: conv = FFC(in_channels, out_channels, kernel_size=3, 
                ratio_gin=0.5, ratio_gout=0.5, inline = True
            )
        else: conv = SeperableConv2d(in_channels, out_channels, **kwargs)
        m = [conv]
        
        if use_bn: m.append(nn.BatchNorm2d(out_channels))
        if use_act: m.append(nn.LeakyReLU(0.2, inplace=True) if discriminator else nn.PReLU(num_parameters=out_channels))
        super(ConvBlock, self).__init__(*m)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, index):
        super(ResidualBlock, self).__init__()
        
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_ffc=True if index % 2 == 0 else False
        )
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False
        )
        
    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = out.mul(0.1)
        out += x
        return out


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False


class Generator(nn.Module):
    def __init__(self, in_channels: int = 3, num_channels: int = 64, num_blocks: int = 16, upscale_factor: int = 4):
        super(Generator, self).__init__()
        
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=3, use_act=False)
        self.residual = nn.Sequential(
            *[ResidualBlock(num_channels, _) for _ in range(num_blocks)]
        )
        self.upsampler = UpsampleBlock(num_channels, scale_factor=2)
        self.final_conv = SeperableConv2d(num_channels, in_channels, kernel_size=3)

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        rgb_range = 255
        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)
        
    def forward(self, x):
        x = self.sub_mean(x)
        initial = self.initial(x)
        x = self.residual(initial) + initial
        x = self.upsampler(x)
        out = self.final_conv(x)
        out = self.add_mean(out)
        return out


class FourierUnit(nn.Module):

    def __init__(self, in_channels, out_channels, ffc3d=False, fft_norm='ortho'):
        super(FourierUnit, self).__init__()
        self.conv_layer = SeperableConv2d(in_channels=in_channels * 2,
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = torch.nn.ReLU(inplace=True)

        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        batch = x.shape[0]
        r_size = x.size()

        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(ffted)

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        return output


class SpectralTransform(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1, enable_lfu=True, **fu_kwargs):
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            SeperableConv2d(in_channels, out_channels // 2, kernel_size=1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, **fu_kwargs)

        if self.enable_lfu:
            self.lfu = FourierUnit(out_channels // 2, out_channels // 2)
        self.conv2 = SeperableConv2d(out_channels // 2, out_channels, kernel_size=1, bias=False)

    def forward(self, x):

        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)
        
        if self.enable_lfu:
            n, c, h, w = x.shape
            split_no = 2
            split_h = h // split_no
            split_w = w // split_no
            xs = torch.cat(torch.split(x[:, :c // 4], split_h, dim=-2)[0:2], dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_w, dim=-1)[0:2], dim=1).contiguous()
            xs = self.lfu(xs)
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()

            if h % 2 == 1:
                h_zeros = torch.zeros(xs.shape[0], xs.shape[1], 1, xs.shape[3]).to(DEVICE)
                xs = torch.cat((xs, h_zeros), dim=2)
            if w % 2 == 1:
                w_zeros = torch.zeros(xs.shape[0], xs.shape[1], xs.shape[2], 1).to(DEVICE)
                xs = torch.cat((xs, w_zeros), dim=3)
        else:
            xs = 0

        output = self.conv2(x + output + xs)

        return output


class FFC(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, inline=True, stride=1, padding=0,
                 dilation=1, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride
        self.inline = inline

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else SeperableConv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else SeperableConv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else SeperableConv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, enable_lfu, **spectral_kwargs)

        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else SeperableConv2d
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        if self.inline:
            x_l, x_g = x[:, :-self.global_in_num], x[:, -self.global_in_num:]
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
            
        out = out_xl, out_xg
        if self.inline:
            out = torch.cat(out, dim=1)

        return out

###Testing

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
UPSCALE_FACTOR = 4

In [5]:
def swiftsrgan_psnr_testing(_path, model_path, ds):
    netG = Generator().to(DEVICE)
    netG.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
    netG.eval()

    with torch.no_grad():
        valing_results = {
            "mse": 0,
            "psnr": 0,
            "batch_sizes": 0,
            "ssim": 0,
            "ssims": 0,
        }
        nums = 0
        shutil.rmtree('./PIRM2018/your_results')
        shutil.rmtree('./PIRM2018/self_validation_HR')
        os.makedirs('./PIRM2018/your_results')
        os.makedirs('./PIRM2018/self_validation_HR')
        for name in os.listdir(_path):
            full_path = _path + "/" + name
            hr_image = Image.open(full_path).convert('RGB')
            image_width = (hr_image.width // 4) * 4
            image_height = (hr_image.height // 4) * 4
            hr_scale = transforms.Resize((image_height, image_width), interpolation=Image.BICUBIC)
            lr_scale = transforms.Resize((image_height // 4, image_width // 4), interpolation=Image.BICUBIC)
            lr_image = lr_scale(hr_image)
            hr_image = hr_scale(hr_image)
            lr_image = np.asarray(lr_image)
            hr_image = np.asarray(hr_image)
            
            # inp = imageio.imread(full_path)
            # [inp] = imgs_to_tensors([inp])
            [hr_image] = imgs_to_tensors([hr_image])
            [lr_image] = imgs_to_tensors([lr_image])
            out = netG(lr_image)

            # valing_results["psnr"] += compute_PSNR(hr_image, out)
            # mse = ((hr_image - out) ** 2).data.mean()
            # valing_results["psnr"] += 10 * math.log10((1.0 ** 2) / (mse))
            [out] = tensors_to_imgs([out])
            [hr_image] = tensors_to_imgs([hr_image])

            imageio.imwrite('./PIRM2018/your_results' + "/" + name, out)
            imageio.imwrite("./PIRM2018/self_validation_HR" + "/" + name, hr_image)
            nums += 1
    
    return valing_results["psnr"]/nums, valing_results["ssim"]

In [6]:
def swiftsrgan_time_testing(_path, model_path, ds):
    netG = Generator().to(DEVICE)
    netG.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
    netG.eval()

    exec_time = 0.0
    nums = 0
    with torch.no_grad():
        for name in os.listdir(_path):
            full_path = _path + "/" + name
            hr_image = Image.open(full_path).convert('RGB')
            image_width = (hr_image.width // 4) * 4
            image_height = (hr_image.height // 4) * 4
            hr_scale = transforms.Resize((image_height, image_width), interpolation=Image.BICUBIC)
            lr_scale = transforms.Resize((image_height // 4, image_width // 4), interpolation=Image.BICUBIC)
            
            lr_image = lr_scale(hr_image)
            hr_image = hr_scale(hr_image)
            lr_image = np.asarray(lr_image)
            hr_image = np.asarray(hr_image)
            [inp] = imgs_to_tensors([lr_image])
            # inp = imageio.imread(full_path)
            # [inp] = imgs_to_tensors([inp])

            # torch.cuda.synchronize()
            t0 = time.time()
            out = netG(inp)
            # torch.cuda.synchronize()
            exec_time += time.time() - t0
            nums += 1
    
    return exec_time / nums
    

In [7]:
dataset = ["Set5"]
model_path = "./models/swiftfsrgan_combinev1/content/netG_4x_epoch6.pth.tar"

for ds in dataset:
    psnr, ssim_val = swiftsrgan_psnr_testing("./dataset/SR_testing_datasets/" + ds, model_path, ds)
    # exec_time = swiftsrgan_time_testing("./dataset/SR_testing_datasets/" + ds, model_path, ds)
    # print(f'Execution time in {ds} = {exec_time}')
    print(f'PSNR in {ds} = {psnr}')
    print(f'SSIM in {ds} = {ssim_val}')




PSNR in Set5 = 0.0
SSIM in Set5 = 0
