# Data preprocessing & dataloaders

In [None]:
# Quick environment check: report whether PyTorch can see CUDA devices.
# NOTE(review): `subprocess` is imported but not used in this cell.
import torch, subprocess
print("cuda_available:", torch.cuda.is_available())
print("num_gpus:", torch.cuda.device_count())
# get_device_name(0) is only called when CUDA is available (it needs a device to query).
print("name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)


In [None]:
# ============================================================
# Data preprocessing & dataloaders (single-file, no argparse, flat configs)
# ============================================================
import os, math, random
from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

# ----------------------
# Minimal flat "config"
# ----------------------
# === Test sets & example output ===
RAIN100H_INP_DIR = "data/rain100H/input"    # expected to hold only .png rainy images
RAIN100H_GT_DIR  = "data/rain100H/target"
RAIN100L_INP_DIR = "data/rain100L/input"
RAIN100L_GT_DIR  = "data/rain100L/target"
OUT_DIR_GF       = "output/gf"              # example: where GF outputs are saved (not used by this code)

# === Train sets ===  (as described earlier: 400 pairs each from Rain200L/H, [100.png-499.png])
# If your actual folder names use different capitalisation, edit these paths to match.
# NOTE(review): the test paths above use lowercase "rain100*" while these use "Rain200*"; confirm both match disk.
RAIN200L_TRAIN_INP_DIR = "data/Rain200L/train/input"
RAIN200L_TRAIN_GT_DIR  = "data/Rain200L/train/target"
RAIN200H_TRAIN_INP_DIR = "data/Rain200H/train/input"
RAIN200H_TRAIN_GT_DIR  = "data/Rain200H/train/target"

# === Split & reproducibility ===
VAL_PERCENT         = 0.15     # fraction of each training subset held out for validation (15%)
SEED                = 3407

# === Preprocess / Augment ===
TRAIN_PATCH_SIZE    = 128      # side length of the random training crop
USE_PATCH_TRAIN     = True
AUG_FLIP            = True     # random horizontal/vertical flips
AUG_ROT90           = True     # random 90-degree rotations
VAL_USE_FULL_IMAGE  = True     # validate on full images; set False if memory is tight
VAL_CENTER_CROP     = 256      # only used when VAL_USE_FULL_IMAGE=False

# === Dataloader ===
BATCH_SIZE          = 8
NUM_WORKERS         = 0        # 0 or 2 recommended on macOS/M1
PIN_MEMORY          = False

# === Metric settings (shared by PRN and GF evaluation) ===
METRIC_USE_Y_CHANNEL = True   # compute metrics on the Y (luma) channel
METRIC_SHAVE_PIXELS  = 4      # border pixels removed from each side before computing metrics


In [None]:
# ----------------------
# Utils
# ----------------------
def set_global_seed(seed: int):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and every CUDA device) with `seed`."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def _is_image_file(name: str) -> bool:
    """Return True if `name` ends in .png/.jpg/.jpeg/.bmp (extension compared case-insensitively)."""
    _, extension = os.path.splitext(name)
    return extension.lower() in (".png", ".jpg", ".jpeg", ".bmp")

def list_paired_samples(dir_input: str, dir_target: str) -> List[Tuple[str, str, str]]:
    """
    Collect (input_path, target_path, basename) triples for images present in both folders.

    An image file in `dir_input` is kept only if a file with the same basename
    exists in `dir_target`. Ordering: names whose stem is an integer come first,
    sorted numerically (so 2.png < 10.png). All other names follow, sorted
    lexicographically. Ties between numeric stems (e.g. "01" vs "1") are broken
    by the full name, so the order never depends on os.listdir().

    Raises:
        AssertionError: if either directory does not exist.
    """
    assert os.path.isdir(dir_input), f"Not found: {dir_input}"
    assert os.path.isdir(dir_target), f"Not found: {dir_target}"

    pairs = []
    for n in os.listdir(dir_input):
        if not _is_image_file(n):
            continue
        tp = os.path.join(dir_target, n)
        if os.path.isfile(tp):
            pairs.append((os.path.join(dir_input, n), tp, n))

    def key_fn(item):
        name = item[2]
        stem = os.path.splitext(name)[0]
        # The key is always a tuple of the same shape. The old int-or-str key
        # raised TypeError when a folder mixed numeric and non-numeric names.
        try:
            return (0, int(stem), name)
        except ValueError:
            return (1, 0, name)

    pairs.sort(key=key_fn)
    return pairs

def split_train_val(
    pairs: List[Tuple[str, str, str]],
    val_percent: float,
    by_tail: bool = True,
    seed: int = 0
):
    """
    Split `pairs` into (train, val).

    The validation size is ceil(len(pairs) * val_percent).
      - by_tail=True (default): validation is the last k entries in the given
        order. This is deterministic, and `seed` is ignored.
      - by_tail=False: the held-out indices come from random.Random(seed).shuffle.
        Both outputs keep the original relative order of `pairs`.
    """
    assert 0.0 < val_percent < 1.0
    total = len(pairs)
    n_val = int(math.ceil(total * val_percent))

    if not by_tail:
        order = list(range(total))
        random.Random(seed).shuffle(order)
        held_out = set(order[:n_val])
        train = [p for i, p in enumerate(pairs) if i not in held_out]
        val = [p for i, p in enumerate(pairs) if i in held_out]
        return train, val

    cut = total - n_val
    return pairs[:cut], pairs[cut:]

