In [70]:
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(modules):
    """Placeholder weight initialiser; deliberately does nothing.

    Args:
        modules: whatever the caller passes. Callers in this notebook pass
            ``self.modules`` (the bound method, not an iterable of modules),
            so a real implementation would have to call it first.

    Returns:
        None.
    """
    return None
   

class MeanShift(nn.Module):
    """Frozen 1x1 convolution that adds or subtracts a per-channel RGB mean.

    Args:
        mean_rgb: indexable with the three per-channel means (R, G, B).
        sub: if truthy the means are subtracted, otherwise they are added.
    """

    def __init__(self, mean_rgb, sub):
        super(MeanShift, self).__init__()

        direction = -1 if sub else 1
        shift = [mean_rgb[channel] * direction for channel in range(3)]

        self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
        # Identity kernel: each output channel is a copy of its input channel,
        # so the layer only adds the bias.
        self.shifter.weight.data = torch.eye(3).reshape(3, 3, 1, 1)
        self.shifter.bias.data = torch.Tensor(shift)

        # Constant preprocessing layer: exclude it from gradient updates.
        for param in self.shifter.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.shifter(x)


class BasicBlock(nn.Module):
    """One convolution followed by an in-place ReLU.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution.
        ksize: kernel size.
        stride: convolution stride.
        pad: zero padding on each side.
    """

    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlock, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True),
        ]
        self.body = nn.Sequential(*layers)

        init_weights(self.modules)

    def forward(self, x):
        return self.body(x)


