In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

In [15]:
import os
import numpy as np
import matplotlib.pyplot as plt
import glob

import torch
from torchvision import transforms
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import DataLoader

import torch.nn.functional as F
from torch.autograd import Variable
import math
from math import exp
from tqdm import tqdm

from kornia.filters.sobel import Sobel
import wandb
from torchvision.utils import make_grid
import gc
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Loss, Metric and Other Functions

In [17]:
def normalize(sample):
#     MIN_H = 0.0
#     MAX_H = 8802.0
    MIN_H = sample.min()
    MAX_H = sample.max()
    return (sample - MIN_H)/(MAX_H-MIN_H)

In [18]:
class Sobel_older(nn.Module):
    def __init__(self):
        super().__init__()
        self.filter = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0, bias=False)

        Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
        Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
        G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
        G = G.unsqueeze(1)
        self.filter.weight = nn.Parameter(G, requires_grad=False)

    def forward(self, img):
        x = self.filter(img)
        x = torch.mul(x, x)
        x = torch.sum(x, dim=1, keepdim=True)
        x = torch.sqrt(x)
#         x = (torch.tanh(x) + 1)/2
        return x

In [19]:
def calculate_psnr(img1, img2, border=0 ,data_min=0.0 ,data_max=1.0 ):
    
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[2:]

    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    mse = np.mean((img1 - img2)**2)
#     print(mse)
    if mse == 0:
        return float('inf')
    return 20 * math.log10((data_max - data_min)/ math.sqrt(mse))

In [20]:
class gradientAwareLoss(nn.Module): 
    def __init__(self):
        super().__init__()
        self.sobelFilter = Sobel().to('cuda')
        self.l1Loss = nn.L1Loss().to('cuda')

    def forward(self, hr, sr):
        hrEdgeMap = self.sobelFilter(hr)
        srEdgeMap = self.sobelFilter(sr)
        return self.l1Loss(hrEdgeMap, srEdgeMap)

In [21]:
sobel_old = Sobel_older()
inp = torch.rand([1,1,256,256])
out = sobel_old(inp)
print(out.shape)

torch.Size([1, 1, 254, 254])


# Dataset and Dataloader

In [22]:
# Dataset class
class Dataset(data.Dataset):
    def __init__(self, hr_paths,lr_paths,transform = None):
        self.load_dir_hr = hr_paths
        self.load_dir_lr = lr_paths
        self.tranform = transform
        
    def __getitem__(self, index):
        hr = normalize(cv2.imread(self.load_dir_hr[index])).astype(np.float32)
        lr = normalize(cv2.imread(self.load_dir_lr[index])).astype(np.float32)
        
        
        if self.tranform:
            hr, lr = self.tranform(hr), self.tranform(lr)
        
        return hr, lr
    
    def __len__(self):
        return len(self.load_dir_hr)

In [26]:
# Train and Test Data Loader
import cv2
train_hr = glob.glob('/kaggle/input/imagesuperresolution-dataset-2x4x-scale2/dataset/train/hr/**.png')
test_hr = glob.glob('/kaggle/input/imagesuperresolution-dataset-2x4x-scale2/dataset/test/hr/**.png')

train_lr = glob.glob('/kaggle/input/imagesuperresolution-dataset-2x4x-scale2/dataset/train/lr_2/**.png')
test_lr = glob.glob('/kaggle/input/imagesuperresolution-dataset-2x4x-scale2/dataset/test/lr_2/**.png')

train_loader = DataLoader(Dataset(train_hr, train_lr,
                                  transform = transforms.Compose([transforms.ToTensor()]))
                            ,batch_size=32, shuffle=True)
test_loader = DataLoader(Dataset(test_hr,test_lr,
                                transform = transforms.Compose([transforms.ToTensor()]))
                                ,batch_size=32, shuffle=False)

In [28]:
for hr , lr in test_loader:
    print(hr.shape)
    print(lr.shape)
    


torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 512, 512])
torch.Size

# Generator

In [31]:
!git clone https://github.com/seungjunlee96/DepthwiseSeparableConvolution_Pytorch.git
%cd DepthwiseSeparableConvolution_Pytorch/
!python3 setup.py install --user

