In [10]:
import torch
from torch import nn
import piq

# Helper function: per-channel SSIM

In [11]:
def ssim_per_channel(
    pred: torch.Tensor,
    target: torch.Tensor,
    data_range: float = 1.0,
    ssim_channels: "list[int] | tuple[int, ...] | None" = (6, 7, 8, 9, 10),
) -> torch.Tensor:
    """
    Mean SSIM over selected channels, each min-max scaled to [0, data_range] first.

    For every channel index ``c`` in ``ssim_channels``:

    - The prediction and target slices are taken and kept as ``(B, 1, H, W)``.
    - Both slices are rescaled with ONE shared min/max, taken over both tensors and
      the whole batch. This puts the inputs in [0, data_range] before they are
      passed to ``piq.ssim``.
    - A channel whose joint range is <= 1e-6 (constant data) is scored as 1.0.

    :param pred: model prediction, indexed as (B, C, H, W).
    :param target: ground truth, indexed as (B, C, H, W).
    :param data_range: value range handed to ``piq.ssim`` (e.g. 1.0).
    :param ssim_channels: channel indices to evaluate; ``None`` means every channel
        of ``pred``. Defaults to channels 6..10. The default is a tuple so it cannot
        be mutated between calls. Any iterable of ints (list, tuple, range) is accepted.
    :return: 0-dim tensor holding the mean per-channel SSIM. Returns 0.0 when no
        channel is selected.
    :raises IndexError: if a channel index is outside ``[0, C)`` for ``pred`` or ``target``.
    """
    if ssim_channels is None:
        ssim_channels = range(pred.size(1))
    # Materialize once so a one-shot iterator is not exhausted by the validation pass.
    channels = list(ssim_channels)

    # A bad index would otherwise yield an empty slice (e.g. c = -1 -> pred[:, -1:0])
    # and fail later inside min() with an unrelated-looking error.
    num_channels = min(pred.size(1), target.size(1))
    for c in channels:
        if not 0 <= c < num_channels:
            raise IndexError(
                f"SSIM channel index {c} is out of range for {num_channels} channels"
            )

    # Keep the fallback values in the same dtype as the piq results so torch.stack
    # never has to mix float32 with float64.
    fill_dtype = pred.dtype if pred.is_floating_point() else torch.get_default_dtype()

    ssim_values = []
    for c in channels:
        pred_c = pred[:, c:c + 1]  # keep (B, 1, H, W) layout expected by piq
        target_c = target[:, c:c + 1]

        # One shared min/max so pred and target are mapped by the same affine transform.
        min_val = torch.minimum(pred_c.min(), target_c.min())
        max_val = torch.maximum(pred_c.max(), target_c.max())
        range_val = max_val - min_val

        if range_val > 1e-6:
            pred_norm = (pred_c - min_val) / range_val * data_range
            target_norm = (target_c - min_val) / range_val * data_range
            # reduction='mean' averages the per-image SSIM over the batch dimension.
            ssim_values.append(
                piq.ssim(pred_norm, target_norm, data_range=data_range, reduction='mean')
            )
        else:
            # A joint range this small means both slices are (near-)constant at the
            # same value, so the channel is treated as a perfect match.
            ssim_values.append(torch.ones((), device=pred.device, dtype=fill_dtype))

    if not ssim_values:
        return torch.zeros((), device=pred.device, dtype=fill_dtype)

    return torch.stack(ssim_values).mean()

# ScaledPhysicsMetrics Class

In [12]:

# Channel indices of the physical fields. The layout is assumed to start at C = 6,
# in the order u, v, U, TKE, T_uw.
(
    IDX_U_COMPONENTS,  # 6  -> u_ped
    IDX_V_COMPONENTS,  # 7  -> v_ped
    IDX_U_MAGNITUDE,   # 8  -> U_ped
    IDX_TKE,           # 9  -> TKE_ped
    IDX_T_UW,          # 10 -> T_uw (flux)
) = range(6, 11)

