### LAPSRN

This notebook implements the LapSRN (Laplacian Pyramid Super-Resolution Network) model, together with training/validation data creation, training, and evaluation on a test set.

In [None]:
"""
Import Library
"""
from torch import nn
import h5py
import numpy as np
import glob
import os
from PIL import Image
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import torch
from tqdm import tqdm
from collections import namedtuple
import copy
import math

In [None]:
"""
Utility function
"""
def get_upsample_filter(size):
    """Return a ``size x size`` bilinear interpolation kernel as a float32 tensor.

    The kernel is the outer product of a 1-D triangular (tent) profile with
    itself. It is used to initialise transposed convolutions so that they
    perform bilinear upsampling.
    """
    half = (size + 1) // 2
    # Odd sizes centre on the middle tap; even sizes centre between the two middle taps.
    centre = half - 1 if size % 2 else half - 0.5
    rows, cols = np.ogrid[:size, :size]
    kernel = (1 - abs(rows - centre) / half) * (1 - abs(cols - centre) / half)
    return torch.from_numpy(kernel).float()

In [None]:
"""
LAPSRN model
"""
class ConvBlock(nn.Module):
    """Feature-extraction branch of one LapSRN pyramid level.

    The branch consists of ten 3x3, 64-channel convolutions, each followed by
    LeakyReLU(0.2). These are followed by a stride-2 transposed convolution
    (kernel 4, padding 1) and another LeakyReLU(0.2), which doubles the
    spatial size of a 64-channel input.
    """

    def __init__(self):
        super().__init__()

        layers = []
        for _ in range(10):
            layers += [
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # The attribute name (including the "cov" spelling) and the layer order
        # determine the saved state_dict keys, so both are kept unchanged.
        self.cov_block = nn.Sequential(*layers)

    def forward(self, x):
        return self.cov_block(x)

class LAPSRN(nn.Module):
    """Two-level LapSRN producing 2x and 4x super-resolved RGB images.

    Each pyramid level upsamples the previous image with a transposed
    convolution (``convt_I*``). It then adds a residual predicted from the
    feature branch (``convt_F*`` followed by ``convt_R*``). ``forward``
    returns the tuple ``(HR_2x, HR_4x)``.
    """

    def __init__(self):
        super(LAPSRN, self).__init__()

        self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        # Level 1 (2x): image upsampler, residual head, feature branch.
        self.convt_I1 = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False)
        self.convt_R1 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.convt_F1 = self.make_layer(ConvBlock)

        # Level 2 (4x): same structure, fed by the level-1 outputs.
        self.convt_I2 = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False)
        self.convt_R2 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.convt_F2 = self.make_layer(ConvBlock)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init based on fan-out (k_h * k_w * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                # Bilinear kernel repeated over every (in, out) channel pair.
                c1, c2, h, w = m.weight.data.size()
                weight = get_upsample_filter(h)
                m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1)
                if m.bias is not None:
                    m.bias.data.zero_()

        # The image-path upsamplers map RGB -> RGB. With the kernel repeated
        # over every channel pair, each output channel would start as the
        # *sum* of all three input channels (3x intensity, colours mixed).
        # Re-initialise them as independent per-channel bilinear upsampling so
        # the image path starts as plain bilinear interpolation. The 64-channel
        # transposed conv inside ConvBlock keeps the repeated initialisation.
        for convt in (self.convt_I1, self.convt_I2):
            self._init_channelwise_bilinear(convt)

    @staticmethod
    def _init_channelwise_bilinear(convt):
        """Set ``convt.weight`` to a bilinear kernel on the channel diagonal and zero elsewhere."""
        # ConvTranspose2d weights are laid out as (in_channels, out_channels, kH, kW).
        c_in, c_out, h, _ = convt.weight.shape
        kernel = get_upsample_filter(h)
        with torch.no_grad():
            convt.weight.zero_()
            for c in range(min(c_in, c_out)):
                convt.weight[c, c].copy_(kernel)

    def make_layer(self, block):
        """Wrap a single ``block()`` instance in ``nn.Sequential``.

        The wrapper is kept because it defines the ``convt_F*.0.`` prefix of the
        saved state_dict keys.
        """
        layers = []
        layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(HR_2x, HR_4x)`` for a 3-channel input batch ``x``."""
        out = self.relu(self.conv_input(x))

        # Level 1: upsampled image plus predicted residual.
        convt_F1 = self.convt_F1(out)
        convt_I1 = self.convt_I1(x)
        convt_R1 = self.convt_R1(convt_F1)
        HR_2x = convt_I1 + convt_R1

        # Level 2: features continue from level 1; the image path upsamples HR_2x.
        convt_F2 = self.convt_F2(convt_F1)
        convt_I2 = self.convt_I2(HR_2x)
        convt_R2 = self.convt_R2(convt_F2)
        HR_4x = convt_I2 + convt_R2

        return HR_2x, HR_4x

In [None]:
"""
Loss function
"""
class L1_Charbonnier_loss(nn.Module):
    """L1 Charbonnier loss: the sum over all elements of sqrt((X - Y)^2 + eps).

    This is a differentiable approximation of the absolute error; ``eps``
    keeps the square root smooth where X == Y. The result is a summed (not
    averaged) 0-dim tensor, so its magnitude grows with batch and image size.
    """

    def __init__(self):
        super().__init__()
        self.eps = 1e-6

    def forward(self, X, Y):
        residual = X - Y
        return torch.sqrt(residual * residual + self.eps).sum()

In [None]:
"""
Setup the dataset
"""
def _load_normalized_chw(path):
    """Load an image file as an RGB float32 array of shape (3, H, W) scaled to [0, 1]."""
    with Image.open(path) as img:
        arr = np.array(img.convert('RGB')).astype(np.float32)
    arr = np.transpose(arr, axes=[2, 0, 1])
    arr /= 255.0
    return arr


def create_data(path, output):
    """Pack the paired images under ``path`` into the HDF5 file ``path/output``.

    The PNGs in ``images_stage5`` are stored as dataset 'lr', those in
    ``images_stage6`` as 'mid', and those in ``images_stage3`` as 'hr'. Each
    dataset is a float32 array of shape (N, 3, H, W) with values in [0, 1].

    Images are paired by position after sorting the file paths, so the three
    folders must contain files whose names sort into corresponding order.
    All images within a folder must share one size (``np.array`` stacks them).

    Raises:
        ValueError: if the three folders hold different numbers of images.
    """
    # glob order is arbitrary; sort so that index i refers to the same sample
    # in every folder.
    hr_image_list = sorted(glob.glob(os.path.join(path, 'images_stage3/*.png')))
    lr_image_list = sorted(glob.glob(os.path.join(path, 'images_stage5/*.png')))
    mid_image_list = sorted(glob.glob(os.path.join(path, 'images_stage6/*.png')))

    if not (len(hr_image_list) == len(lr_image_list) == len(mid_image_list)):
        raise ValueError(
            'image count mismatch under {}: hr={}, lr={}, mid={}'.format(
                path, len(hr_image_list), len(lr_image_list), len(mid_image_list)))

    hr_imgs = np.array([_load_normalized_chw(p) for p in hr_image_list])
    lr_imgs = np.array([_load_normalized_chw(p) for p in lr_image_list])
    mid_imgs = np.array([_load_normalized_chw(p) for p in mid_image_list])

    # Open (and truncate) the output only once all images have loaded, and
    # make sure it is closed even if a write fails.
    with h5py.File(os.path.join(path, output), 'w') as h5_file:
        h5_file.create_dataset('lr', np.shape(lr_imgs), h5py.h5t.IEEE_F32LE, data=lr_imgs)
        h5_file.create_dataset('hr', np.shape(hr_imgs), h5py.h5t.IEEE_F32LE, data=hr_imgs)
        h5_file.create_dataset('mid', np.shape(mid_imgs), h5py.h5t.IEEE_F32LE, data=mid_imgs)


In [None]:
# Build the training and validation HDF5 files (training first).
for split_dir, h5_name in (('dataset/train', 'train_full_mid.h5'),
                           ('dataset/val', 'val_full_mid.h5')):
    create_data(split_dir, h5_name)

In [None]:
"""
Dataset feeding
"""
class CustomDataset(Dataset):
    """Map-style dataset reading samples from an HDF5 file with 'lr', 'hr' and 'mid' datasets.

    The file is reopened read-only on every access, and only its path is
    stored on the instance.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        """Return the ``(lr, hr, mid)`` arrays at position ``idx``."""
        with h5py.File(self.h5_file, 'r') as handle:
            lr, hr, mid = (handle[key][idx] for key in ('lr', 'hr', 'mid'))
        return lr, hr, mid

    def __len__(self):
        """Return the number of samples, taken from the length of the 'lr' dataset."""
        with h5py.File(self.h5_file, 'r') as handle:
            count = len(handle['lr'])
        return count

