In [1]:
import cv2
import numpy as np
import os
from pytorch_msssim import ms_ssim
import torch
from os import path as osp
from piq import psnr

In [2]:
def to_y_channel(img):
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)
        img = img[..., None]
    return img * 255.

def reorder_image(img, input_order='HWC'):
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are '"'HWC' and 'CHW'")
    if len(img.shape) == 2:
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    img = img.astype(np.float64)
    return img

def reorder_image_msssim(img, input_order='HWC'):
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are '"'HWC' and 'CHW'")
    if len(img.shape) == 2:
        img = img[..., None]
    # if input_order == 'CHW':
    #     img = img.transpose(1, 2, 0)
    if input_order == 'HWC':
        img = img.transpose(2, 0, 1)
    img = img.astype(np.float64)
    return img

def _convert_input_type_range(img):
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, 'f'but got {img_type}')
    return img

def _convert_output_type_range(img, dst_type):
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, 'f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)

def bgr2ycbcr(img, y_only=False):
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                            [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img

def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)

def calculate_psnr(img1,img2,crop_border,input_order='HWC',test_y_channel=False):
    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)
    
    return psnr(torch.from_numpy(np.expand_dims(np.clip(img1, 0, 255), axis=0)), torch.from_numpy(np.expand_dims(img2, axis=0)), data_range=255);

def _ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calculate_ssim(img1,img2,crop_border,input_order='HWC',test_y_channel=False):
    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ''"HWC" and "CHW"')
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    ssims = []
    for i in range(img1.shape[2]):
        ssims.append(_ssim(img1[..., i], img2[..., i]))
    return np.array(ssims).mean()

def calculate_ms_ssim(img1,img2,crop_border,input_order='HWC',test_y_channel=False):
    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ''"HWC" and "CHW"')
    img1 = reorder_image_msssim(img1, input_order=input_order)
    img2 = reorder_image_msssim(img2, input_order=input_order)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)
    
    return ms_ssim(torch.from_numpy(np.array([[img1]])), torch.from_numpy(np.array([[img2]])), data_range=255, size_average=True)

In [3]:
# -------------------------------------------------------------------------
folder_gt = '/Users/clemens/vsc-data/test_images'
# folder_restored = '/home/lv71585/cwansch_w/mbt2018_q8_msssim_fullres/'
crop_border = 0
# suffix = '_x2_SR'
# -------------------------------------------------------------------------

results = []

