In [None]:
# srcnn sst cascade
# Cascaded SRCNN super-resolution of single-channel "sst" patches: stage 1 refines a
# bicubic 10*10 -> 30*30 upsampling, stage 2 refines a bicubic 30*30 -> 90*90 upsampling.
# IPython shell escape: list the GPUs visible to this runtime (training falls back to CPU).
!nvidia-smi -L

In [None]:
# Mount Google Drive at /content/drive; the training cell writes TensorBoard logs
# and weight checkpoints under /content/drive/MyDrive/PRAT/cascade_0.2.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import copy
import glob
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import MSELoss
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.optim import lr_scheduler
from torch.nn.modules.activation import ReLU, Sigmoid
from torch.nn import Conv2d, modules
from torch.nn import Sequential

from math import sqrt

import torch.backends.cudnn as cudnn

In [None]:
# utils functions
def img_read(fPath):
    '''
    Read the image at path "fPath" with its stored dtype and channel count unchanged.

    The rest of the notebook treats the result as a single-channel (H, W) array;
    that is an assumption about the data files, not something enforced here.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. cv2.imread returns None
        instead of raising, which would otherwise only fail later with an
        unrelated-looking error.
    '''
    img = cv2.imread(fPath, cv2.IMREAD_UNCHANGED)  # same as flag -1
    if img is None:
        raise FileNotFoundError("could not read image: {}".format(fPath))
    return img

def downsample(orig_img, scale):
    '''
    Downsample "orig_img" by "scale" to get the low-resolution image.

    The output size is (int(h / scale), int(w / scale)), using nearest-neighbour
    sampling (pixels are picked, not averaged). Works for (H, W) and (H, W, C)
    arrays. When scale == 1 the input array itself (not a copy) is returned.
    '''
    if scale == 1:
        return orig_img
    # only the two spatial dims; a trailing channel axis must not break the unpack
    h_orig, w_orig = orig_img.shape[:2]
    h, w = int(h_orig / scale), int(w_orig / scale)
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(orig_img, (w, h), interpolation=cv2.INTER_NEAREST)

def bicubic_sr(lr_img, scale):
    '''
    Bicubic super-resolved reconstruction of "lr_img" by factor "scale".

    The output size is (h * scale, w * scale). Works for (H, W) and (H, W, C) arrays.
    '''
    # only the two spatial dims; a trailing channel axis must not break the unpack
    h, w = lr_img.shape[:2]
    h_orig, w_orig = h * scale, w * scale
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(lr_img, (w_orig, h_orig), interpolation=cv2.INTER_CUBIC)

def computePSNR(img1, img2, max_val=1.0):
    '''
    Compute the PSNR (Peak Signal-to-Noise Ratio, in dB) between two image tensors.

    PSNR = 10 * log10(max_val**2 / MSE). The MSE is taken over *all* elements, so
    for a batch this is the PSNR of the pooled error, not the mean of per-image PSNRs.

    Args:
        img1, img2: torch tensors of identical shape, e.g. gray-level images in [0, max_val].
        max_val: peak signal value; 1.0 (the default) for images in [0, 1].

    Returns:
        A 0-dim tensor holding the PSNR, or float('inf') when the inputs are identical.

    Raises:
        ValueError: if the two shapes differ.
    '''
    # fail loudly: broadcasting mismatched shapes would give a meaningless MSE
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same dimensions.")
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:  # identical images
        return float('inf')
    return 10.0 * torch.log10(max_val ** 2 / mse)

In [None]:
class SRCNN(nn.Module):
    '''
    Three-layer SRCNN: 9x9 conv to 64 maps -> ReLU -> 5x5 conv to 32 maps -> ReLU
    -> 5x5 conv back to "num_channels".

    Every convolution is zero-padded by kernel_size // 2, so the output keeps the
    spatial size of the input. The network refines an already (bicubically)
    upsampled image rather than upsampling by itself.
    '''

    def __init__(self, num_channels=1):
        super().__init__()
        # The layer creation order fixes both the seeded initialisation sequence
        # and the state_dict keys (conv1/conv2/conv3) used by saved checkpoints.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        '''Map an input with num_channels channels to an output of the same shape.'''
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        return self.conv3(hidden)

