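## Metrics (PSNR + SSIM) Playground

Computes PSNR and SSIM between pairs of images: `Clean` vs. `GRAPPA_acc2`, and `GRAPPA_acc2` vs. `GRAPPA_acc6` (slice 0336).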

In [1]:
import os 
import sys 
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def psnr(output, target, max_val=1.0):
    """Peak signal-to-noise ratio (in dB) between two image tensors.

    PSNR = 20 * log10(max_val / sqrt(MSE)), where MSE is averaged over every
    element of the inputs (batch, channels and pixels together).

    Args:
        output: Predicted / degraded image tensor.
        target: Reference image tensor; expected to have the same shape as
            ``output``.
        max_val: Dynamic range of the pixel values (1.0 for images in [0, 1]).

    Returns:
        0-dim tensor holding the PSNR in dB. Identical inputs yield ``inf``,
        still returned as a tensor so callers can uniformly use ``.item()``.
    """
    mse = F.mse_loss(output, target)
    if mse == 0:
        # PSNR is infinite if there is no error. Return a tensor (same dtype and
        # device as mse) instead of a bare Python float, so the return type is
        # consistent with the normal path below.
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(max_val / torch.sqrt(mse))

def ssim(output, target, max_val=1.0, window_size=11):
    """Mean structural similarity (SSIM) between two image tensors.

    Local statistics are computed with a uniform ``window_size`` x ``window_size``
    box filter (``F.avg_pool2d`` with stride 1 and zero padding, default
    ``count_include_pad=True``). This differs from the Gaussian window
    (sigma = 1.5) of the original SSIM definition, so values may differ from
    reference implementations. Border windows also average in the zero padding.

    Args:
        output: Predicted / degraded image tensor, as accepted by
            ``F.avg_pool2d`` (e.g. (N, C, H, W)).
        target: Reference image tensor with the same shape as ``output``.
        max_val: Dynamic range of the pixel values (1.0 for images in [0, 1]).
        window_size: Side length of the averaging window; must be a positive
            odd integer so the SSIM map keeps the input's spatial size.
            Defaults to 11.

    Returns:
        0-dim tensor: the SSIM map averaged over all elements (batch,
        channels and pixels).

    Raises:
        ValueError: if ``window_size`` is not a positive odd integer.
    """
    if window_size < 1 or window_size % 2 == 0:
        raise ValueError(
            f"window_size must be a positive odd integer, got {window_size}"
        )

    C1 = (0.01 * max_val) ** 2
    C2 = (0.03 * max_val) ** 2
    pad = window_size // 2

    def _local_mean(x):
        # Stride 1 with symmetric padding of window_size // 2 keeps H x W
        # unchanged for odd window sizes.
        return F.avg_pool2d(x, kernel_size=window_size, stride=1, padding=pad)

    mu_x = _local_mean(output)
    mu_y = _local_mean(target)

    # Windowed variance / covariance via E[XY] - E[X]E[Y].
    sigma_x = _local_mean(output * output) - mu_x * mu_x
    sigma_y = _local_mean(target * target) - mu_y * mu_y
    sigma_xy = _local_mean(output * target) - mu_x * mu_y

    ssim_map = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
        (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    )

    return torch.mean(ssim_map)


In [3]:
# Relative data paths used below (e.g. "data/4000_img/...") resolve against this directory.
print("Current working directory:", os.getcwd(), sep=" ")

Current working directory: /export1/project/mingi/MRI-Reconstruction


In [4]:
# Load images function
def load_image(image_path, mode='RGB'):
    """Load an image file as a float tensor with a leading batch dimension.

    Args:
        image_path: Path to an image file readable by PIL.
        mode: PIL mode to convert to before tensor conversion. Defaults to
            'RGB' (3 channels, the original behavior). Pass 'L' to get a
            single-channel grayscale tensor instead.

    Returns:
        Tensor of shape (1, C, H, W) produced by ``transforms.ToTensor()``.
        For 8-bit modes such as 'RGB' and 'L', values are scaled to [0, 1],
        matching the ``max_val=1.0`` defaults of ``psnr`` / ``ssim``.
    """
    from PIL import Image
    import torchvision.transforms as transforms

    # Use a context manager so the file handle is closed deterministically;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as img:
        image = img.convert(mode)
    transform = transforms.ToTensor()
    return transform(image).unsqueeze(0)  # Add batch dimension



In [8]:
# Reference: image from the Clean folder; test: the GRAPPA_acc2 image of the same slice.
# Note: despite the "_dir" suffix, these variables hold single file paths.
clean_dir = "data/4000_img/Clean/0336.png"
noisy_dir = "data/4000_img/GRAPPA_acc2/image_0336.png"
clean_image = load_image(clean_dir)
noisy_image = load_image(noisy_dir)

print(clean_image.shape, noisy_image.shape)
# Mismatched shapes may be broadcast by mse_loss, which would yield a meaningless score.
assert clean_image.shape == noisy_image.shape, "reference and test images must have the same shape"

psnr_value = psnr(noisy_image, clean_image)
ssim_value = ssim(noisy_image, clean_image)

# float() accepts both a 0-dim tensor and a plain Python float (e.g. inf for
# identical images), whereas .item() only works on tensors.
print("PSNR value: ", float(psnr_value))
print("SSIM value: ", float(ssim_value))

torch.Size([1, 3, 320, 320]) torch.Size([1, 3, 320, 320])
PSNR value:  30.54088592529297
SSIM value:  0.8830821514129639


In [6]:
# Reference here is the GRAPPA_acc2 image (not the Clean one); test is GRAPPA_acc6.
# The clean_*/noisy_* names are reused from the previous cell, so they now mean
# "reference" / "test" rather than literally clean / noisy.
clean_dir = "data/4000_img/GRAPPA_acc2/image_0336.png"
noisy_dir = "data/4000_img/GRAPPA_acc6/image_0336.png"
clean_image = load_image(clean_dir)
noisy_image = load_image(noisy_dir)

print(clean_image.shape, noisy_image.shape)
# Mismatched shapes may be broadcast by mse_loss, which would yield a meaningless score.
assert clean_image.shape == noisy_image.shape, "reference and test images must have the same shape"

psnr_value = psnr(noisy_image, clean_image)
ssim_value = ssim(noisy_image, clean_image)

# float() accepts both a 0-dim tensor and a plain Python float (e.g. inf for
# identical images), whereas .item() only works on tensors.
print("PSNR value: ", float(psnr_value))
print("SSIM value: ", float(ssim_value))

torch.Size([1, 3, 320, 320]) torch.Size([1, 3, 320, 320])
PSNR value:  32.79168701171875
SSIM value:  0.9352624416351318