# ----------------------
# Tensor transforms
# ----------------------
def pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """
    Convert an image to a float32 CHW tensor by dividing pixel values by 255.

    A 2-D (grayscale) array gets a trailing singleton channel axis before the
    HWC -> CHW transpose.
    """
    hwc = np.array(img, dtype=np.float32) / 255.0
    if hwc.ndim == 2:
        hwc = hwc[:, :, None]
    return torch.from_numpy(np.transpose(hwc, (2, 0, 1)))

def random_crop_pair(inp: torch.Tensor, tgt: torch.Tensor, size: int, rng: random.Random):
    """
    Crop the same random `size` x `size` window from two CHW tensors.

    If `inp` is smaller than `size` in either spatial dimension, both tensors
    are first reflect-padded on the bottom/right by the amount `inp` is short.
    The top offset is drawn from `rng` before the left offset.
    """
    _, height, width = inp.shape
    short_h = max(0, size - height)
    short_w = max(0, size - width)
    if short_h or short_w:
        pad = (0, short_w, 0, short_h)
        inp = torch.nn.functional.pad(inp, pad, mode="reflect")
        tgt = torch.nn.functional.pad(tgt, pad, mode="reflect")
        _, height, width = inp.shape
    top = rng.randint(0, height - size)
    left = rng.randint(0, width - size)
    rows = slice(top, top + size)
    cols = slice(left, left + size)
    return inp[:, rows, cols], tgt[:, rows, cols]

def random_flip_rotate_pair(inp: torch.Tensor, tgt: torch.Tensor, rng: random.Random, flip=True, rot90=True):
    """
    Apply the same random flips and 90-degree rotation to two CHW tensors.

    Draws from `rng`, in this order (kept fixed for reproducibility):
      flip:  one draw for a horizontal flip (W axis), then one for a vertical
             flip (H axis), each with p=0.5
      rot90: one draw deciding whether to rotate (p=0.5), then k from {1, 2, 3}
    """
    ops = []
    if flip:
        if rng.random() < 0.5:
            ops.append(lambda t: torch.flip(t, dims=[2]))
        if rng.random() < 0.5:
            ops.append(lambda t: torch.flip(t, dims=[1]))
    if rot90 and rng.random() < 0.5:
        k = rng.choice([1, 2, 3])
        ops.append(lambda t: torch.rot90(t, k, dims=[1, 2]))
    for op in ops:
        inp, tgt = op(inp), op(tgt)
    return inp, tgt

def center_crop_pair(inp: torch.Tensor, tgt: torch.Tensor, size: int):
    """
    Crop the central `size` x `size` window from two CHW tensors.

    If `inp` is smaller than `size` in either spatial dimension, both tensors
    are first reflect-padded on the bottom/right. The offset is floored when
    the margin is odd.
    """
    _, height, width = inp.shape
    if height < size or width < size:
        pad = (0, max(0, size - width), 0, max(0, size - height))
        inp, tgt = (torch.nn.functional.pad(t, pad, mode="reflect") for t in (inp, tgt))
        _, height, width = inp.shape
    top = (height - size) // 2
    left = (width - size) // 2
    window = (slice(None), slice(top, top + size), slice(left, left + size))
    return inp[window], tgt[window]

import torch

def rgb_to_y(img: torch.Tensor) -> torch.Tensor:
    """
    Convert an RGB batch to a single luma channel.

    img: [B, 3, H, W], expected in [0, 1].
    Returns [B, 1, H, W] with Y = 0.299 R + 0.587 G + 0.114 B (BT.601 luma
    weights), clamped to [0, 1].

    NOTE: this is full-range luma with no 16..235 studio-range offset or scale,
    so it does not reproduce MATLAB rgb2ycbcr's Y channel.
    """
    assert img.ndim == 4 and img.size(1) == 3, "rgb_to_y expects [B,3,H,W]"
    red, green, blue = torch.split(img, 1, dim=1)
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return luma.clamp(0.0, 1.0)

def crop_border(img: torch.Tensor, shave: int) -> torch.Tensor:
    """
    Drop `shave` pixels from every side of the H/W axes of a [B, C, H, W] tensor.

    A non-positive `shave` returns `img` itself unchanged.
    """
    if shave > 0:
        inner = slice(shave, -shave)
        return img[:, :, inner, inner]
    return img


In [None]:
# ----------------------
# Dataset
# ----------------------
@dataclass
class SampleItem:
    """One paired sample: file paths plus bookkeeping labels used in dataset outputs."""
    input_path: str   # path to the rainy input image
    target_path: str  # path to the clean ground-truth image
    name: str         # file basename shared by input and target, e.g. "100.png"
    subset: str  # e.g., "Rain200L", "Rain100H"

