## Import libraries

In [2]:
import os
import copy
import glob
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import MSELoss
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.optim import lr_scheduler
from torch.nn.modules.activation import ReLU, Sigmoid
from torch.nn import Conv2d, modules
from torch.nn import Sequential

from math import sqrt

import torch.backends.cudnn as cudnn

## define SRCNN model

In [3]:
class SRCNN(nn.Module):
    """Three-layer convolutional super-resolution network (9-5-5 kernels).

    conv1 (num_channels -> 64, 9x9) and conv2 (64 -> 32, 5x5) are each
    followed by a ReLU; conv3 (32 -> num_channels, 5x5) is linear.
    Every convolution uses padding k // 2, so the output has the same
    spatial size and channel count as the input.

    Args:
        num_channels: number of input/output image channels (1 for gray).
    """

    def __init__(self, num_channels=1):
        super().__init__()
        # Attribute names are the state_dict keys of the saved checkpoints;
        # keep them (and their creation order) unchanged.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map an (N, num_channels, H, W) tensor to one of the same shape."""
        hidden = self.relu(self.conv2(self.relu(self.conv1(x))))
        return self.conv3(hidden)

## Useful functions and classes

In [4]:
# utils functions
def img_read(fPath):
    '''
    read the image at path "fPath" as stored on disk (cv2.IMREAD_UNCHANGED:
    no color conversion, no bit-depth change)

    the patches of this dataset are presumably single-channel, so the result
    is expected to be an H x W array -- TODO confirm for new data

    raises FileNotFoundError if OpenCV cannot read the file; cv2.imread itself
    returns None in that case, which would otherwise only fail later with an
    unhelpful AttributeError on ".shape"
    '''
    img = cv2.imread(fPath, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError("Could not read image: {}".format(fPath))
    return img

def downsample(orig_img, scale):
    '''
    downsample "orig_img" by factor "scale" (nearest neighbour) to get the
    low resolution image

    orig_img: H x W array (H x W x C also accepted)
    scale: positive downsampling factor; the output is
        int(H / scale) x int(W / scale)
    returns orig_img itself (not a copy) when scale == 1
    raises ValueError if the factor shrinks a side to 0 pixels
    '''
    if scale == 1:
        return orig_img
    # shape[:2] so multi-channel arrays (e.g. from IMREAD_UNCHANGED) work too
    h_orig, w_orig = orig_img.shape[:2]
    h, w = int(h_orig / scale), int(w_orig / scale)
    if h == 0 or w == 0:
        raise ValueError(
            "scale {} is too large for an image of size {}x{}".format(scale, h_orig, w_orig))
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(orig_img, (w, h), interpolation=cv2.INTER_NEAREST)

def bicubic_sr(lr_img, scale):
    '''
    bicubic super-resolved reconstruction: upsample "lr_img" by factor "scale"

    lr_img: H x W array (H x W x C also accepted)
    scale: upsampling factor; the output is round(H * scale) x round(W * scale),
        so non-integer factors work as well (cv2.resize needs integer sizes)
    '''
    # shape[:2] so multi-channel arrays work too (consistent with downsample)
    h, w = lr_img.shape[:2]
    h_sr, w_sr = int(round(h * scale)), int(round(w * scale))
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(lr_img, (w_sr, h_sr), interpolation=cv2.INTER_CUBIC)

def computePSNR(img1, img2, data_range=1.0):
    '''
    compute PSNR (Peak Signal to Noise Ratio, in dB) between two tensors

    img1, img2: tensors of identical shape, e.g. gray-level images in [0, 1].
        The MSE is taken over ALL elements, so for a batch this is the PSNR of
        the whole batch, not the mean of the per-image PSNRs.
    data_range: peak signal value (1.0 for images scaled to [0, 1])

    returns a 0-d tensor, or float('inf') when the inputs are identical
    raises ValueError if the shapes differ (instead of silently broadcasting)
    '''
    if img1.shape != img2.shape:
        raise ValueError(
            "Input images must have the same dimensions, got {} and {}.".format(
                tuple(img1.shape), tuple(img2.shape)))
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:  # identical images: PSNR is unbounded
        return float('inf')
    return 10.0 * torch.log10(data_range ** 2 / mse)

In [5]:
class SRData(Dataset):
    '''
    Dataset of image patches stored as <dataRoot>/<field>/<date>/<patch>.

    Each item is a tuple (T1, T3, T9, bicubT9):
      T1      -- the patch as read from disk (90*90 for this dataset)
      T3      -- T1 downsampled by ``inter_scale``
      T9      -- T1 downsampled by ``lr_scale``
      bicubT9 -- T9 bicubic-upsampled by lr_scale / inter_scale, i.e. back to
                 T3's resolution (30*30 with the default scales)
    All four are passed through ``transform`` when one is given.
    '''
    # raw string: "\w" is not a valid escape sequence (same runtime value)
    def __init__(self, dataRoot=r"D:\work/dataset/", field="sst", inter_scale=3, lr_scale=9, transform=None):
        self.dataRoot = dataRoot
        self.field = field
        self.inter_scale = inter_scale
        self.lr_scale = lr_scale
        self.transform = transform
        self.patches = self.getPatches()

    def __getitem__(self, index):
        T1 = img_read(self.patches[index])  # 90*90
        T3 = downsample(T1, self.inter_scale)  # 30*30
        T9 = downsample(T1, self.lr_scale)  # 10*10
        # upsample T9 back to T3's size (10*10 -> 30*30), the input of stage 1
        bicubT9 = bicubic_sr(T9, scale=int(self.lr_scale / self.inter_scale))
        if self.transform:
            T1 = self.transform(T1)
            T3 = self.transform(T3)
            T9 = self.transform(T9)
            bicubT9 = self.transform(bicubT9)
        return T1, T3, T9, bicubT9

    def __len__(self):
        return len(self.patches)

    def getPatches(self):
        '''
        get the list of patch paths, sorted by date folder and then patch name

        os.listdir returns entries in arbitrary, filesystem-dependent order;
        sorting makes the index-based train/test split reproducible across
        machines. Non-directory entries next to the date folders are skipped.
        '''
        dataset = os.path.join(self.dataRoot, self.field)
        patches = []
        for date in sorted(os.listdir(dataset)):
            dateFolder = os.path.join(dataset, date)
            if not os.path.isdir(dateFolder):
                continue  # stray file (e.g. Thumbs.db), not a date folder
            for patch in sorted(os.listdir(dateFolder)):
                patches.append(os.path.join(dateFolder, patch))
        return patches

class AverageMeter(object):
    '''
    Running weighted average of a stream of values.

    After each update: ``val`` is the latest value, ``sum`` accumulates
    ``val * n``, ``count`` accumulates ``n`` and ``avg`` is ``sum / count``.
    '''
    def __init__(self):
        self.reset()

    def reset(self):
        '''Clear all statistics back to zero.'''
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` with weight ``n`` and refresh the running average.'''
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## prepare test data

In [6]:
# Prepare the test data.

# Convert each numpy patch to a (C, H, W) float tensor. Normalize with
# mean 0 / std 1 is an identity op, kept as an explicit normalization step.
trans_input = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.0,), (1.0,)),  # (x - 0) / 1: does nothing
])

trans_img = transforms.ToPILImage()

# Bicubic resize used between the cascade stages; an int size resizes the
# smaller edge to 90, i.e. 30x30 -> 90x90 for the square patches here.
trans_bicub = transforms.Resize(size=90, interpolation=InterpolationMode.BICUBIC)

DATA_ROOT = r"D:\work/dataset"  # raw string: "\w" is not a valid escape sequence
TEST_START, TEST_END = 57600, 64000  # fixed index range of the test split
N_SAMPLES = 5  # number of randomly chosen test patches for inspection
SEED = 0  # makes the random test sample reproducible

entire_dataset = SRData(dataRoot=DATA_ROOT, field="sst", inter_scale=3, lr_scale=9, transform=trans_input)
# Subset indexes lazily, so a too-small dataset would otherwise only fail mid-evaluation.
assert len(entire_dataset) >= TEST_END, (
    "dataset has {} patches, the test split needs {}".format(len(entire_dataset), TEST_END))

# entire test set data
entire_test_indices = torch.arange(TEST_START, TEST_END)
entire_test_data = torch.utils.data.Subset(entire_dataset, entire_test_indices)
entire_test_dataloader = DataLoader(dataset=entire_test_data, batch_size=1)

# randomly select N_SAMPLES distinct patches of the test data (seeded, no duplicates)
rng = np.random.default_rng(SEED)
sample_positions = rng.choice(len(entire_test_indices), size=N_SAMPLES, replace=False)
sample_test_indices = [int(entire_test_indices[pos]) for pos in sample_positions]
sample_test_data = torch.utils.data.Subset(entire_dataset, sample_test_indices)
sample_test_dataloader = DataLoader(dataset=sample_test_data, batch_size=1)
print("test set length: {}".format(len(entire_test_data)))

test set length: 6400


## Cascade Alpha = 0.2

In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build both cascade stages on DEVICE (stage 1 first, then stage 2).
model1, model2 = (SRCNN(num_channels=1).to(DEVICE) for _ in range(2))

# Stage weights of the alpha = 0.2 cascade.
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
params1, params2 = (
    torch.load(path, map_location=DEVICE)
    for path in ("weights/cascade_0.2_stage1.pth", "weights/cascade_0.2_stage2.pth")
)

model1.load_state_dict(params1)
model2.load_state_dict(params2)

<All keys matched successfully>

In [8]:
# Evaluate the alpha = 0.2 cascade on the whole test split:
# bicubT9 (30*30) -> model1 -> bicubic to 90*90 -> model2 -> PSNR against T1.
model1.to(DEVICE).eval()
model2.to(DEVICE).eval()
cascade_meter = AverageMeter()
with torch.no_grad():
    for batch_T1, batch_T3, batch_T9, batch_bicubT9 in entire_test_dataloader:
        # only the ground truth and the stage-1 input are used; T3/T9 stay on the CPU
        batch_T1 = batch_T1.to(DEVICE)
        batch_bicubT9 = batch_bicubT9.to(DEVICE)
        # model prediction
        batch_I3 = model1(batch_bicubT9)  # 30*30 output of model1
        batch_bicubI3 = trans_bicub(batch_I3)  # 90*90 input of model2
        batch_I1 = model2(batch_bicubI3).clamp(0.0, 1.0)  # 90*90 output of model2
        # calculate psnr
        psnr = computePSNR(batch_T1, batch_I1)
        cascade_meter.update(psnr, len(batch_T1))
print("Average PSNR cascade_0.2 and gt = {:.4f}".format(cascade_meter.avg))

Average PSNR cascade_0.2 and gt = 38.1243


## Cascade Alpha = 0.5

In [9]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build both cascade stages on DEVICE (stage 1 first, then stage 2).
model1, model2 = (SRCNN(num_channels=1).to(DEVICE) for _ in range(2))

# Stage weights of the alpha = 0.5 cascade.
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
params1, params2 = (
    torch.load(path, map_location=DEVICE)
    for path in ("weights/cascade_0.5_stage1.pth", "weights/cascade_0.5_stage2.pth")
)

model1.load_state_dict(params1)
model2.load_state_dict(params2)

<All keys matched successfully>

In [10]:
# Evaluate the alpha = 0.5 cascade on the whole test split:
# bicubT9 (30*30) -> model1 -> bicubic to 90*90 -> model2 -> PSNR against T1.
model1.to(DEVICE).eval()
model2.to(DEVICE).eval()
cascade_meter = AverageMeter()
with torch.no_grad():
    for batch_T1, batch_T3, batch_T9, batch_bicubT9 in entire_test_dataloader:
        # only the ground truth and the stage-1 input are used; T3/T9 stay on the CPU
        batch_T1 = batch_T1.to(DEVICE)
        batch_bicubT9 = batch_bicubT9.to(DEVICE)
        # model prediction
        batch_I3 = model1(batch_bicubT9)  # 30*30 output of model1
        batch_bicubI3 = trans_bicub(batch_I3)  # 90*90 input of model2
        batch_I1 = model2(batch_bicubI3).clamp(0.0, 1.0)  # 90*90 output of model2
        # calculate psnr
        psnr = computePSNR(batch_T1, batch_I1)
        cascade_meter.update(psnr, len(batch_T1))
print("Average PSNR cascade_0.5 and gt = {:.4f}".format(cascade_meter.avg))

Average PSNR cascade_0.5 and gt = 38.1089


## Cascade Alpha = 0.8

In [11]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build both cascade stages on DEVICE (stage 1 first, then stage 2).
model1, model2 = (SRCNN(num_channels=1).to(DEVICE) for _ in range(2))

# Stage weights of the alpha = 0.8 cascade.
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
params1, params2 = (
    torch.load(path, map_location=DEVICE)
    for path in ("weights/cascade_0.8_stage1.pth", "weights/cascade_0.8_stage2.pth")
)

model1.load_state_dict(params1)
model2.load_state_dict(params2)

<All keys matched successfully>

In [12]:
# Evaluate the alpha = 0.8 cascade on the whole test split:
# bicubT9 (30*30) -> model1 -> bicubic to 90*90 -> model2 -> PSNR against T1.
model1.to(DEVICE).eval()
model2.to(DEVICE).eval()
cascade_meter = AverageMeter()
with torch.no_grad():
    for batch_T1, batch_T3, batch_T9, batch_bicubT9 in entire_test_dataloader:
        # only the ground truth and the stage-1 input are used; T3/T9 stay on the CPU
        batch_T1 = batch_T1.to(DEVICE)
        batch_bicubT9 = batch_bicubT9.to(DEVICE)
        # model prediction
        batch_I3 = model1(batch_bicubT9)  # 30*30 output of model1
        batch_bicubI3 = trans_bicub(batch_I3)  # 90*90 input of model2
        batch_I1 = model2(batch_bicubI3).clamp(0.0, 1.0)  # 90*90 output of model2
        # calculate psnr
        psnr = computePSNR(batch_T1, batch_I1)
        cascade_meter.update(psnr, len(batch_T1))
print("Average PSNR cascade_0.8 and gt = {:.4f}".format(cascade_meter.avg))

Average PSNR cascade_0.8 and gt = 37.9992