class ScaledPhysicsMetrics(nn.Module):
    """
    Compute several raw metrics and map each one to a score in [0, 1] (1 = best).

    Scoring rules by metric type:

    - Error-type metrics (MAE-like) use ``1 - min(raw / baseline, 1)``. The
      baselines are "worst case" reference values computed once from the target
      dataset at construction time.
    - Ratio-type metrics (fraction of pixels violating a constraint) use ``1 - ratio``.
    - SSIM is clamped into [0, 1].

    The channel layout (u, v, U magnitude, TKE, T_uw) comes from the module-level
    IDX_* constants. A combined score could be formed by averaging the
    ``physical_*`` entries. This class does not do that.
    """

    # Lower bound applied to a baseline when dividing by it. With a degenerate
    # (zero) baseline, a zero error then scores 1 and any error scores 0. Without
    # it, the score would be NaN (0/0) or come from inf.
    _BASELINE_EPS = 1e-12

    def __init__(self, target_dataset: torch.Tensor, alpha: float = 2.25):
        """
        Compute the worst-case baselines from the full target dataset.

        :param target_dataset: every target sample stacked into one tensor, indexed as
            (N, C, H, W).
        :param alpha: coefficient of the flux bound |T_uw| <= alpha * TKE.
        """
        super().__init__()
        self.alpha = alpha

        # The baselines are stored as plain Python floats, so no autograd graph is needed.
        with torch.no_grad():
            target = target_dataset.detach()

            # Data MAE baseline: MAE of an all-zero prediction, i.e. mean |target|.
            self.mae_max = target.abs().mean().item()

            u_t = target[:, IDX_U_COMPONENTS]
            v_t = target[:, IDX_V_COMPONENTS]
            U_mag_t = target[:, IDX_U_MAGNITUDE]
            U_recons_t = torch.sqrt(u_t ** 2 + v_t ** 2)

            # L_U baseline: twice the mean |U - sqrt(u^2 + v^2)| found in the target
            # data. The factor 2 makes it a looser upper bound. The original author
            # assumed U >= sqrt(u^2 + v^2) because of turbulence; this is not checked here.
            self.l_u_max = (U_mag_t - U_recons_t).abs().mean().item() * 2

            # Reconstructed-U baseline: MAE of a zero prediction against the
            # reconstructed target magnitude, i.e. mean sqrt(u^2 + v^2).
            self.rec_u_mae_max = U_recons_t.abs().mean().item()

        print(f"--- Metric Baseline Calculated (Worst Score = 0) ---")
        print(f"MAE Max Baseline: {self.mae_max:.4f}")
        print(f"L_U Max Baseline: {self.l_u_max:.4f}")
        print(f"Rec. U MAE Max Baseline: {self.rec_u_mae_max:.4f}")

    @classmethod
    def _error_score(cls, raw: torch.Tensor, baseline: float) -> torch.Tensor:
        """Map a non-negative error to ``1 - min(raw / baseline, 1)``, flooring the baseline at ``_BASELINE_EPS``."""
        return 1.0 - (raw / max(baseline, cls._BASELINE_EPS)).clamp(max=1.0)

    def _calculate_metrics(self, pred: torch.Tensor, target: torch.Tensor) -> dict:
        """
        Compute all raw (unscaled) metrics for one batch.

        :param pred: prediction, indexed as (B, C, H, W).
        :param target: ground truth with the same layout.
        :return: dict with keys 'mae', 'ssim', 'l_u_mae', 'rec_u_mae', 'p_k_neg',
            'p_flux_over', each mapped to a 0-dim tensor.
        """
        # 1. Data MAE over every element.
        mae_raw = (pred - target).abs().mean()

        # 2. SSIM restricted to the physical-field channels.
        ssim_raw = ssim_per_channel(
            pred,
            target,
            ssim_channels=[IDX_U_COMPONENTS, IDX_V_COMPONENTS, IDX_U_MAGNITUDE, IDX_TKE, IDX_T_UW],
        )

        # 3. Physical plausibility of the prediction itself.
        u_pred = pred[:, IDX_U_COMPONENTS]
        v_pred = pred[:, IDX_V_COMPONENTS]
        U_pred = pred[:, IDX_U_MAGNITUDE]
        TKE_pred = pred[:, IDX_TKE]
        T_uw_pred = pred[:, IDX_T_UW]

        # 3a/3b. Velocity consistency: mean |U - sqrt(u^2 + v^2)| inside the prediction.
        # This is an L1 stand-in for the Huber-based L_U loss.
        # NOTE(review): 'l_u_mae' and 'rec_u_mae' are the same quantity. They differ only
        # through their baselines (l_u_max vs rec_u_mae_max). The rec_u baseline
        # (mean sqrt(u^2+v^2) of the target) hints that it was meant to compare against
        # the target instead. Confirm the intended definition.
        U_reconstructed = torch.sqrt(u_pred ** 2 + v_pred ** 2)
        u_consistency_mae = (U_pred - U_reconstructed).abs().mean()

        # 3c. Fraction of pixels with negative TKE.
        p_k_neg_raw = (TKE_pred < 0).float().mean()

        # 3d. Fraction of pixels violating |T_uw| <= alpha * TKE. The 1e-8 shift is
        # kept as in the original.
        flux_bound_error = T_uw_pred.abs() - self.alpha * (TKE_pred + 1e-8)
        p_flux_over_raw = (flux_bound_error > 0).float().mean()

        return {
            'mae': mae_raw,
            'ssim': ssim_raw,
            'l_u_mae': u_consistency_mae,
            'rec_u_mae': u_consistency_mae,
            'p_k_neg': p_k_neg_raw,
            'p_flux_over': p_flux_over_raw,
        }

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict:
        """
        Return every metric as a score in [0, 1] (1 = best, 0 = at or beyond baseline).

        :param pred: prediction, indexed as (B, C, H, W).
        :param target: ground truth with the same layout.
        :return: dict mapping score names to 0-dim tensors.
        """
        raw = self._calculate_metrics(pred, target)
        return {
            # Data domain: lower error is better.
            'data_mae_score': self._error_score(raw['mae'], self.mae_max),
            # Structural integrity: SSIM can fall outside [0, 1]. Clamp it so every
            # score stays in the range this class promises.
            'structural_ssim_score': raw['ssim'].clamp(0.0, 1.0),
            # Physical plausibility: violation fractions are already in [0, 1].
            'physical_p_k_neg_score': 1.0 - raw['p_k_neg'],
            'physical_p_flux_score': 1.0 - raw['p_flux_over'],
            'physical_l_u_score': self._error_score(raw['l_u_mae'], self.l_u_max),
            'physical_rec_u_score': self._error_score(raw['rec_u_mae'], self.rec_u_mae_max),
        }