class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv body with an identity skip, then a final ReLU.

    Args:
        in_channels: input channels; the identity skip in ``forward`` expects
            this to match ``out_channels``.
        out_channels: output channels of both convolutions.
    """

    def __init__(self, 
                 in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        init_weights(self.modules)

    def forward(self, x):
        residual = self.body(x)
        return F.relu(x + residual)


class EResidualBlock(nn.Module):
    """Residual block with two (optionally grouped) 3x3 convs and a 1x1 conv.

    Args:
        in_channels: input channels; the identity skip in ``forward`` expects
            this to match ``out_channels``.
        out_channels: output channels of every convolution.
        group: ``groups`` argument of the two 3x3 convolutions.
    """

    def __init__(self, 
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        conv3x3 = dict(kernel_size=3, stride=1, padding=1, groups=group)
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv3x3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv3x3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )

        init_weights(self.modules)

    def forward(self, x):
        residual = self.body(x)
        return F.relu(x + residual)


class UpsampleBlock(nn.Module):
    """Learned 2x upsampler (conv -> ReLU -> PixelShuffle(2)).

    Args:
        n_channels: number of input (and output) channels.
        group: ``groups`` of the expanding convolution. Callers in this
            notebook pass ``group=...``; the default 1 is the ungrouped layer.
    """

    def __init__(self, 
                 n_channels, group=1):
        super(UpsampleBlock, self).__init__()
        # The attribute name `up` is part of the state_dict keys (`up.body.0.*`).
        self.up = _UpsampleBlock(n_channels, group=group)

    def forward(self, x, scale=2):
        """Upsample ``x`` by a factor of 2.

        Args:
            x: feature map, presumably (N, C, H, W) with C == n_channels.
            scale: accepted so callers may pass it (some do, some do not);
                only 2 is supported because the body is a single PixelShuffle(2).

        Raises:
            ValueError: if ``scale`` is not 2 (previously it was silently ignored).
        """
        if scale != 2:
            raise ValueError(f"UpsampleBlock only supports scale=2, got {scale}")
        return self.up(x)


class _UpsampleBlock(nn.Module):
    """Conv(n -> 4n) + ReLU + PixelShuffle(2): doubles H and W, keeps n channels."""

    def __init__(self, n_channels, group=1):
        super(_UpsampleBlock, self).__init__()

        modules = []
        modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
        modules += [nn.PixelShuffle(2)]

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        return self.body(x)



In [71]:
import torch
import torch.nn as nn

class Block(nn.Module):
    """Cascading block: three residual units with 1x1 fusion convolutions.

    Each residual output is concatenated with everything seen so far and
    fused back to 64 channels by a 1x1 ``BasicBlock``.

    Args:
        in_channels, out_channels, group: accepted but unused; the channel
            width is fixed at 64 inside the block.
    """

    def __init__(self, 
                 in_channels, out_channels,
                 group=1):
        super(Block, self).__init__()

        self.b1 = ResidualBlock(64, 64)
        self.b2 = ResidualBlock(64, 64)
        self.b3 = ResidualBlock(64, 64)
        self.c1 = BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = BasicBlock(64*4, 64, 1, 1, 0)

    def forward(self, x):
        concat = x
        out = x
        stages = ((self.b1, self.c1), (self.b2, self.c2), (self.b3, self.c3))
        for residual_unit, fuse in stages:
            # Dense cascade: the concat grows by 64 channels per stage.
            concat = torch.cat([concat, residual_unit(out)], dim=1)
            out = fuse(concat)
        return out
        

class CARN(nn.Module):
    """CARN feature trunk used inside LapCARN.

    Takes a 64-channel feature map and returns a 64-channel feature map
    upsampled by 2. The body is three cascading ``Block`` units, each
    followed by a 1x1 fusion over the concatenation of all previous
    outputs, then a learned 2x upsampler.
    """

    def __init__(self):
        super(CARN, self).__init__()

        # Registered but not applied in forward (LapCARN applies its own mean
        # shift / entry / exit). Kept so the state_dict keys, and therefore
        # checkpoint loading, stay unchanged.
        self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        self.entry = nn.Conv2d(3, 64, 3, 1, 1)

        self.b1 = Block(64, 64)
        self.b2 = Block(64, 64)
        self.b3 = Block(64, 64)
        self.c1 = BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = BasicBlock(64*4, 64, 1, 1, 0)

        # Default (ungrouped) upsampler. No `group=` keyword is passed, because
        # `UpsampleBlock(n_channels)` does not accept one and would raise TypeError.
        self.upsample = UpsampleBlock(64)
        self.exit = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        """Return the 2x-upsampled features of ``x`` (64 channels expected)."""
        concat = out = x
        stages = ((self.b1, self.c1), (self.b2, self.c2), (self.b3, self.c3))
        for block, fuse in stages:
            concat = torch.cat([concat, block(out)], dim=1)
            out = fuse(concat)
        return self.upsample(out, scale=2)


In [72]:
import torch
import torch.nn as nn

class LapCARN(nn.Module):
    """Laplacian-pyramid super-resolution built on a single shared CARN trunk.

    ``forward`` returns three reconstructions at 2x, 4x and 8x the input size.
    At each level:
        1. The shared CARN trunk is applied to the previous level's features.
        2. The resulting features are mapped to RGB with ``exit`` + ``add_mean``.
        3. That residual is added to a learned 2x upsampling of the previous
           level's image.
    """

    def __init__(self):
        super(LapCARN, self).__init__()

        # For upsampling image with 3 channels (Image Reconstruction).
        # `UpsampleBlock(n_channels)` takes no `group` keyword; passing one raises TypeError.
        self.upsample_3 = UpsampleBlock(3)
        # For upsampling image with 64 channels (Feature Extraction). Not used in
        # forward, but registered so the state_dict keys stay unchanged.
        self.upsample_64 = UpsampleBlock(64)
        self.feat = self.make_layer(CARN)

        self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        self.entry = nn.Conv2d(3, 64, 3, 1, 1)
        self.exit = nn.Conv2d(64, 3, 3, 1, 1)

    def make_layer(self, block):
        """Wrap one instance of ``block`` in an ``nn.Sequential`` (keys ``feat.0.*``)."""
        layers = []
        layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(HR_2x, HR_4x, HR_8x)`` reconstructions of ``x``.

        ``x`` is presumably an RGB batch (N, 3, H, W) in [0, 1]. The evaluation
        code feeds it ``ToTensor`` output; confirm this for other callers.
        """
        features = self.entry(self.sub_mean(x))
        image = x
        outputs = []
        for _ in range(3):
            # The same CARN weights are reused at every pyramid level.
            features = self.feat(features)
            residual = self.add_mean(self.exit(features))
            # `scale` must be passed explicitly: UpsampleBlock.forward requires it.
            image = self.upsample_3(image, scale=2) + residual
            outputs.append(image)
        return tuple(outputs)


