### SRCNN

This notebook implements the SRCNN model, together with the creation of the training, validation and test data.

In [None]:
"""
Import Library
"""
from torch import nn
import torch
import numpy as np
import glob
from PIL import Image
import os
import h5py
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import copy
from tqdm import tqdm
from pytorch_ssim import pytorch_ssim

In [None]:
"""
SRCNN model
"""
class SRCNN(nn.Module):
    """Three-layer convolutional network (SRCNN).

    The layers use 9x9, 5x5 and 5x5 kernels with the default stride and a
    padding of ``kernel_size // 2``, so the output keeps the height and
    width of the input.

    Args:
        num_channels: Number of channels of both the input and the output
            (defaults to 1).
    """

    def __init__(self, num_channels=1):
        super().__init__()
        # Paddings are kernel_size // 2: 9 // 2 == 4 and 5 // 2 == 2.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the three convolutions; ReLU follows the first two only."""
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        return self.conv3(mapped)

In [None]:
"""
Utility Function
Helper function to convert from RGB to Y
"""
def convert_rgb_to_y(img):
    """Return the Y (luma) plane of an RGB image.

    Args:
        img: Three-dimensional array indexed as ``img[row, col, channel]``
            with channels ordered R, G, B.
            NOTE(review): the coefficients look like the usual 8-bit
            (0-255) conversion -- confirm the expected value range.

    Returns:
        Two-dimensional array computed per pixel as
        ``16 + (64.738 * R + 129.057 * G + 25.064 * B) / 256``.
    """
    red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    weighted_sum = 64.738 * red + 129.057 * green + 25.064 * blue
    return 16. + weighted_sum / 256.

"""
Convert RGB to YCbCr
"""
def convert_rgb_to_ycbcr(img):
    """Convert an RGB image to YCbCr.

    Args:
        img: Three-dimensional array indexed as ``img[row, col, channel]``
            with channels ordered R, G, B.

    Returns:
        Array with the same row/column layout whose last axis holds
        Y, Cb and Cr, in that order.
    """
    red, green, blue = (img[:, :, channel] for channel in range(3))

    luma = 16. + (64.738 * red + 129.057 * green + 25.064 * blue) / 256.
    blue_diff = 128. + (-37.945 * red - 74.494 * green + 112.439 * blue) / 256.
    red_diff = 128. + (112.439 * red - 94.154 * green - 18.285 * blue) / 256.

    # Planes are stacked channel-first, then the channel axis is moved last.
    planes = np.array([luma, blue_diff, red_diff])
    return planes.transpose([1, 2, 0])


"""
Convert YCbCr to RGB
"""
def convert_ycbcr_to_rgb(img):
    """Convert a YCbCr image to RGB.

    Args:
        img: Three-dimensional array indexed as ``img[row, col, channel]``
            with channels ordered Y, Cb, Cr.

    Returns:
        Array with the same row/column layout whose last axis holds
        R, G and B. Values are neither clipped nor rounded here.
    """
    luma, blue_diff, red_diff = img[:, :, 0], img[:, :, 1], img[:, :, 2]

    # The luma contribution is the same for all three output channels.
    luma_term = 298.082 * luma / 256.

    red = luma_term + 408.583 * red_diff / 256. - 222.921
    green = luma_term - 100.291 * blue_diff / 256. - 208.120 * red_diff / 256. + 135.576
    blue = luma_term + 516.412 * blue_diff / 256. - 276.836

    channels_first = np.array([red, green, blue])
    return channels_first.transpose([1, 2, 0])

"""
Calculate PSNR
"""
def calc_psnr(img1, img2):
    """Return the PSNR in dB between two tensors, using a peak value of 1.

    The mean squared error is taken over every element of the inputs, so a
    batch yields one value rather than one per image.

    Args:
        img1: First tensor.
        img2: Second tensor, broadcastable against ``img1``.

    Returns:
        Zero-dimensional tensor ``10 * log10(1 / mse)``.
    """
    squared_error = (img1 - img2) ** 2
    mse = torch.mean(squared_error)
    return 10. * torch.log10(1. / mse)

