In [4]:
# ======================================================
# 三路融合：CLIP SVM + PRNU CNN + ELA CNN  (Val學權重 → Test評估)
# ======================================================
import os, json, time, glob, math
from pathlib import Path
import numpy as np
from tqdm import tqdm

import torch, torch.nn as nn
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix, roc_curve
import joblib

# ---------------- Config ----------------
SCRIPT_ROOT = "/home/yaya/ai-detect-proj/Script"
SPLITS_JSON = f"{SCRIPT_ROOT}/saved_models/splits_clip_feature_iid_ood.json"
SAVED_DIR   = f"{SCRIPT_ROOT}/saved_models"

# 可調參：TTA 次數（0=關閉）
TTA_PRNU = 0   # 建議先 0 或 8
TTA_ELA  = 0   # 建議先 0 或 8
INPUT_SIZE = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("device =", DEVICE)

# ---------------- 讀 splits + 目錄資訊 ----------------
with open(SPLITS_JSON, "r", encoding="utf-8") as f:
    J = json.load(f)
SPL = J["splits"]
META_DIRS = (J.get("meta", {}) or {}).get("dirs", {})

# 回傳各 split 的 clip 路徑與標籤（以 clip 為主索引）
def clip_paths_labels(split_name):
    real = SPL[split_name]["clip"]["real"]; fake = SPL[split_name]["clip"]["fake"]
    P = real + fake
    y = np.array([0]*len(real) + [1]*len(fake), dtype=int)
    stems = [Path(p).stem.lower() for p in P]
    return P, y, stems

# 建立索引：從 stem 找到對應的 PRNU / ELA 檔案
def index_dir(d, exts=(".npy",".npz")):
    d = Path(d); idx = {}
    if not d.exists(): return idx
    for ext in exts:
        for q in d.glob(f"*{ext}"):
            idx[q.stem.lower()] = str(q)
            idx[q.name.lower()] = str(q)
    return idx

# 取得 PRNU/ELA 目錄（從 meta 取，若無則給預設）
PRNU_REAL_DIR = META_DIRS.get("prnu", {}).get("real", f"{SCRIPT_ROOT}/features_i8/prnu_real_i8_npy")
PRNU_FAKE_DIR = META_DIRS.get("prnu", {}).get("fake", f"{SCRIPT_ROOT}/features_i8/prnu_fake_i8_npy")
ELA_REAL_DIR  = META_DIRS.get("ela",  {}).get("real", f"{SCRIPT_ROOT}/features_npy/ela_real_npy")
ELA_FAKE_DIR  = META_DIRS.get("ela",  {}).get("fake", f"{SCRIPT_ROOT}/features_npy/ela_fake_npy")

IDX_PRNU_REAL = index_dir(PRNU_REAL_DIR, (".npy",))
IDX_PRNU_FAKE = index_dir(PRNU_FAKE_DIR, (".npy",))
IDX_ELA_REAL  = index_dir(ELA_REAL_DIR,  (".npy",".npz"))
IDX_ELA_FAKE  = index_dir(ELA_FAKE_DIR,  (".npy",".npz"))

def map_stems_to_paths(stems, is_real, kind):
    if kind=="prnu":
        idx = IDX_PRNU_REAL if is_real==0 else IDX_PRNU_FAKE
    else:
        idx = IDX_ELA_REAL if is_real==0 else IDX_ELA_FAKE
    out = []
    miss = 0
    for s in stems:
        q = idx.get(s) or idx.get(s + ".npy") or idx.get(s + ".npz")
        if q is None:
            miss += 1
            out.append(None)
        else:
            out.append(q)
    return out, miss

# ---------------- 載入三路模型 ----------------
# 1) CLIP SVM
svm_paths = sorted(Path(SAVED_DIR).glob("clip_linear_svm_feature_*.joblib"))
assert svm_paths, "找不到 clip_linear_svm_feature_*.joblib，請先訓練 CLIP SVM。"
clip_svm = joblib.load(svm_paths[-1])
print("Loaded CLIP SVM:", svm_paths[-1].name)

