## import librarys

In [1]:
import os
import copy
import glob
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import MSELoss
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.optim import lr_scheduler
from torch.nn.modules.activation import ReLU, Sigmoid
from torch.nn import Conv2d, modules
from torch.nn import Sequential

from math import sqrt

import torch.backends.cudnn as cudnn

## define SRCNN and VDSR models

In [2]:
class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=num_channels, 
                               out_channels=64, 
                               kernel_size=9, 
                               padding=9 // 2)
        self.conv2 = nn.Conv2d(in_channels=64, 
                               out_channels=32, 
                               kernel_size=5, 
                               padding=5 // 2)
        self.conv3 = nn.Conv2d(in_channels=32, 
                               out_channels=num_channels, 
                               kernel_size=5, 
                               padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [3]:
class Conv_ReLU_Block(nn.Module):
    def __init__(self):
        super(Conv_ReLU_Block, self).__init__()
        self.conv = nn.Conv2d(in_channels=64, out_channels=64,
                              kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))


class VDSR(nn.Module):
    def __init__(self):
        super(VDSR, self).__init__()
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(
            in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))

    def make_layer(self, block, num_of_layer):
        layers = []
        for _ in range(num_of_layer):
            layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = x
        out = self.relu(self.input(x))
        out = self.residual_layer(out)
        out = self.output(out)
        out = torch.add(out, residual)
        return out

## useful functions and class

In [4]:
# utils functions
def img_read(fPath):
    '''
    read the image given path "fPath"
    '''
    img = cv2.imread(fPath, -1) # single channel image
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

def downsample(orig_img, scale):
    '''
    downsample by "scale" to get the low resolution image
    '''
    if scale == 1:
        return orig_img
    h_orig, w_orig = orig_img.shape
    h, w = int(h_orig/scale), int(w_orig/scale)
    return cv2.resize(orig_img, (w, h), interpolation=cv2.INTER_NEAREST)

def bicubic_sr(lr_img, scale):
    '''
    bibubic super-resolved reconstruction from lr_img by factor "scale"
    '''
    h, w = lr_img.shape
    h_orig, w_orig = h*scale, w*scale
    return cv2.resize(lr_img, (w_orig, h_orig), interpolation=cv2.INTER_CUBIC)

def computePSNR(img1, img2):
    '''
    compute PSNR(Peak Signal to Noise Ratio) to calculate accuracy
    img1 and img2 have range [0, 1], and both are gray level images
    '''
    if not img1.shape == img2.shape:
        print("Input images must have the same dimensions.")
    mse = torch.mean((img1-img2)**2)
    if mse == 0: # img1 and img2 are same images
        return float('inf')
    return 10.0 * torch.log10(1.0/mse)

def getPatches(dataRoot, field):
    '''
    get the list of patches sorted by order
    '''
    dataset = os.path.join(dataRoot, field)
    patches = []
    for date in os.listdir(dataset):
        dateFolder = os.path.join(dataset, date)
        for patch in os.listdir(dateFolder):
            patches.append(os.path.join(dateFolder, patch))
    return patches

In [5]:
class SRData(Dataset):
    def __init__(self, dataRoot="D:\work/dataset/", field="sst", gt_scale=1, lr_scale=9, transform=None):
        self.dataRoot = dataRoot
        self.field = field
        self.gt_scale = gt_scale
        self.lr_scale = lr_scale
        self.transform = transform
        self.patches = getPatches(self.dataRoot, self.field)
    
    def __getitem__(self, index):
        orig_img = img_read(self.patches[index])# 90*90
        gt_img = downsample(orig_img, self.gt_scale) # 90*90
        lr_img = downsample(orig_img, self.lr_scale) #10*10
        bicub_img = bicubic_sr(lr_img, scale=int(self.lr_scale/self.gt_scale)) # 90*90 bicubic sr
        if self.transform:
            gt_img = self.transform(gt_img)
            bicub_img = self.transform(bicub_img)
        return gt_img, bicub_img

    def __len__(self):
        return len(self.patches)

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## prepare test data

In [6]:
'''
prepare data
'''
# convert data to normalized tensor
trans_input = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.0), (1.0)) # do nothing
])

transform = transforms.Compose([
     transforms.ToTensor()
])

entire_dataset = SRData("D:\work/dataset", "sst", 
                     gt_scale=1,
                     lr_scale=9,
                     transform=transform)

# entire test set data
entire_test_indices = torch.arange(57600, 64000)
entire_test_data = torch.utils.data.Subset(entire_dataset, entire_test_indices)
entire_test_dataloader = DataLoader(dataset=entire_test_data, batch_size=1)
# randomly select 5 sample in test data
sample_test_indices = np.random.randint(0, len(entire_test_indices), 5)
sample_test_indices = [entire_test_indices[x] for x in sample_test_indices]
sample_test_data = torch.utils.data.Subset(entire_dataset, sample_test_indices)
sample_test_dataloader = DataLoader(dataset=sample_test_data, batch_size=1)
print("test set length: {}".format(int(len(entire_test_data))))

test set length: 6400


## load model

In [7]:
DEVICE =  "cuda" if torch.cuda.is_available() else "cpu"
# load model
model_srcnn = SRCNN(num_channels=1).to(DEVICE)

params = torch.load("weights/srcnn_x9.pth", map_location=DEVICE)
model_srcnn.load_state_dict(params)

SRCNN(
  (conv1): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu): ReLU(inplace=True)
)


In [8]:
model_srcnn.to(DEVICE)
model_srcnn.eval()
srcnn_meter = AverageMeter()
bicub_meter = AverageMeter()
for data in entire_test_dataloader:
    gt_imgs, bicub_imgs = data
    bicub_imgs = bicub_imgs.to(DEVICE)
    gt_imgs = gt_imgs.to(DEVICE)
    # model prediction
    with torch.no_grad():
        pred_srcnn = model_srcnn(bicub_imgs).clamp(0.0, 1.0)
    # calculate psnr
    psnr_bicub = computePSNR(bicub_imgs, gt_imgs)
    psnr_vdsr =  computePSNR(pred_srcnn, gt_imgs)
    bicub_meter.update(psnr_bicub, len(gt_imgs))
    srcnn_meter.update(psnr_vdsr, len(gt_imgs))
    if srcnn_meter.count % 1000 == 0:
        print("{} tested".format(srcnn_meter.count))
print("Average PSNR bicubic and gt = {:.4f}".format(bicub_meter.avg))
print("Average PSNR srcnn and gt = {:.4f}".format(srcnn_meter.avg))

1000 tested
2000 tested
3000 tested
4000 tested
5000 tested
6000 tested
Average PSNR bicubic and gt = 33.4954
Average PSNR srcnn and gt = 37.6262
