### RCAN

This notebook implements the RCAN model, along with training and test data creation.

In [None]:
"""
Import Library
"""
from torch import nn
import h5py
import numpy as np
import glob
import os
from PIL import Image
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import torch
from tqdm import tqdm
from collections import namedtuple
import copy

In [None]:
"""
RCAN model
"""
class ChannelAttention(nn.Module):
    """Channel-wise gating: rescale every feature map by a learned weight.

    Each channel's weight is derived from its spatial average by a 1x1-conv
    bottleneck (shrink by ``reduction``, expand back) followed by a sigmoid.

    Args:
        num_features: number of channels of the incoming feature map.
        reduction: divisor applied to ``num_features`` inside the bottleneck.
    """

    def __init__(self, num_features, reduction):
        super().__init__()
        squeezed = num_features // reduction
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_features, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, num_features, kernel_size=1),
            nn.Sigmoid(),
        ]
        # The attribute name and layer order determine the state_dict keys.
        self.module = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention weights."""
        weights = self.module(x)
        return x * weights

class RCAB(nn.Module):
    """Residual channel attention block.

    Runs conv3x3 -> ReLU -> conv3x3 -> ChannelAttention and adds the result
    back onto the block input.

    Args:
        num_features: channel count, unchanged by the block.
        reduction: bottleneck divisor forwarded to ``ChannelAttention``.
    """

    def __init__(self, num_features, reduction):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution (3x3 kernel, padding 1).
            return nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)

        body = [
            conv3x3(),
            nn.ReLU(inplace=True),
            conv3x3(),
            ChannelAttention(num_features, reduction),
        ]
        self.module = nn.Sequential(*body)

    def forward(self, x):
        """Return ``x`` plus the attention-weighted residual branch."""
        branch = self.module(x)
        return x + branch

class RG(nn.Module):
    """Residual group: a stack of RCABs closed by a 3x3 conv, with a skip.

    Args:
        num_features: channel count, unchanged by the group.
        num_rcab: how many RCAB blocks to chain.
        reduction: bottleneck divisor forwarded to every RCAB.
    """

    def __init__(self, num_features, num_rcab, reduction):
        super().__init__()
        stages = []
        for _ in range(num_rcab):
            stages.append(RCAB(num_features, reduction))
        # Trailing convolution applied after the last RCAB.
        stages.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.module = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the stacked blocks."""
        stacked = self.module(x)
        return x + stacked

class RCAN(nn.Module):
    """Residual channel attention network for image super-resolution.

    Pipeline: shallow 3x3 conv -> residual groups -> 3x3 conv (+ long skip)
    -> conv + PixelShuffle upscaling -> 3x3 conv back to 3 channels.

    Args:
        args: object exposing the attributes ``scale``, ``num_features``,
            ``num_rg``, ``num_rcab`` and ``reduction``.
    """

    def __init__(self, args):
        super().__init__()
        upscale_factor = args.scale
        width = args.num_features
        group_count = args.num_rg
        blocks_per_group = args.num_rcab
        squeeze = args.reduction

        def conv3x3(in_channels, out_channels):
            # Shape-preserving convolution (3x3 kernel, padding 1).
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Shallow feature extraction from the 3-channel input.
        self.sf = conv3x3(3, width)
        # RIR layers
        groups = [RG(width, blocks_per_group, squeeze) for _ in range(group_count)]
        self.rgs = nn.Sequential(*groups)
        self.conv1 = conv3x3(width, width)
        # Expand channels by scale**2, then let PixelShuffle fold them into space.
        self.upscale = nn.Sequential(
            conv3x3(width, width * (upscale_factor ** 2)),
            nn.PixelShuffle(upscale_factor),
        )
        self.conv2 = conv3x3(width, 3)

    def forward(self, x):
        """Map a low-resolution batch to its upscaled 3-channel prediction."""
        shallow = self.sf(x)
        deep = self.conv1(self.rgs(shallow))
        # Long skip connection around all residual groups.
        deep += shallow
        return self.conv2(self.upscale(deep))

In [None]:
"""
Setup the dataset
"""
def _load_image_chw(image_path):
    """Read an image as RGB and return it as a float32 (C, H, W) array in [0, 1]."""
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(image_path) as img:
        pixels = np.array(img.convert('RGB')).astype(np.float32)
    # PIL yields (H, W, C); move channels first and scale 0..255 to 0..1.
    return np.transpose(pixels, axes=[2, 0, 1]) / 255.0


