## Import Libraries

In [1]:
from skimage.metrics import structural_similarity
import cv2
import lpips
import torchvision
import torch
import torchvision.transforms as Tvt
from DISTS_pytorch import DISTS
import os

  from .autonotebook import tqdm as notebook_tqdm


## Set Parameters

In [2]:
# Prediction set to evaluate. The results file name and prediction directory are
# derived from it so they cannot drift apart when switching to another set.
pred_type = 'cn-200'
results_file_name = 'metrics_file_{}.txt'.format(pred_type)
gt_dir = 'frames/hr'
pred_dir = 'frames/{}'.format(pred_type)
video_names = ['000', '011', '015', '020']
# Frame indices 0..num_frames-1 evaluated in every video folder.
num_frames = 100
img_nums = list(range(num_frames))

In [3]:
# Open the results file for writing (mode "w" truncates any previous run) and
# write the prediction-set name followed by a blank line.
file = open(results_file_name, "w")
file.write('{}\n\n'.format(pred_type))

8

## PSNR Calculation

In [4]:
# Per-video mean PSNR over all frames, then the mean of the per-video values.
file.write("PSNR \n\n")
psnr_global = []
for vid in video_names:
    psnr_local = []
    for img in img_nums:
        frame_name = '{:08d}.png'.format(img)
        gt_path = os.path.join(gt_dir, vid, frame_name)
        pred_path = os.path.join(pred_dir, vid, frame_name)
        # cv2.imread returns None (it does not raise) for a missing/unreadable file.
        img1 = cv2.imread(gt_path)
        img2 = cv2.imread(pred_path)
        if img1 is None:
            raise FileNotFoundError('Could not read ground-truth frame: {}'.format(gt_path))
        if img2 is None:
            raise FileNotFoundError('Could not read predicted frame: {}'.format(pred_path))
        # PSNR is computed over all channels jointly, so imread's BGR channel
        # order needs no conversion. cv2.PSNR defaults to a peak value of 255.
        psnr_local.append(cv2.PSNR(img1, img2))
    psnr_vid = sum(psnr_local)/len(psnr_local)
    psnr_global.append(psnr_vid)
    file.write("PSNR_{}: {}\n".format(vid, psnr_vid))
psnr_total = sum(psnr_global)/len(psnr_global)
file.write("PSNR_TOTAL: {}\n".format(psnr_total))

31

## SSIM Calculation

In [5]:
# Per-video mean SSIM over all frames, then the mean of the per-video values.
file.write("\n\n SSIM \n\n")

ssim_global = []
for vid in video_names:
    ssim_local = []
    for img in img_nums:
        frame_name = '{:08d}.png'.format(img)
        gt_path = os.path.join(gt_dir, vid, frame_name)
        pred_path = os.path.join(pred_dir, vid, frame_name)
        # cv2.imread returns None (it does not raise) for a missing/unreadable file.
        img1 = cv2.imread(gt_path)
        img2 = cv2.imread(pred_path)
        if img1 is None:
            raise FileNotFoundError('Could not read ground-truth frame: {}'.format(gt_path))
        if img2 is None:
            raise FileNotFoundError('Could not read predicted frame: {}'.format(pred_path))
        # SSIM is computed per channel (channel_axis=2) and averaged, so the
        # BGR channel order from imread needs no conversion. Only the mean score
        # is used, so full=True (which also builds a per-pixel map) is omitted.
        score = structural_similarity(img1, img2, channel_axis=2)
        ssim_local.append(score)
    ssim_vid = sum(ssim_local)/len(ssim_local)
    ssim_global.append(ssim_vid)
    file.write("SSIM_{}: {}\n".format(vid, ssim_vid))
ssim_total = sum(ssim_global)/len(ssim_global)
file.write("SSIM_TOTAL: {}\n\n".format(ssim_total))

31

## LPIPS

