In [1]:
import imgaug as ia
import imgaug.augmenters as iaa
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
import cv2
import numpy as np
from sklearn.decomposition import PCA
from google.colab import drive
import math
from torchvision.utils import make_grid
import os, glob
import torch.nn as nn
from torch.nn import init
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from PIL import Image
import time

In [2]:
torch.cuda.is_available()

True

In [3]:
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
default_augmentations = iaa.Sequential([
            iaa.Fliplr(0.5),
            iaa.Flipud(0.5), 
            iaa.Rot90([0, 1, 2, 3])])

default_transforms = transforms.ToTensor()

KERNEL_SIZE = 21

In [5]:
class KernelImagePair(Dataset):
    def __init__(self, imgs:list, 
                 kernel_pickle:str,
                 scale:int,
                 augmentations:object,
                 transforms:object,
                 seed=0,
                 patch_size=(256,256),
                 train=True, noise=False, interpolation="cubic", downsample_on_pipe=True):
        super(KernelImagePair, self).__init__()
        self._kernel_dict = torch.load(kernel_pickle)
        
        self.kernels = self._kernel_dict['kernels']  
        self.kernel_size = (KERNEL_SIZE, KERNEL_SIZE) 
        self.k_reduced = self._kernel_dict['kernels_compressed']
        self.stddevs = self._kernel_dict['sigmas']  

        self.pca = self._kernel_dict['pca']  
        
        self.imgs = imgs
        self.scale = scale
        self.augmentations = augmentations
        
        self.transforms = transforms
        self.seed = seed
        self.patch_size = patch_size
        self.train=train
        self.downsample_on_pipe = downsample_on_pipe
        
        if interpolation == "cubic":
            self.inter = cv2.INTER_CUBIC

        self.random = np.random.RandomState(seed)
        self.noise = None
        if noise:
            self.noise = iaa.AdditiveGaussianNoise(scale=(0.03*255, 0.1*255))
            
    def __getitem__(self, idx) -> dict:
        img = self.imgs[idx]
        img = cv2.imread(img, cv2.IMREAD_COLOR)
        
        if self.train:
            """
            Random Crop image Adding margin w.r.t kernel_size and patch_size
            """
            img_from, img_to = np.zeros(2, dtype=int), np.zeros(2, dtype=int)
            for i in range(2):
                img_from[i] = self.random.randint(0, img.shape[i] - (self.kernel_size[i] + 1 + self.patch_size[i]))
                img_to[i] = img_from[i] + self.kernel_size[i] + 1 + self.patch_size[i]
            img_patch = img[img_from[0]:img_to[0], img_from[1]:img_to[1]]
        else:
            img_patch = img
        
        if self.train and self.augmentations is not None:
            img_patch = self.augmentations.augment_image(img_patch)
            img = img_patch
        kernel_idx = self.random.randint(len(self.kernels))
        stddev = self.stddevs[kernel_idx]

        gaussian_kernel = self.kernels[kernel_idx].reshape(KERNEL_SIZE, KERNEL_SIZE).astype(np.float32)

        if self.downsample_on_pipe:
            img_blur = cv2.filter2D(img_patch, ddepth=-1, kernel=gaussian_kernel)
        else:
            img_blur = img_patch

        k_reduced = self.k_reduced[kernel_idx].astype(np.float32)

        if self.train:            
            half = self.kernel_size[0] // 2 + 1, self.kernel_size[1] // 2 + 1
            img_blur = img_blur[half[0] : -half[0], half[1]:-half[1]]
            img = img[half[0] : -half[0], half[1]:-half[1]]
            img_lr = cv2.resize(img_blur, 
                                (self.patch_size[0]//self.scale, self.patch_size[1]//self.scale), self.inter)
        if self.noise is not None:
                img_lr = self.noise.augment_image(img_lr)
        if self.downsample_on_pipe:
            img_lr = cv2.resize(img_blur, 
                    (img.shape[1]//self.scale, img.shape[0]//self.scale), self.inter)
        else:
            img_lr = img_blur
            img = cv2.imread(self.imgs[idx].replace("/lr/", "/hr/"))

        if self.transforms is not None:
            if self.train:
                if not img_lr.flags['C_CONTIGUOUS']:
                    img_lr = np.ascontiguousarray(img_lr)
                if not img.flags['C_CONTIGUOUS']:
                    img = np.ascontiguousarray(img)

            img_lr = self.transforms(img_lr)
            img = self.transforms(img)
        

        re_dict = dict(LR=img_lr,
                    HR=img,
                    k=gaussian_kernel,
                    k_reduced=k_reduced,
                    stddev=stddev,
                    )      
        return re_dict

    def __len__(self):
        return len(self.imgs)

In [6]:
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
  
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, padding=0, nrow=1, normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  

    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)

In [7]:
def get_datasets(args):
    train_imgs = glob.glob("/content/drive/My Drive/CS663_project/DIV2K_train_HR/*.png", recursive=True)
    valid_imgs = glob.glob("/content/drive/My Drive/CS663_project/DIV2K_valid_HR/*.png", recursive=True)

    if args.augment == "default":
        augmentations = default_augmentations

    print(f"num of train {len(train_imgs)},  num of valid {len(valid_imgs)}.")

    train_dataset = KernelImagePair(imgs=train_imgs, 
                                    kernel_pickle=args.train_kernel, scale=args.scale, 
                                    augmentations=default_augmentations, transforms=default_transforms, 
                                    patch_size=(args.patch_size, args.patch_size),
                                    seed=args.seed, train=True, noise=args.use_noise, interpolation=args.inter)

    valid_dataset = KernelImagePair(imgs=valid_imgs, 
                                    kernel_pickle=args.train_kernel, scale=args.scale, 
                                    augmentations=default_augmentations, transforms=default_transforms, 
                                    patch_size=(args.patch_size, args.patch_size),
                                    seed=args.seed, train=True, interpolation=args.inter)

    rt_dict = dict(train_dataset=train_dataset,valid_dataset=valid_dataset)
    
    return rt_dict

In [8]:
class args():
    def __init__(self):
        self.scale = 4
        self.train_kernel = "/content/drive/My Drive/CS663_project/kernels_scale4dim10.pth"
        self.inter = "cubic"
        self.seed = None
        self.patch_size = 256
        self.augment = "default"
        self.use_noise = False
        self.batch_size=16
        self.num_step=1000
        self.validation_interval=500
        self.num_workers=4
        self.lr=0.00004
        self.lr_decay=0.4
        self.lr_min=1e-7
        self.lr_scheduler="cosine"
        self.optimizer="adam"
        self.loss="l2"
        self.metric="psnr"
        self.resume=False
        self.nf=64
        self.valid_rate=0.1
        self.kernel_dim=10
args = args()

In [9]:
class SFT_Layer(nn.Module):
    def __init__(self, nf=64, para=10):
        super(SFT_Layer, self).__init__()
        self.mul_conv1 = nn.Conv2d(para + nf, 32, kernel_size=3, stride=1, padding=1)
        self.mul_leaky = nn.LeakyReLU(0.2)
        self.mul_conv2 = nn.Conv2d(32, nf, kernel_size=3, stride=1, padding=1)

        self.add_conv1 = nn.Conv2d(para + nf, 32, kernel_size=3, stride=1, padding=1)
        self.add_leaky = nn.LeakyReLU(0.2)
        self.add_conv2 = nn.Conv2d(32, nf, kernel_size=3, stride=1, padding=1)

    def forward(self, feature_maps, para_maps):
        cat_input = torch.cat((feature_maps, para_maps), dim=1)
        mul = torch.sigmoid(self.mul_conv2(self.mul_leaky(self.mul_conv1(cat_input))))
        add = self.add_conv2(self.add_leaky(self.add_conv1(cat_input)))
        return feature_maps * mul + add


In [10]:
class SFT_Residual_Block(nn.Module):
    def __init__(self, nf=64, para=10):
        super(SFT_Residual_Block, self).__init__()
        self.sft1 = SFT_Layer(nf=nf, para=para)
        self.sft2 = SFT_Layer(nf=nf, para=para)
        self.conv1 = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, feature_maps, para_maps):
        fea1 = F.relu(self.sft1(feature_maps, para_maps))
        fea2 = F.relu(self.sft2(self.conv1(fea1), para_maps))
        fea3 = self.conv2(fea2)
        return torch.add(feature_maps, fea3)

In [11]:
class SFTMD(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, scale=4, input_para=10, min=0.0, max=1.0):
        super(SFTMD, self).__init__()
        self.min = min
        self.max = max
        self.para = input_para
        self.num_blocks = nb

        self.conv1 = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1)
        self.relu_conv1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(nf, nf, 3, stride=1, padding=1)
        self.relu_conv2 = nn.LeakyReLU(0.2)
        self.conv3 = nn.Conv2d(nf, nf, 3, stride=1, padding=1)

        for i in range(nb):
            self.add_module('SFT-residual' + str(i + 1), SFT_Residual_Block(nf=nf, para=input_para))

        self.sft = SFT_Layer(nf=nf, para=input_para)
        self.conv_mid = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=True)

        if scale == 4: #x4
            self.upscale = nn.Sequential(
                nn.Conv2d(in_channels=nf, out_channels=nf * scale, kernel_size=3, stride=1, padding=1, bias=True),
                nn.PixelShuffle(scale // 2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(in_channels=nf, out_channels=nf * scale, kernel_size=3, stride=1, padding=1, bias=True),
                nn.PixelShuffle(scale // 2),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.conv_output = nn.Conv2d(in_channels=nf, out_channels=out_nc, kernel_size=9, stride=1, padding=4, bias=True)

    def forward(self, input_dict):
        input = input_dict['LR']
        ker_code = input_dict['k_reduced']
        
        B, C, H, W = input.size() 
        B_h, C_h = ker_code.size() 
        ker_code_exp = ker_code.view((B_h, C_h, 1, 1)).expand((B_h, C_h, H, W)) #kernel_map stretch

        fea_bef = self.conv3(self.relu_conv2(self.conv2(self.relu_conv1(self.conv1(input)))))
        fea_in = fea_bef
        for i in range(self.num_blocks):
            fea_in = self.__getattr__('SFT-residual' + str(i + 1))(fea_in, ker_code_exp)
        fea_mid = fea_in
        fea_add = torch.add(fea_mid, fea_bef)
        fea = self.upscale(self.conv_mid(self.sft(fea_add, ker_code_exp)))
        out = self.conv_output(fea)
        
        return out

In [22]:
def train_SFTMD(train_dl,valid_dl, model, epochs, optimizer, loss_func,batch_size,scheduler):
    os.makedirs( "/content/drive/My Drive/CS663_project/valid_results/", exist_ok=True)
    for t in range(0,epochs):  
      if t%10 == 0 and t!=0:
        path = "/content/drive/My Drive/CS663_project/SFTMD_temp{}.pth".format(t) 
        torch.save(net.state_dict(), path)
      if t%10 == 0:
        with torch.no_grad(): 
          bd = next(iter(valid_dl))
          for k,v in bd.items():
              bd[k] = v.cuda()                
          sr = net(bd)
          hr = bd['HR']
          lr = bd['LR']
          cat = torch.cat([sr, hr], dim=3)
          img = tensor2img(cat.detach())           
          lr_img = tensor2img(lr.detach())
          lr_img = cv2.resize(lr_img, (args.patch_size, args.batch_size * args.patch_size), interpolation=cv2.INTER_CUBIC)
          img = np.concatenate((lr_img, img), axis=1)         
          Image.fromarray(img).save("/content/drive/My Drive/CS663_project/valid_results/{}.png".format(t))
      #start =time.time()
      loss_calc = 0
      for batch in train_dl:
        input = batch
        for k,v in input.items():
          input[k] = v.cuda() 
        gt_hr = input['HR']
        sr = model(input)
        loss = loss_func(sr, gt_hr)
        loss_calc += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      scheduler.step(loss_calc)
      #end = time.time()
      #print(end-start)
      if t%3==0:
        print("loss after {} epochs".format(t),loss_calc)
            

In [13]:
datasets = get_datasets(args)
print(f"datasets are prepared.")

num of train 800,  num of valid 100.
datasets are prepared.


In [14]:
train_dl = DataLoader(datasets['train_dataset'], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
valid_dl = DataLoader(datasets['valid_dataset'], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

In [15]:
net = SFTMD(input_para=args.kernel_dim, scale=args.scale, nf=args.nf).cuda()

In [16]:
loss = F.mse_loss

In [17]:
#metric 

In [18]:
if args.optimizer == "adam":
    optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=args.lr)

In [19]:
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5, factor=args.lr_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=40, eta_min=args.lr_min) 

In [None]:
train_SFTMD(train_dl=train_dl,valid_dl=valid_dl,model= net,epochs=args.num_step, optimizer=optimizer, loss_func=loss,batch_size=args.batch_size,scheduler =scheduler)

loss after 0 epochs tensor(4.5461, device='cuda:0', grad_fn=<AddBackward0>)
loss after 3 epochs tensor(1.0047, device='cuda:0', grad_fn=<AddBackward0>)
loss after 6 epochs tensor(0.5023, device='cuda:0', grad_fn=<AddBackward0>)
loss after 9 epochs tensor(0.4548, device='cuda:0', grad_fn=<AddBackward0>)
loss after 12 epochs tensor(0.3919, device='cuda:0', grad_fn=<AddBackward0>)
loss after 15 epochs tensor(0.3622, device='cuda:0', grad_fn=<AddBackward0>)
loss after 18 epochs tensor(0.3526, device='cuda:0', grad_fn=<AddBackward0>)
loss after 21 epochs tensor(0.3097, device='cuda:0', grad_fn=<AddBackward0>)
loss after 24 epochs tensor(0.2563, device='cuda:0', grad_fn=<AddBackward0>)
loss after 27 epochs tensor(0.2450, device='cuda:0', grad_fn=<AddBackward0>)
loss after 30 epochs tensor(0.2546, device='cuda:0', grad_fn=<AddBackward0>)
loss after 33 epochs tensor(0.2341, device='cuda:0', grad_fn=<AddBackward0>)
loss after 36 epochs tensor(0.2406, device='cuda:0', grad_fn=<AddBackward0>)
los

KeyboardInterrupt: ignored

In [20]:
path = "/content/drive/My Drive/CS663_project/SFTMD_temp260.pth"
net.load_state_dict(torch.load(path))

<All keys matched successfully>

In [None]:
train_SFTMD(train_dl=train_dl,valid_dl=valid_dl,model= net,epochs=args.num_step, optimizer=optimizer, loss_func=loss,batch_size=args.batch_size,scheduler =scheduler)

loss after 261 epochs tensor(0.1708, device='cuda:0', grad_fn=<AddBackward0>)
loss after 264 epochs tensor(0.1668, device='cuda:0', grad_fn=<AddBackward0>)
loss after 267 epochs tensor(0.1597, device='cuda:0', grad_fn=<AddBackward0>)
loss after 270 epochs tensor(0.1645, device='cuda:0', grad_fn=<AddBackward0>)
loss after 273 epochs tensor(0.1547, device='cuda:0', grad_fn=<AddBackward0>)
loss after 276 epochs tensor(0.1604, device='cuda:0', grad_fn=<AddBackward0>)
loss after 279 epochs tensor(0.1729, device='cuda:0', grad_fn=<AddBackward0>)
loss after 282 epochs tensor(0.1625, device='cuda:0', grad_fn=<AddBackward0>)
loss after 285 epochs tensor(0.1679, device='cuda:0', grad_fn=<AddBackward0>)
loss after 288 epochs tensor(0.1655, device='cuda:0', grad_fn=<AddBackward0>)
loss after 291 epochs tensor(0.1629, device='cuda:0', grad_fn=<AddBackward0>)
loss after 294 epochs tensor(0.1646, device='cuda:0', grad_fn=<AddBackward0>)
loss after 297 epochs tensor(0.1619, device='cuda:0', grad_fn=<A

In [20]:
path = "/content/drive/My Drive/CS663_project/SFTMD_temp730.pth"
net.load_state_dict(torch.load(path))

<All keys matched successfully>

In [21]:
train_SFTMD(train_dl=train_dl,valid_dl=valid_dl,model= net,epochs=args.num_step, optimizer=optimizer, loss_func=loss,batch_size=args.batch_size,scheduler =scheduler)

loss after 738 epochs tensor(0.1478, device='cuda:0', grad_fn=<AddBackward0>)
loss after 741 epochs tensor(0.1501, device='cuda:0', grad_fn=<AddBackward0>)
loss after 744 epochs tensor(0.1449, device='cuda:0', grad_fn=<AddBackward0>)
loss after 747 epochs tensor(0.1542, device='cuda:0', grad_fn=<AddBackward0>)
loss after 750 epochs tensor(0.1478, device='cuda:0', grad_fn=<AddBackward0>)
loss after 753 epochs tensor(0.1392, device='cuda:0', grad_fn=<AddBackward0>)
loss after 756 epochs tensor(0.1535, device='cuda:0', grad_fn=<AddBackward0>)
loss after 759 epochs tensor(0.1543, device='cuda:0', grad_fn=<AddBackward0>)
loss after 762 epochs tensor(0.1539, device='cuda:0', grad_fn=<AddBackward0>)
loss after 765 epochs tensor(0.1493, device='cuda:0', grad_fn=<AddBackward0>)
loss after 768 epochs tensor(0.1539, device='cuda:0', grad_fn=<AddBackward0>)
loss after 771 epochs tensor(0.1456, device='cuda:0', grad_fn=<AddBackward0>)
loss after 774 epochs tensor(0.1459, device='cuda:0', grad_fn=<A

In [23]:
train_SFTMD(train_dl=train_dl,valid_dl=valid_dl,model= net,epochs=args.num_step, optimizer=optimizer, loss_func=loss,batch_size=args.batch_size,scheduler =scheduler)