In [None]:
"""
Model Setup
"""
torch.manual_seed(123)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 2

model = LAPSRN().to(device)
criterion = L1_Charbonnier_loss()
# The training loop sets the learning rate at the start of every epoch
# (starting at 1e-3), so construct the optimizer with that same value
# rather than an unused 1e-1.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
"""
Setup data loader
"""
# Training split: shuffled mini-batches; the last incomplete batch is dropped.
train_dataset = CustomDataset('dataset/train/train_full_mid.h5')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

# Validation split: otherwise default DataLoader settings (no shuffling).
eval_dataset = CustomDataset('dataset/val/val_full_mid.h5')
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

In [None]:
"""
Util function to measure error
"""
class AverageMeter:
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

"""
Calculate PSNR
"""
def calc_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio in dB between two tensors.

    Args:
        img1, img2: tensors of matching (or broadcastable) shape. The mean
            squared error is taken over all elements.
        max_val: peak signal value. The default of 1.0 matches images scaled
            to [0, 1], as used throughout this notebook.

    Returns:
        A 0-dim tensor, which is ``inf`` when the inputs are identical.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(max_val ** 2 / mse)

In [None]:
"""
Train and val the model
"""
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

num_epoch = 20

# torch.save does not create directories, so make sure the checkpoint folder exists.
os.makedirs('weight_lapsrn', exist_ok=True)