In [6]:
# Build the LPIPS perceptual metric with an AlexNet trunk, then place it on the GPU.
lpips_model = lpips.LPIPS(net='alex')
loss_fn_alex = lpips_model.to('cuda')

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/nd2794/miniconda3/envs/vsr/lib/python3.11/site-packages/lpips/weights/v0.1/alex.pth


In [7]:
# Per-video mean LPIPS (AlexNet) over all frames, then the mean of the per-video values.
file.write("\n\n LPIPS \n\n")

# LPIPS expects RGB tensors scaled to [-1, 1]: uint8 -> [0, 1] -> (x - 0.5) / 0.5.
to_lpips_range = Tvt.Compose([
    Tvt.ConvertImageDtype(torch.float32),
    Tvt.Normalize(mean=0.5, std=0.5),
])
# Force 3 channels so RGBA/grayscale PNGs don't break the 3-channel network.
rgb_mode = torchvision.io.ImageReadMode.RGB

lpips_global = []
# Evaluation only: no autograd graph is needed (the LPIPS linear layers have
# trainable weights, so without this every forward pass would build one).
with torch.no_grad():
    for vid in video_names:
        lpips_local = []
        for img in img_nums:
            frame_name = '{:08d}.png'.format(img)
            img1 = to_lpips_range(torchvision.io.read_image(os.path.join(gt_dir, vid, frame_name), rgb_mode))
            img2 = to_lpips_range(torchvision.io.read_image(os.path.join(pred_dir, vid, frame_name), rgb_mode))
            # Add a batch dimension of 1; the result holds a single value.
            score = loss_fn_alex(img1.unsqueeze(0).to('cuda'), img2.unsqueeze(0).to('cuda')).item()
            lpips_local.append(score)
        lpips_vid = sum(lpips_local)/len(lpips_local)
        lpips_global.append(lpips_vid)
        file.write("LPIPS_{}: {}\n".format(vid, lpips_vid))
lpips_total = sum(lpips_global)/len(lpips_global)
file.write("LPIPS_TOTAL: {}\n\n".format(lpips_total))

34

## DISTS

In [8]:
# Build the DISTS metric once and place it on the GPU. Binding the result of
# .to() (which returns the same module) keeps the cell from echoing its repr.
D = DISTS().to('cuda')



In [9]:
# Per-video mean DISTS over all frames, then the mean of the per-video values.
file.write("\n\n DISTS \n\n")

# DISTS takes RGB tensors in [0, 1]: a plain uint8 -> float conversion suffices
# (no Normalize / un-normalize round trip needed).
to_dists_range = Tvt.ConvertImageDtype(torch.float32)
# Force 3 channels so RGBA/grayscale PNGs don't break the 3-channel network.
rgb_mode = torchvision.io.ImageReadMode.RGB

dists_global = []
with torch.no_grad():  # evaluation only
    for vid in video_names:
        dists_local = []
        for img in img_nums:
            frame_name = '{:08d}.png'.format(img)
            img1 = to_dists_range(torchvision.io.read_image(os.path.join(gt_dir, vid, frame_name), rgb_mode))
            img2 = to_dists_range(torchvision.io.read_image(os.path.join(pred_dir, vid, frame_name), rgb_mode))
            img1 = img1.unsqueeze(0).to('cuda')
            img2 = img2.unsqueeze(0).to('cuda')
            # Score with the DISTS model `D`. The LPIPS model `loss_fn_alex` is the
            # wrong metric here and also expects [-1, 1] inputs.
            score = D(img1, img2).item()
            dists_local.append(score)
        dists_vid = sum(dists_local)/len(dists_local)
        dists_global.append(dists_vid)
        file.write("DISTS_{}: {}\n".format(vid, dists_vid))
dists_total = sum(dists_global)/len(dists_global)
file.write("DISTS_TOTAL: {}\n\n".format(dists_total))

34

In [10]:
# Close the results file so every metric written by the cells above is flushed to disk.
file.close()