In [3]:
# ======================================================
# ELA(npy/npz) → CNN  — 直接讀取預先計算好的 ELA 陣列（修版：可直接跑通）
# ======================================================
import os, glob, time, math, random, json, warnings
from pathlib import Path
import numpy as np
from tqdm import tqdm

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix, roc_curve

# ---------------- Config ----------------
SCRIPT_ROOT = "/home/yaya/ai-detect-proj/Script"
REAL_DIR = os.path.join(SCRIPT_ROOT, "features_npy", "ela_real_npy")
FAKE_DIR = os.path.join(SCRIPT_ROOT, "features_npy", "ela_fake_npy")
OUTPUT_DIR = os.path.join(SCRIPT_ROOT, "saved_models")
os.makedirs(OUTPUT_DIR, exist_ok=True)

VAL_SIZE = 0.2
RANDOM_SEED = 1337
CAP_TRAIN_PER_CLASS = None   # None 不限制
CAP_VAL_PER_CLASS   = None
AUTO_BALANCE        = True

INPUT_SIZE   = 256
TRAIN_CENTER = False

EPOCHS      = 15
BATCH_SIZE  = 64
ACCUM_STEPS = 1
LR_MAX      = 3e-4
WEIGHT_DECAY= 1e-2
GRAD_CLIP   = 1.0
NUM_WORKERS = max((os.cpu_count() or 1)-1, 0)
PIN_MEMORY  = torch.cuda.is_available()

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device =", device)
random.seed(RANDOM_SEED); np.random.seed(RANDOM_SEED); torch.manual_seed(RANDOM_SEED)

# ---------------- IO Utils ----------------
def list_np(folder):
    files = sorted(glob.glob(os.path.join(folder, "*.npy")) +
                   glob.glob(os.path.join(folder, "*.npz")))
    if not files:
        warnings.warn(f"No .npy/.npz under {folder}")
    return files

def _npz_pick(z):
    # 優先常見鍵；全無時取第一個
    for k in ('ela','arr','arr_0','data'):
        if isinstance(z, np.lib.npyio.NpzFile) and (k in z.files):
            return z[k]
    return z[z.files[0]] if isinstance(z, np.lib.npyio.NpzFile) else z

def load_ela_array(path):
    """讀入 .npy/.npz 的 ELA 陣列；輸出 HxWx3 float32"""
    z = np.load(path, mmap_mode='r')
    a = _npz_pick(z)
    a = np.asarray(a)
    # 形狀整理到 HxWx3
    if a.ndim == 2:
        a = np.repeat(a[..., None], 3, axis=2)
    elif a.ndim == 3 and a.shape[0] in (1,3) and a.shape[-1] not in (1,3):
        a = np.transpose(a, (1,2,0))  # CxHxW -> HxWxC
    elif a.ndim == 3 and a.shape[-1] == 1:
        a = np.repeat(a, 3, axis=2)
    assert a.ndim == 3 and a.shape[-1] == 3, f"Expect HxWx3, got {a.shape} from {path}"

    # dtype 正規化：uint8/int8/float* 全部可；把 0..255 類型先縮放到 0..1
    a = a.astype(np.float32, copy=False)
    if np.nanmax(a) > 1.5:
        a *= (1.0 / 255.0)
    return a

def per_image_zscore_3ch(x: np.ndarray):
    m = x.mean(axis=(0,1), keepdims=True)
    s = x.std(axis=(0,1), keepdims=True); s[s<1e-6] = 1.0
    return (x - m) / s

