In [None]:
# Colab only: mount Google Drive so the '/content/drive/Shareddrives/BTP Meets/...'
# dataset, weights and results paths used in later cells resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import h5py
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    """Training pairs read from an HDF5 file.

    The file holds two datasets, 'lr' and 'hr', indexed along their first
    axis; item ``idx`` is ``(lr[idx], hr[idx])``.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        # The file is opened per access, so no HDF5 handle is kept on the instance.
        with h5py.File(self.h5_file, 'r') as handle:
            return handle['lr'][idx], handle['hr'][idx]

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as handle:
            n_items = len(handle['lr'])
        return n_items


class EvalDataset(Dataset):
    """Evaluation pairs read from an HDF5 file.

    The file holds two groups, 'lr' and 'hr', whose members are named by the
    string form of the item index ('0', '1', ...). Each member is read in full.
    The array shape is whatever was stored (presumably one 2-D image per
    member; verify against the preparation script).
    """

    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            # `Dataset.value` was removed in h5py 3.0; `ds[()]` reads the whole dataset.
            lr = f['lr'][str(idx)][()]
            hr = f['hr'][str(idx)][()]
        return lr, hr

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

In [None]:

from torch import nn


class ConvLayer(nn.Module):
    """Pre-activation convolution: ReLU, then a bias-free Conv2d.

    The submodule layout (``module.0`` = ReLU, ``module.1`` = Conv2d) fixes the
    state_dict key names, so saved checkpoints keep loading.

    NOTE: the ReLU is ``inplace=True``, so ``forward`` overwrites the tensor it
    is given with its rectified values.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2  # stride-1 conv keeps H and W unchanged for odd kernel sizes
        layers = [
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad, bias=False),
        ]
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        out = self.module(x)
        return out


class ResidualUnit(nn.Module):
    """Residual unit: returns ``h0 + F(x)``, where F is two stacked ConvLayers.

    Both ConvLayers map ``num_features`` channels to ``num_features`` channels.

    NOTE(review): ConvLayer starts with ReLU(inplace=True). When ``x`` is the
    very tensor ``h0`` (the first recursion in RecursiveBlock), ``h0`` is
    therefore rectified in place before the addition reads it. Pretrained
    weights were presumably fit with this behaviour, so it is documented here
    rather than changed.
    """

    def __init__(self, num_features):
        super().__init__()
        first = ConvLayer(num_features, num_features)
        second = ConvLayer(num_features, num_features)
        self.module = nn.Sequential(first, second)

    def forward(self, h0, x):
        branch = self.module(x)
        return h0 + branch


class RecursiveBlock(nn.Module):
    """Recursive block: an entry ConvLayer followed by ``U`` passes of one shared ResidualUnit.

    The single ``self.ru`` instance is reused on every pass, so its weights are
    shared across recursions. Each pass adds back the entry output ``h0``.
    """

    def __init__(self, in_channels, out_channels, U):
        super().__init__()
        self.U = U
        self.h0 = ConvLayer(in_channels, out_channels)
        self.ru = ResidualUnit(out_channels)

    def forward(self, x):
        entry = self.h0(x)
        out = entry
        for _ in range(self.U):
            out = self.ru(entry, out)
        return out


class DRRN(nn.Module):
    """Deep Recursive Residual Network with a global skip connection.

    Args:
        B: number of stacked RecursiveBlocks. The first maps ``num_channels`` to
            ``num_features``; the rest keep ``num_features``.
        U: residual-unit recursions per block.
        num_channels: image channels in and out (1 = Y channel only).
        num_features: feature width inside the blocks.

    The output is ``rec(rbs(x)) + x``. NOTE: the first ConvLayer applies an
    in-place ReLU to ``x``, so the caller's input (and the skip term) is
    rectified in place. This is a no-op for non-negative inputs such as the
    Y/255 tensors produced by ``preprocess``.
    """

    def __init__(self, B, U, num_channels=1, num_features=128):
        super().__init__()
        blocks = []
        block_in = num_channels
        for _ in range(B):
            blocks.append(RecursiveBlock(block_in, num_features, U))
            block_in = num_features
        self.rbs = nn.Sequential(*blocks)
        self.rec = ConvLayer(num_features, num_channels)
        self._initialize_weights()

    def _initialize_weights(self):
        # He-normal (fan_out) init for every convolution; biases (none by default) zeroed.
        convs = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        skip = x
        features = self.rbs(x)
        out = self.rec(features)
        out += skip
        return out