class RainPairDataset(Dataset):
    """
    Paired rainy/clean image dataset.

    mode:
      - "train": random crop (if use_patch_train) + light flip/rot90 augmentation
      - "val"  : full image, or a center crop when val_use_full_image=False; no randomness
      - "test" : full image; the target is kept so metrics can be computed later

    Each item is a dict with "input"/"target" float CHW tensors (pixel values
    divided by 255) plus "name", "subset", "input_path" and "target_path".

    Augmentation randomness: in the main process (num_workers=0), draws come
    from `self.rng`, seeded with `seed`. Inside a DataLoader worker, a private
    random.Random seeded from that worker's WorkerInfo.seed is used instead.
    Every worker gets its own copy of this dataset, so drawing from the copied
    `self.rng` would make all workers replay the same crop/flip sequence (and,
    with non-persistent workers, replay it again every epoch).
    """
    def __init__(
        self,
        samples: List[SampleItem],
        mode: str,
        train_patch_size: int = TRAIN_PATCH_SIZE,
        use_patch_train: bool = USE_PATCH_TRAIN,
        aug_flip: bool = AUG_FLIP,
        aug_rot90: bool = AUG_ROT90,
        val_use_full_image: bool = VAL_USE_FULL_IMAGE,
        val_center_crop: int = VAL_CENTER_CROP,
        seed: int = SEED,
    ):
        assert mode in ["train", "val", "test"]
        self.samples = samples
        self.mode = mode
        self.train_patch_size = train_patch_size
        self.use_patch_train = use_patch_train
        self.aug_flip = aug_flip
        self.aug_rot90 = aug_rot90
        self.val_use_full_image = val_use_full_image
        self.val_center_crop = val_center_crop
        self.rng = random.Random(seed)
        self._worker_rng: Optional[random.Random] = None  # created lazily inside a worker

    def __len__(self): return len(self.samples)

    def _get_rng(self) -> random.Random:
        """Return the augmentation RNG: `self.rng` in the main process, a per-worker RNG otherwise."""
        info = torch.utils.data.get_worker_info()
        if info is None:
            return self.rng
        if getattr(self, "_worker_rng", None) is None:
            # info.seed differs per worker (and per epoch for non-persistent workers).
            self._worker_rng = random.Random(info.seed)
        return self._worker_rng

    def _load_pair(self, input_path: str, target_path: Optional[str]):
        """
        Load the input image, and the target if its path is given and exists, as RGB tensors.

        Returns (input_tensor, target_tensor_or_None). Each file is opened in a
        `with` block so its handle is closed right after decoding.
        """
        with Image.open(input_path) as img:
            inp_t = pil_to_tensor(img.convert("RGB"))
        tgt_t = None
        if target_path is not None and os.path.isfile(target_path):
            with Image.open(target_path) as img:
                tgt_t = pil_to_tensor(img.convert("RGB"))
        return inp_t, tgt_t

    def __getitem__(self, idx: int):
        it = self.samples[idx]
        inp_t, tgt_t = self._load_pair(it.input_path, it.target_path)

        if tgt_t is None and self.mode in ("train", "val"):
            # Cropping, augmentation and the loss all need the clean target.
            raise FileNotFoundError(f"Target image not found for {it.input_path}: {it.target_path}")

        if self.mode == "train":
            rng = self._get_rng()
            if self.use_patch_train:
                inp_t, tgt_t = random_crop_pair(inp_t, tgt_t, self.train_patch_size, rng)
            if self.aug_flip or self.aug_rot90:
                inp_t, tgt_t = random_flip_rotate_pair(inp_t, tgt_t, rng, self.aug_flip, self.aug_rot90)

        elif self.mode == "val":
            if not self.val_use_full_image:
                inp_t, tgt_t = center_crop_pair(inp_t, tgt_t, self.val_center_crop)

        # "test": keep the full image untouched
        return {
            "input": inp_t, "target": tgt_t,
            "name": it.name, "subset": it.subset,
            "input_path": it.input_path, "target_path": it.target_path
        }

In [None]:
# ----------------------
# Builders
# ----------------------
def build_train_val_datasets():
    """
    Build the combined Rain200L + Rain200H train and validation datasets.

    Seeds all RNGs with SEED first. For each subset it pairs input/target files
    and holds out the tail VAL_PERCENT (in sorted filename order) for
    validation. Items are ordered Rain200L first, then Rain200H. Prints
    per-subset and combined sizes.

    Returns:
        (train_ds, val_ds): RainPairDataset instances in "train" and "val" mode.
    """
    set_global_seed(SEED)

    def load_subset(subset, inp_dir, gt_dir):
        # Pair files, split off the tail for validation, and tag items with the subset name.
        pairs = list_paired_samples(inp_dir, gt_dir)
        tr, va = split_train_val(pairs, VAL_PERCENT, by_tail=True, seed=SEED)
        tr_items = [SampleItem(ip, tp, n, subset) for ip, tp, n in tr]
        va_items = [SampleItem(ip, tp, n, subset) for ip, tp, n in va]
        return pairs, tr_items, va_items

    L_pairs, L_tr_items, L_va_items = load_subset("Rain200L", RAIN200L_TRAIN_INP_DIR, RAIN200L_TRAIN_GT_DIR)
    H_pairs, H_tr_items, H_va_items = load_subset("Rain200H", RAIN200H_TRAIN_INP_DIR, RAIN200H_TRAIN_GT_DIR)

    train_ds = RainPairDataset(L_tr_items + H_tr_items, mode="train")
    val_ds = RainPairDataset(L_va_items + H_va_items, mode="val")

    print(f"[Data] Rain200L total={len(L_pairs)} | train={len(L_tr_items)} | val={len(L_va_items)}")
    print(f"[Data] Rain200H total={len(H_pairs)} | train={len(H_tr_items)} | val={len(H_va_items)}")
    print(f"[Data] Combined train={len(train_ds)} | val={len(val_ds)}")
    return train_ds, val_ds