def create_data(path, output):
    """Pack the HR/LR PNG pairs found under ``path`` into one HDF5 file.

    High-resolution images are read from ``<path>/images_stage3/*.png`` and
    low-resolution images from ``<path>/images_stage5/*.png``. They are stored
    as float32 datasets named ``'hr'`` and ``'lr'`` in ``<path>/<output>``.

    Args:
        path: directory holding the two image folders; the HDF5 file is
            written into this directory as well.
        output: file name of the HDF5 file to create (overwritten if present).

    Raises:
        ValueError: if the two folders contain a different number of PNGs.
    """
    # glob returns files in arbitrary order, so sort both lists to pair the
    # i-th HR image with the i-th LR image deterministically.
    # NOTE(review): assumes matching HR/LR files sort to the same position
    # (e.g. identical file names in both folders) - confirm against the data.
    hr_image_list = sorted(glob.glob(os.path.join(path, 'images_stage3/*.png')))
    lr_image_list = sorted(glob.glob(os.path.join(path, 'images_stage5/*.png')))

    if len(hr_image_list) != len(lr_image_list):
        raise ValueError(
            'HR/LR image count mismatch under {}: {} vs {}'.format(
                path, len(hr_image_list), len(lr_image_list)))

    # Stacking requires every image within a folder to share one shape.
    hr_imgs = np.array([_load_image_chw(p) for p in hr_image_list])
    lr_imgs = np.array([_load_image_chw(p) for p in lr_image_list])

    # Open the output only after loading succeeded, and let the context
    # manager close it even if writing a dataset fails.
    with h5py.File(os.path.join(path, output), 'w') as h5_file:
        h5_file.create_dataset('lr', np.shape(lr_imgs), h5py.h5t.IEEE_F32LE, data=lr_imgs)
        h5_file.create_dataset('hr', np.shape(hr_imgs), h5py.h5t.IEEE_F32LE, data=hr_imgs)

In [None]:
# Build one HDF5 file per split: training first, then validation.
for split_dir, h5_name in (
    ('dataset/train', 'train_full_s.h5'),
    ('dataset/val', 'val_full_s.h5'),
):
    create_data(split_dir, h5_name)

In [None]:
"""
Dataset feeding
"""
class CustomDataset(Dataset):
    """Dataset serving (lr, hr) pairs from the 'lr' and 'hr' HDF5 datasets.

    The file is reopened read-only on every access; no handle is kept on the
    instance.

    Args:
        h5_file: path of the HDF5 file to read from.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        """Return the ``idx``-th entries of the 'lr' and 'hr' datasets."""
        with h5py.File(self.h5_file, 'r') as handle:
            low_res = handle['lr'][idx]
            high_res = handle['hr'][idx]
        return low_res, high_res

    def __len__(self):
        """Return the number of entries in the 'lr' dataset."""
        with h5py.File(self.h5_file, 'r') as handle:
            size = len(handle['lr'])
        return size

In [None]:
"""
Model Setup
"""
# Seed first so the weight initialisation below is reproducible.
torch.manual_seed(123)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Lightweight attribute container read by RCAN.__init__.
ParamStrct = namedtuple('ParamStrct', 'scale num_features num_rg num_rcab reduction')
param = ParamStrct(
    scale=4,
    num_features=64,
    num_rg=6,
    num_rcab=16,
    reduction=16,
)
batch_size = 1

model = RCAN(param)
model = model.to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
"""
Setup data loader
"""
# Training split: shuffled, loaded in the main process, incomplete last batch dropped.
train_dataset = CustomDataset('dataset/train/train_full_s.h5')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=0,
)

# Validation split: DataLoader defaults apart from the batch size.
eval_dataset = CustomDataset('dataset/val/val_full_s.h5')
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

In [None]:
"""
Util function to measure error
"""
class AverageMeter:
    """Track the latest value and the weighted running mean of a metric.

    Attributes are ``val`` (last value), ``sum`` (weighted total), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