In [None]:
class AverageMeter:
    '''
    Running weighted-average tracker.

    Attributes:
        val   -- the most recent value passed to update()
        sum   -- running sum of val * n
        count -- running total of the weights n
        avg   -- sum / count (0 until the first update)
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Zero every statistic.'''
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        '''Record "val" with weight "n" (e.g. a batch size) and refresh the average.'''
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class SRData(Dataset):
    '''
    Cascade super-resolution dataset of patch images laid out as
    <dataRoot>/<field>/<date folder>/<patch file>.

    Each item is the tuple (T1, T3, T9, bicubT9):
        T1      -- the patch as read from disk; ground truth of the second stage
        T3      -- T1 downsampled by inter_scale; ground truth of the first stage
        T9      -- T1 downsampled by lr_scale; the low-resolution image
        bicubT9 -- T9 bicubically upsampled by int(lr_scale / inter_scale);
                   input of the first stage
    For 90*90 patches and the default scales these are 90*90, 30*30, 10*10 and
    30*30. If "transform" is given, it is applied to each of the four images.
    '''
    def __init__(self, dataRoot="../dataset/", field="sst", inter_scale=3, lr_scale=9, transform=None):
        self.dataRoot = dataRoot
        self.field = field
        self.inter_scale = inter_scale
        self.lr_scale = lr_scale
        self.transform = transform
        # file index built once; item i always maps to self.patches[i]
        self.patches = self.getPatches()

    def __getitem__(self, index):
        T1 = img_read(self.patches[index])  # 90*90, GT for second stage
        T3 = downsample(T1, self.inter_scale)  # 30*30, GT of first stage
        T9 = downsample(T1, self.lr_scale)  # 10*10
        # 30*30 input of first stage; its size only matches T3 when lr_scale is
        # a multiple of inter_scale
        bicubT9 = bicubic_sr(T9, scale=int(self.lr_scale/self.inter_scale))
        if self.transform:
            T1 = self.transform(T1)
            T3 = self.transform(T3)
            T9 = self.transform(T9)
            bicubT9 = self.transform(bicubT9)
        return T1, T3, T9, bicubT9

    def __len__(self):
        return len(self.patches)

    def getPatches(self):
        '''
        Return the paths of all patch files in a deterministic order: sorted by
        date folder, then by file name.

        os.listdir returns entries in an arbitrary, filesystem-dependent order.
        The train/val split is made by index ranges, so the listing must be
        sorted for that split to be reproducible. Entries directly under
        <dataRoot>/<field> that are not directories are skipped.
        '''
        dataset = os.path.join(self.dataRoot, self.field)
        patches = []
        for date in sorted(os.listdir(dataset)):
            dateFolder = os.path.join(dataset, date)
            if not os.path.isdir(dateFolder):
                continue  # e.g. a stray README or .DS_Store next to the date folders
            for patch in sorted(os.listdir(dateFolder)):
                patches.append(os.path.join(dateFolder, patch))
        return patches

In [None]:
'''
hyper parameters
'''
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cuda":
    print("cuda available")
    # let cuDNN benchmark and pick the fastest conv algorithms
    # (favours speed over bit-exact reproducibility on the GPU)
    cudnn.benchmark = True

# seed torch's RNGs (CPU and CUDA) so weight initialisation and DataLoader
# shuffling repeat from run to run
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 16
NUM_WORKERS = 0     # DataLoader worker processes (0 = load in the main process)
LR = 1e-4           # base Adam learning rate (the conv3 layers use LR * 0.1)
EPOCHS = 200
verbose = 1         # run validation every `verbose` epochs
ALPHA = 0.2         # weight of the stage-2 loss; the stage-1 loss gets (1 - ALPHA)
ALPHA = 0.2

