#  STEP 0: Config 

In [None]:
# ===== STEP 0: config =====
import os, glob, json, math
from typing import Optional, List, Tuple
import numpy as np
import cv2
import re
import matplotlib.pyplot as plt

# Metric dependencies (PSNR / SSIM)
from skimage.metrics import peak_signal_noise_ratio as sk_psnr
from skimage.metrics import structural_similarity as sk_ssim
import torch
import skvideo.measure as skm # niqe


In [None]:
# Dataset and output paths (adjust to your environment)
RAIN100H_INP_DIR = "data/rain100H/input"   # expected to contain only the rainy .png images
RAIN100H_GT_DIR  = "data/rain100H/target"   # rain-free ground truth for Rain100H
RAIN100L_INP_DIR = "data/rain100L/input"
RAIN100L_GT_DIR = "data/rain100L/target"
OUT_DIR_GF = "output/gf"            # where Guided Filter (GF) outputs are saved



# Process input

In [None]:

# -----------------------------
# Utility functions
# -----------------------------
def ensure_dir(path: str) -> None:
    """Create the directory ``path`` (and any missing parents).

    Calling it on a directory that already exists is a no-op.

    Args:
        path: Directory path to create.
    """
    os.makedirs(name=path, exist_ok=True)

def imread_bgr_uint8(path: str) -> np.ndarray:
    """Read an image as 3-channel BGR uint8.

    Grayscale images are expanded to three channels and an alpha channel is
    dropped. 16-bit images are rescaled to the 8-bit range (65535 -> 255)
    instead of being clipped, which would saturate almost every pixel.
    Any other non-uint8 dtype is clipped to 0..255 and cast.

    Args:
        path: Image file path.

    Returns:
        np.uint8 array of shape [H, W, 3] with values in 0..255.

    Raises:
        FileNotFoundError: if OpenCV could not read the file.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"Failed to read: {path}")

    if img.ndim == 2:
        # Grayscale -> 3 channels.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:
        # Drop alpha; copy so downstream code gets a contiguous buffer
        # rather than a strided view into the 4-channel array.
        img = np.ascontiguousarray(img[:, :, :3])

    if img.dtype == np.uint16:
        # 65535 / 255 == 257, so this maps the full 16-bit range onto 0..255.
        img = np.rint(img.astype(np.float32) / 257.0)
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    return img

def list_pngs(dir_path: str) -> List[str]:
    """Return the ``*.png`` files directly inside ``dir_path``.

    Sub-directories are not searched. The result is sorted by plain string
    comparison of the full paths.

    Args:
        dir_path: Directory to scan.

    Returns:
        Sorted list of matching paths (``dir_path`` joined with the file name).
    """
    pattern = os.path.join(dir_path, "*.png")
    matches = glob.glob(pattern)
    matches.sort()
    return matches

# -----------------------------
# RAIN100H pairing (rain-001.png <-> norain-001.png)
# -----------------------------
# Both patterns are anchored at the start of the file name: an unanchored
# "rain-(\d+)" also matches inside "norain-001.png", which would let a ground
# truth image be picked up as a rainy input when both live in one directory.
_H_IN_RE  = re.compile(r"^rain-(\d+)\.png$", re.IGNORECASE)
_H_GT_RE  = re.compile(r"^norain-(\d+)\.png$", re.IGNORECASE)


def _index_by_numeric_id(paths, pattern):
    """Map the integer id captured by ``pattern`` to its path (non-matching names are skipped)."""
    indexed = {}
    for path in paths:
        hit = pattern.match(os.path.basename(path))
        if hit:
            # Key by int so that "001" and "1" refer to the same sample.
            indexed[int(hit.group(1))] = path
    return indexed


def build_pairs_rain100h(inp_dir: str, gt_dir: str) -> List[Tuple[str, str]]:
    """Pair Rain100H images by the numeric id in their file names.

    Expected naming:
        input:  rain-001.png
        target: norain-001.png

    Args:
        inp_dir: Directory with the rainy inputs.
        gt_dir:  Directory with the rain-free ground truth.

    Returns:
        ``[(inp_path, gt_path), ...]`` for every id present on both sides,
        ordered by ascending numeric id.

    Raises:
        RuntimeError: if no id is present in both directories.
    """
    id2inp = _index_by_numeric_id(list_pngs(inp_dir), _H_IN_RE)
    id2gt = _index_by_numeric_id(list_pngs(gt_dir), _H_GT_RE)

    common_ids = sorted(id2inp.keys() & id2gt.keys())
    pairs = [(id2inp[_id], id2gt[_id]) for _id in common_ids]
    if not pairs:
        raise RuntimeError("RAIN100H: 没有匹配到任何成对样本，请检查路径与命名。")
    return pairs

# -----------------------------
# RAIN100L pairing (1.png <-> 1.png)
# -----------------------------
_L_RE = re.compile(r"(\d+)\.png$", re.IGNORECASE)


def _rain100l_sort_key(name):
    """Sort key: numbered names first (by number), everything else after (by name).

    Always returns a tuple of the same shape, so names with and without a
    trailing number can be sorted together without an int/str comparison error.
    """
    hit = _L_RE.search(name)
    if hit:
        return (0, int(hit.group(1)), name)
    return (1, 0, name)


def build_pairs_rain100l(inp_dir: str, gt_dir: str) -> List[Tuple[str, str]]:
    """Pair Rain100L images that share the same file name.

    Expected naming:
        input:  1.png
        target: 1.png

    File names are compared case-insensitively.

    Args:
        inp_dir: Directory with the rainy inputs.
        gt_dir:  Directory with the rain-free ground truth.

    Returns:
        ``[(inp_path, gt_path), ...]`` for every file name present on both
        sides, ordered by the number that ends the file name.

    Raises:
        RuntimeError: if no file name is present in both directories.
    """
    name2inp = {os.path.basename(p).lower(): p for p in list_pngs(inp_dir)}
    name2gt = {os.path.basename(p).lower(): p for p in list_pngs(gt_dir)}

    common_names = sorted(name2inp.keys() & name2gt.keys(), key=_rain100l_sort_key)
    pairs = [(name2inp[n], name2gt[n]) for n in common_names]
    if not pairs:
        raise RuntimeError("RAIN100L: 没有匹配到任何成对样本，请检查路径与命名。")
    return pairs

# -----------------------------
# Dataset loading and preview
# -----------------------------
def load_pair_bgr(inp_path: str, gt_path: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Load a rainy image and, when given, its ground truth (both BGR uint8).

    Args:
        inp_path: Path of the rainy image.
        gt_path:  Path of the ground truth image, or None.

    Returns:
        ``(rainy_bgr, gt_bgr)``; ``gt_bgr`` is None when ``gt_path`` is None.
    """
    rainy = imread_bgr_uint8(inp_path)
    if gt_path is None:
        return rainy, None
    return rainy, imread_bgr_uint8(gt_path)