In [73]:
import torch
import numpy as np
from scipy.ndimage import gaussian_filter


def convert_rgb_to_y(img, dim_order='hwc'):
    """Return the BT.601 luma (Y in [16, 235]) of an 8-bit-range RGB image.

    Args:
        img: RGB array with values in [0, 255]. It is channels-last when
            ``dim_order == 'hwc'``; any other value is treated as channels-first.
        dim_order: ``'hwc'`` for channels-last input.

    Returns:
        Array with the channel axis removed.
    """
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]
    # BT.601 weights scaled by 256/255 for 8-bit input. The red weight is
    # 65.738 (= 65.481 * 256 / 255); the former 64.738 was a typo that
    # mapped pure white to Y=234 instead of 235.
    return 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    """Convert an 8-bit-range RGB image to BT.601 YCbCr.

    Args:
        img: RGB array with values in [0, 255]. It is channels-last when
            ``dim_order == 'hwc'``; any other value is treated as channels-first.
        dim_order: ``'hwc'`` for channels-last input.

    Returns:
        Channels-last array (H, W, 3) holding Y, Cb, Cr, for either input layout.
    """
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]
    # BT.601 matrix scaled by 256/255 for 8-bit input. The luma red weight is
    # 65.738 (= 65.481 * 256 / 255); the former 64.738 was a typo.
    y = 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
    cb = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    cr = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    """Convert BT.601 YCbCr (8-bit ranges) back to RGB.

    Args:
        img: YCbCr array. It is channels-last when ``dim_order == 'hwc'``;
            any other value is treated as channels-first.
        dim_order: ``'hwc'`` for channels-last input.

    Returns:
        Channels-last RGB array (H, W, 3). Values are not clipped to [0, 255].
    """
    if dim_order == 'hwc':
        y, cb, cr = img[..., 0], img[..., 1], img[..., 2]
    else:
        y, cb, cr = img[0], img[1], img[2]
    red = 298.082 * y / 256. + 408.583 * cr / 256. - 222.921
    green = 298.082 * y / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    blue = 298.082 * y / 256. + 516.412 * cb / 256. - 276.836
    return np.array([red, green, blue]).transpose([1, 2, 0])


def preprocess(img, device):
    """Convert an RGB image into a normalised Y-channel tensor.

    Args:
        img: RGB image (PIL image or array-like). Presumably H x W x 3 with
            values in [0, 255].
        device: torch device that the returned tensor is moved to.

    Returns:
        (x, ycbcr), where:
            x: Y channel divided by 255, shape (1, 1, H, W).
            ycbcr: full-range float32 YCbCr image (H, W, 3), left unmodified.
    """
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    # Divide out of place. `ycbcr[..., 0]` is a view, so an in-place `/=`
    # would also rescale the Y channel of the returned `ycbcr`.
    x = ycbcr[..., 0] / 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    """Peak signal-to-noise ratio in dB, assuming a peak signal value of 1.

    Args:
        img1, img2: tensors of matching shape, presumably scaled to [0, 1].

    Returns:
        0-dim tensor ``10 * log10(1 / MSE)``; ``inf`` for identical inputs.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(1. / mse)


def calc_ssim(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2):
    """Mean structural similarity (SSIM) using a Gaussian window.

    Local means, variances and covariance are computed with
    ``scipy.ndimage.gaussian_filter`` over the last two (spatial) axes only.
    Leading batch/channel axes are therefore never blurred into each other.
    The default constants assume images scaled to [0, 1].

    Args:
        img1, img2: images of identical shape. Either torch tensors (any
            device, with or without grad) or array-likes.
        sd: Gaussian standard deviation in pixels. A per-axis sequence is
            passed to ``gaussian_filter`` unchanged.
        C1, C2: stabilising constants, (0.01 * L)^2 and (0.03 * L)^2 with L = 1.

    Returns:
        NumPy scalar: the mean of the SSIM map.
    """
    def to_numpy(img):
        if isinstance(img, torch.Tensor):
            # detach() so tensors that require grad can be converted too.
            return img.detach().cpu().numpy()
        return np.asarray(img)

    img1 = to_numpy(img1)
    img2 = to_numpy(img2)

    if np.ndim(sd) == 0:
        spatial = min(2, img1.ndim)
        # sigma 0 leaves an axis unfiltered, so channels stay independent.
        sigma = [0] * (img1.ndim - spatial) + [sd] * spatial
    else:
        sigma = sd

    mu1 = gaussian_filter(img1, sigma)
    mu2 = gaussian_filter(img2, sigma)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = gaussian_filter(img1 * img1, sigma) - mu1_sq
    sigma2_sq = gaussian_filter(img2 * img2, sigma) - mu2_sq
    sigma12 = gaussian_filter(img1 * img2, sigma) - mu1_mu2

    ssim_num = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    ssim_den = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = ssim_num / ssim_den
    return np.mean(ssim_map)


class AverageMeter(object):
    """Tracks the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [74]:
import torchvision.transforms as transforms
import torch

# torch.device accepts a single device, so a multi-GPU string such as
# 'cuda:1,2' raises "Invalid device string". Use GPU 1 when it exists, as the
# original string intended; otherwise fall back to the default GPU, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')
else:
    device = torch.device('cpu')

DEBUG = False  # default for the `debug` flags of the evaluation helpers below
SAVE = True    # default for get_results' `save`: write reconstructed images

def get_results(hr, preds, filename, resolution, dataset, model_name, debug=DEBUG, save=SAVE):
    """Score a prediction against its HR reference, then log and save it.

    PSNR and SSIM are computed on the Y channel scaled to [0, 1]. The
    predicted Y is recombined with the reference's Cb/Cr into an RGB image,
    which is optionally saved under ``results-final/<dataset>/<resolution>/``.
    One CSV row is appended to ``results.csv``.

    Args:
        hr: high-resolution reference image (PIL image).
        preds: either a model output tensor, presumably (1, 3, H, W) in
            [0, 1], or an RGB image such as a bicubic baseline.
        filename: source image file name. The saved file is named
            ``<stem>_<model_name><ext>``.
        resolution: scale label such as ``"2x"``, used in the path and the CSV.
        dataset: dataset name, used in the path and the CSV.
        model_name: label for the CSV and the saved file name.
        debug: print PSNR/SSIM and open the prediction in an image viewer.
        save: write the reconstructed RGB image to disk.

    Returns:
        (psnr, ssim) as Python floats.
    """
    import os

    hr, ycbcr = preprocess(hr, device)

    if isinstance(preds, torch.Tensor):
        # Round-trip through PIL so the prediction is scored as an 8-bit image.
        pil_preds = transforms.ToPILImage()(preds.cpu().squeeze(0))
        if debug:
            # Opening an external viewer for every image is only useful when debugging.
            pil_preds.show()
        preds, _ = preprocess(pil_preds, device)
    else:
        preds, _ = preprocess(preds, device)

    # Plain floats, so the CSV gets numbers rather than "tensor(..., device=...)"
    # reprs. Such a repr contains a comma and would break the CSV columns.
    psnr = float(calc_psnr(hr, preds))
    ssim = float(calc_ssim(hr, preds))

    if debug:
        print(f'PSNR/SSIM: {psnr:.2f}/{ssim:.4f}')

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    if save:
        save_dir = os.path.join('results-final', dataset, resolution)
        os.makedirs(save_dir, exist_ok=True)
        # Insert the model name before the extension only. A blanket
        # str.replace('.') would also rewrite dots elsewhere in the path.
        stem, ext = os.path.splitext(filename)
        output.save(os.path.join(save_dir, f'{stem}_{model_name}{ext}'))
    with open('results.csv', 'a') as f:
        f.write(f'{resolution},{dataset},{filename},{model_name},{psnr},{ssim}\n')
    return psnr, ssim

In [75]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

import matplotlib.pyplot as plt
import cv2


from torchvision import transforms
# Shared PIL -> tensor converter; `test` uses it to build the model's low-resolution input.
convert_tensor = transforms.ToTensor()

