In [None]:
# SRCNN super-resolution of SST ("sst" field) patches, x3 upscaling (lr_scale=3)
# list the GPUs visible to this machine as a quick sanity check before training
!nvidia-smi -L

In [None]:
import os
import copy
import glob
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import MSELoss
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.optim import lr_scheduler
from torch.nn.modules.activation import ReLU, Sigmoid
from torch.nn import Conv2d, modules
from torch.nn import Sequential

import torch.backends.cudnn as cudnn

In [None]:
# utils functions
def img_read(fPath):
    '''
    Read the image stored at path "fPath" without any conversion.

    cv2.IMREAD_UNCHANGED (the value -1) keeps the file's own bit depth and
    channel count instead of converting to 8-bit BGR. It does NOT force a
    single channel: a multi-channel file comes back as an (H, W, C) array.

    Raises
    ------
    OSError
        If OpenCV cannot read the file. cv2.imread signals this by returning
        None rather than raising, which would otherwise surface much later as
        an unrelated AttributeError on `.shape`.
    '''
    img = cv2.imread(fPath, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError("could not read image: {}".format(fPath))
    return img

def downsample(orig_img, scale):
    '''
    Downsample "orig_img" by factor "scale" to get the low-resolution image.

    Parameters
    ----------
    orig_img : numpy.ndarray
        (H, W) or (H, W, C) image, e.g. as returned by img_read.
    scale : int or float
        Downsampling factor. The output is int(H/scale) x int(W/scale).
        scale == 1 returns orig_img itself (the same object, not a copy).

    Returns
    -------
    numpy.ndarray
        The nearest-neighbour resized image. Nearest-neighbour picks existing
        pixel values instead of averaging them (presumably intentional for
        SST data; confirm).

    Raises
    ------
    ValueError
        If the image is too small for "scale", i.e. the output would have a
        zero-sized dimension.
    '''
    if scale == 1:
        return orig_img
    # shape[:2] so that multi-channel (H, W, C) arrays work as well as (H, W)
    h_orig, w_orig = orig_img.shape[:2]
    h, w = int(h_orig / scale), int(w_orig / scale)
    if h < 1 or w < 1:
        raise ValueError(
            "cannot downsample a {}x{} image by {}: output would be {}x{}".format(
                h_orig, w_orig, scale, h, w))
    # cv2.resize takes dsize as (width, height)
    return cv2.resize(orig_img, (w, h), interpolation=cv2.INTER_NEAREST)

def bicubic_sr(lr_img, scale):
    '''
    Bicubic super-resolved reconstruction from "lr_img" by factor "scale".

    Parameters
    ----------
    lr_img : numpy.ndarray
        (H, W) or (H, W, C) low-resolution image.
    scale : int or float
        Upsampling factor. The output is round(H*scale) x round(W*scale).
        Rounding to int makes non-integer factors usable, since cv2.resize
        needs integer sizes. For integer factors the size is exactly
        H*scale x W*scale.

    Returns
    -------
    numpy.ndarray
        The bicubically upsampled image.
    '''
    h, w = lr_img.shape[:2]
    h_out, w_out = int(round(h * scale)), int(round(w * scale))
    # cv2.resize takes dsize as (width, height)
    return cv2.resize(lr_img, (w_out, h_out), interpolation=cv2.INTER_CUBIC)

def computePSNR(img1, img2, data_range=1.0):
    '''
    Compute the PSNR (Peak Signal-to-Noise Ratio, in dB) between two tensors.

    PSNR = 10 * log10(data_range**2 / MSE), where MSE is the mean squared
    difference over ALL elements. For a batch this is therefore the PSNR of
    the whole batch, not the mean of per-image PSNRs.

    Parameters
    ----------
    img1, img2 : torch.Tensor
        Floating-point tensors of identical shape (e.g. gray-level images).
    data_range : float, optional
        Peak-to-peak value range of the images. The default 1.0 corresponds
        to images scaled to [0, 1].

    Returns
    -------
    torch.Tensor or float
        A 0-dim tensor holding the PSNR, or float('inf') when the inputs are
        identical (MSE == 0).

    Raises
    ------
    ValueError
        If img1 and img2 do not have the same shape.
    '''
    if img1.shape != img2.shape:
        # fail loudly: broadcasting mismatched shapes would yield a meaningless PSNR
        raise ValueError(
            "Input images must have the same dimensions, got {} and {}.".format(
                tuple(img1.shape), tuple(img2.shape)))
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:  # img1 and img2 are identical
        return float('inf')
    return 10.0 * torch.log10(data_range ** 2 / mse)

def getPatches(dataRoot, field):
    '''
    Return the paths of all patch files under ``dataRoot/field/<date>/``.

    The expected layout is one sub-directory per date, each holding patch
    files. Date folders and patch files are both visited in sorted
    (lexicographic) order. os.listdir alone returns entries in an arbitrary,
    filesystem-dependent order, so sorting is what makes the returned list
    deterministic. That list fixes every dataset index and hence any
    index-based train/val split built on top of it.

    Non-directory entries directly under ``dataRoot/field`` (stray files)
    and non-file entries inside a date folder are skipped.

    NOTE(review): lexicographic order equals chronological / numeric order
    only if folder and file names are zero-padded (e.g. YYYYMMDD). Confirm
    against the dataset.
    '''
    dataset = os.path.join(dataRoot, field)
    patches = []
    for date in sorted(os.listdir(dataset)):
        dateFolder = os.path.join(dataset, date)
        if not os.path.isdir(dateFolder):
            continue  # os.listdir on a plain file would raise
        for patch in sorted(os.listdir(dateFolder)):
            patchPath = os.path.join(dateFolder, patch)
            if os.path.isfile(patchPath):
                patches.append(patchPath)
    return patches

In [None]:
class SRCNN(nn.Module):
    '''
    Three-layer SRCNN: conv 9x9 (64 maps) -> ReLU -> conv 5x5 (32 maps)
    -> ReLU -> conv 5x5 back to ``num_channels``.

    Every convolution uses padding = kernel_size // 2 with stride 1, so the
    output has the same spatial size as the input. The network therefore
    refines an image that has already been upsampled (e.g. bicubically); it
    does not change resolution itself.
    '''

    def __init__(self, num_channels=1):
        super().__init__()
        # The names conv1/conv2/conv3 are referenced by the optimizer
        # parameter groups and appear as state_dict keys, so keep them stable.
        self.conv1 = self._same_conv(num_channels, 64, 9)
        self.conv2 = self._same_conv(64, 32, 5)
        self.conv3 = self._same_conv(32, num_channels, 5)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _same_conv(in_channels, out_channels, kernel_size):
        '''Build a Conv2d whose padding keeps H and W unchanged (odd kernels).'''
        return nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         padding=kernel_size // 2)

    def forward(self, x):
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        # last layer is linear: no activation on the reconstructed image
        return self.conv3(features)

In [None]:
class AverageMeter(object):
    '''
    Keep a running, weighted mean of a scalar quantity.

    Attributes (all start at 0):
        val   -- most recent value passed to update()
        sum   -- running sum of val * n
        count -- running total of the weights n
        avg   -- sum / count
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Forget everything recorded so far.'''
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` with weight ``n`` (e.g. a batch mean over n samples).'''
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [None]:
class SRData(Dataset):
    '''
    Dataset of (ground-truth, bicubic-upsampled) image pairs for SRCNN.

    For each patch file returned by getPatches(dataRoot, field):
      * gt_img    -- the patch downsampled by gt_scale (gt_scale=1: the patch itself)
      * lr_img    -- the patch downsampled by lr_scale (low-resolution version)
      * bicub_img -- lr_img upsampled bicubically to exactly gt_img's size
    Both returned images are passed through ``transform`` when one is given.
    __getitem__ returns (gt_img, bicub_img).
    '''

    def __init__(self, dataRoot="../dataset/", field="sst", gt_scale=1, lr_scale=3, transform=None):
        self.dataRoot = dataRoot
        self.field = field
        self.gt_scale = gt_scale
        self.lr_scale = lr_scale
        self.transform = transform
        # list of patch paths; its order defines the dataset indices
        self.patches = getPatches(self.dataRoot, self.field)

    def __getitem__(self, index):
        orig_img = img_read(self.patches[index])      # e.g. 90x90 per the original notes
        gt_img = downsample(orig_img, self.gt_scale)  # e.g. 90x90 when gt_scale=1
        lr_img = downsample(orig_img, self.lr_scale)  # e.g. 30x30 when lr_scale=3
        # Upsample straight to gt_img's size rather than by int(lr_scale/gt_scale).
        # That integer factor truncates non-integer ratios (gt_scale=2,
        # lr_scale=3 -> factor 1). It also ignores the flooring done in
        # downsample (a 91-pixel side -> 30 -> 90). Either case yields an input
        # whose size differs from the target. When sizes divide evenly, the
        # result is identical to upsampling by the integer factor.
        gt_h, gt_w = gt_img.shape[:2]
        bicub_img = cv2.resize(lr_img, (gt_w, gt_h), interpolation=cv2.INTER_CUBIC)
        if self.transform:
            gt_img = self.transform(gt_img)
            bicub_img = self.transform(bicub_img)
        return gt_img, bicub_img

    def __len__(self):
        return len(self.patches)

In [None]:
# ---- hyper parameters ----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cuda":
    print("cuda available")
    # let cuDNN benchmark and cache the fastest conv algorithms (pays off when
    # input sizes do not change between batches)
    cudnn.benchmark = True

# Seed the RNGs so weight initialisation and DataLoader shuffling repeat
# between runs. cudnn.benchmark can still cause small GPU run-to-run differences.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 16
NUM_WORKERS = 0   # DataLoader worker processes (0 = load in the main process)
LR = 1e-4         # base Adam learning rate
EPOCHS = 80
verbose = 1       # despite the name: run validation every `verbose` epochs

In [None]:
# ---- prepare data ----
# ToTensor turns each (H, W) numpy patch into a (1, H, W) tensor.
# Normalize(mean=0, std=1) is an identity op. torchvision's Normalize only
# accepts float tensors, so it also acts as a dtype check.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.0,), (1.0,)),  # identity: (x - 0) / 1
])

# Dataset root. Set the SRCNN_DATA_ROOT environment variable to avoid editing
# this cell. Raw string because "\w" in a normal literal is an invalid escape
# sequence (DeprecationWarning; SyntaxWarning on Python 3.12+).
DATA_ROOT = os.environ.get("SRCNN_DATA_ROOT", r"D:\work/dataset")

data = SRData(DATA_ROOT, "sst",
              gt_scale=1,
              lr_scale=3,
              transform=trans)

# Contiguous index split: train = [0, TRAIN_END), val = [TRAIN_END, VAL_END).
# The sizes follow a 7 : 2 : 1 train/val/test ratio. Patches from VAL_END
# onwards (presumably the test share) are not used in this notebook.
# NOTE(review): these bounds were fixed for one particular dataset size;
# recompute them if the patch set changes.
TRAIN_END = 83648
VAL_END = 107551
if VAL_END > len(data):
    raise ValueError("train/val split expects at least {} patches under {}, found {}".format(
        VAL_END, DATA_ROOT, len(data)))

train_data = torch.utils.data.Subset(data, range(TRAIN_END))
val_data = torch.utils.data.Subset(data, range(TRAIN_END, VAL_END))
print("train set length: {}".format(len(train_data)))
print("val set length: {}".format(len(val_data)))

# load data
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              pin_memory=(DEVICE == "cuda"),  # pinning only helps host->GPU copies
                              shuffle=True)
val_dataloader = DataLoader(dataset=val_data,
                            batch_size=1)

In [None]:
# ---- training the SRCNN model ----
from torch.utils.tensorboard import SummaryWriter

WEIGHTS_DIR = "weights"
SAVE_EVERY = 20  # checkpoint the best-so-far weights every SAVE_EVERY epochs
# torch.save does not create missing directories
os.makedirs(WEIGHTS_DIR, exist_ok=True)

total_train_step = 0    # total training steps (batches) so far
total_val_step = 0      # x-axis of "val_psnr" in TensorBoard (epochs)

# build model
model = SRCNN(num_channels=1).to(DEVICE)
# loss function
loss_fn = nn.MSELoss()
# optimizer: Adam; the last layer (conv3) trains with a 10x smaller learning rate
optimizer = torch.optim.Adam([
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
    {'params': model.conv3.parameters(), 'lr': LR * 0.1}], lr=LR)
# learning rate scheduler (currently disabled):
# scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
# visualize (tensorboard)
writer = SummaryWriter("logs")

best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
# -inf so the first validated epoch always counts as the best, even if its
# PSNR is <= 0 (e.g. when the data is not scaled to [0, 1])
best_psnr = float('-inf')

torch.cuda.empty_cache()


def _train_one_epoch(epoch_idx):
    '''Run one pass over train_dataloader; return the sample-weighted mean loss.'''
    global total_train_step
    model.train()
    epoch_losses = AverageMeter()

    with tqdm(total=len(train_data), ncols=100) as pbar:
        pbar.set_description('epoch train: {}/{}'.format(epoch_idx + 1, EPOCHS))

        for gt_imgs, bicub_imgs in train_dataloader:
            gt_imgs = gt_imgs.to(DEVICE)
            bicub_imgs = bicub_imgs.to(DEVICE)

            predicts = model(bicub_imgs)
            loss = loss_fn(predicts, gt_imgs)
            epoch_losses.update(loss.item(), len(bicub_imgs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))

            # log the running epoch loss to TensorBoard every 1000 steps
            total_train_step += 1
            if total_train_step % 1000 == 0:
                writer.add_scalar("train_loss", epoch_losses.avg, total_train_step)

            pbar.update(len(bicub_imgs))

    return epoch_losses.avg


def _validate():
    '''Return the mean PSNR over val_dataloader batches (predictions clamped to [0, 1]).'''
    model.eval()
    epoch_psnr = AverageMeter()
    # no autograd graph is needed for evaluation: saves device memory and time
    with torch.no_grad():
        for gt_imgs, bicub_imgs in val_dataloader:
            gt_imgs = gt_imgs.to(DEVICE)
            bicub_imgs = bicub_imgs.to(DEVICE)
            predicts = model(bicub_imgs).clamp(0.0, 1.0)
            # float(): accumulate plain Python numbers rather than device tensors
            psnr = float(computePSNR(predicts, gt_imgs))
            epoch_psnr.update(psnr, len(bicub_imgs))
    return epoch_psnr.avg


try:
    for i in range(EPOCHS):
        _train_one_epoch(i)

        # validation every `verbose` epochs
        if (i + 1) % verbose == 0:
            val_psnr = _validate()
            print('val set PSNR: {:.2f}'.format(val_psnr))

            total_val_step += verbose
            writer.add_scalar("val_psnr", val_psnr, total_val_step)
            # keep a copy of the weights with the best validation PSNR so far
            if val_psnr > best_psnr:
                best_epoch = i + 1
                best_psnr = val_psnr
                best_weights = copy.deepcopy(model.state_dict())

        # save the best-so-far weights every SAVE_EVERY epochs
        if (i + 1) % SAVE_EVERY == 0:
            print('top {} best epoch: {}, val set psnr: {:.2f}'.format(i + 1, best_epoch, best_psnr))
            torch.save(best_weights, os.path.join(
                WEIGHTS_DIR, "top_{}_best_iter_{}.pth".format(i + 1, best_epoch)))
            print("best model in first {} epochs saved".format(i + 1))
finally:
    # close tensorboard even if training is interrupted
    writer.close()

# save best model
print('global best epoch: {}, val set psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(WEIGHTS_DIR, "best_iter_{}.pth".format(best_epoch)))
print("global best model saved")