def _resize_by_height(img: np.ndarray, target_h: int) -> np.ndarray:
    """按高度等比缩放到 target_h，宽度自适应。"""
    h, w = img.shape[:2]
    if h == target_h:
        return img
    new_w = int(round(w * (target_h / h)))
    return cv2.resize(img, (new_w, target_h), interpolation=cv2.INTER_AREA)

def preview_groups(
    groups: List[Tuple[str, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]],
    col_names: Tuple[str, str, str] = ("Rainy", "Processed", "GT"),
    row_height: int = 320,
    suptitle: str = ""
) -> None:
    """Show several image groups in a grid, one group per row.

    Each row has three columns (Rainy / Processed / GT). A cell whose image is
    None is left blank but keeps its column title.

    Args:
        groups: List of ``(group_title, rain_bgr, proc_bgr_or_None, gt_bgr_or_None)``.
        col_names: Titles of the three columns.
        row_height: Height in pixels every image is resized to before display
            (width follows the aspect ratio).
        suptitle: Optional figure title.

    Raises:
        AssertionError: if ``groups`` is empty.
    """
    assert len(groups) > 0, "groups 不能为空"
    n_rows = len(groups)
    n_cols = 3

    # Pre-process: resize to a common height, then BGR -> RGB for matplotlib.
    rows_rgb = []
    for gtitle, rain_bgr, proc_bgr, gt_bgr in groups:
        row_imgs = [
            None if img_bgr is None
            else cv2.cvtColor(_resize_by_height(img_bgr, row_height), cv2.COLOR_BGR2RGB)
            for img_bgr in (rain_bgr, proc_bgr, gt_bgr)
        ]
        rows_rgb.append((gtitle, row_imgs))

    # Rough figure size: derived from the widest row (sum of its image widths).
    row_widths = [
        sum(im.shape[1] for im in row_imgs if im is not None)
        for _, row_imgs in rows_rgb
    ]
    fig_w = max(8, min(24, int(np.ceil(max(row_widths) / 80))))
    fig_h = max(4, n_rows * (row_height / 80 + 0.8))

    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h), squeeze=False)

    for r, (gtitle, row_imgs) in enumerate(rows_rgb):
        for c, im in enumerate(row_imgs):
            ax = axes[r, c]
            if im is not None:
                ax.imshow(im)
            # Hide ticks and the frame by hand instead of ax.axis("off"):
            # turning the axis off also suppresses axis labels, which would
            # make the per-row group title below invisible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_title(col_names[c], fontsize=11)
        # Group name on the left of each row.
        axes[r, 0].set_ylabel(gtitle, fontsize=12, rotation=0, labelpad=40, va='center')

    if suptitle:
        fig.suptitle(suptitle, fontsize=14)
    plt.tight_layout()
    plt.show()

# -----------------------------
# Build the paired sample lists for Rain100H / Rain100L
# -----------------------------

# ==== CELL 1: Image import demo (Rain100L & Rain100H first pair) ====

# 1) Build the pair lists
pairs_H = build_pairs_rain100h(RAIN100H_INP_DIR, RAIN100H_GT_DIR)
pairs_L = build_pairs_rain100l(RAIN100L_INP_DIR, RAIN100L_GT_DIR)
print(f"RAIN100H pairs: {len(pairs_H)}")
print(f"RAIN100L pairs: {len(pairs_L)}")

assert pairs_H and pairs_L, "pairs_H 或 pairs_L 为空，请检查数据路径与命名。"

# 2) Paths of the first pair of each set (later cells reuse these variables)
L_inp_path, L_gt_path = pairs_L[0]
H_inp_path, H_gt_path = pairs_H[0]
print("L sample:", os.path.basename(L_inp_path), "|", os.path.basename(L_gt_path))
print("H sample:", os.path.basename(H_inp_path), "|", os.path.basename(H_gt_path))

# 3) Read the images as BGR uint8 (load_pair_bgr reads both via imread_bgr_uint8)
rainL_bgr, gtL_bgr = load_pair_bgr(L_inp_path, L_gt_path)
rainH_bgr, gtH_bgr = load_pair_bgr(H_inp_path, H_gt_path)

# 4) Show rainy input next to ground truth; nothing has been processed yet,
#    so the middle column is left empty.
preview_groups(
    [("Rain100L", rainL_bgr, None, gtL_bgr),
     ("Rain100H", rainH_bgr, None, gtH_bgr)],
    col_names=("Rainy", "Processed", "GT"),
    suptitle="Image Import L vs H (Rainy / GT)",
)