Cloning into 'DepthwiseSeparableConvolution_Pytorch'...
remote: Enumerating objects: 105, done.[K
remote: Counting objects: 100% (105/105), done.[K
remote: Compressing objects: 100% (105/105), done.[K
remote: Total 105 (delta 57), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (105/105), 312.97 KiB | 5.05 MiB/s, done.
Resolving deltas: 100% (57/57), done.
/kaggle/working/StyleSwin/DepthwiseSeparableConvolution_Pytorch
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
        Please avoid r

In [32]:
from DepthwiseSeparableConvolution import depthwise_separable_conv

In [33]:
%cd ..

/kaggle/working/StyleSwin


In [34]:

import torch
import torch.nn as nn
import time
import torch
import torch.nn as nn


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


LRELU_SLOPE = 0.1


class AddSkipConn(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.add = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        return self.add.add(x, self.net(x))


class ConcatSkipConn(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        return torch.cat([x, self.net(x)], 1)


class FourierUnit(torch.nn.Module):
    """Implements Fourier Unit block.

    Applies FFT to tensor and performs convolution in spectral domain.
    After that return to time domain with Inverse FFT.

    Attributes:
        inter_conv: conv-bn-relu block that performs conv in spectral domain

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        fu_kernel: int = 1,
        padding_type: str = "reflect",
        fft_norm: str = "ortho",
        use_only_freq: bool = False,
        norm_layer=nn.BatchNorm2d,
        bias: bool = True,
    ):
        super().__init__()
        self.fft_norm = fft_norm
        self.use_only_freq = use_only_freq

        self.inter_conv = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                out_channels * 2,
                kernel_size=fu_kernel,
                stride=1,
                padding=get_padding(fu_kernel),
                padding_mode=padding_type,
                bias=bias,
            ),
            norm_layer(out_channels * 2),
            nn.ReLU(True),
        )

    def forward(self, x):
        batch_size, ch, freq_dim, embed_dim = x.size()

        dims_to_fft = (-2,) if self.use_only_freq else (-2, -1)
        recover_length = (freq_dim,) if self.use_only_freq else (freq_dim, embed_dim)

        fft_representation = torch.fft.rfftn(x, dim=dims_to_fft, norm=self.fft_norm)

        # (B, Ch, 2, FFT_freq, FFT_embed)
        fft_representation = torch.stack(
            (fft_representation.real, fft_representation.imag), dim=2
        )  # .view(batch_size, ch * 2, -1, embed_dim)

        ffted_dims = fft_representation.size()[-2:]
        fft_representation = fft_representation.view(
            (
                batch_size,
                ch * 2,
            )
            + ffted_dims
        )

        fft_representation = (
            self.inter_conv(fft_representation)
            .view(
                (
                    batch_size,
                    ch,
                    2,
                )
                + ffted_dims
            )
            .permute(0, 1, 3, 4, 2)
        )

        fft_representation = torch.complex(
            fft_representation[..., 0], fft_representation[..., 1]
        )

        reconstructed_x = torch.fft.irfftn(
            fft_representation, dim=dims_to_fft, s=recover_length, norm=self.fft_norm
        )

        assert reconstructed_x.size() == x.size()

        return reconstructed_x


class SpectralTransform(torch.nn.Module):
    """Implements Spectrals Transform block.

    Residual Block containing Fourier Unit with convolutions before and after.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        fu_kernel: int = 1,
        padding_type: str = "reflect",
        fft_norm: str = "ortho",
        use_only_freq: bool = False,
        norm_layer=nn.BatchNorm2d,
        bias: bool = False,
    ):
        super().__init__()
        halved_out_ch = out_channels // 2

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, halved_out_ch, kernel_size=1, stride=1, bias=bias),
            norm_layer(halved_out_ch),
            nn.ReLU(True),
        )

        self.fu = FourierUnit(
            halved_out_ch,
            halved_out_ch,
            fu_kernel=fu_kernel,
            use_only_freq=use_only_freq,
            fft_norm=fft_norm,
            padding_type=padding_type,
            norm_layer=norm_layer,
        )

        self.conv2 = nn.Conv2d(
            halved_out_ch, out_channels, kernel_size=1, stride=1, bias=bias
        )

    def forward(self, x):

        residual = self.conv1(x)
        x = self.fu(residual)
        x += residual
        x = self.conv2(x)

        return x


class FastFourierConvolution(torch.nn.Module):
    """Implements FFC block.

    Divides Tensor in two branches: local and global. Local branch performs
    convolutions and global branch applies Spectral Transform layer.
    After performing transforms in local and global branches outputs are passed through BatchNorm + ReLU
    and eventually concatenated. Based on proportion of input and output global channels if the number is equal
    to zero respective blocks are replaced by Identity Transform.
    For clarity refer to original paper.

    Attributes:
        local_in_channels: # input channels for l2l and l2g convs
        local_out_channels: # output channels for l2l and g2l convs
        global_in_channels: # input channels for g2l and g2g convs
        global_out_channels: # output_channels for l2g and g2g convs
        l2l_layer: local to local Convolution
        l2g_layer: local to global Convolution
        g2l_layer: global to local Convolution
        g2g_layer: global to global Spectral Transform

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        alpha_in: float = 0.5,
        alpha_out: float = 0.5,
        kernel_size: int = 3,
        padding_type: str = "reflect",
        fu_kernel: int = 1,
        fft_norm: str = "ortho",
        bias: bool = True,
        norm_layer=nn.BatchNorm2d,
        activation=nn.ReLU(True),
        use_only_freq: bool = False,
    ):
        """Inits FFC module.

        Args:
            in_channels: total channels of tensor before dividing into local and global
            alpha_in:
                proportion of global channels as input
            alpha_out:
                proportion of global channels as output
            use_only_freq:
                controls dimensionality of fft in Fourier Unit. If false uses 2D fft in Fourier Unit affecting both
                frequency and time dimensions, otherwise applies 1D FFT only to frequency dimension

        """
        super().__init__()
        self.global_in_channels = int(in_channels * alpha_in)
        self.local_in_channels = in_channels - self.global_in_channels
        self.global_out_channels = int(out_channels * alpha_out)
        self.local_out_channels = out_channels - self.global_out_channels

        padding = get_padding(kernel_size)

        tmp_module = self._get_module_on_true_predicate(
            self.local_in_channels > 0 and self.local_out_channels > 0,
            nn.Conv2d,
            nn.Identity,
        )
        self.l2l_layer = tmp_module(
            self.local_in_channels,
            self.local_out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            padding_mode=padding_type,
            bias=bias,
        )

        tmp_module = self._get_module_on_true_predicate(
            self.local_in_channels > 0 and self.global_out_channels > 0,
            nn.Conv2d,
            nn.Identity,
        )
        self.l2g_layer = tmp_module(
            self.local_in_channels,
            self.global_out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            padding_mode=padding_type,
            bias=bias,
        )

        tmp_module = self._get_module_on_true_predicate(
            self.global_in_channels > 0 and self.local_out_channels > 0,
            nn.Conv2d,
            nn.Identity,
        )
        self.g2l_layer = tmp_module(
            self.global_in_channels,
            self.local_out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            padding_mode=padding_type,
            bias=bias,
        )

        tmp_module = self._get_module_on_true_predicate(
            self.global_in_channels > 0 and self.global_out_channels > 0,
            SpectralTransform,
            nn.Identity,
        )
        self.g2g_layer = tmp_module(
            self.global_in_channels,
            self.global_out_channels,
            fu_kernel=fu_kernel,
            fft_norm=fft_norm,
            padding_type=padding_type,
            bias=bias,
            norm_layer=norm_layer,
            use_only_freq=use_only_freq,
        )

        self.local_bn_relu = (
            nn.Sequential(norm_layer(self.local_out_channels), activation)
            if self.local_out_channels != 0
            else nn.Identity()
        )

        self.global_bn_relu = (
            nn.Sequential(norm_layer(self.global_out_channels), activation)
            if self.global_out_channels != 0
            else nn.Identity()
        )

    @staticmethod
    def _get_module_on_true_predicate(
        condition: bool, true_module=nn.Identity, false_module=nn.Identity
    ):
        if condition:
            return true_module
        else:
            return false_module

    def forward(self, x):

        #  chunk into local and global channels
        x_l, x_g = (
            x[:, : self.local_in_channels, ...],
            x[:, self.local_in_channels :, ...],
        )
        x_l = 0 if x_l.size()[1] == 0 else x_l
        x_g = 0 if x_g.size()[1] == 0 else x_g

        out_local, out_global = torch.Tensor(0).to(x.device), torch.Tensor(0).to(
            x.device
        )

        if self.local_out_channels != 0:
            out_local = self.l2l_layer(x_l) + self.g2l_layer(x_g)
            out_local = self.local_bn_relu(out_local)

        if self.global_out_channels != 0:
            out_global = self.l2g_layer(x_l) + self.g2g_layer(x_g)
            out_global = self.global_bn_relu(out_global)

        #  (B, out_ch, F, T)
        output = torch.cat((out_local, out_global), dim=1)

        return output


class FFCResNetBlock(torch.nn.Module):
    """Implements Residual FFC block.

    Contains two FFC blocks with residual connection.

    Wraps around FFC arguments.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        alpha_in: float = 0.5,
        alpha_out: float = 0.5,
        kernel_size: int = 3,
        padding_type: str = "reflect",
        bias: bool = True,
        fu_kernel: int = 1,
        fft_norm: str = "ortho",
        use_only_freq: bool = False,
        norm_layer=nn.BatchNorm2d,
        activation=nn.ReLU(True),
    ):
        super().__init__()
        self.ffc1 = FastFourierConvolution(
            in_channels,
            out_channels,
            alpha_in=alpha_in,
            alpha_out=alpha_out,
            kernel_size=kernel_size,
            padding_type=padding_type,
            fu_kernel=fu_kernel,
            fft_norm=fft_norm,
            use_only_freq=use_only_freq,
            bias=bias,
            norm_layer=norm_layer,
            activation=activation,
        )

        self.ffc2 = FastFourierConvolution(
            in_channels,
            out_channels,
            alpha_in=alpha_in,
            alpha_out=alpha_out,
            kernel_size=kernel_size,
            padding_type=padding_type,
            fu_kernel=fu_kernel,
            fft_norm=fft_norm,
            use_only_freq=use_only_freq,
            bias=bias,
            norm_layer=norm_layer,
            activation=activation,
        )

    def forward(self, x):
        out = self.ffc1(x)
        # out = self.ffc2(out)
        return x 
    
    
inp = torch.rand([16,64,128,128])

ffc = FastFourierConvolution(64,64,kernel_size=5,use_only_freq=False)

start = time.time()

out = ffc(inp)
end = time.time()

print(f"Runtime of the program is {end - start}")
# print(ffc)

print("Out shape",out.shape)

Runtime of the program is 0.6585361957550049
Out shape torch.Size([16, 64, 128, 128])


In [35]:
class RefinementBlock(nn.Module):
    def __init__(self,inp_channel,out_channel,kernel_sizes=[3,5,7]):
        super().__init__()
        self.inp = inp_channel
        self.out = out_channel
        self.kernel_size = kernel_sizes

        self.ffc0 = FastFourierConvolution(inp_channel,out_channel,kernel_size=kernel_sizes[0])
        self.ffc1 = FastFourierConvolution(inp_channel,out_channel,kernel_size=kernel_sizes[1])
        self.ffc2 = FastFourierConvolution(inp_channel,out_channel,kernel_size=kernel_sizes[2])
        self.relu = nn.ReLU()
        
    def forward(self,x):
        
        x0 = self.ffc0(x)
        x1 = self.ffc1(x)
        x2 = self.ffc2(x)
        
        return x + x0 + x1 + x2 
    
inp = torch.rand([12,4,128,128])
rb = RefinementBlock(4,4)

out = rb(inp)
out.shape

torch.Size([12, 4, 128, 128])

In [36]:
class ERAM(nn.Module):
    def __init__(self, channel_begin, dimension):
        super().__init__()
        self.conv = nn.Conv2d(channel_begin, channel_begin, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.avgpool = nn.AvgPool2d(dimension)
        
        self.conv1 = nn.Conv2d(channel_begin, channel_begin//2, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(channel_begin//2, channel_begin, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(channel_begin, channel_begin, kernel_size=3, stride=1, padding=1)

        self.dconv = depthwise_separable_conv(channel_begin, channel_begin, kernel_size = 3, padding = 1, bias=False)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        si_ca = self.avgpool(x) + torch.var_mean(x, dim=(2,3))[0].unsqueeze(2).unsqueeze(2)
        mi_ca = self.conv2(self.relu(self.conv1(si_ca)))

        mi_sa = self.conv3(self.relu(self.dconv(x)))

        return self.sigmoid(mi_ca+mi_sa) * x

In [37]:
class RepetitiveBlock(nn.Module):
    def __init__(self,input_channel,out_channel,dimension=128,kernel_sizes=[3,5,7]):
        super().__init__()
        
        self.input_channel = input_channel
        self.out_channel = out_channel
        self.kernel_sizes = kernel_sizes
        self.dimension = dimension
        
        self.refinement = RefinementBlock(inp_channel=self.input_channel,
                                          out_channel=self.out_channel,
                                          kernel_sizes=self.kernel_sizes)
        
        self.eram = ERAM(channel_begin=self.out_channel,dimension=self.dimension)
        self.gelu = nn.GELU()
        
    def forward(self,x):
        
        x = self.refinement(x)
        x = self.eram(x)
        x = self.gelu(x)
        
        return x

    

In [38]:
inp = torch.rand([12,4,128,128])
rb = RepetitiveBlock(input_channel=4,out_channel=4)

out = rb(inp)
out.shape

torch.Size([12, 4, 128, 128])

In [39]:
import torch
import torch.nn as nn


import time
from DepthwiseSeparableConvolution import depthwise_separable_conv



class RefinementBlock(nn.Module):
    def __init__(self,inp_channel,out_channel,kernel_sizes=[3,5,7]):
        super().__init__()
        self.inp = inp_channel # channel needs to be even
        self.out = out_channel
        self.kernel_size = kernel_sizes

        self.ffc0 = FastFourierConvolution(inp_channel,out_channel,kernel_size=kernel_sizes[0])
        self.ffc1 = FastFourierConvolution(inp_channel,out_channel,kernel_size=kernel_sizes[1])
        self.ffc2 = FastFourierConvolution(inp_channel,out_channel,kernel_size=kernel_sizes[2])
        self.relu = nn.ReLU()
        
    def forward(self,x):
        
        x0 = self.ffc0(x)
        x1 = self.ffc1(x)
        x2 = self.ffc2(x)
        
        return x + x0 + x1 + x2 
    

        
class ERAM(nn.Module):
    def __init__(self, channel_begin, dimension):
        super().__init__()
        self.conv = nn.Conv2d(channel_begin, channel_begin, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.avgpool = nn.AvgPool2d(dimension)
        
        self.conv1 = nn.Conv2d(channel_begin, channel_begin//2, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(channel_begin//2, channel_begin, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(channel_begin, channel_begin, kernel_size=3, stride=1, padding=1)

        self.dconv = depthwise_separable_conv(channel_begin, channel_begin, kernel_size = 3, padding = 1, bias=False)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        si_ca = self.avgpool(x) + torch.var_mean(x, dim=(2,3))[0].unsqueeze(2).unsqueeze(2)
        mi_ca = self.conv2(self.relu(self.conv1(si_ca)))

        mi_sa = self.conv3(self.relu(self.dconv(x)))

        return self.sigmoid(mi_ca+mi_sa) * x


class RepetitiveBlock(nn.Module):
    def __init__(self,input_channel,out_channel,dimension=128,kernel_sizes=[3,5,7]):
        super().__init__()
        
        self.input_channel = input_channel
        self.out_channel = out_channel
        self.kernel_sizes = kernel_sizes
        self.dimension = dimension
        
        self.refinement = RefinementBlock(inp_channel=self.input_channel,
                                          out_channel=self.out_channel,
                                          kernel_sizes=self.kernel_sizes)
        
        self.eram = ERAM(channel_begin=self.out_channel,dimension=self.dimension)
        self.gelu = nn.GELU()
        
    def forward(self,x):
        
        x = self.refinement(x)
        x = self.eram(x)
        x = self.gelu(x)
        
        return x
        
        
    
        
        
        

class Generator(nn.Module):
    def __init__(self, inp_channel, out_channel, repitive_units=5 ,channel_expansion=16, scale_factor=2):
        super().__init__()
        self.scale_factor =scale_factor
        self.channel_expansion = channel_expansion
        
        # feature extraction block
        self.conv1 = nn.Conv2d(inp_channel,self.channel_expansion,kernel_size=3,padding='same')
        self.conv2 = nn.Conv2d(self.channel_expansion,self.channel_expansion*4,kernel_size=3, padding='same')
        self.conv3 = nn.Conv2d(self.channel_expansion*4,self.channel_expansion*4,kernel_size=3, padding='same')
        
        self.repetitive_units = 5
        

        self.repetive_blocks_1 = RepetitiveBlock(input_channel=self.channel_expansion,out_channel=self.channel_expansion)
        self.repetive_blocks_2 = RepetitiveBlock(input_channel=self.channel_expansion,out_channel=self.channel_expansion)
        self.repetive_blocks_3 = RepetitiveBlock(input_channel=self.channel_expansion,out_channel=self.channel_expansion)
        self.repetive_blocks_4 = RepetitiveBlock(input_channel=self.channel_expansion,out_channel=self.channel_expansion)
        self.repetive_blocks_5 = RepetitiveBlock(input_channel=self.channel_expansion,out_channel=self.channel_expansion)
            
        self.upsample_PixelShuffle = nn.PixelShuffle(int(scale_factor))
        self.downsample_PixelShuffle = nn.PixelShuffle(int(1.0/scale_factor))
        self.bicubic_upsample = nn.Upsample(mode='bicubic',scale_factor=self.scale_factor)
        self.refinement_Block = ...
        
        self.conv4 = nn.Conv2d(19,3,kernel_size=3,padding='same') # need to fix channel
        self.relu = nn.ReLU()
        
        self.tanh = nn.Tanh()
        
    def forward(self,x0):
        
        # x0 -> inp -> (b,3, h, w)
        x = self.relu(self.conv1(x0)) # (b,3, h, w) -> (b,16, h, w)
       
        x = self.repetive_blocks_1(x)  # (b,16, h, w) -> (b,16, h, w)
        x = self.repetive_blocks_2(x)  # (b,16, h, w) -> (b,16, h, w)
        x = self.repetive_blocks_3(x)  # (b,16, h, w) -> (b,16, h, w)
        x = self.repetive_blocks_4(x)  # (b,16, h, w) -> (b,16, h, w)
        x = self.repetive_blocks_5(x)  # (b,16, h, w) -> (b,16, h, w)
        
        x = self.conv2(x)   # (b,16, h, w) -> (b,64, h, w)
        x = self.conv3(x)   # (b,64, h, w) -> (b,64, h, w)

        x = self.upsample_PixelShuffle(x) #   # (b,16, h, w) -> (b,4, h*2, w*2)

        x_bc_upsample = self.bicubic_upsample(x0)   # (b,3, h, w) -> (b,3, h*2, w*2)
        
        
        x = torch.concat([x,x_bc_upsample],dim=1)   # (b,4, h, w)  + (b,3, h, w)  -> (b,7, h, w)
        
        x = self.conv4(x)   # (b,7, h, w) -> (b,3, h, w)
        
        return (self.tanh(x) + 1.0)/2.0   

In [40]:
inp = torch.rand([12,3,128,128])
gen = Generator(3,3)

out = gen(inp)
out.shape

torch.Size([12, 3, 256, 256])

# Discriminator

In [41]:
!git clone https://github.com/microsoft/StyleSwin

Cloning into 'StyleSwin'...
remote: Enumerating objects: 114, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 114 (delta 25), reused 24 (delta 24), pack-reused 86[K
Receiving objects: 100% (114/114), 11.54 MiB | 13.67 MiB/s, done.
Resolving deltas: 100% (49/49), done.


In [42]:
%cd /kaggle/working/StyleSwin

/kaggle/working/StyleSwin


In [43]:
!pip install ninja



In [44]:
from op import FusedLeakyReLU, upfirdn2d
from models.basic_layers import (Blur, Downsample, EqualConv2d, EqualLinear,
                                 ScaledLeakyReLU)

In [45]:
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        sn=False
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        if sn:
            # Not use equal conv2d when apply SN
            layers.append(
                spectral_norm(nn.Conv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=self.padding,
                    stride=stride,
                    bias=bias and not activate,
                ))
            )
        else:
            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=self.padding,
                    stride=stride,
                    bias=bias and not activate,
                )
            )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)

In [46]:
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], sn=False):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3, sn=sn)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, sn=sn)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        return out


