In [1]:
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
from skimage.metrics import structural_similarity as ssim_metric
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
import numpy as np
from tqdm import tqdm

In [2]:
def _load_rgb_array(path, transform):
    """Load the image at ``path`` as RGB and return ``transform(img).numpy()``.

    With ``transforms.ToTensor()`` as the transform, the result is a float
    array laid out as (C, H, W) with values scaled to [0, 1]. That scaling is
    what the ``data_range=1.0`` passed to the metrics in
    ``calculate_metrics`` assumes.
    """
    # The context manager closes the underlying file handle deterministically.
    with Image.open(path) as img:
        return transform(img.convert("RGB")).numpy()


def calculate_metrics(pred_dir, gt_dir):
    """Compute mean SSIM and PSNR between predicted and ground-truth images.

    Every ``.png``/``.jpg``/``.jpeg`` file in ``pred_dir`` (matched
    case-insensitively) is paired with the file of the same name in
    ``gt_dir``. Predictions without a matching ground-truth file are skipped
    with a printed warning. Both images are converted to RGB and scaled to
    [0, 1] before scoring. The averages are printed and also returned.

    Args:
        pred_dir: Directory containing the predicted images.
        gt_dir: Directory containing the ground-truth images.

    Returns:
        ``(avg_ssim, avg_psnr)`` as Python floats. Both are NaN when no
        image pair could be evaluated. A pair of identical images has
        infinite PSNR, which makes the PSNR average infinite as well.

    Raises:
        ValueError: If a prediction and its ground truth differ in size.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    ssim_scores = []
    psnr_scores = []

    image_files = sorted(
        f for f in os.listdir(pred_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    for img_name in tqdm(image_files, desc="Evaluating images"):
        pred_path = os.path.join(pred_dir, img_name)
        gt_path = os.path.join(gt_dir, img_name)

        if not os.path.exists(gt_path):
            print(f"Warning: Ground truth for {img_name} not found. Skipping.")
            continue

        # ToTensor layout: (C, H, W), float values in [0, 1].
        pred_np = _load_rgb_array(pred_path, transform)
        gt_np = _load_rgb_array(gt_path, transform)

        # Fail with the offending file name instead of a generic skimage error.
        if pred_np.shape != gt_np.shape:
            raise ValueError(
                f"Shape mismatch for {img_name}: prediction {pred_np.shape} "
                f"vs ground truth {gt_np.shape}"
            )

        # Channels are on axis 0. The deprecated `multichannel` flag is
        # deliberately not passed. Where scikit-image still honours it,
        # `multichannel=True` means channel_axis=-1, which conflicts with
        # the (C, H, W) layout used here.
        ssim_val = ssim_metric(gt_np, pred_np, data_range=1.0, channel_axis=0)

        # PSNR is derived from the MSE over all elements, so it needs no
        # channel_axis and works directly on (C, H, W).
        psnr_val = psnr_metric(gt_np, pred_np, data_range=1.0)

        ssim_scores.append(ssim_val)
        psnr_scores.append(psnr_val)

    if not ssim_scores:
        # Avoid np.mean([]) (RuntimeWarning + NaN) and say why nothing was scored.
        print("Warning: no image pairs were evaluated.")
        return float("nan"), float("nan")

    avg_ssim = float(np.mean(ssim_scores))
    avg_psnr = float(np.mean(psnr_scores))

    print(f"\nAverage SSIM: {avg_ssim:.4f}")
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    return avg_ssim, avg_psnr

In [None]:
# Model outputs vs. sharp ground truth; images are paired by file name.
pred_dir, gt_dir = "data/test/outputs", "data/test/sharp"
calculate_metrics(pred_dir=pred_dir, gt_dir=gt_dir)

Evaluating images: 100%|██████████| 1111/1111 [06:21<00:00,  2.91it/s]


Average SSIM: 0.8688
Average PSNR: 28.85 dB