In [None]:
'''
prepare data
'''
# convert input data to a tensor; Normalize with mean 0 / std 1 is an identity,
# kept as a placeholder for a real normalization
trans_input = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.0,), (1.0,))  # do nothing (single channel)
])

trans_img = transforms.ToPILImage()

# bicubic 30*30 -> 90*90 resize applied to model1's output in the training cell
trans_bicub = transforms.Resize(size=90, interpolation=InterpolationMode.BICUBIC)

data = SRData(dataRoot="dataset", field="sst", inter_scale=3, lr_scale=9, transform=trans_input)

# contiguous index split: the first TRAIN_SIZE patches train, the next VAL_SIZE
# validate (7 : 2 of a presumed 7 : 2 : 1 train/val/test layout — confirm)
TRAIN_SIZE = 44800
VAL_SIZE = 12800
# fail early with a clear message instead of an IndexError when a loader
# reaches an index past the end of the dataset
if len(data) < TRAIN_SIZE + VAL_SIZE:
    raise ValueError(
        "dataset has {} patches, need at least {} for the train/val split".format(
            len(data), TRAIN_SIZE + VAL_SIZE))
train_indices = torch.arange(TRAIN_SIZE)
val_indices = torch.arange(TRAIN_SIZE, TRAIN_SIZE + VAL_SIZE)

train_data = torch.utils.data.Subset(data, train_indices)
val_data = torch.utils.data.Subset(data, val_indices)
print("train set length: {}".format(len(train_data)))
print("val set length: {}".format(len(val_data)))
# load data
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              pin_memory=True,
                              shuffle=True)
# batch_size=1: computePSNR pools the squared error over the whole batch, so one
# image per batch makes the epoch PSNR an average of per-image PSNRs
val_dataloader = DataLoader(dataset=val_data,
                            batch_size=1)

In [None]:
# step counters used for tensorboard logging in the training cell
total_train_step = 0
total_val_step = 0

# the two cascade stages; model1 is created first
model1 = SRCNN(num_channels=1).to(DEVICE)  # first stage model 10->30
model2 = SRCNN(num_channels=1).to(DEVICE)  # second stage model 30->90

# one MSE criterion per stage (the training cell mixes them with ALPHA)
Loss1, Loss2 = nn.MSELoss(), nn.MSELoss()

# Adam with one parameter group per conv layer, stage 1 first then stage 2;
# each stage's last layer (conv3) trains at a tenth of the base learning rate
optimizer = torch.optim.Adam(
    [
        group
        for stage in (model1, model2)
        for group in (
            {'params': stage.conv1.parameters()},
            {'params': stage.conv2.parameters()},
            {'params': stage.conv3.parameters(), 'lr': LR * 0.1},
        )
    ],
    lr=LR,
)
    

In [None]:
'''
training the SRCNN model
'''
from torch.utils.tensorboard import SummaryWriter


def _cascade_forward(batch_bicubT9):
    '''
    Run both cascade stages on a batch of stage-1 inputs.

    batch_bicubT9 -- bicubic-upsampled low-resolution batch (30*30 per the data cell)
    Returns (batch_I3, batch_I1): model1's 30*30 output, and model2's 90*90 output
    computed from trans_bicub(batch_I3).
    '''
    batch_I3 = model1(batch_bicubT9)        # 30*30 output of model1
    batch_bicubI3 = trans_bicub(batch_I3)   # 90*90, input of model2
    batch_I1 = model2(batch_bicubI3)        # 90*90, output of model2
    return batch_I3, batch_I1


# visualize (tensorboard) and checkpoint locations
folder = "/content/drive/MyDrive/PRAT/cascade_0.2"
weights_dir = os.path.join(folder, "saved_weights")
# torch.save does not create missing directories; create it now instead of
# failing at the first checkpoint after hours of training
os.makedirs(weights_dir, exist_ok=True)
writer = SummaryWriter(os.path.join(folder, "logs"))
# optional: a learning-rate scheduler such as
# lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5) could be stepped once per epoch

SAVE_EVERY = 20  # epochs between snapshots of the best-so-far weights