In [47]:
def get_haar_wavelet(in_channels):
    haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
    haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2)
    haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0]

    haar_wav_ll = haar_wav_l.T * haar_wav_l
    haar_wav_lh = haar_wav_h.T * haar_wav_l
    haar_wav_hl = haar_wav_l.T * haar_wav_h
    haar_wav_hh = haar_wav_h.T * haar_wav_h

    return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh

In [48]:
class HaarTransform(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        ll, lh, hl, hh = get_haar_wavelet(in_channels)

        self.register_buffer('ll', ll)
        self.register_buffer('lh', lh)
        self.register_buffer('hl', hl)
        self.register_buffer('hh', hh)

    def forward(self, input):
        ll = upfirdn2d(input, self.ll, down=2)
        lh = upfirdn2d(input, self.lh, down=2)
        hl = upfirdn2d(input, self.hl, down=2)
        hh = upfirdn2d(input, self.hh, down=2)

        return torch.cat((ll, lh, hl, hh), 1)

In [49]:
class InverseHaarTransform(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        ll, lh, hl, hh = get_haar_wavelet(in_channels)

        self.register_buffer('ll', ll)
        self.register_buffer('lh', -lh)
        self.register_buffer('hl', -hl)
        self.register_buffer('hh', hh)

    def forward(self, input):
        ll, lh, hl, hh = input.chunk(4, 1)
        ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
        lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
        hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
        hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))

        return ll + lh + hl + hh