def crop_2d_hw3(img, size=256, center=False, rng=None):
    h, w, c = img.shape
    if h < size or w < size:
        pad_h = max(0, size - h); pad_w = max(0, size - w)
        img = np.pad(img, ((pad_h//2, pad_h - pad_h//2),
                           (pad_w//2, pad_w - pad_w//2),
                           (0,0)), mode='edge')
        h, w, _ = img.shape
    if h == size and w == size:
        return img.copy()
    if center:
        top = (h - size)//2; left = (w - size)//2
    else:
        if rng is None: rng = np.random.default_rng()
        top  = int(rng.integers(0, h - size + 1))
        left = int(rng.integers(0, w - size + 1))
    return img[top:top+size, left:left+size, :].copy()

def cap_per_class(paths, labels, cap, seed=1337):
    if cap is None: return paths, labels
    rng = np.random.RandomState(seed)
    idx = np.arange(len(paths)); labels = np.array(labels)
    keep = []
    for cls in [0,1]:
        idc = idx[labels==cls]
        if len(idc) > cap: idc = rng.choice(idc, cap, replace=False)
        keep.append(idc)
    keep = np.concatenate(keep); rng.shuffle(keep)
    return [paths[i] for i in keep], [int(labels[i]) for i in keep]

# ---------------- Dataset ----------------
class ELAFromNPY(Dataset):
    def __init__(self, paths, labels, train=True, size=256, train_center=False):
        self.paths = paths; self.labels = labels
        self.train = train; self.size = size; self.train_center = train_center
        self.rng = np.random.default_rng()
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        x = load_ela_array(self.paths[i])                                  # HxWx3 float32
        x = crop_2d_hw3(x, size=self.size, center=(not self.train or self.train_center), rng=self.rng)
        x = per_image_zscore_3ch(x)                                        # z-score
        x = np.transpose(x, (2,0,1))                                       # 3xHxW
        t = torch.from_numpy(x)                                            # float32
        y = torch.tensor(self.labels[i], dtype=torch.float32)
        return t, y

# ---------------- Model ----------------
class ELAForensicCNN(nn.Module):
    def __init__(self, in_ch=3):
        super().__init__()
        def bnblk(ci, co):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.BatchNorm2d(co),
                nn.ReLU(inplace=True))
        def gnblk(ci, co, groups=8):
            return nn.Sequential(
                nn.Conv2d(ci, co, 3, padding=1, bias=False),
                nn.GroupNorm(num_groups=min(groups, co), num_channels=co),
                nn.ReLU(inplace=True))
        self.net = nn.Sequential(
            bnblk(in_ch,32), bnblk(32,32), nn.AvgPool2d(2),
            bnblk(32,64),    bnblk(64,64), nn.AvgPool2d(2),
            gnblk(64,128),   gnblk(128,128), nn.AvgPool2d(2),
            gnblk(128,256),  gnblk(256,256), nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(256, 1)
    def forward(self,x):
        return self.fc(self.net(x).flatten(1)).squeeze(1)

# ---------------- Load splits from JSON (ELA-aware) ----------------
SPLITS_JSON = os.path.join(OUTPUT_DIR, "splits_clip_feature_iid_ood.json")
assert os.path.isfile(SPLITS_JSON), f"找不到 splits json：{SPLITS_JSON}"
with open(SPLITS_JSON, "r", encoding="utf-8") as f:
    SPLITS_IN = json.load(f)["splits"]

def _index_dir(d):
    idx = {}
    for q in Path(d).glob("*.npy"):
        idx[q.name.lower()] = str(q)
        idx[q.stem.lower()] = str(q)
    for q in Path(d).glob("*.npz"):
        idx[q.name.lower()] = str(q)
        idx[q.stem.lower()] = str(q)
    return idx
IDX_REAL = _index_dir(REAL_DIR)
IDX_FAKE = _index_dir(FAKE_DIR)

def _pick_lists(sp_name):
    """回傳 (real_list, fake_list) of ELA 路徑或檔名（未必存在）"""
    S = SPLITS_IN.get(sp_name, {})
    # 1) 若 JSON 已包含 ELA 的路徑
    if isinstance(S, dict) and "ela" in S and isinstance(S["ela"], dict):
        real = [p for p in S["ela"].get("real", []) if p]
        fake = [p for p in S["ela"].get("fake", []) if p]
        return real, fake
    # 2) 有 stems：直接用 stem 拼檔名
    if isinstance(S, dict) and "stems" in S and isinstance(S["stems"], dict):
        real = [s + ".npy" for s in S["stems"].get("real", [])]
        fake = [s + ".npy" for s in S["stems"].get("fake", [])]
        return real, fake
    # 3) 退而用 clip 清單，將目錄名改寫成 ELA
    if isinstance(S, dict) and "clip" in S and isinstance(S["clip"], dict):
        r = []
        for p in S["clip"].get("real", []):
            s = Path(p).as_posix().replace("/clip_real_npy/", "/ela_real_npy/")
            r.append(s)
        f = []
        for p in S["clip"].get("fake", []):
            s = Path(p).as_posix().replace("/clip_fake_npy/", "/ela_fake_npy/")
            f.append(s)
        return r, f
    # 舊版 JSON（直接是 list）的相容
    if isinstance(S, list):
        return S, []
    return [], []

def _map_to_existing(src_list, is_real):
    idx = IDX_REAL if is_real else IDX_FAKE
    mapped, miss = [], 0
    for p in src_list:
        q = Path(p)
        # 先試絕對/相對路徑是否存在
        if q.is_file():
            mapped.append(str(q)); continue
        # 不存在則用 stem/basename 在 ELA 目錄找
        key = q.name.lower()
        q2 = idx.get(key) or idx.get(q.stem.lower())
        if q2 is None:
            miss += 1; continue
        mapped.append(q2)
    return mapped, miss

def build_split(sp_name):
    r_list, f_list = _pick_lists(sp_name)
    r_map, r_miss = _map_to_existing(r_list, True)
    f_map, f_miss = _map_to_existing(f_list, False)
    if (r_miss + f_miss) > 0:
        print(f"[{sp_name}] 對應 real {len(r_map)}/{len(r_list)}（缺 {r_miss}），fake {len(f_map)}/{len(f_list)}（缺 {f_miss}）")
    else:
        print(f"[{sp_name}] real={len(r_map)} fake={len(f_map)}")
    paths = r_map + f_map
    labels = [0]*len(r_map) + [1]*len(f_map)
    return paths, labels

train_paths, y_tr = build_split("train")
val_paths,   y_va = build_split("val")
test_iid,    y_ti = build_split("test_iid") if "test_iid" in SPLITS_IN else ([],[])
test_ood,    y_to = build_split("test_ood") if "test_ood" in SPLITS_IN else ([],[])

# 自動對半 / cap
if AUTO_BALANCE and CAP_TRAIN_PER_CLASS is None:
    n0, n1 = sum(np.array(y_tr)==0), sum(np.array(y_tr)==1)
    CAP_TRAIN_PER_CLASS = min(n0, n1)
def _cap(paths, labels, cap):
    return cap_per_class(paths, labels, cap, seed=RANDOM_SEED)
train_paths, y_tr = _cap(train_paths, y_tr, CAP_TRAIN_PER_CLASS)
val_paths,   y_va = _cap(val_paths,   y_va, CAP_VAL_PER_CLASS)

print(f"Train(after cap): {len(train_paths)} | real={sum(np.array(y_tr)==0)} fake={sum(np.array(y_tr)==1)}")
print(f"Val  (after cap): {len(val_paths)}   | real={sum(np.array(y_va)==0)} fake={sum(np.array(y_va)==1)}")

# ---------------- Build Loaders ----------------
def make_loader(dataset, batch_size, shuffle):
    kwargs = dict(batch_size=batch_size, shuffle=shuffle,
                  num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                  drop_last=shuffle)
    if NUM_WORKERS > 0:
        kwargs.update(persistent_workers=True, prefetch_factor=2)
    return DataLoader(dataset, **kwargs)

train_set = ELAFromNPY(train_paths, y_tr, train=True,  size=INPUT_SIZE, train_center=TRAIN_CENTER)
val_set   = ELAFromNPY(val_paths,   y_va, train=False, size=INPUT_SIZE, train_center=True)

dl = make_loader(train_set, BATCH_SIZE, True)
vl = make_loader(val_set,   BATCH_SIZE*2, False)

ti_loader = make_loader(ELAFromNPY(test_iid, y_ti, train=False, size=INPUT_SIZE, train_center=True), BATCH_SIZE*2, False) if len(test_iid) else None
to_loader = make_loader(ELAFromNPY(test_ood, y_to, train=False, size=INPUT_SIZE, train_center=True), BATCH_SIZE*2, False) if len(test_ood) else None

# ---------------- Eval / Threshold helpers ----------------
@torch.no_grad()
def collect_val_scores(model, loader, device="cuda"):
    model.eval(); ys, ps = [], []
    for x,y in loader:
        x = x.to(device, non_blocking=True).contiguous(memory_format=torch.channels_last if device=='cuda' else torch.contiguous_format)
        logit = model(x)
        ps.extend(torch.sigmoid(logit).float().cpu().numpy().tolist())
        ys.extend(y.numpy().tolist())
    return np.array(ys), np.array(ps)

def best_threshold(y, p, mode="youden"):
    if mode == "youden":
        fpr, tpr, thr = roc_curve(y, p); j = tpr - fpr
        return float(thr[np.argmax(j)])
    elif mode == "acc":
        qs = np.quantile(p, np.linspace(0.01, 0.99, 99))
        accs = [(((p>=t).astype(int)==y).mean()) for t in qs]
        return float(qs[int(np.argmax(accs))])
    else:
        raise ValueError

# ---------------- Train ----------------
try:
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
except Exception:
    pass

model = ELAForensicCNN().to(device).to(memory_format=torch.channels_last if device=='cuda' else torch.contiguous_format)
opt   = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=WEIGHT_DECAY)
lossf = nn.BCEWithLogitsLoss()
use_amp = (device == "cuda")
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

steps_per_epoch = max(1, math.ceil(len(dl) / max(1, ACCUM_STEPS)))
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=LR_MAX, epochs=EPOCHS, steps_per_epoch=steps_per_epoch,
    pct_start=0.15, div_factor=10.0, final_div_factor=10.0
)