def build_train_val_loaders():
    """
    Wrap the train/val datasets in DataLoaders.

    Train: shuffled, BATCH_SIZE, incomplete last batch dropped.
    Val: fixed order. Batch size is 1 when validating on full images,
    presumably because full images may not share a size and cannot be
    collated; otherwise BATCH_SIZE.
    """
    train_ds, val_ds = build_train_val_datasets()
    common = dict(num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_batch = 1 if VAL_USE_FULL_IMAGE else BATCH_SIZE
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, **common)
    val_loader = DataLoader(val_ds, batch_size=val_batch, shuffle=False, drop_last=False, **common)
    return train_loader, val_loader

def build_test_datasets():
    """
    Build the Rain100L and Rain100H test datasets (full images, "test" mode).

    Only inputs with a same-named target file are included. Prints both sizes
    and returns (Rain100L dataset, Rain100H dataset).
    """
    def make_split(inp_dir, gt_dir, subset):
        pairs = list_paired_samples(inp_dir, gt_dir)
        return RainPairDataset([SampleItem(ip, tp, n, subset) for ip, tp, n in pairs], mode="test")

    testL = make_split(RAIN100L_INP_DIR, RAIN100L_GT_DIR, "Rain100L")
    testH = make_split(RAIN100H_INP_DIR, RAIN100H_GT_DIR, "Rain100H")
    print(f"[Data] Rain100L test={len(testL)} | Rain100H test={len(testH)}")
    return testL, testH

def build_test_loaders():
    """Return (Rain100L loader, Rain100H loader): batch size 1, fixed order, nothing dropped."""
    testL, testH = build_test_datasets()

    def as_loader(ds):
        return DataLoader(ds, batch_size=1, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)

    return as_loader(testL), as_loader(testH)

# ----------------------
# Quick sanity check (runs whenever this cell executes; comment out if not needed)
# ----------------------
# Builds the train/val loaders (reading the Rain200L/H folders) and prints the
# shape and min/max value of the first batch from each loader.
tl, vl = build_train_val_loaders()
b = next(iter(tl))
print("[Peek][train]", b["input"].shape, b["target"].shape, b["input"].min().item(), b["input"].max().item())
b = next(iter(vl))
print("[Peek][val]", b["input"].shape, b["target"].shape, b["input"].min().item(), b["input"].max().item())

# PRN Model (Progressive Residual Network)

In [None]:
# ============================================================
# PRN Model (Progressive Residual Network, lightweight)
#  - 多阶段渐进细化（T 个阶段）
#  - 采用残差预测：x̂_t = x̂_{t-1} + f([y, x̂_{t-1}])
#  - 可选阶段间权重共享（节省参数/显存）
#  - 适合 M1: C=32, T=6, 每阶段 5 个 ResBlock
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------
# Model configuration (flat variables)
# --------
PRN_T_STAGES        = 6     # number of progressive stages
PRN_BASE_CHANNELS   = 32    # backbone feature channels
PRN_NUM_RESBLOCKS   = 5     # ResBlocks per stage
PRN_WEIGHT_SHARING  = True  # share one backbone across all stages
PRN_USE_RESIDUAL    = True  # predict a residual instead of the image directly
PRN_CLAMP_OUTPUT    = True  # clamp every stage output to [0,1] (both training and inference)

