## README

This part is for testing image, the dataset is using Set 5 to test on the performance through this website(https://deepai.org/dataset/set5-super-resolution)

The weighted file should be download through (https://github.com/spmallick/learnopencv.git) in the Super-Resolution-in-OpenCV folder. 
Make sure to download the dataset in the images directory, the result will show in the image folder as well. 

Changes:

I had changed the overall structure and the model.

### Note

If this code had run before the result image will be in the same folder, so you might need to delete the result image before run this program again.

In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
import argparse

import tensorflow as tf
import math
from os import listdir
import torch
from torch import nn
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image 
from PIL import ImageFile, ImageDraw, ImageChops, ImageFilter
from tqdm import tqdm


class ESPCN(nn.Module):
    def __init__(self, scale_factor, num_channels=1):
        super(ESPCN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=5//2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=3//2),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=3//2),
            nn.ReLU(),
        )
        self.last_part = nn.Sequential(
            nn.Conv2d(32, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.in_channels == 32:
                    nn.init.normal_(m.weight.data, mean=0.0, std=0.001)
                    nn.init.zeros_(m.bias.data)
                else:
                    nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                    nn.init.zeros_(m.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.last_part(x)
        return x


def calc_patch_size(func):
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

import math
from torch import nn


class ESPCN(nn.Module):
    def __init__(self, scale_factor, num_channels=1):
        super(ESPCN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=5//2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=3//2),
            nn.Tanh(),
        )
        self.last_part = nn.Sequential(
            nn.Conv2d(32, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.in_channels == 32:
                    nn.init.normal_(m.weight.data, mean=0.0, std=0.001)
                    nn.init.zeros_(m.bias.data)
                else:
                    nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                    nn.init.zeros_(m.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.last_part(x)
        return x

if __name__ == '__main__':
    weights_file = '/content/drive/MyDrive/ECE570Project/ESPCNV1/espcn_x3.pth'
    image_file = '/content/drive/MyDrive/ECE570Project/ESPCNV1/images/'
    scale = 3
    images_name = [x for x in listdir(image_file)]

    total_test_psnr = 0.0
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = ESPCN(scale_factor=scale).to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()


    for image_name in tqdm(images_name):
        image = pil_image.open(image_file + image_name).convert('RGB')

        image_width = (image.width // scale) * scale
        image_height = (image.height // scale) * scale

        hr = image.resize((image_width, image_height))
        img_blur = hr.copy()
        lowres_input = img_blur.filter(ImageFilter.GaussianBlur)

        lr = lowres_input.resize((hr.width // scale, hr.height // scale))
        bicubic = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
        bicubic.save(image_file+image_name.replace('.', '_bicubic_x{}.'.format(scale)))

        lr, _ = preprocess(lr, device)
        hr, _ = preprocess(hr, device)
        _, ycbcr = preprocess(bicubic, device)


        with torch.no_grad():
            preds = model(lr).clamp(0.0, 1.0)

        psnr = calc_psnr(hr, preds)
        total_test_psnr+=psnr
        
        
        print('PSNR: {:.2f}'.format(psnr))

        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

        output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
        output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
        output = pil_image.fromarray(output)
        output.save(image_file+image_name.replace('.', '_espcn_x{}.'.format(scale)))
lens = len(images_name)
print("Avg. PSNR of lowres images is %.4f" % (total_test_psnr / lens))


  6%|▌         | 1/17 [00:01<00:17,  1.11s/it]

PSNR: 27.35
PSNR: 22.03


 12%|█▏        | 2/17 [00:03<00:30,  2.02s/it]

PSNR: 22.51


 18%|█▊        | 3/17 [00:06<00:29,  2.12s/it]

PSNR: 31.42


 24%|██▎       | 4/17 [00:08<00:28,  2.17s/it]

PSNR: 31.98


 47%|████▋     | 8/17 [00:10<00:07,  1.23it/s]

PSNR: 34.63
PSNR: 33.69
PSNR: 39.65


 59%|█████▉    | 10/17 [00:10<00:03,  2.08it/s]

PSNR: 38.48
PSNR: 40.07


 76%|███████▋  | 13/17 [00:10<00:00,  4.17it/s]

PSNR: 39.45
PSNR: 43.42
PSNR: 42.75


 82%|████████▏ | 14/17 [00:11<00:00,  3.24it/s]

PSNR: 35.50


 88%|████████▊ | 15/17 [00:12<00:00,  2.31it/s]

PSNR: 35.26


 94%|█████████▍| 16/17 [00:12<00:00,  2.06it/s]

PSNR: 35.20


100%|██████████| 17/17 [00:13<00:00,  1.25it/s]

PSNR: 34.92
Avg. PSNR of lowres images is 34.6073