best_auc = -1.0
stamp = time.strftime("%Y%m%d_%H%M%S")
best_path = os.path.join(OUTPUT_DIR, f"ela_fromnpy_cnn_best_{stamp}.pt")
best_thr  = 0.5

for ep in range(1, EPOCHS+1):
    model.train(); running = 0.0
    pbar = tqdm(dl, total=len(dl), desc=f"train ep{ep}")
    opt.zero_grad(set_to_none=True)
    for step,(x,y) in enumerate(pbar,1):
        x = x.to(device, non_blocking=True).contiguous(memory_format=torch.channels_last if device=='cuda' else torch.contiguous_format)
        y = y.to(device, non_blocking=True)
        with torch.autocast(device_type=("cuda" if use_amp else "cpu"),
                            dtype=(torch.float16 if use_amp else torch.bfloat16),
                            enabled=use_amp):
            logit = model(x)
            y_s = y*0.95 + 0.025  # label smoothing 0.05
            loss = lossf(logit, y_s) / max(1, ACCUM_STEPS)
        scaler.scale(loss).backward()
        if step % max(1, ACCUM_STEPS) == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); sched.step()
        running += loss.detach().float().item()

    y_val, p_val = collect_val_scores(model, vl, device=device)
    acc = ((p_val>=0.5).astype(int)==y_val).mean()
    try: auc = roc_auc_score(y_val, p_val)
    except ValueError: auc = float("nan")
    thr_star = best_threshold(y_val, p_val, mode="youden")
    acc_star = ((p_val>=thr_star).astype(int)==y_val).mean()

    print(f"[EP {ep:02d}] train_loss={running/max(1,len(dl)):.4f} | Val acc@0.5={acc:.4f} acc@thr*={acc_star:.4f} AUC={auc:.4f} thr*={thr_star:.3f}")

    if auc > best_auc + 1e-4:
        best_auc = auc; best_thr = thr_star
        torch.save(model.state_dict(), best_path)
        with open(best_path.replace(".pt",".thr.txt"), "w") as f:
            f.write(f"{best_thr:.6f}\n")
        print("  ↳ saved best:", best_path)

