In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.io import loadmat
from torch.nn import functional

KeyboardInterrupt: 

In [2]:
import torch
from torch.nn import functional

# Required tensor rank for the metric helpers: (batch, channel, height, width).
IMG_DIM: int = 4
# Small additive constant that keeps the PSNR log/division away from zero.
EPS: float = 1.0e-6

class SSIMcal(torch.nn.Module):
    """SSIM computed with a uniform (box) window and averaged over the spatial dims.

    Args:
        win_size: side length of the square averaging window (must be >= 2 so the
            sample-covariance correction ``n / (n - 1)`` is defined).
        k1: stabilising constant; ``C1 = (k1 * data_range) ** 2``.
        k2: stabilising constant; ``C2 = (k2 * data_range) ** 2``.

    Raises:
        ValueError: if ``win_size < 2``.
    """

    def __init__(
        self,
        win_size: int = 11,
        k1: float = 0.01,
        k2: float = 0.03,
    ):
        super().__init__()
        if win_size < 2:
            raise ValueError("win_size must be >= 2.")
        self.win_size = win_size
        self.k1, self.k2 = k1, k2
        # Normalised box kernel of shape (1, 1, win_size, win_size); kept as the buffer
        # "w" so state_dict keys are unchanged.
        self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
        # Unbiased-estimate correction for the local (co)variances. Named n_pixels rather
        # than `np` so it does not shadow the numpy alias.
        n_pixels = win_size**2
        self.cov_norm = n_pixels / (n_pixels - 1)

    def forward(
        self,
        img: torch.Tensor,
        ref: torch.Tensor,
        data_range: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean SSIM between ``img`` and ``ref`` per sample and channel.

        Args:
            img: tensor of shape ``(B, C, H, W)``.
            ref: reference tensor with the same shape as ``img``.
            data_range: dynamic range used for C1/C2. Either a scalar or a 1-D tensor
                with one entry per batch element.

        Returns:
            Tensor of shape ``(B, C, 1, 1)``: the SSIM map (valid convolution, no
            padding) averaged over height and width.
        """
        channels = img.shape[1]
        # conv2d requires the kernel to match the input dtype, so cast as well as move
        # the float32 buffer (float64 inputs previously raised). Repeating it per channel
        # with groups=channels makes the convolution depthwise.
        w = self.w.to(device=img.device, dtype=img.dtype).repeat(channels, 1, 1, 1)

        def local_mean(x: torch.Tensor) -> torch.Tensor:
            # Box-filtered local mean, computed independently for each channel.
            return functional.conv2d(x, w, groups=channels)

        # Accept a scalar or a per-sample 1-D range; reshape to broadcast over (C, H, W).
        data_range = torch.as_tensor(data_range, dtype=img.dtype, device=img.device).reshape(-1, 1, 1, 1)
        C1 = (self.k1 * data_range) ** 2
        C2 = (self.k2 * data_range) ** 2

        ux = local_mean(img)
        uy = local_mean(ref)
        uxx = local_mean(img * img)
        uyy = local_mean(ref * ref)
        uxy = local_mean(img * ref)

        vx = self.cov_norm * (uxx - ux * ux)
        vy = self.cov_norm * (uyy - uy * uy)
        vxy = self.cov_norm * (uxy - ux * uy)

        A1 = 2 * ux * uy + C1
        A2 = 2 * vxy + C2
        B1 = ux**2 + uy**2 + C1
        B2 = vx + vy + C2

        S = (A1 * A2) / (B1 * B2)
        return torch.mean(S, dim=[2, 3], keepdim=True)


# Shared SSIM operator with the class defaults; calculate_ssim delegates to it.
ssim_cal: SSIMcal = SSIMcal()


def calculate_ssim(
    img: torch.Tensor,
    ref: torch.Tensor,
    mask: torch.Tensor | None = None,
    data_range: torch.Tensor | float | None = None,
) -> torch.Tensor:
    """Compute per-sample SSIM between ``img`` and ``ref`` using the shared ``ssim_cal``.

    Args:
        img: 4D tensor ``(B, C, H, W)``. When it has two channels, both ``img`` and ``ref``
            are reduced to a single-channel magnitude ``sqrt(c0**2 + c1**2)``.
        ref: 4D reference tensor shaped like ``img``.
        mask: optional 4D mask multiplied into both images before SSIM, mirroring
            ``calculate_psnr`` (a two-channel mask is reduced to its magnitude first).
            Caveat: windows where both masked images are all zero evaluate to SSIM 1,
            so large zeroed regions pull the mean towards 1.
        data_range: dynamic range for the SSIM constants, as a scalar or one value per
            batch element. Defaults to 1 for every sample (the previous fixed behaviour).

    Returns:
        Per-sample SSIM, shape ``(B, 1, 1, 1)`` for single-channel (or magnitude-reduced)
        input.

    Raises:
        ValueError: if ``img``/``ref`` or ``mask`` are not 4D.
    """
    if not (img.dim() == IMG_DIM and ref.dim() == IMG_DIM):
        raise ValueError("All tensors must be 4D.")

    if mask is not None and (mask.dim() != IMG_DIM):
        raise ValueError("Mask must be 4D.")

    if img.shape[1] == 2:
        img = torch.sqrt(img[:, :1, ...] ** 2 + img[:, 1:, ...] ** 2)
        ref = torch.sqrt(ref[:, :1, ...] ** 2 + ref[:, 1:, ...] ** 2)

    if mask is not None:
        # The mask used to be validated and then silently ignored; apply it the same way
        # calculate_psnr does.
        if mask.shape[1] == 2:
            mask = torch.sqrt(mask[:, :1, ...] ** 2 + mask[:, 1:, ...] ** 2)
        img = img * mask
        ref = ref * mask

    batch = ref.shape[0]
    if data_range is None:
        data_range = torch.ones(batch, device=ref.device)
    else:
        data_range = torch.as_tensor(data_range, device=ref.device)
        if data_range.dim() == 0:
            # Broadcast a scalar to the per-sample 1-D form ssim_cal receives by default.
            data_range = data_range.expand(batch)

    return ssim_cal(img, ref, data_range)


def calculate_psnr(
    img: torch.Tensor,
    ref: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute per-sample PSNR (dB) of ``img`` against ``ref``.

    The peak is the per-sample maximum of ``ref`` (plus ``EPS``), and ``EPS`` is also
    added to the MSE before the log.

    Args:
        img: 4D tensor ``(B, C, H, W)``. When it has two channels, both ``img`` and ``ref``
            are reduced to a single-channel magnitude ``sqrt(c0**2 + c1**2)``.
        ref: 4D reference tensor shaped like ``img``.
        mask: optional 4D mask (a two-channel mask is reduced to its magnitude). When
            given, the MSE is the masked squared error divided by the mask sum, so an
            all-zero mask produces NaN.

    Returns:
        PSNR tensor of shape ``(B, 1, 1, 1)`` for both the masked and unmasked paths.

    Raises:
        ValueError: if ``img``/``ref`` or ``mask`` are not 4D.
    """
    if not (img.dim() == IMG_DIM and ref.dim() == IMG_DIM):
        raise ValueError("All tensors must be 4D.")

    if mask is not None and mask.dim() != IMG_DIM:
        raise ValueError("Mask must be 4D.")

    if img.shape[1] == 2:
        img = torch.sqrt(img[:, :1, ...] ** 2 + img[:, 1:, ...] ** 2)
        ref = torch.sqrt(ref[:, :1, ...] ** 2 + ref[:, 1:, ...] ** 2)

    reduce_dims = (1, 2, 3)
    if mask is not None:
        if mask.shape[1] == 2:
            mask = torch.sqrt(mask[:, :1, ...] ** 2 + mask[:, 1:, ...] ** 2)

        img_mask = img * mask
        ref_mask = ref * mask

        # keepdim=True keeps mse at (B, 1, 1, 1) like the unmasked branch. A (B,) result
        # would broadcast against img_max (B, 1, 1, 1) into a (B, 1, 1, B) tensor that
        # mixes samples.
        sq_err_sum = torch.sum((img_mask - ref_mask) ** 2, dim=reduce_dims, keepdim=True)
        mse = sq_err_sum / torch.sum(mask, dim=reduce_dims, keepdim=True)
    else:
        mse = torch.mean(functional.mse_loss(img, ref, reduction="none"), dim=reduce_dims, keepdim=True)

    img_max = torch.amax(ref, dim=reduce_dims, keepdim=True) + EPS
    psnr = 10 * torch.log10(img_max**2 / (mse + EPS))
    return psnr

In [46]:
# Collect the per-sample .mat results written for one run/epoch.
# NOTE(review): absolute machine-local path; consider a configurable LOG_DIR.
path = "/home/intern4/fm2026/fm_flow/code_it/logs"
run_idx = 1
run_ep = 8
log_root = Path(path, f"{run_idx:05d}_train", "test", f"ep_{run_ep}")
res_files = list(log_root.glob("*.mat"))
res_files.sort()
print(f"{len(res_files)}")

144


In [47]:
import re

try:
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity
    _HAS_SKIMAGE = True
except ImportError:
    # Only a missing skimage should disable the library cross-check; other errors
    # raised while importing it should surface rather than be swallowed.
    _HAS_SKIMAGE = False

# Sample filters (previously inline magic numbers).
SKIP_INSTRUCTION_IDS = (263,)  # skip samples whose instruction array contains any of these
LABEL_MASK_THRESHOLD = 1  # pixels with label > this are set to 1 in the mask
MIN_MASK_FRACTION = 0.1  # skip samples whose mask mean is below this

# cnt counts samples passing the instruction filter, *including* those later skipped by
# the mask-fraction filter; the number actually scored is len(psnr_tot).
cnt = 0
psnr_tot = []
ssim_tot = []
psnr_lib_tot = []
ssim_lib_tot = []

for idx, res_file in enumerate(res_files):
    res_mat = loadmat(res_file)
    input = res_mat["input"].squeeze()  # NOTE(review): shadows the builtin `input`
    out = res_mat["out"].squeeze()
    label = res_mat["label"].squeeze()

    instruction = res_mat["instruction"][0]
    if any(x in instruction.flatten() for x in SKIP_INSTRUCTION_IDS):
        continue
    cnt = cnt + 1

    # The mask only decides whether a sample is scored; metrics use the full image.
    mask = torch.zeros_like(torch.from_numpy(input))
    mask[label > LABEL_MASK_THRESHOLD] = 1.0
    if mask.mean() < MIN_MASK_FRACTION:
        continue

    out_t = torch.from_numpy(out[None, None, ...])
    label_t = torch.from_numpy(label[None, None, ...])
    psnr = calculate_psnr(out_t, label_t)
    # calculate_ssim as called here uses a unit data range, while the skimage call below
    # uses label.max() - label.min(), so the two SSIM numbers are not directly comparable.
    ssim = calculate_ssim(out_t, label_t)
    psnr_tot.append(psnr.item())
    ssim_tot.append(ssim.item())

    if _HAS_SKIMAGE:
        data_range = float(label.max() - label.min() + EPS)
        psnr_lib = peak_signal_noise_ratio(label, out, data_range=data_range)
        ssim_lib = structural_similarity(label, out, data_range=data_range)
        psnr_lib_tot.append(psnr_lib)
        ssim_lib_tot.append(ssim_lib)

if psnr_tot:
    print(f"PSNR: {np.mean(psnr_tot):.2f} ± {np.std(psnr_tot):.2f}")
    print(f"SSIM: {np.mean(ssim_tot):.4f} ± {np.std(ssim_tot):.4f}")
else:
    # Avoid np.mean([]) -> nan plus RuntimeWarning.
    print("no samples passed the filters; metrics skipped")
if _HAS_SKIMAGE and psnr_lib_tot:
    print(f"PSNR (skimage): {np.mean(psnr_lib_tot):.2f} ± {np.std(psnr_lib_tot):.2f}")
    print(f"SSIM (skimage): {np.mean(ssim_lib_tot):.4f} ± {np.std(ssim_lib_tot):.4f}")
elif not _HAS_SKIMAGE:
    print("skimage not available; library metrics skipped")
print("count:", cnt)
print("evaluated:", len(psnr_tot))

PSNR: 22.01 ± 3.02
SSIM: 0.6268 ± 0.1158
PSNR (skimage): 23.26 ± 3.00
SSIM (skimage): 0.7775 ± 0.0787
count: 77