In [50]:
class FromRGB(nn.Module):
    def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1], sn=False):
        super().__init__()

        self.downsample = downsample

        if downsample:
            self.iwt = InverseHaarTransform(3)
            self.downsample = Downsample(blur_kernel)
            self.dwt = HaarTransform(3)

        self.conv = ConvLayer(3 * 4, out_channel, 1, sn=sn)

    def forward(self, input, skip=None):
        if self.downsample:
            input = self.iwt(input)
            input = self.downsample(input)
            input = self.dwt(input)

        out = self.conv(input)

        if skip is not None:
            out = out + skip

        return input, out

In [51]:
class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], sn=False, ssd=False):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.dwt = HaarTransform(3)

        self.from_rgbs = nn.ModuleList()
        self.convs = nn.ModuleList()

        log_size = int(math.log(size, 2)) - 1

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size, sn=sn))
            self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel, sn=sn))

            in_channel = out_channel

        self.from_rgbs.append(FromRGB(channels[4], sn=sn))

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3, sn=sn)
        if sn:
            self.final_linear = nn.Sequential(
                spectral_norm(nn.Linear(channels[4] * 4 * 4, channels[4])),
                FusedLeakyReLU(channels[4]),
                spectral_norm(nn.Linear(channels[4], 1)),
        )
        else:
            self.final_linear = nn.Sequential(
                EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
                EqualLinear(channels[4], 1),
            )

    def forward(self, input):
        input = self.dwt(input)
        out = None

        for from_rgb, conv in zip(self.from_rgbs, self.convs):
            input, out = from_rgb(input, out)
            out = conv(out)

        _, out = self.from_rgbs[-1](input, out)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out