def test(weights_file, image_file, scale, dataset, debug=DEBUG, include_bicubic=False):
    """Super-resolve one image with LapCARN and record metrics for ``scale``.

    Steps:
        1. Resize the image so both sides are multiples of ``scale``; this is
           the HR reference.
        2. Bicubic-downscale it by ``scale`` to get the LR input.
        3. Run LapCARN on the LR input.
        4. Score the output for ``scale`` with ``get_results``.

    Args:
        weights_file: path to a LapCARN ``state_dict`` checkpoint.
        image_file: path to the source image.
        scale: upscaling factor; one of 2, 4 or 8.
        dataset: dataset label forwarded to ``get_results``.
        debug: forwarded to ``get_results`` (metric prints and image display).
        include_bicubic: also score the plain bicubic-upsampling baseline.

    Raises:
        ValueError: if ``scale`` is not 2, 4 or 8. Previously such a scale
            silently produced no result.
    """
    import os

    if scale not in (2, 4, 8):
        raise ValueError(f'scale must be 2, 4 or 8, got {scale}')

    cudnn.benchmark = True

    # NOTE(review): the model is rebuilt and reloaded for every image; consider
    # hoisting this out of the per-image loop if evaluation is slow.
    model = LapCARN()
    # map_location lets checkpoints saved on another GPU (or on a CUDA machine)
    # load on the current device. torch.load unpickles the file, so only load
    # trusted checkpoints.
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.eval()
    model.to(device)

    image = pil_image.open(image_file).convert('RGB')
    image_file = os.path.basename(image_file)

    image_width, image_height = image.size
    image_width = (image_width // scale) * scale
    image_height = (image_height // scale) * scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((image_width // scale, image_height // scale), resample=pil_image.BICUBIC)

    lr_tensor = convert_tensor(lr).unsqueeze(0).to(device)

    with torch.no_grad():
        preds_2x, preds_4x, preds_8x = model(lr_tensor)

    # The model output can fall outside [0, 1]; clip to the valid image range.
    preds = {2: preds_2x, 4: preds_4x, 8: preds_8x}[scale].clamp_(0, 1)

    resolution = f"{scale}x"
    get_results(hr, preds, image_file, resolution, dataset, 'lapcarn', debug=debug)

    if include_bicubic:
        bicubic = lr.resize((image_width, image_height), resample=pil_image.BICUBIC)
        get_results(hr, bicubic, image_file, resolution, dataset, 'bicubic', debug=debug)

    return None

In [84]:
import os

def do_test(psnr, ssim, BASE_DIR, scale, dataset, debug=DEBUG,
            weights_file='checkpoint/lapcarn_579400.pth'):
    """Run ``test`` on every ``.png`` image in ``BASE_DIR`` at one scale.

    Per-image metrics are appended to ``results.csv`` by ``test`` (through
    ``get_results``); nothing is accumulated here.

    Args:
        psnr, ssim: unused. Kept so existing call sites keep working.
        BASE_DIR: directory containing the test images.
        scale: upscaling factor forwarded to ``test``.
        dataset: dataset label forwarded to ``test``.
        debug: print each file name and a final summary line. Forwarded to
            ``test`` as well.
        weights_file: LapCARN checkpoint to evaluate.
    """
    # Sorted so that results.csv rows come out in a reproducible order.
    image_names = sorted(name for name in os.listdir(BASE_DIR) if name.endswith(".png"))
    for name in image_names:
        if debug:
            print(name)
        test(weights_file, os.path.join(BASE_DIR, name), scale, dataset, debug=debug)
    # Honour the `debug` argument; the module-level DEBUG is only its default.
    if debug:
        print(f'{dataset} {scale}x: evaluated {len(image_names)} image(s)')

In [85]:
# Evaluate every benchmark dataset; only the 2x scale is run at the moment.
psnr, ssim = {}, {}
DATASETS = ['Set5', 'Set14', 'BSDS100', 'Urban100', 'Manga109']
for DATASET in DATASETS:
    for i in (2,):
        do_test(psnr, ssim, 'dataset/test/' + DATASET + '/', i, DATASET)