for dir in [
    'bicubic_supersampled',
    'mbt2018_q4_mse_fullres',
    'mbt2018_q4_msssim_fullres',
    'mbt2018_q6_mse_fullres',
    'mbt2018_q6_msssim_fullres',
    'mbt2018_q8_mse_fullres',
    'mbt2018_q8_msssim_fullres',
    'mbt2018_q4_mse_x2res_supersampled',
    'mbt2018_q4_msssim_x2res_supersampled',
    'mbt2018_q6_mse_x2res_supersampled',
    'mbt2018_q6_msssim_x2res_supersampled',
    'mbt2018_q8_mse_x2res_supersampled',
    'mbt2018_q8_msssim_x2res_supersampled',
]:
    folder_restored = f'/Users/clemens/vsc-data/{dir}/'
    psnr_rgb_all = []
    psnr_y_all = []
    ssim_rgb_all = []
    ssim_y_all = []
    msssim_rgb_all = []
    msssim_y_all = []
    img_list = sorted(scandir(folder_gt, recursive=False, full_path=True))

    print(dir)
    
    suffix = ''
    if dir.endswith('_supersampled'):
        suffix = '_x2_SR'
    for i, img_path in enumerate(img_list):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt_rgb = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        img_restored_rgb = cv2.imread(osp.join(folder_restored, basename + suffix + ext),cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        
        if img_gt_rgb.shape != img_restored_rgb.shape:
            cut_x = int((img_restored_rgb.shape[0] - img_gt_rgb.shape[0]) / 2)
            cut_y = int((img_restored_rgb.shape[1] - img_gt_rgb.shape[1]) / 2)
            if cut_x == 0:
                img_restored_rgb = img_restored_rgb[:,cut_y:-cut_y]
            elif cut_y == 0:
                img_restored_rgb = img_restored_rgb[cut_x:-cut_x,:]
            else:
                img_restored_rgb = img_restored_rgb[cut_x:-cut_x,cut_y:-cut_y]
        
        img_gt_ycbcr = bgr2ycbcr(img_gt_rgb, y_only=True)
        img_restored_ycbcr = bgr2ycbcr(img_restored_rgb, y_only=True)
        
        if i % 10 == 0:
            print(i)
    
        psnr_rgb = calculate_psnr(img_gt_rgb * 255,img_restored_rgb * 255,crop_border=crop_border,input_order='HWC')
        psnr_y = calculate_psnr(img_gt_ycbcr * 255,img_restored_ycbcr * 255,crop_border=crop_border,input_order='HWC')
        ssim_rgb = calculate_ssim(img_gt_rgb * 255,img_restored_rgb * 255,crop_border=crop_border,input_order='HWC')
        ssim_y = calculate_ssim(img_gt_ycbcr * 255,img_restored_ycbcr * 255,crop_border=crop_border,input_order='HWC')
        # msssim_rgb = calculate_ms_ssim(img_gt_rgb * 255,img_restored_rgb * 255,crop_border=crop_border,input_order='HWC')
        msssim_y = calculate_ms_ssim(img_gt_ycbcr * 255,img_restored_ycbcr * 255,crop_border=crop_border,input_order='HWC')
        psnr_rgb_all.append(psnr_rgb)
        psnr_y_all.append(psnr_y)
        ssim_rgb_all.append(ssim_rgb)
        ssim_y_all.append(ssim_y)
        # msssim_rgb_all.append(msssim_rgb)
        msssim_y_all.append(msssim_y)
        
    psnr_rgb_mean = np.mean(psnr_rgb_all)
    psnr_y_mean = np.mean(psnr_y_all)
    ssim_rgb_mean = np.mean(ssim_rgb_all)
    ssim_y_mean = np.mean(ssim_y_all)
    # msssim_rgb_mean = np.mean(msssim_rgb_all)
    msssim_y_mean = torch.mean(torch.stack(msssim_y_all))
    results.append([dir, psnr_rgb_mean, psnr_y_mean, ssim_rgb_mean, ssim_y_mean, msssim_y_mean])

bicubic_supersampled
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
mbt2018_q4_mse_fullres
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
mbt2018_q4_msssim_fullres
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
mbt2018_q6_mse_fullres
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
mbt2018_q6_msssim_fullres
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210


In [4]:
from tabulate import tabulate
print(tabulate(results))

------------------------------------  -------  -------  --------  --------  --------
bicubic_supersampled                  38.9306  40.3197  0.969609  0.973673  0.997868
mbt2018_q4_mse_fullres                34.7707  36.6957  0.923346  0.940176  0.984101
mbt2018_q4_msssim_fullres             33.8863  35.5944  0.924993  0.939194  0.989894
mbt2018_q6_mse_fullres                37.6689  39.7326  0.956693  0.967833  0.992928
mbt2018_q6_msssim_fullres             35.8378  37.5315  0.95101   0.960852  0.995046
mbt2018_q8_mse_fullres                41.6856  43.9544  0.980405  0.986601  0.997336
mbt2018_q8_msssim_fullres             38.8819  40.5737  0.973731  0.979525  0.997799
mbt2018_q4_mse_x2res_supersampled     29.5499  31.2033  0.815105  0.843316  0.937425
mbt2018_q4_msssim_x2res_supersampled  29.0169  30.5737  0.803351  0.830897  0.935894
mbt2018_q6_mse_x2res_supersampled     30.5356  32.1893  0.843727  0.867576  0.951083
mbt2018_q6_msssim_x2res_supersampled  29.8517  31.3928  0.826651 