print("\n=== Best model on Val ===")
state = torch.load(best_path, map_location=device)
model.load_state_dict(state, strict=True)
y_val, p_val = collect_val_scores(model, vl, device=device)
acc = ((p_val>=0.5).astype(int)==y_val).mean()
auc = roc_auc_score(y_val, p_val) if len(y_val) else float("nan")
thr_path = best_path.replace(".pt",".thr.txt")
thr_star = float(open(thr_path).read().strip()) if os.path.exists(thr_path) else 0.5
acc_star = ((p_val>=thr_star).astype(int)==y_val).mean()
print("Val ACC@0.5:", acc, "ACC@thr*:", acc_star, "AUC:", auc)
print(confusion_matrix(y_val, (p_val>=thr_star).astype(int)))
print(classification_report(y_val, (p_val>=thr_star).astype(int), target_names=["real(0)","fake(1)"]))

# ---------------- Inference ----------------
@torch.no_grad()
def predict_one_ela_npy(npy_or_npz_path, model_path=best_path, input_size=INPUT_SIZE):
    mdl = ELAForensicCNN().to(device).eval()
    mdl.load_state_dict(torch.load(model_path, map_location=device))
    thr_path = model_path.replace(".pt",".thr.txt")
    thr = float(open(thr_path).read().strip()) if os.path.exists(thr_path) else 0.5

    x = load_ela_array(npy_or_npz_path)
    x = crop_2d_hw3(x, size=input_size, center=True)
    x = per_image_zscore_3ch(x)
    x = np.transpose(x, (2,0,1))
    t = torch.from_numpy(x).unsqueeze(0).to(device)
    p = torch.sigmoid(mdl(t)).float().cpu().item()
    return p, int(p >= thr)   # prob_fake, 1=fake


