In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from torchvision.transforms.functional import resize, InterpolationMode

import argparse
import os
import copy

https://github.com/yjn870/DRRN-pytorch

In [4]:
os.environ['KAGGLE_USERNAME'] = 'tungbinhthuong'
os.environ['KAGGLE_KEY'] = 'fd622c39c04da294c5e392c7f76ec1a7'

In [None]:
# !kaggle kernels output tungbinhthuong/imagesuperresolution-cvproject -p /kaggle/working/best_weights

In [None]:
!pip install datasets
!pip install super-image
!pip install torchsummary

In [5]:
class Args:
    def __init__(self, B, U, num_features, 
                 weights_file, eval_scale, 
                 lr, batch_size, num_workers,
                 clip_grad,
                 num_epochs,
                 seed,
                 outputs_dir,
                    eval = True):
        self.B = B
        self.U = U
        self.num_features = num_features
        self.weights_file = weights_file
        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.outputs_dir = os.path.join(outputs_dir, str(eval_scale))
        isExist = os.path.exists(self.outputs_dir)
        if not isExist:
            os.makedirs(self.outputs_dir)
            print("The new directory is created!")
            
        self.clip_grad = clip_grad
        self.num_epochs = num_epochs
        self.seed = seed
        self.eval = eval
        self.eval_scale = eval_scale

args = Args(B = 1, U = 9, num_features = 32, 
            lr = 0.01, weights_file = None, 
            clip_grad = 0.01, num_epochs = 50, 
            num_workers=2, seed = 123, 
            outputs_dir = "/kaggle/working/", 
            batch_size=64, eval_scale = 4)

In [6]:
from torch import nn


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvLayer, self).__init__()
        self.module = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False)
        )

    def forward(self, x):
        return self.module(x)


class ResidualUnit(nn.Module):
    def __init__(self, num_features):
        super(ResidualUnit, self).__init__()
        self.module = nn.Sequential(
            ConvLayer(num_features, num_features),
            ConvLayer(num_features, num_features)
        )

    def forward(self, h0, x):
        return h0 + self.module(x)


class RecursiveBlock(nn.Module):
    def __init__(self, in_channels, out_channels, U):
        super(RecursiveBlock, self).__init__()
        self.U = U
        self.h0 = ConvLayer(in_channels, out_channels)
        self.ru = ResidualUnit(out_channels)

    def forward(self, x):
        h0 = self.h0(x)
        x = h0
        for i in range(self.U):
            x = self.ru(h0, x)
        return x


class DRRN(nn.Module):
    def __init__(self, B, U, num_channels=1, num_features=128):
        super(DRRN, self).__init__()
        self.rbs = nn.Sequential(*[RecursiveBlock(num_channels if i == 0 else num_features, num_features, U) for i in range(B)])
        self.rec = ConvLayer(num_features, num_channels)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        residual = x
        x = self.rbs(x)
        x = self.rec(x)
        x += residual
        return x

In [7]:
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [8]:
from torchsummary import summary

In [9]:
model = DRRN(B=args.B, U=args.U, num_features=args.num_features, num_channels=3)

In [10]:
print(model)