In [52]:
disc = Discriminator(size=256).to('cuda')
inp = torch.randn(1, 3, 256, 256).to('cuda')
out = disc(inp)
out

tensor([[-0.9260]], device='cuda:0', grad_fn=<AddmmBackward0>)

# training 

In [54]:
l1Loss = nn.L1Loss().to(device) 
edgeLoss = gradientAwareLoss().to(device) 
ssim = SSIM(data_range=1.0).to(device) 
anotherl1Loss = nn.L1Loss().to(device) 

optimizer = torch.optim.Adam(gen.parameters(), lr=1e-6 )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0)

In [56]:
scaler = torch.cuda.amp.GradScaler()
sobelFilter = Sobel().to('cuda')

In [57]:
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
ssim = SSIM(data_range=1.0).to('cuda')
inp1 = torch.rand([1,1,256,256])
inp2 = torch.rand([1,1,256,256])
out = ssim(inp1,inp2)
out

tensor(0.0015, device='cuda:0')

In [58]:
def train_one_epoch(epoch):
    print(f"\nEpoch {epoch}: ", end ="")
    
    l1_loss_per_epoch = 0.0
    edge_loss_per_epoch = 0.0
    ssim_loss_per_epoch = 0.0
    ssim_per_epoch = 0.0
    psnr_per_epoch = 0.0
    total_loss_per_epoch = 0.0
    
    network.train()
    loop = tqdm(train_loader)
    for hr, lr in (loop):
        
        
        # Forward Pass
        batched_hr, batched_lr = hr.to(device), lr.to(device)
        
