# Train

In [None]:
import os
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_tensor
from torch.utils.data import Dataset
import torch
from torch import nn
from torchvision.models import vgg19
import math
import torch.nn.functional as F
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler

def is_image_file(filename):
    """Return True when *filename* ends with a recognised image extension.

    The comparison is case-sensitive: only the all-lowercase and all-uppercase
    spellings of ``.png``, ``.jpg`` and ``.jpeg`` are accepted.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    return filename.endswith(image_suffixes)


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Return *crop_size* with its remainder modulo *upscale_factor* removed."""
    _, leftover = divmod(crop_size, upscale_factor)
    return crop_size - leftover


def train_hr_transform(crop_size):
    """Build the transform that produces a high-resolution training patch.

    The pipeline takes a random ``crop_size`` crop (padding when the image is
    too small), flips it vertically with probability 0.3 and horizontally
    with probability 0.4, then converts the result to a tensor.
    """
    steps = [
        transforms.RandomCrop(crop_size, pad_if_needed=True),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomHorizontalFlip(p=0.4),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


def train_lr_transform(crop_size, upscale_factor):
    """Build the transform that derives a low-resolution input from an HR tensor.

    The tensor is converted back to a PIL image, resized with bicubic
    interpolation to ``crop_size // upscale_factor`` and converted to a
    tensor again.
    """
    lr_size = crop_size // upscale_factor
    steps = [
        transforms.ToPILImage(),
        transforms.Resize(lr_size, interpolation=Image.Resampling.BICUBIC),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


class TrainDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) training patch pairs.

    Args:
        dataset_dir: iterable of directory paths; the image files directly
            inside each directory are used.
        crop_size: requested HR patch size; it is reduced to a multiple of
            ``upscale_factor`` by ``calculate_valid_crop_size``.
        upscale_factor: ratio between the HR and LR patch sizes.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDataset, self).__init__()
        # Keep only recognised image files (anything else would make
        # ``Image.open`` fail mid-epoch) and sort the names so the
        # index -> file mapping does not depend on ``os.listdir`` order.
        self.image_filenames = [
            os.path.join(ds_path, name)
            for ds_path in dataset_dir
            for name in sorted(os.listdir(ds_path))
            if is_image_file(name)
        ]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` tensors for the file at *index*."""
        # The context manager closes the file handle as soon as the pixels
        # have been converted, instead of leaving it to garbage collection.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = self.hr_transform(source.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_filenames)


class ValDataset(Dataset):
    """Dataset of validation triples built from a centre crop of each image.

    Args:
        dataset_dir: iterable of directory paths; the image files directly
            inside each directory are used.
        crop_size: requested HR crop size; it is reduced to a multiple of
            ``upscale_factor`` by ``calculate_valid_crop_size``.
        upscale_factor: ratio between the HR and LR sizes.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(ValDataset, self).__init__()
        # Keep only recognised image files and use a stable (sorted) order.
        self.image_filenames = [
            os.path.join(ds_path, name)
            for ds_path in dataset_dir
            for name in sorted(os.listdir(ds_path))
            if is_image_file(name)
        ]
        self.crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.upscale_factor = upscale_factor

        # The transforms depend only on the sizes above, so build them once
        # rather than on every ``__getitem__`` call.
        self.center_crop = transforms.CenterCrop(self.crop_size)
        self.lr_scale = transforms.Resize(
            self.crop_size // self.upscale_factor, interpolation=Image.Resampling.BICUBIC
        )
        self.hr_scale = transforms.Resize(self.crop_size, interpolation=Image.Resampling.BICUBIC)

    def __getitem__(self, index):
        """Return ``(lr, bicubic_restored_hr, hr)`` tensors for the file at *index*."""
        # Close the file handle as soon as the crop has been taken.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = self.center_crop(source.convert('RGB'))
        lr_image = self.lr_scale(hr_image)
        hr_restore_img = self.hr_scale(lr_image)
        return to_tensor(lr_image), to_tensor(hr_restore_img), to_tensor(hr_image)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_filenames)

    
class Discriminator(nn.Module):
    """Convolutional critic that scores an image with one raw logit.

    Args:
        in_channels: number of channels of the input image.
        features: output channel count of each successive ``ConvBlock``.
            Blocks at odd positions request ``stride=2``; the first block
            has no batch normalisation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        features: tuple = (64, 64, 128, 128, 256, 256, 512, 512),
    ) -> None:
        super(Discriminator, self).__init__()

        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        # ``in_channels`` now holds the width of the last block, so the
        # classifier follows ``features`` instead of assuming 512 channels.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the un-normalised score (logit) for each image in *x*.

        No sigmoid is applied here: the losses in this file pass the output
        to ``nn.BCEWithLogitsLoss``, which applies the sigmoid itself.
        Squashing the score first would bound the logit difference to
        [-1, 1] and flatten the gradients.
        """
        return self.classifier(self.blocks(x))

class GeneratorLoss(nn.Module):
    """Generator objective: adversarial + VGG perceptual + total-variation terms.

    The perceptual term compares features from the frozen convolutional part
    of a pretrained VGG-19 after normalising both images with the mean/std
    given in ``self.preprocess``.
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg19(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)).eval()
        # The feature extractor is a fixed metric, never trained.
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.tv_loss = TVLoss()
        self.preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, fake_out, real_out, out_images, target_images):
        """Return the weighted generator loss.

        Args:
            fake_out: discriminator output for the generated images.
            real_out: discriminator output for the ground-truth images.
            out_images: generated images.
            target_images: ground-truth images.

        Returns:
            ``0.6 * adversarial + perceptual + 1e-6 * total_variation``.
        """
        # Adversarial loss: push the fake score above the real score.
        # The all-ones target is built from the logits themselves so this
        # module no longer depends on a global defined by the training loop.
        relativistic_logits = fake_out - real_out
        adversarial_loss = self.bce_loss(relativistic_logits, torch.ones_like(relativistic_logits))
        # Perceptual loss on VGG features of the normalised images.
        generated_features = self.loss_network(self.preprocess(out_images))
        target_features = self.loss_network(self.preprocess(target_images))
        perception_loss = self.mse_loss(generated_features, target_features)
        # Total-variation regulariser on the generated image.
        tv_loss = self.tv_loss(out_images)
        return 0.6 * adversarial_loss + perception_loss + 1e-6 * tv_loss


class TVLoss(nn.Module):
    """Total-variation penalty on squared differences of neighbouring pixels."""

    def __init__(self, tv_loss_weight=1):
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Return the weighted total variation of *x*, divided by its batch size.

        *x* is indexed as a 4-D tensor; differences are taken along
        dimensions 2 and 3.
        """
        n_samples = x.shape[0]

        # Each slice pair lines a pixel up with its neighbour one step back.
        lower_rows, upper_rows = x[:, :, 1:, :], x[:, :, :-1, :]
        right_cols, left_cols = x[:, :, :, 1:], x[:, :, :, :-1]

        vertical_tv = ((lower_rows - upper_rows) ** 2).sum()
        horizontal_tv = ((right_cols - left_cols) ** 2).sum()

        vertical_count = self.tensor_size(lower_rows)
        horizontal_count = self.tensor_size(right_cols)

        normalised = vertical_tv / vertical_count + horizontal_tv / horizontal_count
        return self.tv_loss_weight * 2 * normalised / n_samples

    @staticmethod
    def tensor_size(t):
        """Return the product of dimensions 1, 2 and 3 of *t*."""
        dims = t.shape
        return dims[1] * dims[2] * dims[3]

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length *window_size* that sums to 1.

    The kernel is centred on index ``window_size // 2`` and uses standard
    deviation *sigma*.
    """
    centre = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = []
    for position in range(window_size):
        weights.append(math.exp(-(position - centre) ** 2 / two_sigma_sq))
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()


def create_window(window_size, channel):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian kernel.

    The 2-D kernel is the outer product of the 1-D Gaussian (sigma 1.5) with
    itself, repeated once per channel.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    The cached window is rebuilt whenever the input's channel count or
    tensor type differs from the one it was created for.
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM of *img1* and *img2* (both 4-D tensors)."""
        _, channel, _, _ = img1.size()

        cache_is_valid = (
            channel == self.channel
            and self.window.data.type() == img1.data.type()
        )
        if not cache_is_valid:
            fresh_window = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh_window = fresh_window.cuda(img1.get_device())
            # Remember the window for later calls with the same input kind.
            self.window = fresh_window.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Return the SSIM of two 4-D image tensors using a fresh Gaussian window.

    The window is created for the channel count of *img1* and matched to its
    device and tensor type before the computation.
    """
    _, channel, _, _ = img1.size()

    kernel = create_window(window_size, channel)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, channel, size_average)

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
CROP_SIZE = 96
UPSCALE_FACTOR = 4
NUM_EPOCHS = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16

torch.backends.cudnn.benchmark = True
# Seed the CPU generator as well as CUDA: the DataLoader shuffle and the
# random crops/flips draw from the CPU RNG, which the CUDA-only seed left
# unseeded.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
train_path = ["../input/df2k-ost/train/DIV2K/DIV2K_train_HR", "../input/df2k-ost/train/Flickr2K", "../input/df2k-ost/train/OST"]
test_path = ["../input/df2k-ost/test/DIV2K_valid"]

train_set = TrainDataset(train_path, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDataset(test_path, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)

train_loader = DataLoader(
    dataset=train_set,
    num_workers=2,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
)
val_loader = DataLoader(dataset=val_set, num_workers=2, batch_size=1, shuffle=False)

# ---------------------------------------------------------------------------
# Models, loss, optimisers
# ---------------------------------------------------------------------------
netG = Generator(upscale_factor=UPSCALE_FACTOR).to(DEVICE)
print("# generator parameters:", sum(param.numel() for param in netG.parameters()))
# Start from the saved generator checkpoint. ``map_location`` lets a
# checkpoint saved on GPU load on a CPU-only host as well.
netG.load_state_dict(
    torch.load("../input/finalmse/netG_4x_epoch200.pth.tar", map_location=DEVICE)['model']
)

netD = Discriminator().to(DEVICE)
print(
    "# discriminator parameters:", sum(param.numel() for param in netD.parameters())
)

generator_criterion = GeneratorLoss().to(DEVICE)

optimizerG = torch.optim.AdamW(netG.parameters(), lr=1e-5)
optimizerD = torch.optim.AdamW(netD.parameters(), lr=1e-5)

# Halve both learning rates every 100 scheduler steps.
scheduler_G = lr_scheduler.StepLR(optimizerG, step_size=100, gamma=0.5)
scheduler_D = lr_scheduler.StepLR(optimizerD, step_size=100, gamma=0.5)

# Per-epoch history of the averaged losses and validation metrics.
results = {
    "d_loss": [],
    "g_loss": [],
    "psnr": [],
    "ssim": []
}

for epoch in range(1, NUM_EPOCHS + 1):
    cur_lr = optimizerG.param_groups[0]['lr']
    train_bar = tqdm(train_loader, total=len(train_loader))
    running_results = {
        "batch_sizes": 0,
        "d_loss": 0,
        "g_loss": 0,
        "learning_rate": cur_lr
    }

    netG.train()
    netD.train()
    for lr_img, hr_img in train_bar:
        batch_size = lr_img.size(0)
        running_results["batch_sizes"] += batch_size

        hr_img = hr_img.to(DEVICE)
        lr_img = lr_img.to(DEVICE)

        ############################
        # (1) Update D network: push D(real) above D(fake)
        ###########################
        # The generator is not trained in this step, so its forward pass
        # needs no autograd graph (and no ``retain_graph`` on backward).
        with torch.no_grad():
            sr_img = netG(lr_img)

        netD.zero_grad()
        real_out = netD(hr_img)
        fake_out = netD(sr_img)
        # All-ones targets for the BCE-with-logits objective.
        target_real = torch.ones(batch_size, 1, device=DEVICE)

        d_loss = nn.BCEWithLogitsLoss()(real_out - fake_out, target_real)
        d_loss.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: adversarial + perceptual + TV loss
        ###########################
        netG.zero_grad()

        sr_img = netG(lr_img)
        fake_out = netD(sr_img)
        real_out = netD(hr_img)

        g_loss = generator_criterion(fake_out, real_out, sr_img, hr_img)
        g_loss.backward()

        optimizerG.step()

        # Accumulate sample-weighted losses for the running averages.
        running_results["g_loss"] += g_loss.item() * batch_size
        running_results["d_loss"] += d_loss.item() * batch_size

        train_bar.set_description(
            desc="[%d/%d] Loss_D: %f Loss_G: %f Learning_rate: %f"
            % (
                epoch,
                NUM_EPOCHS,
                running_results["d_loss"] / running_results["batch_sizes"],
                running_results["g_loss"] / running_results["batch_sizes"],
                running_results["learning_rate"]
            )
        )

    # Advance the LR schedules only after this epoch's optimizer steps;
    # stepping before any ``optimizer.step()`` shifts the decay one epoch
    # early and triggers a PyTorch warning.
    scheduler_G.step()
    scheduler_D.step()

    netG.eval()

    with torch.no_grad():
        val_bar = tqdm(val_loader, total=len(val_loader))
        valing_results = {
            "mse": 0,
            "ssims": 0,
            "psnr": 0,
            "ssim": 0,
            "batch_sizes": 0,
        }
        # ValDataset yields (lr, bicubic-restored hr, hr); the restored image
        # is not needed for the metrics below.
        for val_lr, _val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results["batch_sizes"] += batch_size
            lr = val_lr.to(DEVICE)
            hr = val_hr.to(DEVICE)
            # Forward
            sr = netG(lr)
            # Loss & metrics (kept as Python floats, not device tensors)
            batch_mse = ((sr - hr) ** 2).mean().item()
            valing_results["mse"] += batch_mse * batch_size
            batch_ssim = ssim(sr, hr).item()

            valing_results["ssims"] += batch_ssim * batch_size
            valing_results["psnr"] = 10 * math.log10(
                (hr.max().item() ** 2)
                / (valing_results["mse"] / valing_results["batch_sizes"])
            )
            valing_results["ssim"] = (
                valing_results["ssims"] / valing_results["batch_sizes"]
            )
            val_bar.set_description(
                desc="[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f"
                % (valing_results["psnr"], valing_results["ssim"])
            )

    # Back to training mode before the checkpoints are written.
    netG.train()
    netD.train()

    #########################
    torch.save(
        {"model": netG.state_dict()},
        f"./netG_{UPSCALE_FACTOR}x_epoch{epoch}.pth.tar",
    )
    torch.save(
        {"model": netD.state_dict()},
        f"./netD_{UPSCALE_FACTOR}x_epoch{epoch}.pth.tar",
    )
    #########################

    results["d_loss"].append(
        running_results["d_loss"] / running_results["batch_sizes"]
    )
    results["g_loss"].append(
        running_results["g_loss"] / running_results["batch_sizes"]
    )

    results["psnr"].append(valing_results["psnr"])
    results["ssim"].append(valing_results["ssim"])

# Architecture

In [None]:
import torch
from torch import nn

class SeperableConv2d(nn.Sequential):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: channels of the input (also the depthwise group count).
        out_channels: channels produced by the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: accepted for signature compatibility but not used. The
            depthwise convolution always pads so that a stride-1 call keeps
            the spatial size; callers in this file rely on that even when
            they pass ``padding=0``.
        dilation: dilation of the depthwise convolution.
        bias: whether the depthwise convolution has a bias. The pointwise
            convolution always keeps its default bias, so its parameters
            stay compatible with existing checkpoints.
        padding_mode: padding mode of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode='zeros'):
        if stride == 1:
            depthwise_padding = 'same'
        else:
            # ``padding='same'`` is rejected by nn.Conv2d for strided
            # convolutions, so compute the equivalent explicit padding.
            # ``stride`` used to be dropped silently, which meant
            # ``stride=2`` callers never downsampled.
            kernel = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
            dilate = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
            depthwise_padding = tuple(d * (k - 1) // 2 for k, d in zip(kernel, dilate))

        super(SeperableConv2d, self).__init__(
            nn.Conv2d(
                in_channels=in_channels, out_channels=in_channels, groups=in_channels,
                kernel_size=kernel_size, stride=stride, padding=depthwise_padding,
                dilation=dilation, bias=bias, padding_mode=padding_mode
            ),
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )


class UpsampleBlock(nn.Sequential):
    """Two identical conv -> PixelShuffle -> PReLU upsampling stages.

    Each stage expands the channels by ``scale_factor ** 2`` and then applies
    ``nn.PixelShuffle(scale_factor)``, so ``scale_factor`` is applied twice.
    """

    def __init__(self, in_channels, scale_factor):
        expanded_channels = in_channels * scale_factor**2
        stages = []
        for _ in range(2):
            stages.append(SeperableConv2d(in_channels, expanded_channels, kernel_size=3))
            stages.append(nn.PixelShuffle(scale_factor))
            stages.append(nn.PReLU(num_parameters=in_channels))
        super(UpsampleBlock, self).__init__(*stages)


class ConvBlock(nn.Sequential):
    """Convolution with optional batch normalisation and activation.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        use_bn: append ``nn.BatchNorm2d`` after the convolution.
        use_ffc: use an ``FFC`` layer (fixed kernel size 3 and 0.5 global
            ratios) instead of ``SeperableConv2d``; ``kwargs`` are ignored
            in that case.
        use_act: append an activation.
        discriminator: pick ``LeakyReLU(0.2)`` instead of a per-channel
            ``PReLU`` as the activation.
        **kwargs: forwarded to ``SeperableConv2d``.
    """

    def __init__(self, in_channels, out_channels, use_bn=False, use_ffc=False, use_act=True, discriminator=False, **kwargs):
        if use_ffc:
            conv = FFC(
                in_channels, out_channels, kernel_size=3,
                ratio_gin=0.5, ratio_gout=0.5, inline=True,
            )
        else:
            conv = SeperableConv2d(in_channels, out_channels, **kwargs)

        layers = [conv]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            if discriminator:
                activation = nn.LeakyReLU(0.2, inplace=True)
            else:
                activation = nn.PReLU(num_parameters=out_channels)
            layers.append(activation)

        super(ConvBlock, self).__init__(*layers)


class ResidualBlock(nn.Module):
    """Two ``ConvBlock`` layers whose output is scaled by 0.1 and added to the input.

    Blocks with an even *index* use an FFC layer in their first ``ConvBlock``.
    """

    def __init__(self, in_channels, index):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            use_ffc=(index % 2 == 0),
            **conv_kwargs,
        )
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            use_act=False,
            **conv_kwargs,
        )

    def forward(self, x):
        """Return ``0.1 * block2(block1(x)) + x``."""
        branch = self.block2(self.block1(x))
        return branch.mul(0.1) + x


class Generator(nn.Module):
    """Super-resolution generator: initial conv, residual trunk, upsampler, final conv.

    Args:
        in_channels: channels of the input and output images.
        num_channels: feature width used throughout the network.
        num_blocks: number of ``ResidualBlock`` layers in the trunk.
        upscale_factor: total spatial magnification. ``UpsampleBlock``
            applies its scale twice, so this must be a perfect square
            (1, 4, 9, 16, ...).

    Raises:
        ValueError: if ``upscale_factor`` is not a positive perfect square.
    """

    def __init__(self, in_channels: int = 3, num_channels: int = 64, num_blocks: int = 16, upscale_factor: int = 4):
        super(Generator, self).__init__()

        # ``upscale_factor`` used to be ignored (the per-stage scale was
        # hard-coded to 2). Derive the per-stage scale from it; the default
        # of 4 still gives 2.
        stage_scale = int(round(upscale_factor ** 0.5))
        if stage_scale < 1 or stage_scale ** 2 != upscale_factor:
            raise ValueError(
                f"upscale_factor must be a positive perfect square, got {upscale_factor!r}"
            )

        self.initial = ConvBlock(in_channels, num_channels, kernel_size=3, use_act=False)
        self.residual = nn.Sequential(
            *[ResidualBlock(num_channels, block_idx) for block_idx in range(num_blocks)]
        )
        self.upsampler = UpsampleBlock(num_channels, scale_factor=stage_scale)
        self.final_conv = SeperableConv2d(num_channels, in_channels, kernel_size=3)

    def forward(self, x):
        """Return the upscaled image for the low-resolution batch *x*."""
        initial = self.initial(x)
        # Long skip connection around the whole residual trunk.
        features = self.residual(initial) + initial
        features = self.upsampler(features)
        return self.final_conv(features)


class FourierUnit(nn.Module):
    """Apply a 1x1 separable convolution and ReLU in the real-FFT domain.

    The real and imaginary parts of the spectrum are folded into the channel
    dimension (doubling it), convolved, then unfolded and inverse-transformed
    back to the spatial size of the input.
    """

    def __init__(self, in_channels, out_channels, ffc3d=False, fft_norm='ortho'):
        super(FourierUnit, self).__init__()
        # Channel counts are doubled because real and imaginary parts are
        # processed as separate channels.
        self.conv_layer = SeperableConv2d(
            in_channels=in_channels * 2,
            out_channels=out_channels * 2,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.relu = torch.nn.ReLU(inplace=True)

        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        """Return the spectrally filtered version of *x*."""
        n_batch = x.shape[0]
        if self.ffc3d:
            fft_axes = (-3, -2, -1)
            restore_shape = x.shape[-3:]
        else:
            fft_axes = (-2, -1)
            restore_shape = x.shape[-2:]

        spectrum = torch.fft.rfftn(x, dim=fft_axes, norm=self.fft_norm)

        # Append a trailing (real, imag) axis, move it next to the channel
        # axis and merge the two so the convolution sees real-valued channels.
        parts = torch.stack((spectrum.real, spectrum.imag), dim=-1)
        parts = parts.permute(0, 1, 4, 2, 3).contiguous()
        folded = parts.view((n_batch, -1,) + parts.size()[3:])

        filtered = self.relu(self.conv_layer(folded))

        # Undo the folding: split the (real, imag) pair back out of the
        # channels and move it to the last axis again.
        unfolded = filtered.view((n_batch, -1, 2,) + filtered.size()[2:])
        unfolded = unfolded.permute(0, 1, 3, 4, 2).contiguous()
        filtered_spectrum = torch.complex(unfolded[..., 0], unfolded[..., 1])

        return torch.fft.irfftn(filtered_spectrum, s=restore_shape, dim=fft_axes, norm=self.fft_norm)


class SpectralTransform(nn.Module):
    """Global branch of an FFC layer: 1x1 conv, Fourier unit(s), 1x1 conv.

    Args:
        in_channels: input channels.
        out_channels: output channels; the Fourier units work on
            ``out_channels // 2`` channels.
        stride: when 2, the input is first average-pooled by 2.
        enable_lfu: also run a second Fourier unit on a 2x2 spatial split
            of the first quarter of the channels.
        **fu_kwargs: forwarded to the main ``FourierUnit``.
    """

    def __init__(self, in_channels, out_channels, stride=1, enable_lfu=True, **fu_kwargs):
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            SeperableConv2d(in_channels, out_channels // 2, kernel_size=1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, **fu_kwargs)

        if self.enable_lfu:
            self.lfu = FourierUnit(out_channels // 2, out_channels // 2)
        self.conv2 = SeperableConv2d(out_channels // 2, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the spectrally transformed features of *x*."""
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            n, c, h, w = x.shape
            split_no = 2
            split_h = h // split_no
            split_w = w // split_no
            # Stack the first two row-chunks and then the first two
            # column-chunks of the first quarter of the channels along the
            # channel axis.
            xs = torch.cat(torch.split(x[:, :c // 4], split_h, dim=-2)[0:2], dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_w, dim=-1)[0:2], dim=1).contiguous()
            xs = self.lfu(xs)
            # Tile the result back to (roughly) the full spatial size.
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()

            # For odd sizes the tiling is one row/column short; pad with
            # zeros. ``new_zeros`` allocates on the tensor's own device and
            # dtype instead of relying on a global ``DEVICE``.
            if h % 2 == 1:
                h_zeros = xs.new_zeros(xs.shape[0], xs.shape[1], 1, xs.shape[3])
                xs = torch.cat((xs, h_zeros), dim=2)
            if w % 2 == 1:
                w_zeros = xs.new_zeros(xs.shape[0], xs.shape[1], xs.shape[2], 1)
                xs = torch.cat((xs, w_zeros), dim=3)
        else:
            xs = 0

        output = self.conv2(x + output + xs)

        return output


class FFC(nn.Module):
    """Fast Fourier Convolution layer with a local and a global branch.

    The input channels are divided into a local group and a global group
    (``ratio_gin`` is the global share); four paths (local->local,
    local->global, global->local, global->global) produce the local and
    global outputs (``ratio_gout`` is the global share of the output).

    Args:
        in_channels: total input channels.
        out_channels: total output channels.
        kernel_size: kernel size of the three spatial convolutions.
        ratio_gin: fraction of input channels given to the global branch.
        ratio_gout: fraction of output channels produced by the global branch.
        inline: when True, ``forward`` takes and returns one tensor with the
            local channels first and the global channels last; otherwise it
            takes ``(x_l, x_g)`` (or a bare tensor) and returns a tuple.
        stride: 1 or 2.
        padding, dilation: forwarded to the spatial convolutions.
        enable_lfu: forwarded to ``SpectralTransform``.
        padding_type: padding mode of the spatial convolutions.
        gated: learn per-pixel gates for the cross-branch paths.
        **spectral_kwargs: forwarded to ``SpectralTransform``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, inline=True, stride=1, padding=0,
                 dilation=1, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride
        self.inline = inline

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        # A path with no input or no output channels becomes an Identity,
        # which passes its input (possibly the integer 0) straight through.
        module = nn.Identity if in_cl == 0 or out_cl == 0 else SeperableConv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else SeperableConv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else SeperableConv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, enable_lfu, **spectral_kwargs)

        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else SeperableConv2d
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        """Run the four FFC paths and return the combined output.

        Returns a single tensor when ``inline`` is True, otherwise the tuple
        ``(out_local, out_global)`` where a missing branch is the integer 0.
        """
        if self.inline:
            # Split by explicit index: ``x[:, :-g]`` would be an empty slice
            # when the global group has zero channels. An absent group is
            # represented by 0, as in the non-inline path.
            split_at = x.shape[1] - self.global_in_num
            x_l = x[:, :split_at] if split_at > 0 else 0
            x_g = x[:, split_at:] if self.global_in_num > 0 else 0
        else:
            x_l, x_g = x if isinstance(x, tuple) else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input = torch.cat(
                [part for part in (x_l, x_g) if torch.is_tensor(part)], dim=1
            )
            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        if self.inline:
            # Concatenate only the branches that produced a tensor; a branch
            # left at the placeholder 0 cannot be passed to ``torch.cat``.
            return torch.cat(
                [part for part in (out_xl, out_xg) if torch.is_tensor(part)], dim=1
            )
        return out_xl, out_xg