device = cuda
[train] real=56000 fake=56000
[val] real=7000 fake=7000
[test_iid] real=7000 fake=7000
[test_ood] real=8785 fake=8785
Train(after cap): 112000 | real=56000 fake=56000
Val  (after cap): 14000   | real=7000 fake=7000


train ep1: 100%|██████████| 1750/1750 [06:38<00:00,  4.40it/s]


[EP 01] train_loss=0.5100 | Val acc@0.5=0.7964 acc@thr*=0.8251 AUC=0.8903 thr*=0.258
  ↳ saved best: /home/yaya/ai-detect-proj/Script/saved_models/ela_fromnpy_cnn_best_20250820_194931.pt


train ep2: 100%|██████████| 1750/1750 [06:49<00:00,  4.28it/s]


KeyboardInterrupt: 

In [2]:
# ==== ELA-CNN：用 Val 找門檻，評估 Test-IID / Test-OOD，並做 OOD 逐來源報表 ====
from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score, classification_report, confusion_matrix
import numpy as np
from pathlib import Path
from collections import defaultdict

# 讀回最佳模型與 thr*
state = torch.load(best_path, map_location=device)
model.load_state_dict(state, strict=True)
model.eval()
thr_saved = float(open(best_path.replace(".pt",".thr.txt")).read().strip())

def eval_loader(name, loader, thr):
    y, p = collect_val_scores(model, loader, device=device)
    acc = ((p >= thr).astype(int) == y).mean()
    auc = roc_auc_score(y, p)
    print(f"\n[{name}] acc@thr={acc:.4f} | auc={auc:.4f} | thr={thr:.3f}")
    print(confusion_matrix(y, (p >= thr).astype(int)))
    print(classification_report(y, (p >= thr).astype(int), digits=4))
    return y, p

# 1) 用 Val 取兩個門檻：Youden J 與 FPR@5%
y_v, p_v = collect_val_scores(model, vl, device=device)
fpr, tpr, thr = roc_curve(y_v, p_v); J = tpr - fpr
t_youden = float(thr[np.argmax(J)])
idx = np.where(fpr <= 0.05)[0]
t_fpr05 = float(thr[idx[-1]]) if len(idx) else float(thr[0])
print(f"Val thresholds → youden={t_youden:.3f} | fpr@5%={t_fpr05:.3f} | saved*={thr_saved:.3f}")