#         with torch.cuda.amp.autocast():
        predicted_sr = network(batched_lr)

        # Loss evaluation
        l1_loss_per_sample = l1Loss(batched_hr*1000, predicted_sr*1000)
        ssim_per_sample = ssim(batched_hr, predicted_sr)
        ssim_loss_per_sample = 1 - ssim_per_sample

        sobelhr = sobelFilter(batched_hr)*1000
        sobelsr = sobelFilter(predicted_sr)*1000

        edge_loss = l1Loss(sobelhr,sobelsr)
            
#             print(f'L1 Loss : {l1_loss_per_sample.item()}')
#             print(f'Edge loss : {edge_loss.item()}')
#             print(f'SSIM loss')

        loop.set_postfix(L1_loss=f'{l1_loss_per_sample.item()}',
                            Edge_loss = f'{edge_loss.item()}')
        reconstruction_loss = l1_loss_per_sample + 100*(ssim_loss_per_sample) + 50*edge_loss
        
        t_loss = reconstruction_loss
        optimizer.zero_grad()

        t_loss.backward()
        
        optimizer.step()
        
        psnr_per_sample = calculate_psnr(batched_hr.detach().cpu().numpy(), predicted_sr.detach().cpu().numpy())
        
        l1_loss_per_epoch += l1_loss_per_sample.item()
        edge_loss_per_epoch += edge_loss.item() 
        ssim_loss_per_epoch += ssim_loss_per_sample.item() 
        ssim_per_epoch += ssim_per_sample.item()
        psnr_per_epoch += psnr_per_sample 
        total_loss_per_epoch += t_loss.item()
        
    l1_loss_per_epoch /= float(len(train_loader))
    edge_loss_per_epoch /= float(len(train_loader))
    ssim_loss_per_epoch /= float(len(train_loader))
    ssim_per_epoch /= float(len(train_loader))
    psnr_per_epoch /= float(len(train_loader))
    total_loss_per_epoch /= float(len(train_loader))