In [None]:
"""
Create dataset using h5 format
"""
def create_data(path, output):
    """Build an HDF5 file of paired low-/high-resolution Y-channel images.

    Every PNG in ``<path>/images_stage3`` (high resolution) and
    ``<path>/images_stage4`` (low resolution) is loaded, converted to its
    Y channel and stacked; the two stacks are written to ``<path>/<output>``
    as the datasets ``'hr'`` and ``'lr'``.

    Args:
        path: Directory that contains the two image folders. The HDF5 file
            is written into this directory as well.
        output: File name of the HDF5 file to create (opened in ``'w'``
            mode, so an existing file is replaced).

    Raises:
        ValueError: If the two folders hold a different number of PNG files.
    """

    def load_y_channel(image_file):
        # Close the image file as soon as the pixels have been copied out.
        with Image.open(image_file) as image:
            rgb = np.array(image.convert('RGB')).astype(np.float32)
        return convert_rgb_to_y(rgb)

    # glob returns paths in arbitrary order. Sorting both listings makes the
    # i-th entries correspond, presuming matching images share a file name
    # across the two folders -- TODO confirm against the dataset layout.
    hr_image_list = sorted(glob.glob(os.path.join(path, 'images_stage3/*.png')))
    lr_image_list = sorted(glob.glob(os.path.join(path, 'images_stage4/*.png')))

    if len(hr_image_list) != len(lr_image_list):
        raise ValueError(
            'Expected the same number of images in images_stage3 ({}) and '
            'images_stage4 ({}) under {}'.format(
                len(hr_image_list), len(lr_image_list), path))

    hr_imgs = np.array([load_y_channel(image_file) for image_file in hr_image_list])
    lr_imgs = np.array([load_y_channel(image_file) for image_file in lr_image_list])

    # The context manager closes the file even if a write fails.
    with h5py.File(os.path.join(path, output), 'w') as h5_file:
        # Both datasets are stored with an unsigned 8-bit integer type.
        h5_file.create_dataset('lr', np.shape(lr_imgs), h5py.h5t.STD_U8BE, data=lr_imgs)
        h5_file.create_dataset('hr', np.shape(hr_imgs), h5py.h5t.STD_U8BE, data=hr_imgs)

In [None]:
# Build one HDF5 file per split; each file is written inside its split directory.
for split_dir, h5_name in (
        ('dataset/train', 'train_full.h5'),
        ('dataset/val', 'val_full.h5'),
):
    create_data(split_dir, h5_name)

In [None]:
"""
Dataset feeding
"""
class CustomDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) pairs read from an HDF5 file.

    The file must contain the datasets ``'lr'`` and ``'hr'``. It is reopened
    on every access instead of being kept open on the instance.

    Args:
        h5_file: Path to the HDF5 file.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        """Return entry ``idx`` of ``'lr'`` and ``'hr'``, each divided by 255
        and given a new leading axis."""
        with h5py.File(self.h5_file, 'r') as f:
            low_res = f['lr'][idx] / 255.
            high_res = f['hr'][idx] / 255.
        return np.expand_dims(low_res, 0), np.expand_dims(high_res, 0)

    def __len__(self):
        """Return the number of entries in the ``'lr'`` dataset."""
        with h5py.File(self.h5_file, 'r') as f:
            num_items = len(f['lr'])
        return num_items

In [None]:
"""
Setup the model
"""
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Seed the generator before the model is built.
torch.manual_seed(123)

model = SRCNN().to(device)
criterion = nn.MSELoss()

# conv1 and conv2 use the optimizer-wide learning rate; conv3 overrides it
# with a value ten times smaller.
param_groups = [{'params': layer.parameters()} for layer in (model.conv1, model.conv2)]
param_groups.append({'params': model.conv3.parameters(), 'lr': 1e-4 * 0.1})
optimizer = optim.Adam(param_groups, lr=1e-4)

In [None]:
"""
Set up the data loaders
"""
batch_size = 4

train_dataset = CustomDataset('dataset/train/train_full.h5')
eval_dataset = CustomDataset('dataset/val/val_full.h5')

# Training batches are shuffled and the final partial batch is dropped;
# the evaluation loader keeps the DataLoader defaults apart from batch size.
train_loader_options = {
    'batch_size': batch_size,
    'shuffle': True,
    'num_workers': 0,
    'pin_memory': True,
    'drop_last': True,
}
train_dataloader = DataLoader(train_dataset, **train_loader_options)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