# 2) 依兩種門檻做完整評估
eval_loader("Val (youden)", vl, t_youden)
if 'ti_loader' in globals() and ti_loader is not None:
    y_ti, p_ti = eval_loader("Test-IID (youden)", ti_loader, t_youden)
if 'to_loader' in globals() and to_loader is not None:
    y_to, p_to = eval_loader("Test-OOD (youden)", to_loader, t_youden)

eval_loader("Val (FPR@5%)", vl, t_fpr05)
if 'ti_loader' in globals() and ti_loader is not None:
    eval_loader("Test-IID (FPR@5%)", ti_loader, t_fpr05)
if 'to_loader' in globals() and to_loader is not None:
    y_to2, p_to2 = eval_loader("Test-OOD (FPR@5%)", to_loader, t_fpr05)

# 3) （可選）OOD 逐來源報表：看 Unsplash / DALL·E3 誰在拖分
SEPS = ("__", "---", "--", "_", "-", " ")
ALIASES = {"unslpash":"unsplash","dalle-3":"dalle3","mj":"midjourney","midj":"midjourney"}

def infer_tag(p: str, is_real: bool) -> str:
    stem = Path(p).stem
    cut = None
    for s in SEPS:
        i = stem.find(s)
        if i != -1: cut = i if cut is None else min(cut, i)
    tag = stem[:cut] if cut is not None else stem
    tag = ALIASES.get(tag.lower().strip(), tag.lower().strip())
    if (not tag) or tag.isdigit():
        tag = "imagenet" if is_real else "unknown"
    return tag

if 'test_ood' in globals() and test_ood:
    # 先把 OOD 的 y/p 取一次（用 FPR@5% 比較保守）
    y_ood, p_ood = (y_to2, p_to2) if 'p_to2' in globals() else (y_to, p_to)
    groups = defaultdict(list)
    for pth, yy, pp in zip(test_ood, y_ood, p_ood):
        groups[infer_tag(pth, is_real=(yy==0))].append((yy, pp))
    print("\n== OOD per-dataset (thr = FPR@5%) ==")
    for tag, items in sorted(groups.items(), key=lambda kv: -len(kv[1])):
        yy = np.array([a for a,_ in items]); pp = np.array([b for _,b in items])
        pred = (pp >= t_fpr05).astype(int)
        acc = (pred == yy).mean()
        try: auc = roc_auc_score(yy, pp)
        except: auc = float("nan")
        cm = confusion_matrix(yy, pred)
        print(f"- {tag:10s} n={len(yy):5d} | acc={acc:.4f} | auc={auc:.4f}\n{cm}\n")


Val thresholds → youden=0.526 | fpr@5%=0.513 | saved*=0.526

[Val (youden)] acc@thr=0.9236 | auc=0.9600 | thr=0.526
[[6665  335]
 [ 734 6266]]
              precision    recall  f1-score   support

         0.0     0.9008    0.9521    0.9258      7000
         1.0     0.9493    0.8951    0.9214      7000

    accuracy                         0.9236     14000
   macro avg     0.9250    0.9236    0.9236     14000
weighted avg     0.9250    0.9236    0.9236     14000


[Test-IID (youden)] acc@thr=0.9255 | auc=0.9599 | thr=0.526
[[6685  315]
 [ 728 6272]]
              precision    recall  f1-score   support

         0.0     0.9018    0.9550    0.9276      7000
         1.0     0.9522    0.8960    0.9232      7000

    accuracy                         0.9255     14000
   macro avg     0.9270    0.9255    0.9254     14000
weighted avg     0.9270    0.9255    0.9254     14000


[Test-OOD (youden)] acc@thr=0.5478 | auc=0.6227 | thr=0.526
[[ 7523 16262]
 [ 5249 18536]]
              precision