#     scheduler.step()
    wandb.log({"Train L1 Loss": l1_loss_per_epoch})
    wandb.log({"Train Edge Loss": edge_loss_per_epoch})
    wandb.log({"Train SSIM Loss": ssim_loss_per_epoch})
    wandb.log({"Train Total Loss": total_loss_per_epoch})
    wandb.log({"Train SSIM": ssim_per_epoch})
    wandb.log({"Train PSNR": psnr_per_epoch})
        
    print(f"(Train) L1 Loss: {l1_loss_per_epoch:.3f} | SSIM Loss: {ssim_loss_per_epoch:.3f} | Edge Loss: {edge_loss_per_epoch:.3f} | Total Loss: {total_loss_per_epoch:.3f}")
    print(f"SSIM: {ssim_per_epoch:.3f} | PSNR: {psnr_per_epoch}")
    
    torch.cuda.empty_cache()
    gc.collect()
    
    
    return psnr_per_epoch

In [59]:
def valid_one_epoch(epoch):
    ssim_per_epoch = 0.0
    psnr_per_epoch = 0.0
    b_ssim_per_epoch = 0.0
    b_psnr_per_epoch = 0.0
    
    network.eval()
    with torch.no_grad():
        for hr, lr in tqdm(test_loader):
            batched_hr, batched_lr = hr.to(device), lr.to(device)