for epoch in range(num_epoch):
    model.train()
    epoch_losses = AverageMeter()

    # Step decay: 1e-3 for epochs 0-4, 1e-4 for 5-9, and so on. It is applied
    # before the first optimizer step of the epoch.
    lr = 1e-3 * (0.1 ** (epoch // 5))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # training
    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epoch - 1))

        for inputs, labels, mid_labels in train_dataloader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)
            mid_labels = mid_labels.to(device, dtype=torch.float)

            pred_mid, pred = model(inputs)

            # Supervise both pyramid levels (2x against 'mid', 4x against 'hr').
            loss_x2 = criterion(pred_mid, mid_labels)
            loss_x4 = criterion(pred, labels)
            loss = loss_x2 + loss_x4

            optimizer.zero_grad()
            # A single backward pass on the sum yields the same gradients as
            # two separate passes, without retaining the graph in between.
            loss.backward()
            optimizer.step()

            epoch_losses.update(loss.item(), len(inputs))
            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    torch.save(model.state_dict(), os.path.join('weight_lapsrn', 'epoch_{}.pth'.format(epoch)))

    # validation
    model.eval()
    epoch_psnr = AverageMeter()

    with torch.no_grad():
        for inputs, labels, _ in eval_dataloader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            _, preds = model(inputs)
            preds = torch.clamp(preds, 0.0, 1.0)

            # Accumulate plain floats rather than device tensors.
            epoch_psnr.update(calc_psnr(preds, labels).item(), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    # save the best weight
    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

    # Rewritten every epoch so an interrupted run still leaves the best weights so far.
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join('weight_lapsrn', 'best.pth'))

In [None]:
"""
Evaluate the model with test set
"""
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = LAPSRN().to(device)
state_dict = model.state_dict()
# The checkpoint is mapped onto the CPU first; copy_ then writes each tensor
# into the model's own (already device-resident) parameter or buffer.
checkpoint = torch.load('weight_lapsrn/best.pth', map_location=lambda storage, loc: storage)
for name, tensor in checkpoint.items():
    # Fail loudly on any key the current architecture does not have.
    if name not in state_dict:
        raise KeyError(name)
    state_dict[name].copy_(tensor)

model.eval()

In [None]:
from pytorch_ssim import pytorch_ssim

lr_image_path = 'dataset/test/images_stage5/*.png'
lr_image_list = sorted(glob.glob(lr_image_path))
hr_image_path = 'dataset/test/images_stage3/*.png'
hr_image_list = sorted(glob.glob(hr_image_path))

# glob order is arbitrary; with both lists sorted, LR/HR images are paired by
# position, so the two folders must hold files whose names sort in the same order.
if not lr_image_list:
    raise FileNotFoundError('no test images match {}'.format(lr_image_path))
if len(lr_image_list) != len(hr_image_list):
    raise ValueError('found {} LR but {} HR test images'.format(
        len(lr_image_list), len(hr_image_list)))

os.makedirs('result_lapsrn_new', exist_ok=True)


def _image_to_batch(path):
    """Load an image as a (1, 3, H, W) float32 tensor with values in [0, 1] on ``device``."""
    with Image.open(path) as img:
        arr = np.array(img.convert('RGB')).astype(np.float32)
    arr = np.transpose(arr, axes=[2, 0, 1]) / 255.0
    return torch.from_numpy(arr).to(device).unsqueeze(0)


psnr_total = 0
ssim_total = 0

for i, (lr_path, hr_path) in enumerate(zip(lr_image_list, hr_image_list)):
    image = _image_to_batch(lr_path)
    label = _image_to_batch(hr_path)

    with torch.no_grad():
        _, preds = model(image)
        preds = torch.clamp(preds, 0.0, 1.0)

    psnr = calc_psnr(label, preds)
    psnr_total += psnr
    ssim = pytorch_ssim.ssim(label, preds)
    ssim_total += ssim
    print('PSNR: {:.2f}'.format(psnr))
    print('SSIM: {:.2f}'.format(ssim))

    # Round (rather than truncate) to the nearest 8-bit level when saving.
    output = preds.mul(255.0).round().clamp(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
    Image.fromarray(output).save(f'result_lapsrn_new/img_{i}.png')

psnr_total /= len(lr_image_list)
ssim_total /= len(lr_image_list)
print('PSNR_T: {:.4f}'.format(psnr_total))
print('SSIM_T: {:.4f}'.format(ssim_total))