In [2]:
import sys
import os
import torch
import numpy as np
from pathlib import Path
from PIL import Image
from torchvision.transforms import ToTensor
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

current_dir = Path(os.getcwd())
project_root = current_dir.parent
sys.path.append(str(project_root))

from src.train.model import Generator

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_EPOCH = 35 # Select best epoch
UPSCALE_FACTOR = 2

# Initialize Model
model = Generator(upscale_factor=UPSCALE_FACTOR).to(DEVICE)

# Load Weights
ckpt_path = project_root / "outputs" / "checkpoints" / f"checkpoint_epoch_{CHECKPOINT_EPOCH}.pth"
checkpoint = torch.load(ckpt_path, map_location=DEVICE)
if 'G_state' in checkpoint:
    model.load_state_dict(checkpoint['G_state'])
else:
    model.load_state_dict(checkpoint)
model.eval()

Generator(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (prelu): PReLU(num_parameters=1)
  (res_blocks): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): Residua

In [6]:
# Folder containing HR images for testing (e.g., Set5)
test_dir = project_root / "data" / "custom" / "Set14"
test_files = list(test_dir.glob("*.png")) + list(test_dir.glob("*.jpg"))

total_psnr = 0
total_ssim = 0
count = 0

print(f"{'Image':<20} | {'PSNR':<10} | {'SSIM':<10}")
print("-" * 45)

for img_path in test_files:
    # 1. Prepare HR
    hr_img = Image.open(img_path).convert("RGB")
    w, h = hr_img.size
    new_w = w - (w % UPSCALE_FACTOR)
    new_h = h - (h % UPSCALE_FACTOR)
    hr_img = hr_img.resize((new_w, new_h), Image.BICUBIC)
    
    # 2. Simulate LR
    lr_img = hr_img.resize((new_w // UPSCALE_FACTOR, new_h // UPSCALE_FACTOR), Image.BICUBIC)
    
    # 3. Inference
    lr_tensor = ToTensor()(lr_img).unsqueeze(0).to(DEVICE)
    lr_tensor = lr_tensor * 2 - 1 # Normalize
    
    with torch.no_grad():
        sr_tensor = model(lr_tensor)
        
    # 4. Post-process
    sr_tensor = (sr_tensor.squeeze(0).cpu().clamp(-1, 1) + 1) / 2
    sr_np = sr_tensor.permute(1, 2, 0).numpy()
    hr_np = np.array(hr_img) / 255.0 # Normalize to [0,1] for skimage
    
    # 5. Calculate Metrics
    p_val = psnr(hr_np, sr_np, data_range=1.0)
    s_val = ssim(hr_np, sr_np, data_range=1.0, channel_axis=2)
    
    total_psnr += p_val
    total_ssim += s_val
    count += 1
    
    print(f"{img_path.name:<20} | {p_val:.4f}     | {s_val:.4f}")

print("-" * 45)
print(f"Average PSNR: {total_psnr/count:.4f}")
print(f"Average SSIM: {total_ssim/count:.4f}")

Image                | PSNR       | SSIM      
---------------------------------------------
baboon.png           | 22.3416     | 0.7001
barbara.png          | 26.2287     | 0.8507
bridge.png           | 26.6603     | 0.8408
coastguard.png       | 27.0375     | 0.7634
comic.png            | 27.0506     | 0.9228
face.png             | 30.8454     | 0.8051
flowers.png          | 30.4088     | 0.9048
foreman.png          | 32.8445     | 0.9552
lenna.png            | 32.4981     | 0.8533
man.png              | 28.9449     | 0.8602
monarch.png          | 35.2825     | 0.9592
pepper.png           | 31.0394     | 0.8309
ppt3.png             | 28.0242     | 0.9550
zebra.png            | 31.1504     | 0.9291
---------------------------------------------
Average PSNR: 29.3112
Average SSIM: 0.8665