In [None]:
"""
Utility class that tracks the running average of a metric
"""
class AverageMeter:
    """Keep the latest value, weighted sum, count and mean of a series.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (sum of ``val * n``), ``count`` (sum of ``n``) and ``avg``
    (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the mean.

        Args:
            val: New observation.
            n: Weight of the observation, e.g. the batch size it was
                averaged over (defaults to 1).
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
"""
Train and validate the model
"""
# Checkpoints are written below 'weight_srcnn'; create the directory up
# front so torch.save does not fail on a fresh run.
os.makedirs('weight_srcnn', exist_ok=True)

best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

num_epoch = 20

for epoch in range(num_epoch):
    # ---- training ----
    model.train()
    epoch_losses = AverageMeter()

    # The bar counts samples. Its total leaves out the remainder of
    # len(train_dataset) / batch_size, matching a loader that drops the
    # final partial batch.
    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epoch - 1))

        for data in train_dataloader:
            inputs, labels = data

            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            preds = model(inputs)
            loss = criterion(preds, labels)
            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    # Keep a checkpoint of every epoch.
    torch.save(model.state_dict(), os.path.join('weight_srcnn', 'epoch_{}.pth'.format(epoch)))

    # ---- validation ----
    # Runs after the progress bar has been closed so the printed metrics do
    # not interleave with it.
    model.eval()
    epoch_psnr = AverageMeter()

    with torch.no_grad():
        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            preds = model(inputs).clamp(0.0, 1.0)

            # Store a plain float so the meter (and best_psnr) do not hold
            # on to device tensors.
            epoch_psnr.update(float(calc_psnr(preds, labels)), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    # Remember the weights of the best epoch so far.
    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

    # Rewritten after every epoch, so an interrupted run still leaves a
    # usable best.pth behind.
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join('weight_srcnn', 'best.pth'))

In [None]:
"""
Evaluate the model with test set
"""

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Build a fresh model and copy the best checkpoint into it entry by entry.
model = SRCNN().to(device)
state_dict = model.state_dict()
best_checkpoint = torch.load('weight_srcnn/best.pth', map_location=lambda storage, loc: storage)
for name, weights in best_checkpoint.items():
    # A checkpoint entry without a counterpart in the model is an error.
    if name not in state_dict:
        raise KeyError(name)
    state_dict[name].copy_(weights)

# Switch the model to evaluation mode.
model.eval()


In [None]:
# Evaluate the model on the test set: run every low-resolution image through
# the network, report PSNR/SSIM against the high-resolution image and save
# the reconstructed RGB result.

# glob returns paths in arbitrary order. Sorting both listings makes entry i
# of each list correspond, presuming matching images share a file name
# across the two folders -- TODO confirm against the dataset layout.
lr_image_path = 'dataset/test/images_stage4/*.png'
lr_image_list = sorted(glob.glob(lr_image_path))
hr_image_path = 'dataset/test/images_stage3/*.png'
hr_image_list = sorted(glob.glob(hr_image_path))

if len(lr_image_list) != len(hr_image_list):
    raise ValueError(
        'Expected the same number of test images in {} ({}) and {} ({})'.format(
            lr_image_path, len(lr_image_list), hr_image_path, len(hr_image_list)))

# Image.save does not create missing directories.
os.makedirs('result_srcnn_new', exist_ok=True)

psnr_total = 0
ssim_total = 0

for i, (lr_file, hr_file) in enumerate(zip(lr_image_list, hr_image_list)):
    image = Image.open(lr_file).convert('RGB')
    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # Network input: Y channel scaled by 1/255, with two leading axes added
    # in front of the image axes.
    y = torch.from_numpy(ycbcr[..., 0] / 255.).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    label = Image.open(hr_file).convert('RGB')
    label = np.array(label).astype(np.float32)

    # The prediction is a Y channel, so the reference must be the label's
    # Y channel as well (not its red channel).
    y_l = torch.from_numpy(convert_rgb_to_y(label) / 255.).to(device)
    y_l = y_l.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Accumulate plain floats rather than device tensors.
    psnr = float(calc_psnr(y_l, preds))
    psnr_total += psnr
    ssim = float(pytorch_ssim.ssim(y_l, preds))
    ssim_total += ssim
    print('PSNR: {:.2f}'.format(psnr))
    print('SSIM: {:.2f}'.format(ssim))

    # Inverse transform: combine the predicted Y with the input's Cb/Cr,
    # convert back to RGB and save as an 8-bit image.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = Image.fromarray(output)
    output.save(f'result_srcnn_new/img_{i}.png')

num_images = len(lr_image_list)
if num_images:
    psnr_total /= num_images
    ssim_total /= num_images
    print('PSNR_T: {:.4f}'.format(psnr_total))
    print('SSIM_T: {:.4f}'.format(ssim_total))
else:
    # Avoid a division by zero when the pattern matched nothing.
    print('No test images found for {}'.format(lr_image_path))