In [None]:
import lightning as L
from model import BaseLineUnet
import torch
import torch.nn as nn
from dataset import GoProDataset
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import random
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm

In [None]:
import glob

# Every checkpoint written by one Lightning training run (searched recursively).
CKPT_GLOB = 'lightning_logs/version_15/**/*.ckpt'
# glob's order depends on the filesystem; sort so checkpoints are evaluated and
# reported in a reproducible order.
ckpt_paths = sorted(glob.glob(CKPT_GLOB, recursive=True))
ckpt_paths

In [None]:
# Use the GPU when PyTorch can see one; later cells move tensors to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Presumably the encoder configuration the checkpoints were trained with — it must
# match, or the strict load_state_dict call in the evaluation loop fails.
# Module.to() returns the module itself, so the assignment keeps the same object.
model = BaseLineUnet(num_encoder_blocks=[1, 1, 1, 14]).to(device)
model

In [None]:
# Restored images are clamped to [0, 1] before scoring (and sharp targets are
# assumed to be in [0, 1] too — confirm in GoProDataset). Pin data_range=1.0 so
# PSNR/SSIM do not infer the range from each batch's min/max, which would make
# the scores depend on image content.
psnr_fn = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim_fn = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
# normalize=True: LPIPS expects inputs in [0, 1] instead of its default [-1, 1].
lpips_fn = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)

# Per-batch metric values, appended by big_img_inference.
psnr_arr = []
ssim_arr = []
lpips_arr = []

In [None]:
def big_img_inference(img_tensor, sharp_tensor, multiple=256):
    """Deblur one batch with the global ``model`` and record PSNR/SSIM/LPIPS.

    The blurred batch is zero-padded on the bottom/right so both spatial sizes
    become multiples of ``multiple`` (presumably so every downsampling stage of the
    network divides evenly — confirm against BaseLineUnet). The padded batch runs
    through ``model`` under ``torch.inference_mode()``. The output is clamped to
    [0, 1], cropped back to the input's height/width, and scored against
    ``sharp_tensor``. One value per metric is appended to the module-level
    ``psnr_arr``, ``ssim_arr`` and ``lpips_arr`` lists.

    Args:
        img_tensor: blurred batch indexed as (B, C, H, W); assumed already on the
            same device as ``model``.
        sharp_tensor: ground-truth batch; must match the cropped output's shape.
        multiple: value the padded height/width must be divisible by. Defaults to
            256, the value this notebook has always used.

    Raises:
        ValueError: if ``multiple`` is not a positive integer.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be >= 1, got {multiple}")
    _, _, height, width = img_tensor.shape
    # -n % m is the distance up to the next multiple of m (0 when already aligned).
    pad_h = -height % multiple
    pad_w = -width % multiple
    # F.pad's tuple is (left, right, top, bottom) for the last two dims.
    padded = F.pad(img_tensor, (0, pad_w, 0, pad_h))
    with torch.inference_mode():
        restored = model(padded)

    # Crop with the original size instead of subtracting the pad amounts.
    restored = torch.clamp(restored, 0, 1)[:, :, :height, :width]
    psnr_arr.append(psnr_fn(restored, sharp_tensor))
    ssim_arr.append(ssim_fn(restored, sharp_tensor))
    lpips_arr.append(lpips_fn(restored, sharp_tensor))

In [None]:
import os

# GoPro test split location. Defaults to the original author's local path; set
# the GOPRO_TEST_DIR environment variable to evaluate from somewhere else.
GOPRO_TEST_DIR = os.environ.get('GOPRO_TEST_DIR', 'E:\\Downloads\\GOPRO_Large\\test')
# Previously hard-coded to 11; capped at the CPU count so smaller machines are not
# oversubscribed (DataLoader warns when workers exceed available CPUs).
NUM_WORKERS = min(11, os.cpu_count() or 1)

dataset = GoProDataset(GOPRO_TEST_DIR, mode='test')
# shuffle=False keeps the batch order deterministic across checkpoints.
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
def _load_lightning_weights(path):
    """Return a state dict for the bare network from a Lightning checkpoint file.

    Reads ``checkpoint["state_dict"]``, strips a leading ``model.`` from each key
    and drops keys under ``loss_fn.``, so the result can be passed to
    ``model.load_state_dict``.
    """
    # map_location="cpu" lets CUDA-saved checkpoints load on CPU-only machines;
    # load_state_dict then copies the tensors onto the model's own device.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # checkpoints that pickle extra objects (e.g. hyper-parameters) — confirm.
    checkpoint = torch.load(path, map_location="cpu")
    weights = {}
    for key, tensor in checkpoint["state_dict"].items():
        # Strip only the leading prefix; str.replace would also rewrite any
        # "model." occurring inside a nested submodule name.
        if key.startswith("model."):
            key = key[len("model."):]
        # Presumably the training loss module's state; it is not part of the network.
        if key.startswith("loss_fn."):
            continue
        weights[key] = tensor
    return weights


for path in ckpt_paths:
    # Fresh per-checkpoint accumulators; big_img_inference appends to these globals.
    psnr_arr = []
    ssim_arr = []
    lpips_arr = []
    model.load_state_dict(_load_lightning_weights(path))
    model.eval()
    for x, y in tqdm(dataloader):
        big_img_inference(x.to(device), y.to(device))

    if not psnr_arr:
        raise RuntimeError(f"No batches were evaluated for {path}; is the dataset empty?")

    # Mean of per-batch metric values (the last, possibly smaller, batch is weighted equally).
    psnr_avg = sum(psnr_arr) / len(psnr_arr)
    ssim_avg = sum(ssim_arr) / len(ssim_arr)
    lpips_avg = sum(lpips_arr) / len(lpips_arr)
    print(path)
    print(f"{psnr_avg.item():.4f}\t{ssim_avg.item():.4f}\t{lpips_avg.item():.4f}")