In [None]:
import torch
import numpy as np


def calc_patch_size(func):
    """Decorator that sets ``args.patch_size`` from ``args.scale`` before calling ``func``.

    Mapping: scale 2 -> 10, 3 -> 7, 4 -> 6. Any other scale raises
    ``Exception('Scale Error', args.scale)``.
    """
    size_table = ((2, 10), (3, 7), (4, 6))

    def wrapper(args):
        for scale, size in size_table:
            if args.scale == scale:
                args.patch_size = size
                break
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    """Return the BT.601 luma (Y, 16-235 studio range) of an RGB image on the 0-255 scale.

    ``dim_order='hwc'`` reads channels from the last axis; any other value
    reads them from the first axis.
    """
    if dim_order != 'hwc':
        red, green, blue = img[0], img[1], img[2]
    else:
        red, green, blue = img[..., 0], img[..., 1], img[..., 2]
    return 16. + (64.738 * red + 129.057 * green + 25.064 * blue) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    """Convert an RGB image (0-255 scale) to BT.601 YCbCr.

    ``dim_order='hwc'`` reads channels from the last axis; any other value
    reads them from the first axis. The result is always channels-last:
    ``(..., 3)`` holding Y, Cb, Cr.
    """
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]
    luma = 16. + (64.738 * r + 129.057 * g + 25.064 * b) / 256.
    chroma_b = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    chroma_r = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.
    stacked = np.array([luma, chroma_b, chroma_r])
    return stacked.transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    """Convert BT.601 YCbCr (0-255 scale) back to RGB; the inverse of ``convert_rgb_to_ycbcr``.

    ``dim_order='hwc'`` reads Y/Cb/Cr from the last axis; any other value reads
    them from the first axis. The result is channels-last and unclipped, so
    callers clip to [0, 255] themselves.
    """
    if dim_order == 'hwc':
        luma, cb, cr = img[..., 0], img[..., 1], img[..., 2]
    else:
        luma, cb, cr = img[0], img[1], img[2]
    red = 298.082 * luma / 256. + 408.583 * cr / 256. - 222.921
    green = 298.082 * luma / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    blue = 298.082 * luma / 256. + 516.412 * cb / 256. - 276.836
    stacked = np.array([red, green, blue])
    return stacked.transpose([1, 2, 0])


def preprocess(img, device):
    """Turn an RGB image into the model's Y-channel input tensor.

    Args:
        img: RGB PIL image or array convertible to an (H, W, 3) array on the 0-255 scale.
        device: torch device (or device string) for the returned tensor.

    Returns:
        (x, ycbcr):
            x: float32 tensor of shape (1, 1, H, W) holding Y / 255.
            ycbcr: the full (H, W, 3) YCbCr array, still on the 0-255 scale.
    """
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    # Divide out of place: `ycbcr[..., 0]` is a view, so an in-place `/=` would
    # also rescale the Y channel of the `ycbcr` array returned to the caller.
    y = ycbcr[..., 0] / 255.
    x = torch.from_numpy(y).to(device)
    return x.unsqueeze(0).unsqueeze(0), ycbcr