#             with torch.cuda.amp.autocast():
            predicted_sr = network(batched_lr)
                
            bilinear_sr = F.interpolate(batched_lr, scale_factor=2, mode='bilinear')
            
#             print("Bil min : ",bilinear_sr.min())
#             print("Bil max : ",bilinear_sr.max())

            ssim_per_epoch += ssim(batched_hr, predicted_sr)
            psnr_per_epoch += calculate_psnr(batched_hr.cpu().numpy(), predicted_sr.cpu().numpy())

            b_ssim_per_epoch += ssim(batched_hr, bilinear_sr)
            b_psnr_per_epoch += calculate_psnr(batched_hr.cpu().numpy(), bilinear_sr.cpu().numpy())

            grid1 = make_grid(batched_lr[:4])
            grid2 = make_grid(batched_hr[:4])
            grid3 = make_grid(predicted_sr[:4])
            grid4 = make_grid(bilinear_sr[:4])

            grid1 = wandb.Image(grid1, caption="Low Resolution DEM")
            grid2 = wandb.Image(grid2, caption="High Resolution DEM")
            grid3 = wandb.Image(grid3, caption="Reconstructed High Resolution DEM")
            grid4 = wandb.Image(grid4, caption="Bilinear High Resolution DEM")
            
            wandb.log({"Original LR": grid1})
            wandb.log({"Original HR": grid2})
            wandb.log({"Reconstruced": grid3})
            wandb.log({"Bilinear": grid4})

        ssim_per_epoch /= float(len(test_loader))
        psnr_per_epoch /= float(len(test_loader))
        b_ssim_per_epoch /= float(len(test_loader))
        b_psnr_per_epoch /= float(len(test_loader))

        wandb.log({"Test Predicted SSIM": ssim_per_epoch})
        wandb.log({"Test Predicted PSNR": psnr_per_epoch})
        wandb.log({"Bilinear SSIM": b_ssim_per_epoch})
        wandb.log({"Bilinear PSNR": b_psnr_per_epoch})

        print(f"(Val) SSIM: {ssim_per_epoch:.3f} | PSNR: {psnr_per_epoch:.3f}")
        print(f"(Bil) SSIM: {b_ssim_per_epoch:.3f} | PSNR: {b_psnr_per_epoch:.3f}")
        
        torch.cuda.empty_cache()
        gc.collect()
        
        return psnr_per_epoch 


In [None]:
best_psnr = 0
count = 0
prev_psnr =0 
for i in range(150):
    torch.cuda.empty_cache()
    gc.collect()
    train_psnr = train_one_epoch(i)
    valid_psnr = valid_one_epoch(i)
    
    if valid_psnr >= prev_psnr:
        count =0
    else :
        count +=1
        
        if count ==5 :
            network = network.load_state_dict(torch.load(f"best_model_{best_psnr}.pt"))
    
    
    
    if valid_psnr > best_psnr:
        best_psnr = valid_psnr
        torch.save(network.state_dict(), f"best_model_{best_psnr}.pt")