best_weight1 = copy.deepcopy(model1.state_dict())  # first stage best model
best_weight2 = copy.deepcopy(model2.state_dict())  # second stage best model
best_epoch = 0
best_psnr = 0.0

torch.cuda.empty_cache()

try:
    for i in range(EPOCHS):
        # ---- training ----
        model1.train()
        model2.train()
        epoch_losses = AverageMeter()

        with tqdm(total=len(train_data), ncols=100) as t1:
            t1.set_description('epoch train: {}/{}'.format(i+1, EPOCHS))
            # loop variable is `batch`, not `data`, so the dataset object `data`
            # from the data cell is not overwritten
            for batch in train_dataloader:
                batch_T1, batch_T3, _, batch_bicubT9 = batch  # T9 itself is not needed
                batch_T1 = batch_T1.to(DEVICE)
                batch_T3 = batch_T3.to(DEVICE)
                batch_bicubT9 = batch_bicubT9.to(DEVICE)

                batch_I3, batch_I1 = _cascade_forward(batch_bicubT9)

                # weighted loss: (1 - ALPHA) on stage 1, ALPHA on stage 2
                l1 = Loss1(batch_I3, batch_T3)
                l2 = Loss2(batch_I1, batch_T1)
                loss_total = ALPHA*l2 + (1-ALPHA)*l1
                epoch_losses.update(loss_total.item(), len(batch_I1))

                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()

                t1.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))

                # log the running epoch-average loss every 1000 steps
                total_train_step += 1
                if total_train_step % 1000 == 0:
                    writer.add_scalar("train_loss", epoch_losses.avg, total_train_step)

                t1.update(len(batch_T1))

        # ---- validation ----
        if (i+1) % verbose == 0:
            model1.eval()
            model2.eval()
            epoch_psnr = AverageMeter()

            with torch.no_grad():
                for batch in val_dataloader:
                    batch_T1, _, _, batch_bicubT9 = batch
                    batch_T1 = batch_T1.to(DEVICE)
                    batch_bicubT9 = batch_bicubT9.to(DEVICE)

                    _, batch_I1 = _cascade_forward(batch_bicubT9)
                    batch_I1 = batch_I1.clamp(0.0, 1.0)

                    # float() turns the (possibly GPU) scalar tensor into a Python
                    # float, so epoch_psnr / best_psnr don't hold device tensors
                    psnr = float(computePSNR(batch_T1, batch_I1))
                    epoch_psnr.update(psnr, len(batch_T1))

            print('val set PSNR: {:.4f}'.format(epoch_psnr.avg))

            total_val_step += verbose
            writer.add_scalar("val_psnr", epoch_psnr.avg, total_val_step)
            # keep a copy of the best weights seen so far
            if epoch_psnr.avg > best_psnr:
                best_epoch = i+1
                best_psnr = epoch_psnr.avg
                best_weight1 = copy.deepcopy(model1.state_dict())
                best_weight2 = copy.deepcopy(model2.state_dict())

        # snapshot the best-so-far weights every SAVE_EVERY epochs
        if (i+1) % SAVE_EVERY == 0:
            print('top {} best epoch: {}, val set psnr: {:.4f}'.format(i+1, best_epoch, best_psnr))
            torch.save(best_weight1, os.path.join(weights_dir, "top_{}_best_iter_{}_stage1.pth".format(i+1, best_epoch)))
            torch.save(best_weight2, os.path.join(weights_dir, "top_{}_best_iter_{}_stage2.pth".format(i+1, best_epoch)))
            print("best model in first {} epochs saved".format(i+1))
finally:
    # close tensorboard even if training is interrupted
    writer.close()

# save best model
print('global best epoch: {}, val set psnr: {:.4f}'.format(best_epoch, best_psnr))
torch.save(best_weight1, os.path.join(weights_dir, "global_best_stage1_iter_{}.pth".format(best_epoch)))
torch.save(best_weight2, os.path.join(weights_dir, "global_best_stage2_iter_{}.pth".format(best_epoch)))
print("global best model saved")