# 5) 可选：快速自检，确保都是 HxWx3 uint8
def _check_img(img, name):
    assert img.dtype == np.uint8 and img.ndim == 3 and img.shape[2] == 3, \
        f"{name}: expect uint8 HxWx3, got {img.dtype} {img.shape}"
    print(f"{name}: ok, shape={img.shape}")

# Run the sanity check on the four demo images loaded in this cell.
_check_img(rainL_bgr, "rainL_bgr")
_check_img(gtL_bgr,   "gtL_bgr")
_check_img(rainH_bgr, "rainH_bgr")
_check_img(gtH_bgr,   "gtH_bgr")



In [None]:
# -----------------------------
# Basic utilities: dtype / colour space conversion
# -----------------------------

def bgr_to_ycrcb(img_bgr_uint8: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert a BGR image to YCrCb and return the three planes separately.

    Args:
        img_bgr_uint8: BGR image (uint8 expected by the callers in this file).

    Returns:
        ``(Y, Cr, Cb)`` single-channel planes as produced by ``cv2.split``.
    """
    luma, chroma_r, chroma_b = cv2.split(
        cv2.cvtColor(img_bgr_uint8, cv2.COLOR_BGR2YCrCb)
    )
    return luma, chroma_r, chroma_b

def merge_ycrcb_to_bgr(Y_uint8: np.ndarray, Cr_uint8: np.ndarray, Cb_uint8: np.ndarray) -> np.ndarray:
    """Stack Y, Cr and Cb planes into a YCrCb image and convert it back to BGR."""
    stacked = cv2.merge([Y_uint8, Cr_uint8, Cb_uint8])
    return cv2.cvtColor(stacked, cv2.COLOR_YCrCb2BGR)

def to_float01(x_uint8: np.ndarray) -> np.ndarray:
    """Cast to float32 and divide by 255 (maps uint8 0..255 onto [0, 1])."""
    as_f32 = x_uint8.astype(np.float32)
    return np.divide(as_f32, 255.0)

def to_uint8_from01(x_float01: np.ndarray) -> np.ndarray:
    """Convert a [0, 1] float image to uint8.

    Values are scaled by 255, clipped to 0..255 and rounded to the nearest
    integer. Rounding (rather than truncating) avoids a systematic downward
    bias and guarantees that ``to_uint8_from01(to_float01(x)) == x`` for
    uint8 ``x``.
    """
    scaled = np.clip(x_float01 * 255.0, 0, 255)
    return np.rint(scaled).astype(np.uint8)

# -----------------------------
# Radius consistency (optional: scale with the image's short side)
# -----------------------------
def scale_radius_by_short_side(r: int, img_shape_hw: Tuple[int, int],
                               ref_short: int = 512) -> int:
    """Scale a pixel radius by ``short_side / ref_short``.

    Useful when image resolutions differ a lot, so that a filter radius covers
    a comparable fraction of the image.

    Args:
        r: Radius in pixels, defined for an image whose short side is ``ref_short``.
        img_shape_hw: ``(H, W)`` of the actual image.
        ref_short: Reference short side (default 512).

    Returns:
        The rounded scaled radius, never smaller than 1.
    """
    height, width = img_shape_hw
    ratio = min(height, width) / float(ref_short)
    return max(1, int(round(r * ratio)))

# -----------------------------
# Main preprocessing function (input for the Guided Filter)
# -----------------------------
def prepare_for_gf(
    img_bgr_uint8: np.ndarray,
    r_base: int,
    eps_base: float,
    r_detail: Optional[int] = None,
    eps_detail: Optional[float] = None,
    scale_radius: bool = False,
    ref_short: int = 512
):
    """Bundle everything a guided-filter step needs for one BGR image.

    The image is converted to YCrCb; the Y plane is returned as float32 in
    [0, 1] and the chroma planes are kept untouched for later reconstruction.
    Radii are optionally rescaled by the image's short side.

    Args:
        img_bgr_uint8: Input BGR image (uint8).
        r_base, eps_base: Guided-filter parameters for the base layer.
        r_detail, eps_detail: Parameters for detail-layer smoothing; either may
            be None.
        scale_radius: When True, radii are scaled by ``short_side / ref_short``.
        ref_short: Reference short side used for that scaling.

    Returns:
        dict with the keys
          - "Y01": Y plane as float32 in [0, 1]
          - "Cr_u8", "Cb_u8": chroma planes as returned by ``bgr_to_ycrcb``
          - "r_base_eff", "r_detail_eff": radii actually to be used
            ("r_detail_eff" is None when ``r_detail`` is None)
          - "eps_base", "eps_detail": the eps values as floats (or None)
          - "shape_hw": ``(H, W)``
    """
    height, width = img_bgr_uint8.shape[:2]
    luma_u8, chroma_r_u8, chroma_b_u8 = bgr_to_ycrcb(img_bgr_uint8)
    luma01 = to_float01(luma_u8)

    def _effective_radius(radius):
        # Either rescale by the short side or just coerce to int.
        if scale_radius:
            return scale_radius_by_short_side(radius, (height, width), ref_short=ref_short)
        return int(radius)

    prepared = {
        "Y01": luma01,
        "Cr_u8": chroma_r_u8,
        "Cb_u8": chroma_b_u8,
        "r_base_eff": _effective_radius(r_base),
        "r_detail_eff": None if r_detail is None else _effective_radius(r_detail),
        "eps_base": float(eps_base),
        "eps_detail": None if eps_detail is None else float(eps_detail),
        "shape_hw": (height, width),
    }
    return prepared

# -----------------------------
# Metric preparation: PSNR/SSIM on the Y channel, optional border shave
# -----------------------------
def prepare_y_for_metrics(
    pred_bgr_uint8: np.ndarray,
    gt_bgr_uint8: Optional[np.ndarray] = None,
    shave: int = 0
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Extract the Y planes used for evaluation, optionally cropping the border.

    Args:
        pred_bgr_uint8: Predicted / processed BGR image.
        gt_bgr_uint8:   Ground truth BGR image, or None.
        shave:          Number of pixels removed from each of the four sides
                        (values <= 0 leave the planes uncropped).

    Returns:
        ``(y_pred, y_gt)`` single-channel planes; ``y_gt`` is None when no
        ground truth was given.
    """
    y_pred = bgr_to_ycrcb(pred_bgr_uint8)[0]
    y_gt = None if gt_bgr_uint8 is None else bgr_to_ycrcb(gt_bgr_uint8)[0]

    if shave > 0:
        inner = (slice(shave, -shave), slice(shave, -shave))
        y_pred = y_pred[inner]
        if y_gt is not None:
            y_gt = y_gt[inner]
    return y_pred, y_gt

# -----------------------------
# Metric preparation: grayscale input for NIQE (no ground truth needed)
# -----------------------------
def to_gray_for_niqe(img_bgr_uint8: np.ndarray) -> np.ndarray:
    """Convert a BGR image to single-channel grayscale.

    Intended as input for ``skvideo.measure.niqe``; the caller is responsible
    for reshaping it to whatever layout that function requires.
    """
    return cv2.cvtColor(img_bgr_uint8, cv2.COLOR_BGR2GRAY)

# -----------------------------
# Example: preprocess one Rain100L and one Rain100H image and show them side by side
# -----------------------------
# 1) Preprocess (parameters can be kept in sync with the later GF step)
prep_L = prepare_for_gf(
    rainL_bgr,
    r_base=16, eps_base=1e-3,
    r_detail=3, eps_detail=1e-3,
    scale_radius=False
)
prep_H = prepare_for_gf(
    rainH_bgr,
    r_base=16, eps_base=1e-3,
    r_detail=3, eps_detail=1e-3,
    scale_radius=False
)

# 2) Turn the preprocessed Y01 plane into a 3-channel gray image for display
def y01_to_bgr_gray(y01):
    """Convert a [0, 1] float Y plane to a 3-channel (gray) BGR uint8 image."""
    return cv2.cvtColor(to_uint8_from01(y01), cv2.COLOR_GRAY2BGR)

# Displayable versions of the preprocessed Y planes.
yL_vis = y01_to_bgr_gray(prep_L["Y01"])
yH_vis = y01_to_bgr_gray(prep_H["Y01"])

# 3) Visualise: Rainy / Y (preprocessed) / GT
preview_groups(
    groups=[
        ("Rain100L", rainL_bgr, yL_vis, gtL_bgr),
        ("Rain100H", rainH_bgr, yH_vis, gtH_bgr),
    ],
    col_names=("Rainy", "Y (preproc)", "GT"),
    suptitle="Preprocess demo - Rain100L vs Rain100H"
)

# 4) Optional: print the effective radii, the Y01 value range and the image
#    shape to double-check parameters and numeric range.
print(
    "L -> r_base_eff:", prep_L["r_base_eff"],
    "| r_detail_eff:", prep_L["r_detail_eff"],
    "| Y01 range:", (float(prep_L["Y01"].min()), float(prep_L["Y01"].max())),
    "| shape:", prep_L["shape_hw"]
)
print(
    "H -> r_base_eff:", prep_H["r_base_eff"],
    "| r_detail_eff:", prep_H["r_detail_eff"],
    "| Y01 range:", (float(prep_H["Y01"].min()), float(prep_H["Y01"].max())),
    "| shape:", prep_H["shape_hw"]
)



# Guided Filter

In [None]:
# ==========================
# CELL 3 — Guided Filter demo
# (Rain100L & Rain100H first pair)
# ==========================
import cv2, numpy as np, matplotlib.pyplot as plt

# ---- utility: box filter for NumPy fallback ----
def _boxfilter(img: np.ndarray, r: int) -> np.ndarray:
    """Mean filter over a square window of side ``2*r + 1`` (via ``cv2.blur``)."""
    side = 2*r + 1
    return cv2.blur(img, (side, side))

# ---- (1) NumPy & OpenCV guided filter ----
def guided_filter_numpy(p: np.ndarray, I: np.ndarray = None,
                        r: int = 16, eps: float = 1e-3) -> np.ndarray:
    """Guided filter built from box filters.

    A local linear model ``q = a*I + b`` is fitted per window and the
    coefficients are averaged before being applied.

    Args:
        p: Image to be filtered.
        I: Guide image; defaults to ``p`` (self-guided filtering).
        r: Window radius passed to ``_boxfilter``.
        eps: Regularisation added to the guide variance.

    Returns:
        The filtered image, clipped to [0, 1].
    """
    guide = p if I is None else I

    mu_guide = _boxfilter(guide, r)
    mu_src = _boxfilter(p, r)
    cov_guide_src = _boxfilter(guide*p, r) - mu_guide*mu_src
    var_guide = _boxfilter(guide*guide, r) - mu_guide*mu_guide

    # Per-window linear coefficients.
    slope = cov_guide_src / (var_guide + eps)
    offset = mu_src - slope*mu_guide

    # Average the coefficients over all windows covering a pixel.
    mu_slope = _boxfilter(slope, r)
    mu_offset = _boxfilter(offset, r)
    return np.clip(mu_slope*guide + mu_offset, 0., 1.)

def guided_filter_cv(p_u8: np.ndarray, I_u8: np.ndarray = None,
                     r: int = 16, eps: float = 1e-3) -> np.ndarray:
    """Guided filter on uint8 images, returning uint8.

    Uses ``cv2.ximgproc.guidedFilter`` when the contrib module is available
    and falls back to ``guided_filter_numpy`` (per channel for 3-D input)
    otherwise. Filtering is done on values scaled to [0, 1], so ``eps`` is
    expressed on that scale.

    The result is rounded to the nearest integer before the uint8 cast.
    Truncating instead would bias every call downward, and the bias
    accumulates when the filter is applied iteratively.

    Args:
        p_u8: Image to be filtered (uint8, 2-D or 3-D).
        I_u8: Guide image; defaults to ``p_u8``. In the fallback path a 2-D
            guide is reused for every channel of a multi-channel ``p_u8``.
        r: Filter radius.
        eps: Regularisation term.

    Returns:
        Filtered image as uint8.
    """
    if I_u8 is None:
        I_u8 = p_u8

    if hasattr(cv2, "ximgproc") and hasattr(cv2.ximgproc, "guidedFilter"):
        p32 = p_u8.astype(np.float32) / 255.
        I32 = I_u8.astype(np.float32) / 255.
        q01 = cv2.ximgproc.guidedFilter(I32, p32, r, eps)
    elif p_u8.ndim == 3:
        # Fallback: NumPy implementation, one channel at a time.
        chans = []
        for c in range(p_u8.shape[2]):
            guide_c = I_u8[..., c] if I_u8.ndim == 3 else I_u8
            chans.append(guided_filter_numpy(p_u8[..., c] / 255., guide_c / 255., r, eps))
        q01 = np.stack(chans, axis=2)
    else:
        q01 = guided_filter_numpy(p_u8 / 255., I_u8 / 255., r, eps)

    return np.clip(np.rint(q01 * 255.), 0, 255).astype(np.uint8)

# ---- (2) strong derain GF with soft-threshold & contrast match ----
def derain_guided_filter_bgr_strong(
        img_bgr_u8,
        r_base=28, eps_base=5e-3,
        use_detail_smooth=True, r_detail=2, eps_detail=1e-3,
        detail_gain=0.55, iters=2,
        soft_tau_val=0.60, soft_tau_mode="percentile",
        do_contrast_match=True):
    """Derain a BGR image by iterated guided filtering of its Y channel.

    Per iteration the current Y plane is split into a guided-filter base layer
    and a residual (detail) layer; the residual magnitude is soft-thresholded
    and a scaled version is added back onto the base. Chroma planes are left
    untouched and re-attached at the end.

    Args:
        img_bgr_u8: Input BGR image (uint8).
        r_base, eps_base: Guided-filter radius / eps for the base layer.
        use_detail_smooth: Whether to smooth the residual magnitude with a
            second guided filter (only done when ``r_detail > 0``).
        r_detail, eps_detail: Radius / eps of that second filter.
        detail_gain: Factor applied to the shrunk residual before it is added
            back to the base layer.
        iters: Number of iterations (at least one is always run).
        soft_tau_val: With ``soft_tau_mode == "percentile"``, the fraction
            (0..1) whose percentile of the residual magnitude is the threshold.
        soft_tau_mode: "percentile", or any other value to use
            ``median + 2.5 * (MAD + 1e-6)`` of the residual magnitude instead.
        do_contrast_match: When True, the output Y is shifted/scaled to the
            mean and standard deviation of the input Y.

    Returns:
        The processed BGR image (uint8).
    """
    # colour ↔ YCrCb
    y_u8, cr_u8, cb_u8 = cv2.split(cv2.cvtColor(img_bgr_u8, cv2.COLOR_BGR2YCrCb))
    y_cur = y_u8.copy()

    for _ in range(max(1, iters)):
        # Self-guided filter gives the smooth base layer.
        base_u8 = guided_filter_cv(y_cur, y_cur, r_base, eps_base)
        # Signed residual (int16 avoids uint8 wrap-around) and its magnitude.
        h_i16   = y_cur.astype(np.int16) - base_u8.astype(np.int16)
        h_abs   = cv2.absdiff(y_cur, base_u8).astype(np.float32)

        if use_detail_smooth and r_detail > 0:
            # Smooth the residual magnitude before thresholding.
            h_abs = guided_filter_cv(h_abs.astype(np.uint8), h_abs.astype(np.uint8),
                                     r_detail, eps_detail).astype(np.float32)

        # Threshold: a percentile of |h|, or median + 2.5 * MAD otherwise.
        tau = (np.percentile(h_abs, soft_tau_val*100)
               if soft_tau_mode=="percentile"
               else np.median(h_abs)+2.5*(np.median(np.abs(h_abs-np.median(h_abs)))+1e-6))

        # Soft shrinkage: keep the sign, reduce the magnitude by tau (floor 0).
        h_shrink = np.sign(h_i16).astype(np.float32)*np.maximum(h_abs-tau, 0.)
        y_cur    = np.clip(base_u8.astype(np.float32)+detail_gain*h_shrink,
                           0, 255).astype(np.uint8)

    if do_contrast_match:                                  # mean-std match
        # y0 is the original Y, y1 the processed Y; 1e-6 guards a zero std.
        y0, y1 = y_u8.astype(np.float32), y_cur.astype(np.float32)
        y_cur  = np.clip((y1-y1.mean())*(y0.std()/(y1.std()+1e-6))+y0.mean(),
                         0, 255).astype(np.uint8)

    # Re-attach the untouched chroma planes and go back to BGR.
    out_bgr = cv2.cvtColor(cv2.merge([y_cur, cr_u8, cb_u8]), cv2.COLOR_YCrCb2BGR)
    return out_bgr

# ---- 0) ensure previous cells ran ----
# Fail early with a readable message if the demo images are not defined yet.
for n in ["rainL_bgr","gtL_bgr","rainH_bgr","gtH_bgr"]:
    assert n in globals(), f"{n} missing – run Cell 1 first"

# ---- 1) run GF  (L & H)  ----
# very–strong GF preset (overwrites previous gfL_bgr / gfH_bgr)
# gfL_bgr = derain_guided_filter_bgr_strong(
#     rainL_bgr,
#     r_base     = 48,         # much larger than before
#     eps_base   = 2e-2,
#     r_detail   = 0,          # no extra smoothing on light set
#     detail_gain= 0.25,
#     iters      = 2,
#     soft_tau_val = 0.8,
#     do_contrast_match = True)

# gfH_bgr = derain_guided_filter_bgr_strong(
#     rainH_bgr,
#     r_base     = 64,         # even larger for heavy streaks
#     eps_base   = 3e-2,
#     r_detail   = 3,          # keep small-radius polish
#     detail_gain= 0.15,
#     iters      = 2,
#     soft_tau_val = 0.8,
#     do_contrast_match = True)

# ---- 1) run GF  (L & H)  ----
# Strong GF preset for the two demo images (defines gfL_bgr / gfH_bgr).
# NOTE: r_detail is not passed for the L image, so the function default applies.
gfL_bgr = derain_guided_filter_bgr_strong(
    rainL_bgr,
    r_base     = 64,         # base-layer filter radius
    eps_base   = 5e-3,         # base-layer filter regularisation
    detail_gain= 1.1,
    iters      = 4,
    soft_tau_val = 0.55,
    do_contrast_match = True)


gfH_bgr = derain_guided_filter_bgr_strong(
    rainH_bgr,
    r_base     = 64,         # base-layer filter radius
    eps_base   = 0.04,         # base-layer filter regularisation (larger than for L)
    detail_gain= 1,
    iters      = 4,
    soft_tau_val = 0.4,
    r_detail= 2,
    do_contrast_match = True)

# ---- 2) visualize ----
preview_groups(
    groups=[("Rain100L", rainL_bgr, gfL_bgr, gtL_bgr),
            ("Rain100H", rainH_bgr, gfH_bgr, gtH_bgr)],
    col_names=("Rainy","GF","GT"),
    suptitle="Guided Filter Results – Rain100L & Rain100H")

# ---- 3) quick diff check ----
# Mean absolute difference between input and output, as a rough indication
# of how much the filter changed each image.
mad_L = np.mean(np.abs(rainL_bgr.astype(np.float32)-gfL_bgr.astype(np.float32)))
mad_H = np.mean(np.abs(rainH_bgr.astype(np.float32)-gfH_bgr.astype(np.float32)))
print(f"MAD  L={mad_L:.2f} , H={mad_H:.2f}")


# Store the single Demo Image

In [None]:
OUT_DIR_GF_ROOT = "output/gf"
OUT_DIR_GF_L = os.path.join(OUT_DIR_GF_ROOT, "rain100L")
OUT_DIR_GF_H = os.path.join(OUT_DIR_GF_ROOT, "rain100H")
os.makedirs(OUT_DIR_GF_L, exist_ok=True)
os.makedirs(OUT_DIR_GF_H, exist_ok=True)

# ---------- 1) Save the two images of the current demo ----------
# Requires, from the earlier cells:
#   L_inp_path, H_inp_path  (file paths of the first pair of each set)
#   gfL_bgr, gfH_bgr        (the matching GF results)
save_name_L = f"gf_{os.path.basename(L_inp_path)}"
save_name_H = f"gf_{os.path.basename(H_inp_path)}"
for _out_dir, _out_name, _out_img in (
        (OUT_DIR_GF_L, save_name_L, gfL_bgr),
        (OUT_DIR_GF_H, save_name_H, gfH_bgr)):
    _out_path = os.path.join(_out_dir, _out_name)
    # cv2.imwrite reports failure through its return value, so check it
    # instead of printing "[Saved]" unconditionally.
    if not cv2.imwrite(_out_path, _out_img):
        raise IOError(f"cv2.imwrite failed: {_out_path}")
    print(f"[Saved] {_out_path}")

# PSNR/SSIM for single Demo Image

In [None]:
# ===== Unified PSNR/SSIM (Y-channel) for Rain100L/H =====
import os
from skimage.metrics import peak_signal_noise_ratio as sk_psnr
from skimage.metrics import structural_similarity as sk_ssim

def compute_metrics_rain100(mode: str, index: int, shave: int = 0,
                            inp_dir_L="data/rain100L/input",   gt_dir_L="data/rain100L/target",   gf_dir_L="output/gf/rain100L",
                            inp_dir_H="data/rain100H/input",   gt_dir_H="data/rain100H/target",   gf_dir_H="output/gf/rain100H"):
    """Compute PSNR / SSIM on the Y channel for one saved GF result.

    Args:
        mode: 'L' for Rain100L (files named ``1.png``) or 'H' for Rain100H
            (files named ``rain-001.png`` / ``norain-001.png``); case-insensitive.
        index: Sample number. For 'H' it is formatted as a 3-digit id.
        shave: Border pixels cropped on every side before measuring.
        inp_dir_L, gt_dir_L, gf_dir_L: Input / ground truth / GF output
            directories for Rain100L.
        inp_dir_H, gt_dir_H, gf_dir_H: The same for Rain100H.

    Returns:
        ``(psnr_y, ssim_y)``.

    Raises:
        AssertionError: if ``mode`` is invalid or one of the input, ground
            truth or GF result files is missing. Raised explicitly (not via
            ``assert``) so the checks survive ``python -O``; the batch
            evaluator relies on this exception type to skip samples.
    """
    mode = mode.upper()
    if mode not in ("L", "H"):
        raise AssertionError("mode must be 'L' or 'H'")

    if mode == "L":
        id_label = f"{index}"
        inp_name = gt_name = f"{index}.png"
        inp_dir, gt_dir, gf_dir = inp_dir_L, gt_dir_L, gf_dir_L
    else:
        id_label = f"{index:03d}"
        inp_name = f"rain-{id_label}.png"
        gt_name = f"norain-{id_label}.png"
        inp_dir, gt_dir, gf_dir = inp_dir_H, gt_dir_H, gf_dir_H

    inp_path = os.path.join(inp_dir, inp_name)
    gt_path = os.path.join(gt_dir, gt_name)
    gf_path = os.path.join(gf_dir, f"gf_{inp_name}")

    # All three files must exist; report exactly which ones do not.
    missing = [p for p in (inp_path, gt_path, gf_path) if not os.path.exists(p)]
    if missing:
        raise AssertionError("Missing: " + ", ".join(missing))

    pred_bgr = imread_bgr_uint8(gf_path)
    gt_bgr = imread_bgr_uint8(gt_path)

    # Evaluate on the Y channel.
    y_pred, y_gt = prepare_y_for_metrics(pred_bgr, gt_bgr, shave=shave)
    psnr_y = sk_psnr(y_gt, y_pred, data_range=255)
    ssim_y = sk_ssim(y_gt, y_pred, data_range=255)

    tag = f"Rain100{mode} #{id_label}"
    print(f"[{tag}] PSNR-Y: {psnr_y:.4f}  SSIM-Y: {ssim_y:.4f}")
    return psnr_y, ssim_y

# Rain100L, image #1
psnr_y_1L, ssim_y_1L = compute_metrics_rain100('L', 1, shave=4)

# Rain100H, image #1 (i.e. rain-001.png)
psnr_y_1H, ssim_y_1H = compute_metrics_rain100('H', 1, shave=4)


# Batch Processing Output

In [None]:
# Run the GF presets of the demo cell over every available pair and save the
# results. The range comes from the pair lists (not a hard-coded 100), and
# starts at 0 so this cell does not depend on the demo cell having saved the
# first image.
os.makedirs(OUT_DIR_GF_L, exist_ok=True)
os.makedirs(OUT_DIR_GF_H, exist_ok=True)

if len(pairs_L) != len(pairs_H):
    print(f"[Warn] pair counts differ (L={len(pairs_L)}, H={len(pairs_H)}); "
          "only the common index range is processed.")

for i in range(min(len(pairs_L), len(pairs_H))):
    print("=============== Image ", i, "==========================================")
    # Paths of the i-th pair of each set
    L_inp_path, L_gt_path = pairs_L[i]
    H_inp_path, H_gt_path = pairs_H[i]
    print("L sample:", os.path.basename(L_inp_path), "|", os.path.basename(L_gt_path))
    print("H sample:", os.path.basename(H_inp_path), "|", os.path.basename(H_gt_path))

    # Read the images as BGR uint8
    rainL_bgr, gtL_bgr = imread_bgr_uint8(L_inp_path), imread_bgr_uint8(L_gt_path)
    rainH_bgr, gtH_bgr = imread_bgr_uint8(H_inp_path), imread_bgr_uint8(H_gt_path)

    # Sanity check: everything must be HxWx3 uint8
    _check_img(rainL_bgr, "rainL_bgr")
    _check_img(gtL_bgr,   "gtL_bgr")
    _check_img(rainH_bgr, "rainH_bgr")
    _check_img(gtH_bgr,   "gtH_bgr")

    # Preprocessing info (printed for inspection only; the GF call below works
    # on the BGR image directly)
    prep_L = prepare_for_gf(
        rainL_bgr,
        r_base=16, eps_base=1e-3,
        r_detail=3, eps_detail=1e-3,
        scale_radius=False
    )
    prep_H = prepare_for_gf(
        rainH_bgr,
        r_base=16, eps_base=1e-3,
        r_detail=3, eps_detail=1e-3,
        scale_radius=False
    )

    yL_vis = y01_to_bgr_gray(prep_L["Y01"])
    yH_vis = y01_to_bgr_gray(prep_H["Y01"])

    print(
        "L -> r_base_eff:", prep_L["r_base_eff"],
        "| r_detail_eff:", prep_L["r_detail_eff"],
        "| Y01 range:", (float(prep_L["Y01"].min()), float(prep_L["Y01"].max())),
        "| shape:", prep_L["shape_hw"]
    )
    print(
        "H -> r_base_eff:", prep_H["r_base_eff"],
        "| r_detail_eff:", prep_H["r_detail_eff"],
        "| Y01 range:", (float(prep_H["Y01"].min()), float(prep_H["Y01"].max())),
        "| shape:", prep_H["shape_hw"]
    )

    # ---- run GF (L & H) with the same presets as the single-image demo ----
    gfL_bgr = derain_guided_filter_bgr_strong(
        rainL_bgr,
        r_base     = 64,
        eps_base   = 5e-3,
        detail_gain= 1.1,
        iters      = 4,
        soft_tau_val = 0.55,
        do_contrast_match = True)

    gfH_bgr = derain_guided_filter_bgr_strong(
        rainH_bgr,
        r_base     = 64,
        eps_base   = 0.04,
        detail_gain= 1,
        iters      = 4,
        soft_tau_val = 0.4,
        r_detail= 2,
        do_contrast_match = True)

    # ---- quick diff check ----
    mad_L = np.mean(np.abs(rainL_bgr.astype(np.float32)-gfL_bgr.astype(np.float32)))
    mad_H = np.mean(np.abs(rainH_bgr.astype(np.float32)-gfH_bgr.astype(np.float32)))
    print(f"MAD  L={mad_L:.2f} , H={mad_H:.2f}")

    # ---- save both results; stop if OpenCV reports a failed write ----
    save_name_L = f"gf_{os.path.basename(L_inp_path)}"
    save_name_H = f"gf_{os.path.basename(H_inp_path)}"
    for _out_dir, _out_name, _out_img in (
            (OUT_DIR_GF_L, save_name_L, gfL_bgr),
            (OUT_DIR_GF_H, save_name_H, gfH_bgr)):
        _out_path = os.path.join(_out_dir, _out_name)
        if not cv2.imwrite(_out_path, _out_img):
            raise IOError(f"cv2.imwrite failed: {_out_path}")
        print(f"[Saved] {_out_path}")

    






# Batch Processing PSNR/SSIM for the Test Dataset

In [None]:
# ==========================
# Batch evaluation on Rain100L/H (PSNR/SSIM on Y)
# ==========================
import os, csv, numpy as np
from datetime import datetime

def eval_rain100_batch(mode: str, shave: int = 4,
                       inp_dir_L="data/rain100L/input",   gt_dir_L="data/rain100L/target",   gf_dir_L="output/gf/rain100L",
                       inp_dir_H="data/rain100H/input",   gt_dir_H="data/rain100H/target",   gf_dir_H="output/gf/rain100H",
                       csv_out_dir="output/metrics"):
    """Evaluate all saved GF results of Rain100L or Rain100H and write a CSV.

    The sample indices are taken from the file names actually present in the
    input directory (``N.png`` for 'L', ``rain-NNN.png`` for 'H'), so gaps in
    the numbering or unrelated PNG files do not distort the evaluated range.
    Each sample is measured with ``compute_metrics_rain100``; samples that fail
    are reported and skipped.

    Args:
        mode: 'L' or 'H' (case-insensitive).
        shave: Border pixels cropped before measuring (keep it identical when
            comparing against other methods).
        inp_dir_L, gt_dir_L, gf_dir_L: Directories for Rain100L.
        inp_dir_H, gt_dir_H, gf_dir_H: Directories for Rain100H.
        csv_out_dir: Directory the CSV is written to (created if missing).

    Returns:
        None. Per-image rows plus a summary row are written to a time-stamped
        CSV and the summary is printed.

    Raises:
        AssertionError: if ``mode`` is neither 'L' nor 'H'.
    """
    import re  # local import keeps this cell self-contained

    os.makedirs(csv_out_dir, exist_ok=True)
    mode = mode.upper()
    if mode not in ("L", "H"):
        raise AssertionError("mode must be 'L' or 'H'")

    if mode == "L":
        inp_dir, name_pattern, name_fmt = inp_dir_L, r"^(\d+)\.png$", "{}.png"
    else:
        inp_dir, name_pattern, name_fmt = inp_dir_H, r"^rain-(\d+)\.png$", "rain-{:03d}.png"

    # Discover the indices from the real file names.
    matcher = re.compile(name_pattern, re.IGNORECASE)
    indices = sorted({
        int(hit.group(1))
        for hit in map(matcher.match, os.listdir(inp_dir))
        if hit
    })

    rows = []
    for idx in indices:
        fname = name_fmt.format(idx)
        try:
            psnr_y, ssim_y = compute_metrics_rain100(
                mode, idx, shave=shave,
                inp_dir_L=inp_dir_L, gt_dir_L=gt_dir_L, gf_dir_L=gf_dir_L,
                inp_dir_H=inp_dir_H, gt_dir_H=gt_dir_H, gf_dir_H=gf_dir_H)
        except AssertionError as e:
            # Missing input / GT / GF file for this index.
            print(f"[Skip {fname}] {e}")
        except Exception as e:
            # Deliberate best-effort: one bad image must not abort the batch.
            print(f"[Error {fname}] {e}")
        else:
            rows.append([idx, fname, psnr_y, ssim_y])

    if not rows:
        print(f"[Eval Rain100{mode}] No valid samples found.")
        return

    # Summary statistics.
    P = np.array([row[2] for row in rows])
    S = np.array([row[3] for row in rows])
    summary = {
        "count": len(rows),
        "psnr_mean": float(P.mean()), "psnr_median": float(np.median(P)), "psnr_std": float(P.std()),
        "ssim_mean": float(S.mean()), "ssim_median": float(np.median(S)), "ssim_std": float(S.std()),
        "shave": shave,
    }
    summary_fields = ["count", "psnr_mean", "psnr_median", "psnr_std",
                      "ssim_mean", "ssim_median", "ssim_std", "shave"]

    # Write the CSV: per-image rows, a blank line, then the summary.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    csv_path = os.path.join(csv_out_dir, f"rain100{mode}_gf_psnr_ssim_shave{shave}_{stamp}.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "filename", "psnr_y", "ssim_y"])
        writer.writerows(rows)
        writer.writerow([])
        writer.writerow(summary_fields)
        writer.writerow([summary[field] for field in summary_fields])

    print(f"[Eval Rain100{mode}] N={summary['count']}  "
          f"PSNR(mean±std)={summary['psnr_mean']:.3f}±{summary['psnr_std']:.3f}  "
          f"SSIM(mean±std)={summary['ssim_mean']:.4f}±{summary['ssim_std']:.4f}  "
          f"(shave={shave})")
    print(f"[Saved] {csv_path}")


# Evaluate both datasets with the same border shave.
eval_rain100_batch('L', shave=4)
eval_rain100_batch('H', shave=4)