DRRN(
  (rbs): Sequential(
    (0): RecursiveBlock(
      (h0): ConvLayer(
        (module): Sequential(
          (0): ReLU(inplace=True)
          (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (ru): ResidualUnit(
        (module): Sequential(
          (0): ConvLayer(
            (module): Sequential(
              (0): ReLU(inplace=True)
              (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (1): ConvLayer(
            (module): Sequential(
              (0): ReLU(inplace=True)
              (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
        )
      )
    )
  )
  (rec): ConvLayer(
    (module): Sequential(
      (0): ReLU(inplace=True)
      (1): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
)


In [11]:
import PIL.Image as pil_image
import numpy as np
import torch
import cv2


def load_image(path):
    return pil_image.open(path).convert('RGB')


def generate_lr(image, scale):
    image = image.resize((image.width // scale, image.height // scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC)
    return image


def modcrop(image, modulo):
    w = image.width - image.width % modulo
    h = image.height - image.height % modulo
    return image.crop((0, 0, w, h))


def generate_patch(image, patch_size, stride):
    for i in range(0, image.height - patch_size + 1, stride):
        for j in range(0, image.width - patch_size + 1, stride):
            yield image.crop((j, i, j + patch_size, i + patch_size))


def image_to_array(image):
    return np.array(image).transpose((2, 0, 1))

def normalize(x):
    return x / 255.0

def rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.

def load_weights(model, path):
    state_dict = model.state_dict()
    for n, p in torch.load(path, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    return model


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [12]:
def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def denormalize(img):
    img = img.mul(255.0).clamp(0.0, 255.0)
    return img

def calc_psnr(img1, img2, max=255.0):
    return 10. * ((max ** 2) / ((img1 - img2) ** 2).mean()).log10()

def ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calc_ssim(img1, img2):
    """calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    """
    img1=img1.detach().cpu().numpy()
    img2 = img2.detach().cpu().numpy()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1, img2))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')

def compute_metrics(eval_prediction, scale):
    preds = eval_prediction.predictions
    labels = eval_prediction.labels

    # from piq import ssim, psnr
    # print(psnr(denormalize(preds), denormalize(labels), data_range=255.),
    #       ssim(denormalize(preds), denormalize(labels), data_range=255.))

    # original = preds[0][0][0][0]

    preds = convert_rgb_to_y(denormalize(preds.squeeze(0)), dim_order='chw')
    labels = convert_rgb_to_y(denormalize(labels.squeeze(0)), dim_order='chw')

    # print(preds[0][0], original * 255.)

    preds = preds[scale:-scale, scale:-scale]
    labels = labels[scale:-scale, scale:-scale]

    # print(calc_psnr(preds, labels), calc_ssim(preds, labels))

    return {
        'psnr': calc_psnr(preds, labels),
        'ssim': calc_ssim(preds, labels)
    }

In [13]:
import random
import numpy as np
from typing import NamedTuple, Tuple, Union

class EvalPrediction(NamedTuple):
    """
    Evaluation output (always contains labels), to be used to compute metrics.
    Parameters:
        predictions (:obj:`np.ndarray`): Predictions of the model.
        labels (:obj:`np.ndarray`): Targets to be matched.
    """

    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    labels: np.ndarray


### LOAD Div2k Training dataset

In [14]:
from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop

augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x{}'.format(args.eval_scale), split='train')                          # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset)                                                     # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x{}'.format(args.eval_scale), split='validation'))      # prepare the eval dataset for the PyTorch DataLoader

In [15]:
# model = DRRN(B=args.B, U=args.U, num_features=args.num_features, num_channels=3).to(device)

if args.weights_file is not None:
    model = load_weights(model, args.weights_file)

#     train_dataset = TrainDataset(args.train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

if args.eval == True:
#     eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0
model.to(device)

criterion = nn.L1Loss()

# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
step_size = int(len(train_dataset) / args.batch_size * 200)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size)

for epoch in range(args.num_epochs):
#     lr = args.lr * (0.5 ** ((epoch + 1) // 10))
    lr = args.lr * (0.1 ** (epoch // int(num_train_epochs * 0.8)))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    model.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

        for data in train_dataloader:
            inputs, labels = data
            x = resize(inputs, size = (inputs.shape[2] * args.eval_scale, inputs.shape[3] * args.eval_scale), interpolation  =InterpolationMode.BICUBIC)
            x = x.to(device)
            
            labels = labels.to(device)

            preds = model(x)

            loss = criterion(preds, labels) / (2 * len(inputs))
            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()

#             nn.utils.clip_grad.clip_grad_norm_(model.parameters(), args.clip_grad / lr)

            optimizer.step()
            scheduler.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr=lr)
            t.update(len(inputs))

    torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

    if args.eval == True:
        model.eval()
        
        epoch_psnr = AverageMeter()
        epoch_ssim = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data
            x = resize(inputs, size = (inputs.shape[2] * args.eval_scale, inputs.shape[3] * args.eval_scale), 
                       interpolation = InterpolationMode.BICUBIC)
            x = x.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(x)

#             preds = denormalize(preds.squeeze(0).squeeze(0))
#             labels = denormalize(labels.squeeze(0).squeeze(0))
            res = compute_metrics(EvalPrediction(predictions=preds, labels=labels), args.eval_scale)
            
#             epoch_psnr.update(PSNR(preds, labels, shave_border=args.eval_scale), len(inputs))
            epoch_psnr.update(res['psnr'], len(inputs))
            epoch_ssim.update(res['ssim'], len(inputs))

        print('eval psnr: {:.2f}, eval ssim: {:.2f}'.format(epoch_psnr.avg, epoch_ssim.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_psnr_ssim = epoch_ssim.avg
            best_weights = copy.deepcopy(model.state_dict())

if args.eval ==True:
    print('best epoch: {}, psnr: {:.2f}, ssim: {:.2f}'.format(best_epoch, best_psnr, best_psnr_ssim))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

epoch: 0/49: : 800it [01:32,  8.69it/s, loss=9686115770608.767578, lr=0.1]      


eval psnr: 11.55, eval ssim: 0.53


epoch: 1/49: : 800it [01:23,  9.56it/s, loss=0.033248, lr=0.1]                  


eval psnr: 22.27, eval ssim: 0.77


epoch: 2/49: : 800it [01:26,  9.27it/s, loss=0.000653, lr=0.1]                  


eval psnr: 27.58, eval ssim: 0.77


epoch: 3/49: : 800it [01:27,  9.17it/s, loss=0.000259, lr=0.1]                  


eval psnr: 28.27, eval ssim: 0.78


epoch: 4/49: : 800it [01:25,  9.36it/s, loss=0.000250, lr=0.1]                  


eval psnr: 28.27, eval ssim: 0.78


epoch: 5/49: : 800it [01:21,  9.86it/s, loss=0.000253, lr=0.1]                  


eval psnr: 28.27, eval ssim: 0.78


epoch: 6/49: : 800it [01:19, 10.05it/s, loss=0.000246, lr=0.1]                  


eval psnr: 28.27, eval ssim: 0.78


epoch: 7/49: : 800it [01:19, 10.00it/s, loss=0.000255, lr=0.1]                  


eval psnr: 28.27, eval ssim: 0.78


epoch: 8/49: : 800it [01:19, 10.05it/s, loss=0.000243, lr=0.1]                  


eval psnr: 28.27, eval ssim: 0.78


epoch: 9/49: : 800it [01:19, 10.02it/s, loss=0.000260, lr=0.05]                 


eval psnr: 28.27, eval ssim: 0.78


epoch: 10/49: : 800it [01:22,  9.72it/s, loss=0.000254, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 11/49: : 800it [01:19, 10.11it/s, loss=0.000255, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 12/49: : 800it [01:18, 10.19it/s, loss=0.000252, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 13/49: : 800it [01:20,  9.94it/s, loss=0.000247, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 14/49: : 800it [01:28,  9.06it/s, loss=0.000248, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 15/49: : 800it [01:26,  9.27it/s, loss=0.000251, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 16/49: : 800it [01:22,  9.73it/s, loss=0.000253, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 17/49: : 800it [01:29,  8.97it/s, loss=0.000256, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 18/49: : 800it [01:27,  9.19it/s, loss=0.000237, lr=0.05]                


eval psnr: 28.27, eval ssim: 0.78


epoch: 19/49: : 800it [01:22,  9.66it/s, loss=0.000250, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 20/49: : 800it [01:20,  9.91it/s, loss=0.000251, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 21/49: : 800it [01:20,  9.94it/s, loss=0.000249, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 22/49: : 800it [01:20,  9.91it/s, loss=0.000254, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 23/49: : 800it [01:19, 10.06it/s, loss=0.000239, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 24/49: : 800it [01:20,  9.98it/s, loss=0.000252, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 25/49: : 800it [01:21,  9.80it/s, loss=0.000257, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 26/49: : 800it [01:20,  9.90it/s, loss=0.000258, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 27/49: : 800it [01:21,  9.86it/s, loss=0.000249, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 28/49: : 800it [01:20,  9.88it/s, loss=0.000250, lr=0.025]               


eval psnr: 28.27, eval ssim: 0.78


epoch: 29/49: : 800it [01:18, 10.24it/s, loss=0.000247, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 30/49: : 800it [01:18, 10.24it/s, loss=0.000257, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 31/49: : 800it [01:17, 10.33it/s, loss=0.000256, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 32/49: : 800it [01:21,  9.81it/s, loss=0.000247, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 33/49: : 800it [01:24,  9.43it/s, loss=0.000257, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 34/49: : 800it [01:26,  9.27it/s, loss=0.000250, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 35/49: : 800it [01:24,  9.52it/s, loss=0.000244, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 36/49: : 800it [01:23,  9.59it/s, loss=0.000248, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 37/49: : 800it [01:20,  9.93it/s, loss=0.000243, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 38/49: : 800it [01:19, 10.06it/s, loss=0.000246, lr=0.0125]              


eval psnr: 28.27, eval ssim: 0.78


epoch: 39/49: : 800it [01:19, 10.12it/s, loss=0.000253, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 40/49: : 800it [01:18, 10.16it/s, loss=0.000251, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 41/49: : 800it [01:17, 10.28it/s, loss=0.000256, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 42/49: : 800it [01:21,  9.80it/s, loss=0.000252, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 43/49: : 800it [01:21,  9.76it/s, loss=0.000249, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 44/49: : 800it [01:29,  8.93it/s, loss=0.000253, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 45/49: : 800it [01:26,  9.21it/s, loss=0.000245, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 46/49: : 800it [01:27,  9.16it/s, loss=0.000246, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 47/49: : 800it [01:28,  9.08it/s, loss=0.000252, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 48/49: : 800it [01:27,  9.17it/s, loss=0.000261, lr=0.00625]             


eval psnr: 28.27, eval ssim: 0.78


epoch: 49/49: : 800it [01:25,  9.41it/s, loss=0.000245, lr=0.00313]             


eval psnr: 28.27, eval ssim: 0.78
best epoch: 4, psnr: 28.27, ssim: 0.78


In [None]:
from super_image import ImageLoader
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
import requests

def detect_from_url(image_url, model, device = 'cpu'):
    model.to(device)
    
    image = Image.open(requests.get(image_url, stream=True).raw)
    
    inputs = ImageLoader.load_image(image)
    preds = model(inputs.to(device))

    ImageLoader.save_image(preds, './scaled_2x.png')
    ImageLoader.save_compare(inputs, preds, './scaled_2x_compare.png')

def detect_from_path(image_path, model, device = "cpu"):
    model.to(device)
    
    # Load the image from the given image_path
    image = Image.open(image_path)

    # Load the image into a format that the model can understand
    inputs = ToTensor()(image).unsqueeze(0)

    # Use the given model to make predictions
    preds = model(inputs.to(device))

    # Convert the predictions into an image
    output_image = ToPILImage()(preds.squeeze(0))

    # Save the scaled image and a comparison image
    output_image.save('./scaled_2x.png')
    ImageLoader.save_compare(inputs, preds, './scaled_2x_compare.png')

# Test
# url = 'https://paperswithcode.com/media/datasets/Set5-0000002728-07a9793f_zA3bDjj.jpg'
# detect_from_url(url, model)