In [None]:
def _print_scores(scores: dict) -> None:
    """Print every scaled score as a left-aligned name and a 4-decimal value."""
    for key, score in scores.items():
        print(f"{key:<30}: {score.item():.4f}")


if __name__ == '__main__':
    # 0. Build the evaluation pool. `create_pool`, `loaded_dataset` and `device` are
    #    expected to come from earlier cells (not visible here).
    x_pool, y_pool = create_pool(loaded_dataset, pool_size=256, channel_c=24)

    # 1. Full target dataset used for the baselines. `.to(device)` returns `y_pool`
    #    itself when no transfer is needed. Clone it so the in-place edit below cannot
    #    silently change `y_pool`, and with it every `target_batch` drawn from it.
    full_target_data = y_pool.to(device).clone()

    # Leftover from the synthetic-data version of this demo: force TKE >= 0.
    # No baseline reads the TKE channel, so this does not change any score.
    full_target_data[:, IDX_TKE] = torch.abs(full_target_data[:, IDX_TKE])

    # 2. Initialise the metric (this computes the baselines).
    metric_calculator = ScaledPhysicsMetrics(full_target_data)

    # 3. Draw a reproducible batch of prediction/target pairs. The seeded torch
    #    generator replaces an unseeded np.random.choice, which also depended on
    #    numpy being imported elsewhere.
    generator = torch.Generator().manual_seed(0)
    batch_idx = torch.randperm(len(x_pool), generator=generator)[:20]
    pred_batch = x_pool[batch_idx].clone()  # clone so later edits never touch x_pool
    target_batch = y_pool[batch_idx]

    # Nothing below needs gradients.
    with torch.no_grad():
        # 4. Reference scores for the sampled batch.
        scores = metric_calculator(pred_batch, target_batch)
        print("\n--- 基礎 Scaled Metric Scores (參考值) ---")
        _print_scores(scores)

        # ====================================================================
        #   Test case A: MAE score must drop to 0 (extreme data error)
        # ====================================================================
        print("\n--- 測試案例 A: MAE Score 歸零測試 (極端數據誤差) ---")
        bad_pred_mae = target_batch + 100 * metric_calculator.mae_max  # error 100x the baseline
        bad_scores_mae = metric_calculator(bad_pred_mae, target_batch)

        mae_score = bad_scores_mae['data_mae_score'].item()
        expected_msg_mae = "Score 應為 0.0000 (誤差遠超 Baseline)"
        print(f"Data MAE Score: {mae_score:.4f} ({expected_msg_mae})")
        # Deterministic: raw / baseline ~= 100 >= 1, so the clamped score is exactly 0.
        assert mae_score == 0.0, f"MAE score should be 0, got {mae_score}"

        # ====================================================================
        #   Test case B: L_U / Rec. U scores under a broken velocity field
        # ====================================================================
        print("\n--- 測試案例 B: L_U / Rec. U Score 歸零測試 (速度不一致性極端) ---")
        bad_pred_u = pred_batch.clone()
        bad_pred_u[:, IDX_U_MAGNITUDE, :, :] = 100.0  # U_pred far above sqrt(u^2 + v^2)

        bad_scores_u = metric_calculator(bad_pred_u, target_batch)

        l_u_score = bad_scores_u['physical_l_u_score'].item()
        rec_u_score = bad_scores_u['physical_rec_u_score'].item()

        # Whether these reach exactly 0 depends on the data scale relative to the
        # baselines, so they are printed rather than asserted.
        expected_msg_u = "Score 應為 0.0000 (速度一致性完全崩潰)"
        print(f"Physical L_U Score:    {l_u_score:.4f} ({expected_msg_u})")
        print(f"Physical Rec. U Score: {rec_u_score:.4f} ({expected_msg_u})")

        # ====================================================================
        #   Test case C: P_flux score must drop to 0 (flux blow-up)
        # ====================================================================
        print("\n--- 測試案例 C: P_flux Score 歸零測試 (Flux 爆炸) ---")
        bad_pred_flux = pred_batch.clone()

        TKE_val = 0.1
        alpha_val = metric_calculator.alpha
        bad_pred_flux[:, IDX_TKE, :, :] = TKE_val
        T_uw_val = alpha_val * TKE_val * 100.0  # T_uw is 100x the bound alpha * TKE
        bad_pred_flux[:, IDX_T_UW, :, :] = T_uw_val

        bad_scores_flux = metric_calculator(bad_pred_flux, target_batch)

        p_flux_score = bad_scores_flux['physical_p_flux_score'].item()
        expected_msg_flux = "Score 應為 0.0000 (Flux 100% 超界)"
        print(f"Physical P_flux Score: {p_flux_score:.4f} ({expected_msg_flux})")
        # Deterministic: every pixel violates the bound, so the score is exactly 0.
        assert p_flux_score == 0.0, f"P_flux score should be 0, got {p_flux_score}"

        # ====================================================================
        #   Test case D: good prediction (target plus small noise)
        # ====================================================================
        print("\n--- 測試案例 D: 良好狀態數據輸入 (Good Prediction) ---")

        good_pred_batch = target_batch + torch.randn_like(target_batch) * 0.01
        good_scores = metric_calculator(good_pred_batch, target_batch)

        print("預期所有 Metric Score 均接近 1.0000 (少量誤差):")
        _print_scores(good_scores)

        # ====================================================================
        #   Test case E: ground truth as prediction (perfect case)
        # ====================================================================
        print("\n--- 測試案例 E: Ground Truth 輸入 (Perfect Case) ---")

        # NOTE: only the data/SSIM scores are guaranteed to be 1 here. The physical
        # scores measure the prediction's own consistency. The L_U baseline is
        # 2x the dataset's mean |U - sqrt(u^2+v^2)|, so ground truth typically scores
        # about 0.5 on L_U. P_k<0 and P_flux equal 1 only if the target data itself
        # satisfies those constraints.
        perfect_pred_batch = target_batch.clone()
        perfect_scores = metric_calculator(perfect_pred_batch, target_batch)

        print("預期所有 Metric Score 均為 1.0000 (完美匹配):")
        _print_scores(perfect_scores)