def calc_psnr(img1, img2):
    """Peak signal-to-noise ratio in dB for tensors on a [0, 1] scale (peak value 1).

    Identical inputs give +inf, because the mean squared error is zero.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(1. / mse)


from scipy.ndimage import gaussian_filter

def calc_ssim(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2):
    """Mean structural similarity (SSIM) between two images of the same shape.

    Local means and variances use a Gaussian window of std ``sd`` over the
    last two (spatial) axes only. Leading batch/channel axes are therefore
    scored independently, and the SSIM map is then averaged over everything.
    The default constants are (0.01 * L)^2 and (0.03 * L)^2 with L = 1, so
    inputs are expected on a [0, 1] scale.

    Args:
        img1, img2: torch tensors (any device; may require grad) or array-likes.
        sd: Gaussian window standard deviation, in pixels.
        C1, C2: SSIM stabilising constants.

    Returns:
        Mean of the SSIM map, as a numpy float.
    """
    def _as_float64(img):
        # detach() lets tensors that require grad be converted. float64 limits
        # cancellation in the E[x^2] - E[x]^2 variance terms below.
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        return np.asarray(img, dtype=np.float64)

    img1 = _as_float64(img1)
    img2 = _as_float64(img2)
    # Blur only the trailing two axes. A sigma of 0 leaves an axis untouched,
    # so images stacked along batch/channel axes are not mixed together.
    sigma = [0] * max(img1.ndim - 2, 0) + [sd] * min(img1.ndim, 2)

    mu1 = gaussian_filter(img1, sigma)
    mu2 = gaussian_filter(img2, sigma)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = gaussian_filter(img1 * img1, sigma) - mu1_sq
    sigma2_sq = gaussian_filter(img2 * img2, sigma) - mu2_sq
    sigma12 = gaussian_filter(img1 * img2, sigma) - mu1_mu2

    ssim_num = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    ssim_den = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = ssim_num / ssim_den
    return np.mean(ssim_map)


class AverageMeter(object):
    """Keeps the latest value and a count-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [None]:
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm


def _load_state_dict_strict(model, weights_file):
    """Copy every tensor of a saved state_dict into ``model`` in place; unknown keys raise KeyError."""
    state_dict = model.state_dict()
    for n, p in torch.load(weights_file, map_location=lambda storage, loc: storage).items():
        if n not in state_dict:
            raise KeyError(n)
        state_dict[n].copy_(p)
    return model


def _shave(t, border):
    """Drop ``border`` pixels from each side of the last two axes (no-op for 0/None)."""
    if not border:
        return t
    return t[..., border:-border, border:-border]


def _evaluate_psnr(model, dataloader, device, shave_border):
    """Mean PSNR of ``model`` over ``dataloader``, with predictions clamped to [0, 1].

    NOTE(review): assumes the eval HDF5 stores inputs/labels on the same [0, 1]
    scale that ``preprocess`` feeds the model (calc_psnr uses a peak of 1).
    TODO: confirm against the dataset preparation script.
    """
    model.eval()
    meter = AverageMeter()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = model(inputs).clamp(0.0, 1.0)
            psnr = calc_psnr(_shave(preds, shave_border), _shave(labels, shave_border))
            meter.update(psnr.item(), len(inputs))
    return meter.avg


def train(train_file, outputs_dir, eval_file, eval_scale, weights_file, B=1, U=9, num_features=128, lr=0.1, clip_grad=0.01, batch_size=128, num_epochs=50, num_workers=8, seed=123):
    """Train a DRRN on HDF5 patch pairs and write a checkpoint every epoch.

    Args:
        train_file: HDF5 file read by TrainDataset.
        outputs_dir: parent directory; checkpoints go to ``<outputs_dir>/x234``.
        eval_file: optional HDF5 file read by EvalDataset. When None, evaluation
            is skipped and no 'best.pth' is written.
        eval_scale: pixels shaved from each border before the eval PSNR (0/None = no shave).
        weights_file: optional state_dict to start from; unknown keys raise KeyError.
        B, U, num_features: DRRN architecture (blocks, recursions per block, width).
        lr: base learning rate, halved every 10 epochs (counted on ``epoch + 1``).
        clip_grad: the gradient-norm clip is ``clip_grad / current_lr``, so it
            loosens as the learning rate decays.
        batch_size, num_epochs, num_workers, seed: usual training knobs.
    """
    outputs_dir = os.path.join(outputs_dir, 'x234')
    os.makedirs(outputs_dir, exist_ok=True)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(seed)

    model = DRRN(B=B, U=U, num_features=num_features).to(device)

    if weights_file is not None:
        _load_state_dict_strict(model, weights_file)

    # Summed squared error divided by 2 * batch below, i.e. 1/2 * per-sample SSE.
    criterion = nn.MSELoss(reduction='sum')
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    train_dataset = TrainDataset(train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True)

    if eval_file is not None:
        eval_dataset = EvalDataset(eval_file)
        eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(num_epochs):
        # Step decay from the *base* lr. Reassigning `lr` itself would compound
        # the decay on every epoch after the first drop.
        current_lr = lr * (0.5 ** ((epoch + 1) // 10))
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        model.train()
        epoch_losses = AverageMeter()

        # drop_last is off, so every sample is seen; the bar total counts them all.
        with tqdm(total=len(train_dataset), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, num_epochs - 1))

            for inputs, labels in train_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels) / (2 * len(inputs))
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()

                nn.utils.clip_grad_norm_(model.parameters(), clip_grad / current_lr)

                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr=current_lr)
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

        if eval_file is not None:
            epoch_psnr = _evaluate_psnr(model, eval_dataloader, device, eval_scale)
            print('eval psnr: {:.2f}'.format(epoch_psnr))

            if epoch_psnr > best_psnr:
                best_epoch = epoch
                best_psnr = epoch_psnr
                best_weights = copy.deepcopy(model.state_dict())

    if eval_file is not None:
        print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
        torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))