# --------
# 基础模块
# --------
def kaiming_init(m):
    """
    Re-initialise a Conv2d in place: Kaiming-normal weights (ReLU gain) and a zero bias.

    Meant for `module.apply(kaiming_init)`; modules that are not Conv2d are left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    if m.bias is not None:
        nn.init.zeros_(m.bias)

class ConvReLU(nn.Module):
    """Conv2d (default 3x3, stride 1, padding 1) followed by an in-place ReLU; Kaiming-initialised."""
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p)
        self.relu = nn.ReLU(inplace=True)
        # Re-initialise every Conv2d in this module (weights: Kaiming normal, bias: zeros).
        self.apply(kaiming_init)

    def forward(self, x):
        return self.relu(self.conv(x))

class ResBlock(nn.Module):
    """
    Residual block computing ReLU(conv2(ReLU(conv1(x))) + x).

    Both convs are 3x3, stride 1, padding 1 with `ch` in/out channels, so the
    input shape is preserved. Weights are Kaiming-initialised.
    """
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(ch, ch, 3, 1, 1)
        self.relu  = nn.ReLU(inplace=True)
        self.apply(kaiming_init)

    def forward(self, x):
        idt = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # ReLU after the skip addition, so the block output is non-negative.
        return self.relu(out + idt)

class PRNBackbone(nn.Module):
    """
    Backbone for one stage (can be shared by several stages).
      Input channels  = in_ch, default 6 (concatenation of y and x_prev).
      Output channels = out_ch, default 3 (a predicted residual or a predicted image).
    Structure: Conv3x3+ReLU -> num_blocks * ResBlock -> Conv3x3 (no final activation).
    """
    def __init__(self, base_ch=32, num_blocks=5, in_ch=6, out_ch=3):
        super().__init__()
        layers = [ConvReLU(in_ch, base_ch, 3, 1, 1)]
        for _ in range(num_blocks):
            layers.append(ResBlock(base_ch))
        self.trunk = nn.Sequential(*layers)
        self.head  = nn.Conv2d(base_ch, out_ch, 3, 1, 1)
        kaiming_init(self.head)

    def forward(self, x):
        feat = self.trunk(x)
        return self.head(feat)

class PRN(nn.Module):
    """
    Progressive Residual Network (no recurrent state).

    Stage t receives concat([y, x_{t-1}]) (6 channels, with x_0 = y) and produces
    x_t = x_{t-1} + f_t(...) when use_residual, otherwise x_t = f_t(...).
    With weight_sharing, one PRNBackbone (`self.shared`) serves every stage;
    otherwise each stage has its own (`self.stages`). When clamp_out is set,
    every stage output is clamped to [0, 1]. torch.clamp passes no gradient
    for pixels outside that range.

    forward(y) returns {"all": [x_1, ..., x_T], "final": x_T}. "all" is what
    deep supervision consumes.
    """
    def __init__(
        self,
        T=PRN_T_STAGES,
        base_ch=PRN_BASE_CHANNELS,
        num_blocks=PRN_NUM_RESBLOCKS,
        weight_sharing=PRN_WEIGHT_SHARING,
        use_residual=PRN_USE_RESIDUAL,
        clamp_out=PRN_CLAMP_OUTPUT
    ):
        super().__init__()
        self.T = T
        self.use_residual = use_residual
        self.clamp_out = clamp_out

        # Attribute names and creation order are kept fixed: they define the
        # state_dict keys stored in checkpoints and the parameter-init RNG draws.
        self.shared = PRNBackbone(base_ch, num_blocks, in_ch=6, out_ch=3) if weight_sharing else None
        self.stages = None if weight_sharing else nn.ModuleList(
            [PRNBackbone(base_ch, num_blocks, in_ch=6, out_ch=3) for _ in range(T)]
        )

    def _stage_block(self, t: int):
        # Per-stage backbone when weights are not shared, otherwise the shared one.
        if self.shared is None:
            return self.stages[t]
        return self.shared

    @staticmethod
    def _concat(y, x_prev):
        # [B,3,H,W] + [B,3,H,W] -> [B,6,H,W] along the channel axis
        return torch.cat([y, x_prev], dim=1)

    def forward(self, y: torch.Tensor):
        """
        y: rainy image batch of shape [B,3,H,W], preferably in [0,1].
        """
        estimates = []
        current = y
        for stage in range(self.T):
            update = self._stage_block(stage)(self._concat(y, current))
            current = current + update if self.use_residual else update
            if self.clamp_out:
                current = torch.clamp(current, 0.0, 1.0)
            estimates.append(current)
        return {"all": estimates, "final": estimates[-1]}

# -------------
# 构建/统计工具
# -------------
def build_prn():
    """
    Build a PRN from the module-level PRN_* flat config.

    To try a different configuration, edit the PRN_* constants rather than this function.
    """
    settings = {
        "T": PRN_T_STAGES,
        "base_ch": PRN_BASE_CHANNELS,
        "num_blocks": PRN_NUM_RESBLOCKS,
        "weight_sharing": PRN_WEIGHT_SHARING,
        "use_residual": PRN_USE_RESIDUAL,
        "clamp_out": PRN_CLAMP_OUTPUT,
    }
    return PRN(**settings)

def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar parameters in `model` that have requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# ----------------------
# (Optional) quick model sanity check; runs whenever this cell executes
# ----------------------
# Builds the default PRN, runs a random batch of two 3x128x128 images through it,
# and prints the architecture, the trainable-parameter count (in millions), the
# number of stage outputs, and the final output shape.
m = build_prn()
print(m)
x = torch.rand(2, 3, 128, 128)  # B=2
y = m(x)  # dict with "all" (per-stage outputs) and "final"
print("params(M):", count_parameters(m)/1e6)
print("outs:", len(y["all"]), y["final"].shape)


# Losses, Metrics, Train/Val Loop, Test & Saving

In [None]:
# ============================================================
# Losses, Metrics, Train/Val Loop, Test & Saving (single-file)
# ============================================================
import time, csv
from pathlib import Path

# ----------------------
# Train/val/test hyper-parameters (flat variables)
# ----------------------
EPOCHS             = 120
LR                 = 1e-4
WEIGHT_DECAY       = 0.0
LR_MILESTONES      = [72, 96]   # LR decay epochs chosen to fit a ~3-hour budget; adjust as needed
LR_GAMMA           = 0.2
EARLY_STOP_PATIENCE= 12         # early stopping: epochs without a validation-SSIM improvement
SAVE_DIR_CKPT      = "checkpoints"
CKPT_BEST_PATH     = os.path.join(SAVE_DIR_CKPT, "prn_best.pth")
CKPT_LAST_PATH     = os.path.join(SAVE_DIR_CKPT, "prn_last.pth")

# PRN output directories (as required by the task)
PRN_OUT_DIR_L      = "output/prn/rain100L"
PRN_OUT_DIR_H      = "output/prn/rain100H"

# Deep supervision and loss weights
USE_DEEP_SUPERVISION = True
LAMBDA_SSIM          = 0.2      # Loss = L1 + LAMBDA_SSIM * (1-SSIM)
DS_WEIGHTS_MODE      = "equal"  # "equal" or "last_heavier"
# In "last_heavier" mode the last stage's weight is DS_LAST_EXTRA times each earlier
# stage's weight (all weights normalised to sum to 1); with 1.0 it matches "equal".
DS_LAST_EXTRA        = 1.0

# AMP (mixed precision) is applied on mps/cuda devices
USE_AMP = True

# ----------------------
# Metrics: PSNR / SSIM
# ----------------------
import torch
import torch.nn.functional as F

def _gaussian_window(window_size: int, sigma: float, device, channels: int):
    """
    Build a normalised 2-D Gaussian kernel for depthwise conv2d.

    Returns a float32 tensor of shape [channels, 1, window_size, window_size]
    (an expanded view, so every channel shares the same kernel). Each kernel
    sums to 1.
    """
    half = window_size // 2
    offsets = torch.arange(window_size, dtype=torch.float32, device=device) - half
    kernel_1d = torch.exp(- (offsets ** 2) / (2 * sigma * sigma))
    kernel_1d = kernel_1d / kernel_1d.sum()
    kernel_2d = torch.outer(kernel_1d, kernel_1d)
    return kernel_2d.expand(channels, 1, window_size, window_size)

def ssim_torch(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, sigma: float = 1.5):
    """
    Per-sample SSIM with a Gaussian window.

    x, y: [B, C, H, W] of identical shape, values expected in [0, 1] (the
          constants C1/C2 assume a data range of 1); C may be 1 or 3.
    Returns: a [B] float32 tensor, the SSIM map averaged over channels and pixels.

    Precision: the whole computation runs in float32 with autocast disabled.
    The variance terms E[x^2] - mu^2 subtract nearly equal numbers. In float16
    (what autocast produces, e.g. when called from the training loss), their
    rounding error is on the order of C2 = 0.03^2, which makes the value noisy.
    Casting also avoids a dtype mismatch between a float16 input and the
    float32 window when called outside autocast.

    Borders are zero-padded by window_size // 2, so values near the image
    edges differ from implementations that use only fully covered windows.
    """
    assert x.shape == y.shape, "SSIM expects same shape"
    from contextlib import nullcontext
    try:
        autocast_off = torch.autocast(device_type=x.device.type, enabled=False)
    except (RuntimeError, ValueError):
        # This device type has no autocast support, so there is nothing to disable.
        autocast_off = nullcontext()

    with autocast_off:
        x = x.float()
        y = y.float()
        C = x.size(1)
        window = _gaussian_window(window_size, sigma, x.device, channels=C)
        pad = window_size // 2
        C1 = (0.01 ** 2)
        C2 = (0.03 ** 2)

        def local_mean(t):
            # Depthwise Gaussian blur: one kernel per channel.
            return F.conv2d(t, window, groups=C, padding=pad)

        mu_x = local_mean(x)
        mu_y = local_mean(y)
        sigma_x  = local_mean(x * x) - mu_x ** 2
        sigma_y  = local_mean(y * y) - mu_y ** 2
        sigma_xy = local_mean(x * y) - mu_x * mu_y

        ssim_map = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / \
                   ((mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2))
        # Average over channels and pixels -> one value per sample.
        return ssim_map.mean(dim=[1, 2, 3])


def psnr_torch(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8):
    """
    Per-sample PSNR in dB for images whose peak value is 1.

    x, y: [B, C, H, W] of the same shape (C is 3 for RGB or 1 for the Y channel).
    Returns a [B] tensor equal to -10 * log10(MSE + eps); `eps` keeps the value
    finite (80 dB) when the inputs are identical.
    """
    squared_err = F.mse_loss(x, y, reduction='none')
    mse = squared_err.mean(dim=[1, 2, 3])
    return -10.0 * torch.log10(mse + eps)

# ----------------------
# Loss with Deep Supervision
# ----------------------
def prn_loss(outputs: dict, target: torch.Tensor):
    """
    Weighted L1 + SSIM loss over the PRN stage outputs.

    outputs: {"all": [x1, x2, ..., xT], "final": xT}
    target : clean image, same shape as each stage output

    Per stage: L1(pred, target) + LAMBDA_SSIM * (1 - mean SSIM). With
    USE_DEEP_SUPERVISION every stage contributes; otherwise only "final" does.
    Weights: "last_heavier" gives the last stage DS_LAST_EXTRA times the weight
    of each earlier stage (weights sum to 1). "equal", or any other
    DS_WEIGHTS_MODE value, gives uniform weights 1/T.
    """
    preds = outputs["all"] if USE_DEEP_SUPERVISION else [outputs["final"]]
    n_stages = len(preds)
    if DS_WEIGHTS_MODE == "last_heavier":
        base = 1.0 / (n_stages - 1 + DS_LAST_EXTRA)
        weights = [base] * (n_stages - 1) + [base * DS_LAST_EXTRA]
    else:
        weights = [1.0 / n_stages] * n_stages

    def stage_loss(pred):
        l1_term = F.l1_loss(pred, target)
        ssim_term = 1.0 - ssim_torch(pred, target).mean()
        return l1_term + LAMBDA_SSIM * ssim_term

    return sum(w * stage_loss(pred) for w, pred in zip(weights, preds))

# Model Training & Validation

In [None]:
# ----------------------
# 设备选择 & AMP context
# ----------------------
def get_device():
    """Pick the compute device: Apple MPS if available, then CUDA, otherwise CPU."""
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")

def amp_autocast(device):
    """
    Return the mixed-precision context manager for `device`.

    float16 autocast on "cuda"/"mps" when USE_AMP is set. If AMP is off, or
    the device is anything else (e.g. CPU), return a no-op nullcontext().
    """
    from contextlib import nullcontext
    amp_device = device.type if USE_AMP else None
    if amp_device in ("cuda", "mps"):
        return torch.autocast(device_type=amp_device, dtype=torch.float16)
    return nullcontext()

# ----------------------
# 训练与验证
# ----------------------
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

def train_one_epoch(model, optimizer, scaler, train_loader, device):
    """
    Run one training epoch.

    Uses the GradScaler path (scale/step/update) when USE_AMP is set and the
    device is cuda or mps; otherwise uses plain backward/step.

    Returns:
        (mean loss weighted by batch size, elapsed wall-clock seconds)
    """
    model.train()
    use_scaler = USE_AMP and device.type in ("cuda", "mps")
    loss_sum, seen = 0.0, 0
    start = time.time()
    for batch in train_loader:
        inp = batch["input"].to(device, non_blocking=False)
        tgt = batch["target"].to(device, non_blocking=False)

        optimizer.zero_grad(set_to_none=True)
        with amp_autocast(device):
            loss = prn_loss(model(inp), tgt)

        if use_scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_size = inp.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size

    return loss_sum / max(1, seen), time.time() - start

@torch.no_grad()
def validate(model, val_loader, device):
    """
    Evaluate the model on the validation loader.

      - The loss is the same RGB L1+SSIM loss used for training (same autocast).
      - PSNR/SSIM use the final output: on the Y channel when
        METRIC_USE_Y_CHANNEL, after shaving METRIC_SHAVE_PIXELS border pixels.
      - Metrics are computed in float32. Under autocast the prediction may be
        float16 (e.g. when the model predicts the image directly), which would
        otherwise clash with the float32 Gaussian window inside ssim_torch.

    Returns:
        (mean loss weighted by batch size, mean PSNR, mean SSIM, elapsed seconds)
    """
    model.eval()
    total_loss = 0.0
    psnr_list, ssim_list = [], []
    n = 0
    t0 = time.time()

    for batch in val_loader:
        inp = batch["input"].to(device, non_blocking=False)   # [B,3,H,W]
        tgt = batch["target"].to(device, non_blocking=False)  # [B,3,H,W]

        # Forward pass and RGB loss under the same autocast setting as training.
        with amp_autocast(device):
            out = model(inp)
            loss = prn_loss(out, tgt)

        # Accumulate the loss weighted by the number of samples in the batch.
        bs = inp.size(0)
        total_loss += loss.item() * bs
        n += bs

        # ---- Metrics: float32, optional Y channel, border shave ----
        ref_eval  = tgt.float()
        pred_eval = out["final"].float()
        if METRIC_USE_Y_CHANNEL:
            ref_eval  = rgb_to_y(ref_eval)   # -> [B,1,H,W]
            pred_eval = rgb_to_y(pred_eval)  # -> [B,1,H,W]
        if METRIC_SHAVE_PIXELS > 0:
            ref_eval  = crop_border(ref_eval,  METRIC_SHAVE_PIXELS)
            pred_eval = crop_border(pred_eval, METRIC_SHAVE_PIXELS)

        psnr_list.append(psnr_torch(pred_eval, ref_eval))   # [B]
        ssim_list.append(ssim_torch(pred_eval, ref_eval))   # [B]

    avg_loss = total_loss / max(1, n)
    avg_psnr = torch.cat(psnr_list).mean().item() if psnr_list else 0.0
    avg_ssim = torch.cat(ssim_list).mean().item() if ssim_list else 0.0
    return avg_loss, avg_psnr, avg_ssim, time.time() - t0


def fit_prn():
    """
    Train PRN with early stopping on validation SSIM.

    Each epoch: train, validate, step the MultiStepLR schedule, overwrite the
    "last" checkpoint, and overwrite the "best" checkpoint whenever validation
    SSIM improves. Stops after EARLY_STOP_PATIENCE epochs without improvement.
    Checkpoints hold epoch, model/optimizer/scheduler state dicts and the
    epoch's validation SSIM.
    """
    Path(SAVE_DIR_CKPT).mkdir(parents=True, exist_ok=True)

    # Data
    train_loader, val_loader = build_train_val_loaders()

    # Model & device
    device = get_device()
    model = build_prn().to(device)
    optimizer = Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=LR_GAMMA)

    # Gradient scaling is enabled only on CUDA. torch.cuda.amp.GradScaler cannot
    # scale on MPS (without CUDA it warns and disables itself). A disabled scaler
    # passes scale/step/update straight through, so train_one_epoch still works
    # on mps. Prefer the device-generic torch.amp.GradScaler when this PyTorch has it.
    use_scaler = USE_AMP and device.type == "cuda"
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)
    else:  # older PyTorch
        scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    print(f"[Device] {device} | params: {count_parameters(model)/1e6:.3f} M")
    best_ssim = -1.0
    epochs_no_improve = 0

    def make_ckpt(epoch, ssim):
        # state_dict() is called at save time, so each file captures the current weights.
        return {"epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "ssim": ssim}

    for epoch in range(1, EPOCHS + 1):
        tr_loss, tr_time = train_one_epoch(model, optimizer, scaler, train_loader, device)
        val_loss, val_psnr, val_ssim, val_time = validate(model, val_loader, device)
        # Log the LR actually used this epoch (read it before the scheduler advances).
        lr_used = optimizer.param_groups[0]["lr"]
        scheduler.step()

        print(f"[Epoch {epoch:03d}] "
              f"train_loss={tr_loss:.4f} ({tr_time:.1f}s) | "
              f"val_loss={val_loss:.4f} psnr={val_psnr:.2f} ssim={val_ssim:.4f} ({val_time:.1f}s) | "
              f"lr={lr_used:.2e}")

        # Always save the latest state.
        torch.save(make_ckpt(epoch, val_ssim), CKPT_LAST_PATH)

        # Model selection on validation SSIM.
        if val_ssim > best_ssim:
            best_ssim = val_ssim
            torch.save(make_ckpt(epoch, val_ssim), CKPT_BEST_PATH)
            print(f"  -> New best (SSIM={best_ssim:.4f}), saved to {CKPT_BEST_PATH}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= EARLY_STOP_PATIENCE:
                print(f"  -> Early stopping at epoch {epoch} (no improve {epochs_no_improve} epochs).")
                break

    print(f"[Train Done] best_val_ssim={best_ssim:.4f}")

# Model Inference, Saving Results & Metric Calculation on the Test Set

In [None]:
# ----------------------
# 推理 & 保存结果 & 计算指标 (测试集)
# ----------------------
from torchvision.utils import save_image

@torch.no_grad()
def inference_and_save(model, loader, out_dir: str, device):
    """
    Run the model over `loader`, save its predictions and write per-image metrics.

      - Each prediction is clamped to [0, 1] and saved as `out_dir/<name>`,
        where <name> comes from batch["name"].
      - PSNR/SSIM are computed in float32 on the (unquantized) prediction: on
        the Y channel when METRIC_USE_Y_CHANNEL, after shaving
        METRIC_SHAVE_PIXELS border pixels.
      - `out_dir/metrics.csv` gets one row per image plus an AVERAGE row.

    Works with any loader batch size whose samples can be collated together.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    rows = [("name", "psnr", "ssim")]
    psnr_all, ssim_all = [], []

    model.eval()
    for batch in loader:
        inp = batch["input"].to(device)           # [B,3,H,W]
        tgt = batch["target"].to(device).float()  # [B,3,H,W]
        names = batch["name"]                     # file names, e.g. "1.png"

        # Forward pass
        with amp_autocast(device):
            out = model(inp)
        # Float32 for metrics and saving: under autocast the prediction may be
        # float16, which would clash with the float32 Gaussian window in ssim_torch.
        pred = out["final"].float()

        # ---- Metrics: optional Y channel, border shave ----
        ref_eval  = tgt
        pred_eval = pred
        if METRIC_USE_Y_CHANNEL:
            ref_eval  = rgb_to_y(ref_eval)    # [B,1,H,W]
            pred_eval = rgb_to_y(pred_eval)   # [B,1,H,W]
        if METRIC_SHAVE_PIXELS > 0:
            ref_eval  = crop_border(ref_eval,  METRIC_SHAVE_PIXELS)
            pred_eval = crop_border(pred_eval, METRIC_SHAVE_PIXELS)

        psnrs = psnr_torch(pred_eval, ref_eval).tolist()
        ssims = ssim_torch(pred_eval, ref_eval).tolist()

        # One saved image and one CSV row per sample in the batch.
        for i, (name, psnr, ssimv) in enumerate(zip(names, psnrs, ssims)):
            save_image(pred[i].clamp(0, 1), os.path.join(out_dir, name))
            rows.append((name, f"{psnr:.4f}", f"{ssimv:.4f}"))
            psnr_all.append(psnr)
            ssim_all.append(ssimv)

    # Write metrics.csv (per-image rows + average)
    mean_psnr = sum(psnr_all) / len(psnr_all) if psnr_all else 0.0
    mean_ssim = sum(ssim_all) / len(ssim_all) if ssim_all else 0.0
    rows.append(("AVERAGE", f"{mean_psnr:.4f}", f"{mean_ssim:.4f}"))

    metrics_csv = os.path.join(out_dir, "metrics.csv")
    with open(metrics_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerows(rows)

    print(f"[Test] Saved to {out_dir} | PSNR={mean_psnr:.3f} SSIM={mean_ssim:.4f}")


def test_prn_on_rain100():
    """
    Load the best PRN checkpoint and evaluate it on Rain100L, then Rain100H.

    Predictions and metrics.csv go to PRN_OUT_DIR_L / PRN_OUT_DIR_H.
    """
    device = get_device()
    assert os.path.isfile(CKPT_BEST_PATH), f"Best checkpoint not found: {CKPT_BEST_PATH}"
    # NOTE(review): the checkpoint also stores optimizer/scheduler state. Newer PyTorch
    # releases default torch.load to weights_only=True; confirm this load still succeeds there.
    ckpt = torch.load(CKPT_BEST_PATH, map_location="cpu")
    model = build_prn().to(device)
    model.load_state_dict(ckpt["model"])
    print(f"[Load] {CKPT_BEST_PATH} (epoch={ckpt.get('epoch','?')}, ssim={ckpt.get('ssim','-')})")

    testL_loader, testH_loader = build_test_loaders()
    jobs = ((testL_loader, PRN_OUT_DIR_L), (testH_loader, PRN_OUT_DIR_H))
    for loader, out_dir in jobs:
        inference_and_save(model, loader, out_dir, device)

# ----------------------
# （可选）一键运行：先训后测
# ----------------------
# if __name__ == "__main__":
#     fit_prn()               # 训练，保存 best/last
#     test_prn_on_rain100()   # 加载 best 在 Rain100L/H 上评测并落盘