"""
Calculate PSNR
"""
def calc_psnr(img1, img2):
    """Return the PSNR in dB between two tensors, using a peak value of 1.

    The mean squared error is taken over all elements of the inputs.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(1. / mse)

In [None]:
"""
Train and val the model
"""
# Train for a fixed number of epochs. After every epoch the weights are saved,
# the model is scored on the validation set, and the best-PSNR weights so far
# are written to 'best.pth'.
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

num_epoch = 20

# torch.save does not create missing directories; without this the first
# checkpoint write would fail only after a full epoch of training.
weight_dir = 'weight_rcan'
os.makedirs(weight_dir, exist_ok=True)

for epoch in range(num_epoch):
    model.train()
    epoch_losses = AverageMeter()

    # The train loader drops the incomplete last batch, so the bar total
    # excludes the leftover samples.
    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epoch - 1))

        # training
        for inputs, labels in train_dataloader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            preds = model(inputs)
            loss = criterion(preds, labels)

            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    # Validation runs after the progress bar is closed so the prints below
    # do not interleave with it.
    torch.save(model.state_dict(), os.path.join(weight_dir, 'epoch_{}.pth'.format(epoch)))

    # validation
    model.eval()
    epoch_psnr = AverageMeter()

    with torch.no_grad():
        for inputs, labels in eval_dataloader:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            # Clamp to the valid [0, 1] range before measuring PSNR.
            preds = model(inputs).clamp(0.0, 1.0)

            # .item() stores a plain float (as done for the loss) instead of
            # accumulating device tensors inside the meter.
            epoch_psnr.update(calc_psnr(preds, labels).item(), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

    # Written every epoch so an interrupted run still leaves a usable best.pth.
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(weight_dir, 'best.pth'))

In [None]:
"""
Evaluate the model with test set
"""
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build a fresh model and copy the saved tensors into its parameters in place.
model = RCAN(param).to(device)
state_dict = model.state_dict()

# map_location returns each storage unchanged instead of relocating it.
saved_weights = torch.load('weight_rcan/best.pth', map_location=lambda storage, loc: storage)
for n, p in saved_weights.items():
    # A checkpoint entry the model does not know is an error.
    if n not in state_dict:
        raise KeyError(n)
    state_dict[n].copy_(p)

model.eval()

In [None]:
from pytorch_ssim import pytorch_ssim


def _load_image_tensor(image_path):
    """Load an image as a float32 (1, 3, H, W) tensor in [0, 1] on ``device``."""
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(image_path) as img:
        pixels = np.array(img.convert('RGB')).astype(np.float32)
    # PIL yields (H, W, C); move channels first and scale 0..255 to 0..1.
    pixels = np.ascontiguousarray(np.transpose(pixels, axes=[2, 0, 1])) / 255.0
    return torch.from_numpy(pixels).to(device).unsqueeze(0)


# Score the model on the test set (PSNR and SSIM per image and on average)
# and save every prediction as a PNG.
lr_image_path = 'dataset/test/images_stage5/*.png'
hr_image_path = 'dataset/test/images_stage3/*.png'

# glob returns files in arbitrary order; sort both lists so each LR image is
# compared with its own HR counterpart and output numbering is reproducible.
# NOTE(review): assumes matching LR/HR files sort to the same position
# (e.g. identical file names in both folders) - confirm against the data.
lr_image_list = sorted(glob.glob(lr_image_path))
hr_image_list = sorted(glob.glob(hr_image_path))

if len(lr_image_list) != len(hr_image_list):
    raise ValueError('LR/HR image count mismatch: {} vs {}'.format(
        len(lr_image_list), len(hr_image_list)))

# Image.save does not create missing directories.
result_dir = 'result_rcan_new'
os.makedirs(result_dir, exist_ok=True)

psnr_total = 0.0
ssim_total = 0.0

for i, (lr_file, hr_file) in enumerate(zip(lr_image_list, hr_image_list)):
    image = _load_image_tensor(lr_file)
    label = _load_image_tensor(hr_file)

    with torch.no_grad():
        preds = model(image).clamp(0.0, 1.0)

    # Convert to plain floats so the running totals do not hold device tensors.
    psnr = float(calc_psnr(label, preds))
    ssim = float(pytorch_ssim.ssim(label, preds))
    psnr_total += psnr
    ssim_total += ssim
    print('PSNR: {:.2f}'.format(psnr))
    print('SSIM: {:.2f}'.format(ssim))

    # Metrics are already computed, so preds can be rescaled in place to
    # 0..255 and reordered to (H, W, C) uint8 for PIL.
    output = preds.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
    Image.fromarray(output).save(os.path.join(result_dir, f'img_{i}.png'))

num_images = len(lr_image_list)
if num_images:
    psnr_total /= num_images
    ssim_total /= num_images
    print('PSNR_T: {:.4f}'.format(psnr_total))
    print('SSIM_T: {:.4f}'.format(ssim_total))
else:
    # Avoid a ZeroDivisionError when the test folder is empty.
    print('No test images found for pattern {}'.format(lr_image_path))