In [None]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image


def test(weights_file, image_file, scale, save=False, debug=False, B=1, U=9, num_features=128,
         save_dir='/content/drive/Shareddrives/BTP Meets/results/Set5'):
    """Super-resolve one image with DRRN and score it against the reference.

    The HR reference is the image bicubic-resized to a multiple of ``scale``.
    It is bicubic-downscaled by ``scale`` and bicubic-upscaled back, and the
    network refines the Y channel of that upscaled image. PSNR/SSIM are
    computed on the Y channel scaled to [0, 1].

    Args:
        weights_file: saved DRRN state_dict; unknown keys raise KeyError.
        image_file: path of the input image (converted to RGB).
        scale: downscaling factor.
        save: when True, write the RGB result to ``<save_dir>/<scale>x/<stem>_drrn<ext>``.
        debug: print the PSNR/SSIM line.
        B, U, num_features: DRRN architecture; must match the weights.
        save_dir: root folder for saved results (defaults to the original Set5 folder).

    Returns:
        (psnr, ssim) as Python floats.
    """
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = DRRN(B=B, U=U, num_features=num_features).to(device)
    state_dict = model.state_dict()
    for n, p in torch.load(weights_file, map_location=lambda storage, loc: storage).items():
        if n not in state_dict:
            raise KeyError(n)
        state_dict[n].copy_(p)

    model.eval()

    # The context manager closes the file; convert() returns an independent, loaded copy.
    with pil_image.open(image_file) as source:
        image = source.convert('RGB')
    image_name = os.path.basename(image_file)

    image_width = (image.width // scale) * scale
    image_height = (image.height // scale) * scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // scale, hr.height // scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)

    hr, _ = preprocess(hr, device)
    bicubic, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        # Pre-upsampling: the network refines the bicubic-upscaled Y channel.
        # Clamp so the metrics score the same [0, 1] range the saved image can hold.
        preds = model(bicubic).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    ssim = calc_ssim(hr, preds)
    if debug:
        print(f'PSNR/SSIM: {psnr:.2f}/{ssim:.4f}')

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    if save:
        # Insert the suffix before the extension only; str.replace('.') would hit every dot.
        stem, ext = os.path.splitext(image_name)
        out_dir = os.path.join(save_dir, f'{scale}x')
        os.makedirs(out_dir, exist_ok=True)
        output.save(os.path.join(out_dir, f'{stem}_drrn{ext}'))
    return float(psnr), float(ssim)


In [None]:
import os

def do_test(psnr, ssim, BASE_DIR, save=False, debug=False, scales=(2, 3, 4, 8, 16),
            weights_file='/content/drive/Shareddrives/BTP Meets/models/drrn.pth'):
    """Run ``test`` on every PNG in ``BASE_DIR`` at each scale and collect the scores.

    Args:
        psnr, ssim: dicts mapping scale -> list of per-image scores. They are
            mutated in place, and a new scale gets a fresh list.
        BASE_DIR: folder scanned for ``*.png`` files (extension matched case-insensitively).
        save, debug: forwarded to ``test``. ``debug`` also prints file names and scales.
        scales: upscaling factors to evaluate, in order.
        weights_file: DRRN state_dict passed to ``test``.

    Files are visited in sorted order, so per-image logs and list positions are
    reproducible across runs (``os.listdir`` order is arbitrary).
    """
    for file in sorted(os.listdir(BASE_DIR)):
        if not file.lower().endswith('.png'):
            continue
        image_file_path = os.path.join(BASE_DIR, file)
        if debug:
            print(file)
        for scale in scales:
            if debug:
                print(f"Scale: {scale}")
            psnr_val, ssim_val = test(weights_file, image_file_path, scale, save, debug)
            psnr.setdefault(scale, []).append(psnr_val)
            ssim.setdefault(scale, []).append(ssim_val)
        if debug:
            print()


In [None]:
# Evaluate Set5 with per-image logging, saving the super-resolved outputs.
psnr, ssim = {}, {}
do_test(psnr, ssim, '/content/drive/Shareddrives/BTP Meets/datasets/test/Set5/', save=True, debug=True)

head.png
Scale: 2
PSNR/SSIM: 36.01/0.8907
Scale: 3
PSNR/SSIM: 35.58/0.8792
Scale: 4
PSNR/SSIM: 32.91/0.7951
Scale: 8
PSNR/SSIM: 29.33/0.6819
Scale: 16
PSNR/SSIM: 27.02/0.6785

butterfly.png
Scale: 2
PSNR/SSIM: 34.59/0.9752
Scale: 3
PSNR/SSIM: 30.58/0.9459
Scale: 4
PSNR/SSIM: 27.11/0.8949
Scale: 8
PSNR/SSIM: 17.92/0.4952
Scale: 16
PSNR/SSIM: 15.19/0.3714

bird.png
Scale: 2
PSNR/SSIM: 42.37/0.9889
Scale: 3
PSNR/SSIM: 36.71/0.9650
Scale: 4
PSNR/SSIM: 33.42/0.9267
Scale: 8
PSNR/SSIM: 25.30/0.7026
Scale: 16
PSNR/SSIM: 21.75/0.5765

baby.png
Scale: 2
PSNR/SSIM: 38.80/0.9672
Scale: 3
PSNR/SSIM: 36.43/0.9402
Scale: 4
PSNR/SSIM: 33.56/0.8918
Scale: 8
PSNR/SSIM: 27.42/0.7406
Scale: 16
PSNR/SSIM: 24.45/0.6804

woman.png
Scale: 2
PSNR/SSIM: 36.08/0.9730
Scale: 3
PSNR/SSIM: 32.76/0.9433
Scale: 4
PSNR/SSIM: 30.01/0.9062
Scale: 8
PSNR/SSIM: 22.62/0.6884
Scale: 16
PSNR/SSIM: 19.28/0.5742



In [None]:
import statistics

scales = [2, 3, 4, 8, 16]
for scale in scales:
    # Mean over the per-image Set5 scores collected by do_test above.
    mean_psnr = statistics.mean(psnr[scale])
    mean_ssim = statistics.mean(ssim[scale])
    print(f'Avg PSNR/SSIM {scale}x: {mean_psnr:.2f}/{mean_ssim:.4f}')

Avg PSNR/SSIM 2x: 37.57/0.9590
Avg PSNR/SSIM 3x: 34.41/0.9347
Avg PSNR/SSIM 4x: 31.40/0.8829
Avg PSNR/SSIM 8x: 24.52/0.6617
Avg PSNR/SSIM 16x: 21.54/0.5762


In [None]:
scales = [2, 3, 4, 8, 16]

def calc_result(dataset, base_dir='/content/drive/Shareddrives/BTP Meets/datasets/test'):
    """Evaluate every PNG in ``<base_dir>/<dataset>/`` and print the mean PSNR/SSIM per scale.

    Per-image scores are printed by do_test (debug=True); nothing is saved.
    Scales are reported in the order of the module-level ``scales`` list. A
    scale with no results is reported as such instead of raising KeyError, so
    an empty folder does not abort the remaining datasets in the cell.
    """
    print()
    print(dataset)
    psnr = {}
    ssim = {}
    do_test(psnr, ssim, f'{base_dir}/{dataset}/', False, True)
    for scale in scales:
        if not psnr.get(scale):
            print(f'Avg PSNR/SSIM {scale}x: no images evaluated')
            continue
        print(f'Avg PSNR/SSIM {scale}x: {statistics.mean(psnr[scale]):.2f}/{statistics.mean(ssim[scale]):.4f}')

# Remaining benchmark suites, evaluated in this order.
for _dataset_name in ('Set14', 'BSDS100', 'Manga109', 'Urban100'):
    calc_result(_dataset_name)

'''
Set14
Avg PSNR/SSIM 2x: 33.09/0.9147
Avg PSNR/SSIM 3x: 30.34/0.8509
Avg PSNR/SSIM 4x: 28.04/0.7740
Avg PSNR/SSIM 8x: 23.14/0.5782
Avg PSNR/SSIM 16x: 20.81/0.5050

BSDS100
Avg PSNR/SSIM 2x: 33.73/0.9226
Avg PSNR/SSIM 3x: 29.44/0.8143
Avg PSNR/SSIM 4x: 28.22/0.7552
Avg PSNR/SSIM 8x: 24.21/0.5729
Avg PSNR/SSIM 16x: 22.27/0.5148

Manga109
Avg PSNR/SSIM 2x: 37.93/0.9770
Avg PSNR/SSIM 3x: 32.62/0.9398
Avg PSNR/SSIM 4x: 29.51/0.8975
Avg PSNR/SSIM 8x: 21.77/0.6632
Avg PSNR/SSIM 16x: 19.44/0.6009

Urban100
Avg PSNR/SSIM 2x: 30.88/0.9158
Avg PSNR/SSIM 3x: 28.37/0.8549
Avg PSNR/SSIM 4x: 25.28/0.7574
Avg PSNR/SSIM 8x: 20.86/0.5232
Avg PSNR/SSIM 16x: 19.11/0.4598
'''


Manga109
RinToSiteSippuNoNaka.png
Scale: 2
PSNR/SSIM: 39.16/0.9814
Scale: 3
PSNR/SSIM: 32.60/0.9539
Scale: 4
PSNR/SSIM: 29.25/0.9294
Scale: 8
PSNR/SSIM: 21.04/0.7699
Scale: 16
PSNR/SSIM: 19.16/0.7250

YamatoNoHane.png
Scale: 2
PSNR/SSIM: 43.47/0.9928
Scale: 3
PSNR/SSIM: 38.62/0.9842
Scale: 4
PSNR/SSIM: 34.21/0.9694
Scale: 8
PSNR/SSIM: 24.03/0.8033
Scale: 16
PSNR/SSIM: 21.52/0.7527

UchuKigekiM774.png
Scale: 2
PSNR/SSIM: 32.92/0.9407
Scale: 3
PSNR/SSIM: 28.00/0.8352
Scale: 4
PSNR/SSIM: 25.86/0.7454
Scale: 8
PSNR/SSIM: 20.50/0.4038
Scale: 16
PSNR/SSIM: 18.96/0.3450

Raphael.png
Scale: 2
PSNR/SSIM: 37.12/0.9808
Scale: 3
PSNR/SSIM: 30.97/0.9369
Scale: 4
PSNR/SSIM: 26.67/0.8679
Scale: 8
PSNR/SSIM: 20.05/0.5929
Scale: 16
PSNR/SSIM: 17.90/0.5023

YumeiroCooking.png
Scale: 2
PSNR/SSIM: 37.82/0.9873
Scale: 3
PSNR/SSIM: 30.94/0.9478
Scale: 4
PSNR/SSIM: 27.32/0.8814
Scale: 8
PSNR/SSIM: 22.43/0.6835
Scale: 16
PSNR/SSIM: 20.42/0.6335

MayaNoAkaiKutsu.png
Scale: 2
PSNR/SSIM: 40.60/0.9867
Scale: 3
P

'\nSet14\nAvg PSNR/SSIM 2x: 33.09/0.9147\nAvg PSNR/SSIM 3x: 30.34/0.8509\nAvg PSNR/SSIM 4x: 28.04/0.7740\nAvg PSNR/SSIM 8x: 23.14/0.5782\nAvg PSNR/SSIM 16x: 20.81/0.5050\n\nBSDS100\nAvg PSNR/SSIM 2x: 33.73/0.9226\nAvg PSNR/SSIM 3x: 29.44/0.8143\nAvg PSNR/SSIM 4x: 28.22/0.7552\nAvg PSNR/SSIM 8x: 24.21/0.5729\nAvg PSNR/SSIM 16x: 22.27/0.5148\n'