# 2) PRNU CNN（網路結構需與訓練一致）
class SmallForensicCNN(nn.Module):
    def __init__(self, in_ch=1):
        super().__init__()
        def blk(ci, co, groups=8):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                nn.ReLU(inplace=True)
            )
        self.net = nn.Sequential(
            blk(in_ch,32), blk(32,32), nn.AvgPool2d(2),
            blk(32,64),   blk(64,64),  nn.AvgPool2d(2),
            blk(64,128),  blk(128,128), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(128, 1)
    def forward(self,x): return self.fc(self.net(x).flatten(1)).squeeze(1)

prnu_ckpts = sorted(Path(SAVED_DIR).glob("prnu_cnn_i8_best_*.pt"))
assert prnu_ckpts, "找不到 prnu_cnn_i8_best_*.pt，請先訓練 PRNU CNN。"
prnu_model = SmallForensicCNN(1).to(DEVICE).eval()
prnu_model.load_state_dict(torch.load(prnu_ckpts[-1], map_location=DEVICE))
print("Loaded PRNU CNN:", prnu_ckpts[-1].name)

# 3) ELA CNN
class ELAForensicCNN(nn.Module):
    def __init__(self, in_ch=3):
        super().__init__()
        def bnblk(ci, co):
            return nn.Sequential(nn.Conv2d(ci, co, 3, padding=1, bias=False),
                                 nn.BatchNorm2d(co), nn.ReLU(inplace=True))
        def gnblk(ci, co, groups=8):
            return nn.Sequential(nn.Conv2d(ci, co, 3, padding=1, bias=False),
                                 nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                                 nn.ReLU(inplace=True))
        self.net = nn.Sequential(
            bnblk(in_ch,32), bnblk(32,32), nn.AvgPool2d(2),
            bnblk(32,64),    bnblk(64,64), nn.AvgPool2d(2),
            gnblk(64,128),   gnblk(128,128), nn.AvgPool2d(2),
            gnblk(128,256),  gnblk(256,256), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(256, 1)
    def forward(self,x): return self.fc(self.net(x).flatten(1)).squeeze(1)

ela_ckpts = sorted(Path(SAVED_DIR).glob("ela_fromnpy_cnn_best_*.pt"))
assert ela_ckpts, "找不到 ela_fromnpy_cnn_best_*.pt，請先訓練 ELA CNN。"
ela_model = ELAForensicCNN(3).to(DEVICE).eval()
ela_model.load_state_dict(torch.load(ela_ckpts[-1], map_location=DEVICE))
print("Loaded ELA CNN:", ela_ckpts[-1].name)



# ---------------- 三路分數計算（含前處理/TTA） ----------------
def l2norm(v):
    v = v.astype(np.float32).reshape(-1)
    n = np.linalg.norm(v) + 1e-12
    return v / n

def clip_scores(paths, batch=2048):
    X = np.vstack([l2norm(np.load(p, allow_pickle=True)) for p in tqdm(paths, desc="CLIP load")])
    sc = clip_svm.decision_function(X).astype(np.float32)  # >0 => fake
    return sc

# PRNU helpers
def avg_pool_2x(x):
    H,W = x.shape
    H2,W2 = (H//2)*2, (W//2)*2
    x = x[:H2,:W2]  
    return x.reshape(H2//2,2,W2//2,2).mean(axis=(1,3))

def crop_center_or_rand(x, size=256, center=True, rng=None):
    h,w = x.shape[:2]
    if h<size or w<size:
        ph,pw = max(0,size-h), max(0,size-w)
        if x.ndim==2:
            x = np.pad(x, ((ph//2, ph-ph//2),(pw//2, pw-pw//2)), mode='edge')
        else:
            x = np.pad(x, ((ph//2, ph-ph//2),(pw//2, pw-pw//2),(0,0)), mode='edge')
        h,w = x.shape[:2]
    if h==size and w==size: 
        return x.copy()
    if center:
        y0,x0 = (h-size)//2, (w-size)//2
    else:
        if rng is None: rng = np.random.default_rng()
        y0 = int(rng.integers(0, h-size+1)); x0 = int(rng.integers(0, w-size+1))
    return x[y0:y0+size, x0:x0+size].copy() if x.ndim==2 else x[y0:y0+size, x0:x0+size, :].copy()

def prnu_per_image_norm(a_i8: np.ndarray):
    x = a_i8.astype(np.float32, copy=False)
    m,s = x.mean(), x.std()
    if (not np.isfinite(s)) or s<1e-6:
        m,s = 0.0, 20.0
    return (x - m) / (s if s>0 else 1.0)

def load_prnu_i8(path):
    a = np.load(path, mmap_mode='r')
    a = np.asarray(a)
    if a.ndim == 3 and (a.shape[0]==1 or a.shape[-1]==1):
        a = a.squeeze()
    assert a.ndim==2 and a.dtype==np.int8, f"PRNU expect 2D int8, got {a.shape} {a.dtype}"
    return a

@torch.no_grad()
def prnu_logits(paths, tta=TTA_PRNU, batch=256):
    prnu_model.eval()
    rng = np.random.default_rng(1337)
    out = np.zeros(len(paths), np.float32)
    for i in tqdm(range(0, len(paths), batch), desc=f"PRNU logits (TTA={tta})"):
        chunk = paths[i:i+batch]
        acc = np.zeros(len(chunk), np.float32)
        reps = max(1, int(tta))  # 0->1
        for _ in range(reps):
            xs=[]
            for p in chunk:
                a = load_prnu_i8(p)
                if a.shape[0] >= 512 and a.shape[1] >= 512:
                    a = avg_pool_2x(a)
                a = crop_center_or_rand(a, INPUT_SIZE, center=(tta==0), rng=rng)
                a = prnu_per_image_norm(a)
                xs.append(torch.from_numpy(a).unsqueeze(0))  # [1,H,W]
            xb = torch.stack(xs,0).to(DEVICE)               # [B,1,H,W]
            logit = prnu_model(xb).float().cpu().numpy()
            acc += logit.astype(np.float32)
        out[i:i+batch] = acc/float(reps)
    return out  # logits

# ELA helpers
def _npz_pick(z):
    for k in ('ela','arr','arr_0','data'):
        if isinstance(z, np.lib.npyio.NpzFile) and (k in z.files):
            return z[k]
    return z[z.files[0]] if isinstance(z, np.lib.npyio.NpzFile) else z

def load_ela_array(path):
    z = np.load(path, mmap_mode='r')
    a = _npz_pick(z); a = np.asarray(a)
    if a.ndim==2: a = np.repeat(a[...,None],3,axis=2)
    elif a.ndim==3 and a.shape[0] in (1,3) and a.shape[-1] not in (1,3):
        a = np.transpose(a,(1,2,0))
    elif a.ndim==3 and a.shape[-1]==1:
        a = np.repeat(a,3,axis=2)
    assert a.ndim==3 and a.shape[-1]==3, f"ELA expect HxWx3, got {a.shape}"
    a = a.astype(np.float32, copy=False)
    if np.nanmax(a) > 1.5: a *= (1.0/255.0)
    return a

def ela_zscore(x):
    m = x.mean(axis=(0,1), keepdims=True)
    s = x.std(axis=(0,1), keepdims=True); s[s<1e-6] = 1.0
    return (x - m) / s

@torch.no_grad()
def ela_logits(paths, tta=TTA_ELA, batch=256):
    ela_model.eval()
    rng = np.random.default_rng(1337)
    out = np.zeros(len(paths), np.float32)
    for i in tqdm(range(0, len(paths), batch), desc=f"ELA logits (TTA={tta})"):
        chunk = paths[i:i+batch]
        acc = np.zeros(len(chunk), np.float32)
        reps = max(1, int(tta))
        for _ in range(reps):
            xs=[]
            for p in chunk:
                x = load_ela_array(p)
                x = crop_center_or_rand(x, INPUT_SIZE, center=(tta==0), rng=rng)
                x = ela_zscore(x)
                xs.append(torch.from_numpy(np.transpose(x,(2,0,1))))  # [3,H,W]
            xb = torch.stack(xs,0).to(DEVICE)
            logit = ela_model(xb).float().cpu().numpy()
            acc += logit.astype(np.float32)
        out[i:i+batch] = acc/float(reps)
    return out  # logits

# === TENT: norm-only 自適應（BN / SyncBN / GroupNorm）===
import torch.nn.functional as F

def tent_enable_norm(model):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)):
            m.train()
            for p in m.parameters(): p.requires_grad_(True)   # 只開 norm 層
        else:
            for p in getattr(m, 'parameters', lambda: [])():
                p.requires_grad_(False)
    return model

def tent_adapt(model, unlabeled_loader, steps=1, lr=1e-4, device=DEVICE):
    model = tent_enable_norm(model)
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    model.train()
    for _ in range(steps):
        for xb,_ in unlabeled_loader:
            xb = xb.to(device, non_blocking=True)
            logits = model(xb)
            prob = torch.sigmoid(logits).clamp(1e-6, 1-1e-6)
            ent = -(prob*torch.log(prob) + (1-prob)*torch.log(1-prob)).mean()  # 最小化熵
            opt.zero_grad(); ent.backward(); opt.step()
    model.eval()
    return model


# ---------------- 建立三路對齊資料（Val / Test-IID / Test-OOD） ----------------
def aligned_triplet(split_name):
    clipP, y, stems = clip_paths_labels(split_name)
    # map stem -> prnu/ela path（依 label 選 real/fake 目錄）
    prnu_paths = []
    ela_paths  = []
    miss_prnu = miss_ela = 0
    for s, yi in zip(stems, y):
        p_prnu = (IDX_PRNU_REAL.get(s) if yi==0 else IDX_PRNU_FAKE.get(s)) \
                 or (IDX_PRNU_REAL.get(s+".npy") if yi==0 else IDX_PRNU_FAKE.get(s+".npy"))
        p_ela  = (IDX_ELA_REAL.get(s)  if yi==0 else IDX_ELA_FAKE.get(s)) \
                 or (IDX_ELA_REAL.get(s+".npy") if yi==0 else IDX_ELA_FAKE.get(s+".npy")) \
                 or (IDX_ELA_REAL.get(s+".npz") if yi==0 else IDX_ELA_FAKE.get(s+".npz"))
        prnu_paths.append(p_prnu)
        ela_paths.append(p_ela)
        miss_prnu += int(p_prnu is None); miss_ela += int(p_ela is None)
    if miss_prnu or miss_ela:
        print(f"[{split_name}] 對齊缺檔 → PRNU:{miss_prnu}  ELA:{miss_ela}（將跳過）")

    # 保留三路皆存在的索引
    keep = [i for i,(p1,p2) in enumerate(zip(prnu_paths, ela_paths)) if (p1 is not None and p2 is not None)]
    clipP = [clipP[i] for i in keep]
    prnu_paths = [prnu_paths[i] for i in keep]
    ela_paths  = [ela_paths[i] for i in keep]
    y = y[keep]
    print(f"[{split_name}] 使用樣本數（三路齊全）：{len(keep)}")
    return clipP, prnu_paths, ela_paths, y

def compute_triplet_scores(clipP, prnuP, elaP):
    s_clip = clip_scores(clipP)
    z_prnu = prnu_logits(prnuP, tta=TTA_PRNU)
    z_ela  = ela_logits(elaP,  tta=TTA_ELA)
    # 轉成特徵向量： [clip_decision, prnu_logit, ela_logit]
    X = np.stack([s_clip, z_prnu, z_ela], axis=1).astype(np.float32)
    return X

# ---------------- 取各 split 的三路分數 ----------------
X_va = None
clip_va, prnu_va, ela_va, y_va = aligned_triplet("val")
if len(y_va): X_va = compute_triplet_scores(clip_va, prnu_va, ela_va)

clip_ti, prnu_ti, ela_ti, y_ti = aligned_triplet("test_iid")
X_ti = compute_triplet_scores(clip_ti, prnu_ti, ela_ti) if len(y_ti) else None

clip_to, prnu_to, ela_to, y_to = aligned_triplet("test_ood")

# === OOD 的 unlabeled loader（只用影像，不用標籤）===
from torch.utils.data import Dataset, DataLoader

class PRNUUnlab(Dataset):
    def __init__(self, paths): self.paths = paths
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        a = load_prnu_i8(self.paths[i]).astype(np.float32)
        if a.shape[0] >= 512 and a.shape[1] >= 512:
            a = avg_pool_2x(a)
        a = crop_center_or_rand(a, INPUT_SIZE, center=False)
        a = prnu_per_image_norm(a)
        return torch.from_numpy(a).unsqueeze(0), 0.0  # [1,H,W], dummy label

class ELAUnlab(Dataset):
    def __init__(self, paths): self.paths = paths
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        x = load_ela_array(self.paths[i])
        x = crop_center_or_rand(x, INPUT_SIZE, center=False)
        x = ela_zscore(x)
        return torch.from_numpy(np.transpose(x,(2,0,1))), 0.0  # [3,H,W], dummy label

bs = 128
prnu_unlab = DataLoader(PRNUUnlab(prnu_to), batch_size=bs, shuffle=True, num_workers=0, drop_last=True)
ela_unlab  = DataLoader(ELAUnlab(ela_to),  batch_size=bs, shuffle=True, num_workers=0, drop_last=True)

# === 在 OOD 上做一次 TENT 自適應（小步長、一次即可）===
prnu_model = tent_adapt(prnu_model, prnu_unlab, steps=1, lr=1e-4)
ela_model  = tent_adapt(ela_model,  ela_unlab,  steps=1, lr=1e-4)

# === 開 TTA，重算 OOD 的 logits，再組合 X_to ===
TTA_PRNU = 8
TTA_ELA  = 8
s_clip_to = clip_scores(clip_to)                 # 原本就有
z_prnu_to = prnu_logits(prnu_to, tta=TTA_PRNU)   # 重新計算（已自適應）
z_ela_to  = ela_logits(ela_to,   tta=TTA_ELA)
X_to = np.stack([s_clip_to, z_prnu_to, z_ela_to], axis=1).astype(np.float32)


X_to = compute_triplet_scores(clip_to, prnu_to, ela_to) if len(y_to) else None

assert X_va is not None and len(X_va), "Val split 在三路對齊後為空，請檢查資料。"

# ---------------- 訓練融合器（StandardScaler → LogisticRegression） ----------------
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(X_va)
Xv = scaler.transform(X_va)

base = HistGradientBoostingClassifier(
    max_depth=3, learning_rate=0.1, max_iter=300,
    l2_regularization=1e-3, min_samples_leaf=20, random_state=1337
)
# 先訓練、再用 Val 做保留折校準（讓輸出可當機率）
fuser = CalibratedClassifierCV(base, method="isotonic", cv=3)
fuser.fit(Xv, y_va)

def proba_from(X): return fuser.predict_proba(scaler.transform(X))[:,1]
p_va = proba_from(X_va)
p_ti = proba_from(X_ti) if X_ti is not None else None
p_to = proba_from(X_to) if X_to is not None else None


# ---------------- 保存融合模型 ----------------
stamp = time.strftime("%Y%m%d_%H%M%S")
save_path = f"{SAVED_DIR}/fusion3_lr_{stamp}.joblib"
joblib.dump({
    "scaler": scaler,
    "fuser": fuser,
    "config": {
        "tta_prnu": TTA_PRNU, "tta_ela": TTA_ELA,
        "input_size": INPUT_SIZE, "device": DEVICE
    },
    "components": {
        "clip_svm": Path(svm_paths[-1]).name,
        "prnu_cnn": Path(prnu_ckpts[-1]).name,
        "ela_cnn":  Path(ela_ckpts[-1]).name,
    },
    "threshold_val_youden": th_star
}, save_path)
print("\n✅ Saved fusion model:", save_path)


device = cuda
Loaded CLIP SVM: clip_linear_svm_feature_20250819_011904.joblib
Loaded PRNU CNN: prnu_cnn_i8_best_20250819_020555.pt
Loaded ELA CNN: ela_fromnpy_cnn_best_20250819_085103.pt
[val] 使用樣本數（三路齊全）：14000


CLIP load: 100%|██████████| 14000/14000 [00:00<00:00, 23670.83it/s]
PRNU logits (TTA=0): 100%|██████████| 55/55 [00:56<00:00,  1.02s/it]
ELA logits (TTA=0): 100%|██████████| 55/55 [01:51<00:00,  2.04s/it]


[test_iid] 使用樣本數（三路齊全）：14000


CLIP load: 100%|██████████| 14000/14000 [00:08<00:00, 1565.58it/s]
PRNU logits (TTA=0): 100%|██████████| 55/55 [00:58<00:00,  1.06s/it]
ELA logits (TTA=0): 100%|██████████| 55/55 [01:57<00:00,  2.14s/it]


[test_ood] 使用樣本數（三路齊全）：47570


CLIP load: 100%|██████████| 47570/47570 [00:13<00:00, 3520.00it/s]
PRNU logits (TTA=8): 100%|██████████| 186/186 [12:38<00:00,  4.08s/it]
ELA logits (TTA=8): 100%|██████████| 186/186 [35:36<00:00, 11.49s/it]
CLIP load: 100%|██████████| 47570/47570 [00:14<00:00, 3295.08it/s]
PRNU logits (TTA=8): 100%|██████████| 186/186 [12:55<00:00,  4.17s/it]
ELA logits (TTA=8): 100%|██████████| 186/186 [35:44<00:00, 11.53s/it]



✅ Saved fusion model: /home/yaya/ai-detect-proj/Script/saved_models/fusion3_lr_20250819_130718.joblib


In [1]:
# ======================================================
# Test fusion model on splits (val / test_iid / test_ood) + single sample
# ======================================================
import os, json, glob, time
from pathlib import Path
import numpy as np
from tqdm import tqdm
import joblib

import torch, torch.nn as nn
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix, roc_curve

# ---------- Paths ----------
SCRIPT_ROOT = "/home/yaya/ai-detect-proj/Script"
SAVED_DIR   = f"{SCRIPT_ROOT}/saved_models"
SPLITS_JSON = f"{SAVED_DIR}/splits_clip_feature_iid_ood.json"

# 自動抓最後一個融合模型
FUSION_PATH = sorted(Path(SAVED_DIR).glob("fusion3_lr_*.joblib"))[-1]
print("Using fusion:", FUSION_PATH.name)

# ---------- Load fusion (scaler + fuser + thr) ----------
F = joblib.load(FUSION_PATH)
scaler = F["scaler"]
fuser  = F["fuser"]
thr_star = float(F.get("threshold_val_youden", 0.5))
print("Val-Youden threshold =", thr_star)

# ---------- Backbones (同你訓練時結構) ----------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("device =", DEVICE)

# PRNU CNN（GroupNorm 小網）
class SmallForensicCNN(nn.Module):
    def __init__(self, in_ch=1):
        super().__init__()
        def blk(ci, co, groups=8):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                nn.ReLU(inplace=True)
            )
        self.net = nn.Sequential(
            blk(in_ch,32), blk(32,32), nn.AvgPool2d(2),
            blk(32,64),   blk(64,64),  nn.AvgPool2d(2),
            blk(64,128),  blk(128,128), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(128, 1)
    def forward(self,x): return self.fc(self.net(x).flatten(1)).squeeze(1)

# ELA CNN（BN → GN）
class ELAForensicCNN(nn.Module):
    def __init__(self, in_ch=3):
        super().__init__()
        def bnblk(ci, co):
            return nn.Sequential(nn.Conv2d(ci, co, 3, padding=1, bias=False),
                                 nn.BatchNorm2d(co), nn.ReLU(inplace=True))
        def gnblk(ci, co, groups=8):
            return nn.Sequential(nn.Conv2d(ci, co, 3, padding=1, bias=False),
                                 nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                                 nn.ReLU(inplace=True))
        self.net = nn.Sequential(
            bnblk(in_ch,32), bnblk(32,32), nn.AvgPool2d(2),
            bnblk(32,64),    bnblk(64,64), nn.AvgPool2d(2),
            gnblk(64,128),   gnblk(128,128), nn.AvgPool2d(2),
            gnblk(128,256),  gnblk(256,256), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(256, 1)
    def forward(self,x): return self.fc(self.net(x).flatten(1)).squeeze(1)

# 載入你最近的 PRNU/ELA 權重 + CLIP SVM
prnu_ckpt = sorted(Path(SAVED_DIR).glob("prnu_cnn_i8_best_*.pt"))[-1]
ela_ckpt  = sorted(Path(SAVED_DIR).glob("ela_fromnpy_cnn_best_*.pt"))[-1]
clip_svm  = joblib.load(sorted(Path(SAVED_DIR).glob("clip_linear_svm_feature_*.joblib"))[-1])

prnu_model = SmallForensicCNN(1).to(DEVICE).eval()
ela_model  = ELAForensicCNN(3).to(DEVICE).eval()
prnu_model.load_state_dict(torch.load(prnu_ckpt, map_location=DEVICE))
ela_model.load_state_dict(torch.load(ela_ckpt,  map_location=DEVICE))
print("Loaded:", prnu_ckpt.name, "|", ela_ckpt.name)

# ---------- IO helpers ----------
def l2norm(v): 
    v = v.astype(np.float32).reshape(-1)
    n = np.linalg.norm(v) + 1e-12
    return v / n

def load_prnu_i8(path):
    a = np.load(path, mmap_mode='r')
    a = np.asarray(a)
    if a.ndim==3 and (a.shape[0]==1 or a.shape[-1]==1): a = a.squeeze()
    assert a.ndim==2 and a.dtype==np.int8, f"PRNU expect 2D int8, got {a.shape} {a.dtype}"
    return a

def avg_pool_2x(x):
    H,W = x.shape; H2,W2=(H//2)*2,(W//2)*2
    x = x[:H2,:W2]; return x.reshape(H2//2,2,W2//2,2).mean(axis=(1,3))

def prnu_per_image_norm(a_i8):
    x = a_i8.astype(np.float32, copy=False)
    m,s = x.mean(), x.std()
    if (not np.isfinite(s)) or s<1e-6: m,s = 0.0,20.0
    return (x-m)/(s if s>0 else 1.0)

def load_ela_array(path):
    z = np.load(path, mmap_mode='r')
    if isinstance(z, np.lib.npyio.NpzFile):
        for k in ('ela','arr','arr_0','data'):
            if k in z.files: a = z[k]; break
        else: a = z[z.files[0]]
    else:
        a = z
    a = np.asarray(a)
    if a.ndim==2: a = np.repeat(a[...,None],3,axis=2)
    elif a.ndim==3 and a.shape[0] in (1,3) and a.shape[-1] not in (1,3):
        a = np.transpose(a,(1,2,0))
    elif a.ndim==3 and a.shape[-1]==1:
        a = np.repeat(a,3,axis=2)
    a = a.astype(np.float32, copy=False)
    if np.nanmax(a)>1.5: a *= (1.0/255.0)
    return a

def ela_zscore(x):
    m = x.mean(axis=(0,1), keepdims=True)
    s = x.std(axis=(0,1), keepdims=True); s[s<1e-6]=1.0
    return (x-m)/s

def crop_center_or_rand(x, size=256, center=True, rng=None):
    h,w = x.shape[:2]
    if h<size or w<size:
        ph,pw=max(0,size-h),max(0,size-w)
        if x.ndim==2: x=np.pad(x,((ph//2, ph-ph//2),(pw//2, pw-pw//2)),mode='edge')
        else: x=np.pad(x,((ph//2, ph-ph//2),(pw//2, pw-pw//2),(0,0)),mode='edge')
        h,w=x.shape[:2]
    if h==size and w==size: return x.copy()
    if center:
        y0,x0=(h-size)//2,(w-size)//2
    else:
        if rng is None: rng=np.random.default_rng()
        y0=int(rng.integers(0, h-size+1)); x0=int(rng.integers(0, w-size+1))
    return x[y0:y0+size, x0:x0+size].copy() if x.ndim==2 else x[y0:y0+size, x0:x0+size,:].copy()

# ---------- Score extractors (可調 batch / TTA / AMP) ----------
INPUT_SIZE=256
BATCH_CLIP=4096
BATCH_PRNU=512
BATCH_ELA =512
TTA_PRNU=0
TTA_ELA =0

def clip_scores(paths):
    out=[]
    for i in tqdm(range(0,len(paths),BATCH_CLIP), desc="CLIP load"):
        X = np.vstack([l2norm(np.load(p, allow_pickle=True)) for p in paths[i:i+BATCH_CLIP]]).astype(np.float32)
        out.append(clip_svm.decision_function(X).astype(np.float32))
    return np.concatenate(out,0)

@torch.no_grad()
def prnu_logits(paths, tta=TTA_PRNU):
    out = np.zeros(len(paths), np.float32)
    rng = np.random.default_rng(1337)
    for i in tqdm(range(0,len(paths),BATCH_PRNU), desc=f"PRNU logits (TTA={tta})"):
        chunk = paths[i:i+BATCH_PRNU]
        acc = np.zeros(len(chunk), np.float32); reps=max(1,int(tta))
        for _ in range(reps):
            xs=[]
            for p in chunk:
                a = load_prnu_i8(p)
                if a.shape[0]>=512 and a.shape[1]>=512: a = avg_pool_2x(a)
                a = crop_center_or_rand(a, INPUT_SIZE, center=(tta==0), rng=rng)
                a = prnu_per_image_norm(a)
                xs.append(torch.from_numpy(a).unsqueeze(0))
            xb = torch.stack(xs,0).to(DEVICE).contiguous(memory_format=torch.channels_last)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                logit = prnu_model(xb)
            acc += logit.float().cpu().numpy().astype(np.float32)
        out[i:i+BATCH_PRNU] = acc/float(reps)
    return out

@torch.no_grad()
def ela_logits(paths, tta=TTA_ELA):
    out = np.zeros(len(paths), np.float32)
    rng = np.random.default_rng(1337)
    for i in tqdm(range(0,len(paths),BATCH_ELA), desc=f"ELA logits (TTA={tta})"):
        chunk = paths[i:i+BATCH_ELA]
        acc = np.zeros(len(chunk), np.float32); reps=max(1,int(tta))
        for _ in range(reps):
            xs=[]
            for p in chunk:
                x = load_ela_array(p)
                x = crop_center_or_rand(x, INPUT_SIZE, center=(tta==0), rng=rng)
                x = ela_zscore(x)
                xs.append(torch.from_numpy(np.transpose(x,(2,0,1))))
            xb = torch.stack(xs,0).to(DEVICE).contiguous(memory_format=torch.channels_last)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                logit = ela_model(xb)
            acc += logit.float().cpu().numpy().astype(np.float32)
        out[i:i+BATCH_ELA] = acc/float(reps)
    return out

# ---------- Read splits & build aligned triplets ----------
with open(SPLITS_JSON, "r", encoding="utf-8") as f:
    J = json.load(f)
SPL = J["splits"]

def _get_split(name):
    real = SPL[name]["clip"]["real"]; fake = SPL[name]["clip"]["fake"]
    clipP = real + fake
    y = np.array([0]*len(real) + [1]*len(fake), dtype=int)
    # 以目錄快速推測 PRNU/ELA 路徑（與你訓練一致）
    prnu_real = (J["meta"]["dirs"]["prnu"]["real"])
    prnu_fake = (J["meta"]["dirs"]["prnu"]["fake"])
    ela_real  = (J["meta"]["dirs"]["ela"]["real"])
    ela_fake  = (J["meta"]["dirs"]["ela"]["fake"])
    prnuP=[]; elaP=[]
    miss_prnu=miss_ela=0
    for p,yi in zip(clipP, y):
        stem = Path(p).stem
        pr = Path(prnu_real if yi==0 else prnu_fake)/f"{stem}.npy"
        eln= Path(ela_real  if yi==0 else ela_fake )/f"{stem}.npy"
        elz= Path(ela_real  if yi==0 else ela_fake )/f"{stem}.npz"
        prnuP.append(str(pr) if pr.is_file() else None);  miss_prnu += int(not pr.is_file())
        elaP.append( str(eln) if eln.is_file() else (str(elz) if elz.is_file() else None) )
        miss_ela  += int(not (eln.is_file() or elz.is_file()))
    keep=[i for i,(a,b) in enumerate(zip(prnuP, elaP)) if (a and b)]
    if miss_prnu or miss_ela: print(f"[{name}] 缺檔 PRNU:{miss_prnu}  ELA:{miss_ela} → 使用齊全 {len(keep)}")
    clipP=[clipP[i] for i in keep]; prnuP=[prnuP[i] for i in keep]; elaP=[elaP[i] for i in keep]; y=y[keep]
    return clipP, prnuP, elaP, y

def scores_for(clipP, prnuP, elaP):
    s_clip = clip_scores(clipP)
    z_prnu = prnu_logits(prnuP, tta=TTA_PRNU)
    z_ela  = ela_logits(elaP,  tta=TTA_ELA)
    X = np.stack([s_clip, z_prnu, z_ela], 1).astype(np.float32)
    return X

def eval_block(name, X, y, thr=thr_star):
    Xs = scaler.transform(X)
    p  = fuser.predict_proba(Xs)[:,1]
    pred = (p>=thr).astype(int)
    acc = accuracy_score(y, pred)
    auc = roc_auc_score(y, p)
    print(f"\n[{name}] acc@thr={acc:.4f} | auc={auc:.4f} | n={len(y)}")
    print(confusion_matrix(y, pred))
    print(classification_report(y, pred, target_names=["real(0)","fake(1)"], digits=4))
    return p

# ---------- Run on splits ----------
for split in ["val","test_iid","test_ood"]:
    if split not in SPL: continue
    clipP, prnuP, elaP, y = _get_split(split)
    X = scores_for(clipP, prnuP, elaP)
    _ = eval_block(split.upper(), X, y, thr=thr_star)

# ---------- Single-sample inference ----------
def predict_one_by_stem(stem, is_real_unknown=True):
    # 從三個目錄拼出三路檔案
    prnu_real = (J["meta"]["dirs"]["prnu"]["real"]); prnu_fake = (J["meta"]["dirs"]["prnu"]["fake"])
    ela_real  = (J["meta"]["dirs"]["ela"]["real"]);  ela_fake  = (J["meta"]["dirs"]["ela"]["fake"])
    # 先嘗試 real 目錄，找不到再用 fake（若你已知真假可直接指定）
    pr = Path(prnu_real)/f"{stem}.npy"; 
    if not pr.is_file(): pr = Path(prnu_fake)/f"{stem}.npy"
    el = Path(ela_real)/f"{stem}.npy"; 
    if not el.is_file(): el = Path(ela_real)/f"{stem}.npz"
    if not el.exists():   el = Path(ela_fake)/f"{stem}.npy"
    if not el.exists():   el = Path(ela_fake)/f"{stem}.npz"
    cp = None
    # 在 clip 兩個資料夾找向量
    for cdir in [f"{SCRIPT_ROOT}/features_npy/clip_real_npy", f"{SCRIPT_ROOT}/features_npy/clip_fake_npy"]:
        cand = Path(cdir)/f"{stem}.npy"
        if cand.is_file(): cp = cand; break
    assert cp and pr.is_file() and el.exists(), f"找不到三路檔案：{stem}"
    # 算三路分數 → fusion
    s_clip = clip_scores([str(cp)])
    z_prnu = prnu_logits([str(pr)], tta=0)
    z_ela  = ela_logits ([str(el)], tta=0)
    X = np.stack([s_clip, z_prnu, z_ela], 1).astype(np.float32)
    p = fuser.predict_proba(scaler.transform(X))[:,1][0]
    yhat = int(p>=thr_star)
    return float(p), yhat  # prob_fake, 1=fake

# 範例：prob, pred = predict_one_by_stem("some_image_stem")


Using fusion: fusion3_lr_20250819_130718.joblib
Val-Youden threshold = 0.6154414870978013
device = cuda
Loaded: prnu_cnn_i8_best_20250819_020555.pt | ela_fromnpy_cnn_best_20250819_130848.pt


CLIP load: 100%|██████████| 4/4 [00:04<00:00,  1.15s/it]
PRNU logits (TTA=0): 100%|██████████| 28/28 [05:18<00:00, 11.36s/it]
ELA logits (TTA=0): 100%|██████████| 28/28 [04:13<00:00,  9.05s/it]



[VAL] acc@thr=0.9644 | auc=0.9885 | n=14000
[[6959   41]
 [ 458 6542]]
              precision    recall  f1-score   support

     real(0)     0.9382    0.9941    0.9654      7000
     fake(1)     0.9938    0.9346    0.9633      7000

    accuracy                         0.9644     14000
   macro avg     0.9660    0.9644    0.9643     14000
weighted avg     0.9660    0.9644    0.9643     14000



CLIP load: 100%|██████████| 4/4 [00:04<00:00,  1.13s/it]
PRNU logits (TTA=0): 100%|██████████| 28/28 [05:22<00:00, 11.53s/it]
ELA logits (TTA=0): 100%|██████████| 28/28 [03:48<00:00,  8.17s/it]



[TEST_IID] acc@thr=0.9642 | auc=0.9841 | n=14000
[[6956   44]
 [ 457 6543]]
              precision    recall  f1-score   support

     real(0)     0.9384    0.9937    0.9652      7000
     fake(1)     0.9933    0.9347    0.9631      7000

    accuracy                         0.9642     14000
   macro avg     0.9658    0.9642    0.9642     14000
weighted avg     0.9658    0.9642    0.9642     14000



CLIP load: 100%|██████████| 12/12 [00:18<00:00,  1.52s/it]
PRNU logits (TTA=0):   5%|▌         | 5/93 [01:04<18:57, 12.92s/it]


KeyboardInterrupt: 

In [2]:
# ============================================================
# 通用特徵測試：CLIP / PRNU / ELA 逐路評估 + OOD 逐資料集拆解
# ============================================================
import os, json, glob, math, time, re
from pathlib import Path
from collections import defaultdict, Counter

import numpy as np
from tqdm import tqdm
import joblib

import torch, torch.nn as nn
from sklearn.metrics import (accuracy_score, roc_auc_score, classification_report,
                             confusion_matrix, roc_curve, precision_recall_fscore_support)

# ---------------- 基本設定 ----------------
SCRIPT_ROOT = "/home/yaya/ai-detect-proj/Script"
SAVED_DIR   = f"{SCRIPT_ROOT}/saved_models"
SPLITS_JSON = f"{SAVED_DIR}/splits_clip_feature_iid_ood.json"

# 執行設定：可依硬體與需求調整
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
AMP         = True                 # 推論時使用 AMP
BATCH_CLIP  = 4096
BATCH_PRNU  = 512
BATCH_ELA   = 512
INPUT_SIZE  = 256
TTA_PRNU    = 0                    # 先設 0（測速）；要更穩可拉到 4~8
TTA_ELA     = 0

print("device =", DEVICE)

# ---------------- 來源前綴解析（與你 split 腳本一致） ----------------
SEPS = ("__", "---", "--", "_", "-", " ")
ALIASES = {
    # real
    "imagenet1k":"imagenet", "imgnet":"imagenet", "imagenet":"imagenet",
    "unslpash":"unsplash", "unsplash":"unsplash",
    "flicker30k":"flickr30k", "flicker30K":"flickr30k", "flickr30k":"flickr30k",
    "places365":"places365", "coco2017":"coco2017", "div2k":"div2k",
    # fake
    "sd3":"sd3", "sdxl":"sd3",
    "flux":"flux", "black-forest-labs":"flux",
    "dalle3":"dalle3", "dalle-3":"dalle3",
    "midjourney-v6-llava":"midjourney", "midjourney":"midjourney"
}
def canonical(tag:str)->str:
    return ALIASES.get(tag.lower().strip(), tag.lower().strip())

def infer_tag_from_stem(stem:str, is_real:bool)->str:
    for k in sorted(ALIASES.keys(), key=lambda s: -len(s)):
        if stem.lower().startswith(k):
            return canonical(k)
    cut = None
    for s in SEPS:
        i = stem.find(s)
        if i != -1:
            cut = i if cut is None else min(cut, i)
    tag = stem[:cut] if cut is not None else stem
    if (not tag) or tag.isdigit():
        tag = "imagenet" if is_real else "unknown"
    return canonical(tag)

# ---------------- 讀 splits ----------------
assert os.path.isfile(SPLITS_JSON), f"找不到 splits json：{SPLITS_JSON}"
with open(SPLITS_JSON, "r", encoding="utf-8") as f:
    J = json.load(f)

META = J.get("meta", {})
DIRS = META.get("dirs", {})
CLIP_REAL_DIR = DIRS.get("clip", {}).get("real", f"{SCRIPT_ROOT}/features_npy/clip_real_npy")
CLIP_FAKE_DIR = DIRS.get("clip", {}).get("fake", f"{SCRIPT_ROOT}/features_npy/clip_fake_npy")
PRNU_REAL_DIR = DIRS.get("prnu", {}).get("real", f"{SCRIPT_ROOT}/features_i8/prnu_real_i8_npy")
PRNU_FAKE_DIR = DIRS.get("prnu", {}).get("fake", f"{SCRIPT_ROOT}/features_i8/prnu_fake_i8_npy")
ELA_REAL_DIR  = DIRS.get("ela",  {}).get("real", f"{SCRIPT_ROOT}/features_npy/ela_real_npy")
ELA_FAKE_DIR  = DIRS.get("ela",  {}).get("fake", f"{SCRIPT_ROOT}/features_npy/ela_fake_npy")

SPL = J["splits"]

def is_new_split_format(s):
    return isinstance(s, dict) and "clip" in s

def get_clip_lists(name):
    # 回傳 clip 路徑（real+fake）與標籤、stem
    S = SPL[name]
    if is_new_split_format(S):
        real = S["clip"]["real"]; fake = S["clip"]["fake"]
    else:
        # 老格式：直接是一串檔案路徑
        allp = S
        real = [p for p in allp if "/clip_real_npy/" in Path(p).as_posix()]
        fake = [p for p in allp if "/clip_fake_npy/" in Path(p).as_posix()]
    clip_paths = real + fake
    y = np.array([0]*len(real) + [1]*len(fake), dtype=int)
    stems = [Path(p).stem for p in clip_paths]
    return clip_paths, y, stems

def map_by_stem(stems, y, real_dir, fake_dir, suff=(".npy",)):
    # 依 stem 在 real/fake 目錄組路徑；找不到給 None
    out=[]; miss=0
    for st, yi in zip(stems, y):
        base = Path(real_dir if yi==0 else fake_dir)
        found=None
        for sx in suff:
            cand = base/f"{st}{sx}"
            if cand.is_file():
                found=str(cand); break
        if found is None:
            miss+=1
        out.append(found)
    return out, miss

# ---------------- 載入/建置各路模型與讀檔 ----------------
def l2norm(v):
    v = np.asarray(v, dtype=np.float32).reshape(-1)
    n = np.linalg.norm(v) + 1e-12
    return v / n

# CLIP：用你已訓練的 LinearSVC
CLIP_SVM = joblib.load(sorted(Path(SAVED_DIR).glob("clip_linear_svm_feature_*.joblib"))[-1])
def clip_scores(paths):
    # 每個 .npy 是一個向量；決策函數越大越像「fake」
    out=[]
    for i in tqdm(range(0,len(paths),BATCH_CLIP), desc="CLIP load"):
        chunk = paths[i:i+BATCH_CLIP]
        X = np.vstack([l2norm(np.load(p, allow_pickle=True)) for p in chunk]).astype(np.float32)
        out.append(CLIP_SVM.decision_function(X).astype(np.float32))
    return np.concatenate(out, 0)

# PRNU CNN（你的小網）
class SmallForensicCNN(nn.Module):
    def __init__(self, in_ch=1):
        super().__init__()
        def blk(ci, co, groups=8):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                nn.ReLU(inplace=True)
            )
        self.net = nn.Sequential(
            blk(in_ch,32), blk(32,32), nn.AvgPool2d(2),
            blk(32,64),   blk(64,64),  nn.AvgPool2d(2),
            blk(64,128),  blk(128,128), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(128, 1)
    def forward(self,x): return self.fc(self.net(x).flatten(1)).squeeze(1)

PRNU_CKPT = sorted(Path(SAVED_DIR).glob("prnu_cnn_i8_best_*.pt"))[-1]
prnu_model = SmallForensicCNN(1).to(DEVICE).eval()
prnu_model.load_state_dict(torch.load(PRNU_CKPT, map_location=DEVICE))
def load_prnu_i8(path):
    a = np.load(path, mmap_mode='r')
    a = np.asarray(a)
    if a.ndim==3 and (a.shape[0]==1 or a.shape[-1]==1): a = a.squeeze()
    assert a.ndim==2 and a.dtype==np.int8, f"PRNU expect 2D int8, got {a.shape} {a.dtype} from {path}"
    return a
def avg_pool_2x(x):
    H,W = x.shape; H2,W2=(H//2)*2,(W//2)*2
    x = x[:H2,:W2]; return x.reshape(H2//2,2,W2//2,2).mean(axis=(1,3))
def prnu_per_image_norm(a_i8):
    x = a_i8.astype(np.float32, copy=False)
    m,s = x.mean(), x.std()
    if (not np.isfinite(s)) or s<1e-6: m,s=0.0,20.0
    return (x-m)/(s if s>0 else 1.0)
@torch.no_grad()
def prnu_logits(paths, tta=TTA_PRNU):
    out = np.zeros(len(paths), np.float32)
    rng = np.random.default_rng(1337)
    for i in tqdm(range(0,len(paths),BATCH_PRNU), desc=f"PRNU logits (TTA={tta})"):
        chunk = paths[i:i+BATCH_PRNU]
        acc = np.zeros(len(chunk), np.float32); reps=max(1,int(tta))
        for _ in range(reps):
            xs=[]
            for p in chunk:
                a = load_prnu_i8(p)
                if a.shape[0]>=512 and a.shape[1]>=512: a = avg_pool_2x(a)
                a = crop_center_or_rand(a, INPUT_SIZE, center=(tta==0), rng=rng)
                a = prnu_per_image_norm(a)
                xs.append(torch.from_numpy(a).unsqueeze(0))
            xb = torch.stack(xs,0).to(DEVICE).contiguous(memory_format=torch.channels_last)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=AMP):
                logit = prnu_model(xb)
            acc += logit.float().cpu().numpy().astype(np.float32)
        out[i:i+BATCH_PRNU] = acc/float(reps)
    return out

# ELA CNN
class ELAForensicCNN(nn.Module):
    def __init__(self, in_ch=3):
        super().__init__()
        def bnblk(ci, co):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.BatchNorm2d(co),
                nn.ReLU(inplace=True))
        def gnblk(ci, co, groups=8):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                nn.ReLU(inplace=True))
        self.net = nn.Sequential(
            bnblk(in_ch,32), bnblk(32,32), nn.AvgPool2d(2),
            bnblk(32,64),    bnblk(64,64), nn.AvgPool2d(2),
            gnblk(64,128),   gnblk(128,128), nn.AvgPool2d(2),
            gnblk(128,256),  gnblk(256,256), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(256, 1)
    def forward(self,x): return self.fc(self.net(x).flatten(1)).squeeze(1)

ELA_CKPT = sorted(Path(SAVED_DIR).glob("ela_fromnpy_cnn_best_*.pt"))[-1]
ela_model = ELAForensicCNN(3).to(DEVICE).eval()
ela_model.load_state_dict(torch.load(ELA_CKPT, map_location=DEVICE))
def load_ela_array(path):
    z = np.load(path, mmap_mode='r')
    if isinstance(z, np.lib.npyio.NpzFile):
        for k in ('ela','arr','arr_0','data'):
            if k in z.files: a = z[k]; break
        else: a = z[z.files[0]]
    else:
        a = z
    a = np.asarray(a)
    if a.ndim==2: a = np.repeat(a[...,None],3,axis=2)
    elif a.ndim==3 and a.shape[0] in (1,3) and a.shape[-1] not in (1,3):
        a = np.transpose(a,(1,2,0))
    elif a.ndim==3 and a.shape[-1]==1:
        a = np.repeat(a,3,axis=2)
    a = a.astype(np.float32, copy=False)
    if np.nanmax(a)>1.5: a *= (1.0/255.0)
    return a
def ela_zscore(x):
    m = x.mean(axis=(0,1), keepdims=True)
    s = x.std(axis=(0,1), keepdims=True); s[s<1e-6]=1.0
    return (x-m)/s
@torch.no_grad()
def ela_logits(paths, tta=TTA_ELA):
    out = np.zeros(len(paths), np.float32)
    rng = np.random.default_rng(1337)
    for i in tqdm(range(0,len(paths),BATCH_ELA), desc=f"ELA logits (TTA={tta})"):
        chunk = paths[i:i+BATCH_ELA]
        acc = np.zeros(len(chunk), np.float32); reps=max(1,int(tta))
        for _ in range(reps):
            xs=[]
            for p in chunk:
                x = load_ela_array(p)
                x = crop_center_or_rand(x, INPUT_SIZE, center=(tta==0), rng=rng)
                x = ela_zscore(x)
                xs.append(torch.from_numpy(np.transpose(x,(2,0,1))))
            xb = torch.stack(xs,0).to(DEVICE).contiguous(memory_format=torch.channels_last)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=AMP):
                logit = ela_model(xb)
            acc += logit.float().cpu().numpy().astype(np.float32)
        out[i:i+BATCH_ELA] = acc/float(reps)
    return out

# 通用裁切（和你前面一致）
def crop_center_or_rand(x, size=256, center=True, rng=None):
    h,w = x.shape[:2]
    if h<size or w<size:
        ph,pw=max(0,size-h),max(0,size-w)
        if x.ndim==2: x=np.pad(x,((ph//2, ph-ph//2),(pw//2, pw-pw//2)),mode='edge')
        else: x=np.pad(x,((ph//2, ph-ph//2),(pw//2, pw-pw//2),(0,0)),mode='edge')
        h,w=x.shape[:2]
    if h==size and w==size: return x.copy()
    if center:
        y0,x0=(h-size)//2,(w-size)//2
    else:
        if rng is None: rng=np.random.default_rng()
        y0=int(rng.integers(0, h-size+1)); x0=int(rng.integers(0, w-size+1))
    return x[y0:y0+size, x0:x0+size].copy() if x.ndim==2 else x[y0:y0+size, x0:x0+size,:].copy()

# ---------------- 驗證門檻（Youden / FPR@5%） ----------------
def thr_youden(y, s):
    fpr, tpr, thr = roc_curve(y, s)
    j = tpr - fpr
    return float(thr[int(np.argmax(j))])

def thr_fpr_at(y, s, target_fpr=0.05):
    fpr, tpr, thr = roc_curve(y, s)
    # 找到 <= target_fpr 的最大門檻
    ok = np.where(fpr <= target_fpr)[0]
    if len(ok)==0: return float(thr[-1])
    return float(thr[int(ok[-1])])

# ---------------- 逐路評估主程式 ----------------
def evaluate_feature(feature:str):
    assert feature in ("clip","prnu","ela")

    results = {}
    tags_cache = {}

    for split in ["val","test_iid","test_ood"]:
        if split not in SPL: continue

        clipP, y, stems = get_clip_lists(split)

        if feature=="clip":
            paths = clipP
        elif feature=="prnu":
            paths, miss = map_by_stem(stems, y, PRNU_REAL_DIR, PRNU_FAKE_DIR, (".npy",))
            if miss: print(f"[{split}|PRNU] 缺檔 {miss}，將忽略。")
        else: # ela
            paths, miss = map_by_stem(stems, y, ELA_REAL_DIR, ELA_FAKE_DIR, (".npy",".npz"))
            if miss: print(f"[{split}|ELA] 缺檔 {miss}，將忽略。")

        # 過濾缺檔
        keep = [i for i,p in enumerate(paths) if p is not None]
        paths = [paths[i] for i in keep]; yy = y[keep]
        ss   = [stems[i] for i in keep]

        print(f"\n[{feature.upper()}] {split} 使用樣本數：{len(paths)}")

        # 計分
        if feature=="clip":
            scores = clip_scores(paths)
        elif feature=="prnu":
            scores = prnu_logits(paths, tta=TTA_PRNU)
        else:
            scores = ela_logits(paths, tta=TTA_ELA)

        # 在 Val 上選門檻
        if split=="val":
            thrJ  = thr_youden(yy, scores)
            thr5  = thr_fpr_at(yy, scores, target_fpr=0.05)
            print(f"[{feature.upper()}|Val] thresholds → youden={thrJ:.3f} | fpr@5%={thr5:.3f}")
            results["thr_youden"] = thrJ
            results["thr_fpr5"]   = thr5

        # 總體評估（用 Val-Youden 門檻）
        thr_used = results.get("thr_youden", 0.5)
        pred = (scores >= thr_used).astype(int)
        acc  = accuracy_score(yy, pred)
        try:
            auc  = roc_auc_score(yy, scores)
        except ValueError:
            auc = float("nan")
        print(f"[{feature.upper()}|{split}] acc@thr={acc:.4f} | auc={auc:.4f} | thr={thr_used:.3f}")
        print(confusion_matrix(yy, pred))
        print(classification_report(yy, pred, target_names=["real(0)","fake(1)"], digits=4))

        # 存放以便 OOD 分析
        results[split] = {"y": yy, "scores": scores, "pred": pred, "stems": ss, "paths": paths}

    # -------- OOD 逐資料集拆解（若有 OOD） --------
    if "test_ood" in results:
        y  = results["test_ood"]["y"]
        ss = results["test_ood"]["stems"]
        sc = results["test_ood"]["scores"]
        pr = results["test_ood"]["pred"]

        # 以 label 決定 real/fake，再由 stem 推 tag
        keys = []
        for yi, st in zip(y, ss):
            tag = infer_tag_from_stem(st, is_real=(yi==0))
            keys.append(("real" if yi==0 else "fake") + ":" + tag)

        # 聚合指標
        groups = defaultdict(list)
        for i,k in enumerate(keys):
            groups[k].append(i)

        print(f"\n== OOD per-dataset ({feature.upper()}) | thr=Val-Youden ==")
        rows=[]
        for k, idxs in sorted(groups.items(), key=lambda kv: -len(kv[1])):
            yy = y[idxs]; pp = pr[idxs]; ss_ = sc[idxs]
            n  = len(idxs)
            acc = accuracy_score(yy, pp)
            # class-wise recall
            prec, rec, f1, _ = precision_recall_fscore_support(yy, pp, labels=[0,1], zero_division=0)
            r_real, r_fake = rec[0], rec[1]
            try:
                auc = roc_auc_score(yy, ss_)
            except ValueError:
                auc = float("nan")
            print(f"- {k:12s} n={n:5d} | acc={acc:.4f} | r_real={r_real:.4f} r_fake={r_fake:.4f} | auc={auc if not np.isnan(auc) else 'nan'}")
            rows.append((k, n, acc, r_real, r_fake, auc))

        # 額外列出「表現最差 Top-8」（依 acc 升序）
        worst = sorted(rows, key=lambda x: (x[2], x[1]))[:8]
        print("\n>>> 最差 Top-8（依 acc）")
        for k,n,acc,r0,r1,auc in worst:
            print(f"  {k:12s} n={n:5d} | acc={acc:.4f} | r_real={r0:.4f} r_fake={r1:.4f} | auc={auc if not np.isnan(auc) else 'nan'}")

    return results

# ================== 執行：分別測試三條路 ==================
print("\n===== Evaluate: CLIP =====")
RES_CLIP = evaluate_feature("clip")

print("\n===== Evaluate: PRNU =====")
RES_PRNU = evaluate_feature("prnu")

print("\n===== Evaluate: ELA =====")
RES_ELA  = evaluate_feature("ela")

print("\n完成。你可以從各路的 '最差 Top-8' 直接看到 OOD 是哪個資料集在拖分。")


device = cuda

===== Evaluate: CLIP =====

[CLIP] val 使用樣本數：14000


CLIP load: 100%|██████████| 4/4 [00:01<00:00,  2.88it/s]


[CLIP|Val] thresholds → youden=0.081 | fpr@5%=-0.225
[CLIP|val] acc@thr=0.9611 | auc=0.9842 | thr=0.081
[[6880  120]
 [ 424 6576]]
              precision    recall  f1-score   support

     real(0)     0.9419    0.9829    0.9620      7000
     fake(1)     0.9821    0.9394    0.9603      7000

    accuracy                         0.9611     14000
   macro avg     0.9620    0.9611    0.9611     14000
weighted avg     0.9620    0.9611    0.9611     14000


[CLIP] test_iid 使用樣本數：14000


CLIP load: 100%|██████████| 4/4 [00:00<00:00,  6.47it/s]


[CLIP|test_iid] acc@thr=0.9624 | auc=0.9834 | thr=0.081
[[6913   87]
 [ 440 6560]]
              precision    recall  f1-score   support

     real(0)     0.9402    0.9876    0.9633      7000
     fake(1)     0.9869    0.9371    0.9614      7000

    accuracy                         0.9624     14000
   macro avg     0.9635    0.9624    0.9623     14000
weighted avg     0.9635    0.9624    0.9623     14000


[CLIP] test_ood 使用樣本數：47570


CLIP load: 100%|██████████| 12/12 [00:01<00:00,  6.35it/s]


[CLIP|test_ood] acc@thr=0.6222 | auc=0.7130 | thr=0.081
[[ 8210 15575]
 [ 2398 21387]]
              precision    recall  f1-score   support

     real(0)     0.7739    0.3452    0.4774     23785
     fake(1)     0.5786    0.8992    0.7041     23785

    accuracy                         0.6222     47570
   macro avg     0.6763    0.6222    0.5908     47570
weighted avg     0.6763    0.6222    0.5908     47570


== OOD per-dataset (CLIP) | thr=Val-Youden ==
- fake:midjourney n=17273 | acc=0.8824 | r_real=0.0000 r_fake=0.8824 | auc=nan
- real:places365 n=15000 | acc=0.1260 | r_real=0.1260 r_fake=0.0000 | auc=nan
- real:coco2017 n= 8785 | acc=0.7194 | r_real=0.7194 r_fake=0.0000 | auc=nan
- fake:dalle3  n= 6512 | acc=0.9438 | r_real=0.0000 r_fake=0.9438 | auc=nan

>>> 最差 Top-8（依 acc）
  real:places365 n=15000 | acc=0.1260 | r_real=0.1260 r_fake=0.0000 | auc=nan
  real:coco2017 n= 8785 | acc=0.7194 | r_real=0.7194 r_fake=0.0000 | auc=nan
  fake:midjourney n=17273 | acc=0.8824 | r_real=0.000

PRNU logits (TTA=0): 100%|██████████| 28/28 [05:43<00:00, 12.26s/it]


[PRNU|Val] thresholds → youden=-0.189 | fpr@5%=1.226
[PRNU|val] acc@thr=0.8759 | auc=0.9321 | thr=-0.189
[[6156  844]
 [ 894 6106]]
              precision    recall  f1-score   support

     real(0)     0.8732    0.8794    0.8763      7000
     fake(1)     0.8786    0.8723    0.8754      7000

    accuracy                         0.8759     14000
   macro avg     0.8759    0.8759    0.8759     14000
weighted avg     0.8759    0.8759    0.8759     14000


[PRNU] test_iid 使用樣本數：14000


PRNU logits (TTA=0): 100%|██████████| 28/28 [05:05<00:00, 10.90s/it]


[PRNU|test_iid] acc@thr=0.8685 | auc=0.9286 | thr=-0.189
[[6086  914]
 [ 927 6073]]
              precision    recall  f1-score   support

     real(0)     0.8678    0.8694    0.8686      7000
     fake(1)     0.8692    0.8676    0.8684      7000

    accuracy                         0.8685     14000
   macro avg     0.8685    0.8685    0.8685     14000
weighted avg     0.8685    0.8685    0.8685     14000


[PRNU] test_ood 使用樣本數：47570


PRNU logits (TTA=0): 100%|██████████| 93/93 [15:26<00:00,  9.96s/it]


[PRNU|test_ood] acc@thr=0.6100 | auc=0.6544 | thr=-0.189
[[ 9508 14277]
 [ 4277 19508]]
              precision    recall  f1-score   support

     real(0)     0.6897    0.3997    0.5061     23785
     fake(1)     0.5774    0.8202    0.6777     23785

    accuracy                         0.6100     47570
   macro avg     0.6336    0.6100    0.5919     47570
weighted avg     0.6336    0.6100    0.5919     47570


== OOD per-dataset (PRNU) | thr=Val-Youden ==
- fake:midjourney n=17273 | acc=0.8487 | r_real=0.0000 r_fake=0.8487 | auc=nan
- real:places365 n=15000 | acc=0.1281 | r_real=0.1281 r_fake=0.0000 | auc=nan
- real:coco2017 n= 8785 | acc=0.8635 | r_real=0.8635 r_fake=0.0000 | auc=nan
- fake:dalle3  n= 6512 | acc=0.7446 | r_real=0.0000 r_fake=0.7446 | auc=nan

>>> 最差 Top-8（依 acc）
  real:places365 n=15000 | acc=0.1281 | r_real=0.1281 r_fake=0.0000 | auc=nan
  fake:dalle3  n= 6512 | acc=0.7446 | r_real=0.0000 r_fake=0.7446 | auc=nan
  fake:midjourney n=17273 | acc=0.8487 | r_real=0.000

ELA logits (TTA=0): 100%|██████████| 28/28 [03:47<00:00,  8.14s/it]


[ELA|Val] thresholds → youden=0.163 | fpr@5%=-0.137
[ELA|val] acc@thr=0.9309 | auc=0.9679 | thr=0.163
[[6705  295]
 [ 672 6328]]
              precision    recall  f1-score   support

     real(0)     0.9089    0.9579    0.9327      7000
     fake(1)     0.9555    0.9040    0.9290      7000

    accuracy                         0.9309     14000
   macro avg     0.9322    0.9309    0.9309     14000
weighted avg     0.9322    0.9309    0.9309     14000


[ELA] test_iid 使用樣本數：14000


ELA logits (TTA=0): 100%|██████████| 28/28 [03:44<00:00,  8.03s/it]


[ELA|test_iid] acc@thr=0.9341 | auc=0.9686 | thr=0.163
[[6727  273]
 [ 650 6350]]
              precision    recall  f1-score   support

     real(0)     0.9119    0.9610    0.9358      7000
     fake(1)     0.9588    0.9071    0.9322      7000

    accuracy                         0.9341     14000
   macro avg     0.9353    0.9341    0.9340     14000
weighted avg     0.9353    0.9341    0.9340     14000


[ELA] test_ood 使用樣本數：47570


ELA logits (TTA=0): 100%|██████████| 93/93 [14:14<00:00,  9.18s/it]

[ELA|test_ood] acc@thr=0.5665 | auc=0.6163 | thr=0.163
[[ 7665 16120]
 [ 4503 19282]]
              precision    recall  f1-score   support

     real(0)     0.6299    0.3223    0.4264     23785
     fake(1)     0.5447    0.8107    0.6516     23785

    accuracy                         0.5665     47570
   macro avg     0.5873    0.5665    0.5390     47570
weighted avg     0.5873    0.5665    0.5390     47570


== OOD per-dataset (ELA) | thr=Val-Youden ==
- fake:midjourney n=17273 | acc=0.9907 | r_real=0.0000 r_fake=0.9907 | auc=nan
- real:places365 n=15000 | acc=0.0268 | r_real=0.0268 r_fake=0.0000 | auc=nan
- real:coco2017 n= 8785 | acc=0.8268 | r_real=0.8268 r_fake=0.0000 | auc=nan
- fake:dalle3  n= 6512 | acc=0.3332 | r_real=0.0000 r_fake=0.3332 | auc=nan

>>> 最差 Top-8（依 acc）
  real:places365 n=15000 | acc=0.0268 | r_real=0.0268 r_fake=0.0000 | auc=nan
  fake:dalle3  n= 6512 | acc=0.3332 | r_real=0.0000 r_fake=0.3332 | auc=nan
  real:coco2017 n= 8785 | acc=0.8268 | r_real=0.8268 r_f


