In [8]:
# ======================================================
# Extract ELA(int8) / PRNU(int8) / CLIP(float32) → save as .npy
# ======================================================

from pathlib import Path
from io import BytesIO
import os, random, warnings, json
import numpy as np
from PIL import Image, ImageChops, ImageFile
from skimage import io as skio
from skimage.util import img_as_float32
from skimage.restoration import denoise_wavelet
from skimage.transform import resize
from tqdm.notebook import tqdm
warnings.filterwarnings("ignore", category=UserWarning)
# Let Pillow load truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ---------------- Config (edit here) ----------------
REAL_DIR = Path("/home/yaya/ai-detect-proj/data/Pic")
FAKE_DIR = Path("/home/yaya/ai-detect-proj/data/6kflux")

OUT_ROOT = Path("/home/yaya/ai-detect-proj/Script/features_test")  # created automatically
# Produces:
#   OUT_ROOT/ela_real_npy/*.npy   (int8)
#   OUT_ROOT/ela_fake_npy/*.npy   (int8)
#   OUT_ROOT/prnu_real_npy/*.npy  (int8)
#   OUT_ROOT/prnu_fake_npy/*.npy  (int8)
#   OUT_ROOT/clip_real_npy/*.npy  (float32)
#   OUT_ROOT/clip_fake_npy/*.npy  (float32)

RUN_CLASSES = ["fake"]                 # may be ["real"], ["fake"], or ["real","fake"]
SELECT_FEATURES = ["ela", "prnu", "clip"]  # remove entries to run only a subset

# Sampling cap: at most this many images are processed per class (None = no limit)
MAX_PER_CLASS = 30000  # e.g. 10000 for a 10k subset; or None to use every image

# ELA parameters
IMG_SIZE     = 256     # short side is upscaled to at least this length, then a center crop is taken
ELA_QUALITY  = 90      # JPEG quality used for the re-compression
ELA_SCALE    = 15      # only amplifies the difference for contrast; the result is still mapped and quantized afterwards
ELA_FEASZ    = 128     # ELA output size (ELA_FEASZ x ELA_FEASZ)
# ELA int8 format: 0..1 -> uint8 (0..255) -> int8 = uint8 - 128 (so it can be stored as int8).
# To recover 0..1 when reading: (arr_i8.astype(float)+128)/255.0

# PRNU parameters
PRNU_CROP_FROM = 256   # side length of the center crop (image is upscaled first if smaller)
PRNU_OUT_SIZE  = 256   # PRNU output size
PRNU_WAVELET   = "db8" # wavelet basis
PRNU_MODE      = "soft"  # thresholding mode passed to denoise_wavelet

# PRNU int8 quantization: symmetric, q = clip(round(x/S*127), -127, 127)
PRNU_Q_MODE    = "per_file"  # 'per_file' (recommended) | 'global' -- NOTE: 'global' currently uses a per-file 6*std scale (see extract_prnu_i8)
PRNU_Q_PERC    = 0.999       # scale S = p99.9 of |x| (insensitive to outliers)
PRNU_Q_SAMPLES = 4096        # number of pixels sampled per image to estimate the quantile

# CLIP model name (originally for the openai-clip package; lazy-loaded).
# NOTE(review): appears unused -- the open_clip loader reads CLIP_BACKBONE / CLIP_PRETRAINED instead.
CLIP_MODEL_NAME = "ViT-L/14"  # "ViT-B/32" also possible

SEED = 42
random.seed(SEED); np.random.seed(SEED)


In [9]:
import os
from concurrent.futures import ProcessPoolExecutor, as_completed

# ===== Parallelism / batching (tune per machine) =====
N_WORKERS_CPU = max(1, (os.cpu_count() or 4) - 1)  # number of processes for ELA/PRNU
CLIP_BATCH    = 64                                   # RTX 4060 + ViT-L/14: start at 64 (drop to 48/32 on OOM)
DL_WORKERS    = min(4, max(1, (os.cpu_count() or 4)//2))  # CPU workers for the CLIP DataLoader
PIN_MEMORY    = True

# Keep BLAS/OpenMP libraries single-threaded so the pool processes don't oversubscribe cores.
# NOTE(review): numpy is already imported in the first cell, so these settings probably do
# not change the thread pool of the BLAS it has already loaded (forked workers inherit that
# state); they only take effect if set before numpy is imported -- confirm and move if needed.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# tqdm for notebooks (falls back to tqdm.auto if that import fails)
try:
    from tqdm.notebook import tqdm
except Exception:
    from tqdm.auto import tqdm
TQDM_KW = dict(dynamic_ncols=True, leave=False)

from pathlib import Path
import os, numpy as np
from tempfile import NamedTemporaryFile

def _atomic_save_npy(path: Path, arr: np.ndarray):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # 方法 A：簡單！臨時檔用 ".tmp.npy"（np.save 不會再加 .npy）
    tmp = path.with_suffix(".tmp.npy")
    np.save(tmp, arr, allow_pickle=False)
    os.replace(tmp, path)


In [10]:

# ---------------- Utils ----------------
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tif", ".tiff"}

def all_images(root: Path):
    """
    All image files under `root` (recursive), in sorted order.

    A path qualifies when its lower-cased suffix is in IMG_EXTS and it is a regular
    file (a directory named e.g. "x.jpg" is skipped). Sorting matters: rglob order
    depends on the filesystem, and the runner samples MAX_PER_CLASS files by index
    with a fixed seed -- without a stable order the chosen subset differs per machine/run.
    """
    return sorted(p for p in root.rglob("*") if p.suffix.lower() in IMG_EXTS and p.is_file())

def ensure_dirs():
    """Create OUT_ROOT plus one ``<feature>_<class>_npy`` folder per selected feature/class pair."""
    targets = [OUT_ROOT / f"{feat}_{cls}_npy" for feat in SELECT_FEATURES for cls in RUN_CLASSES]
    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    for target in targets:
        target.mkdir(parents=True, exist_ok=True)

def center_resize_crop_PIL(img: Image.Image, to_size: int) -> Image.Image:
    """
    Square center crop of side `to_size`.

    If the shorter edge is below `to_size`, the image is first upscaled (bicubic,
    aspect ratio kept) so that the shorter edge equals `to_size`.
    """
    short_side = min(img.size)
    if short_side < to_size:
        factor = to_size / short_side
        img = img.resize(tuple(int(round(d * factor)) for d in img.size), Image.BICUBIC)
    width, height = img.size
    left = (width - to_size) // 2
    top = (height - to_size) // 2
    return img.crop((left, top, left + to_size, top + to_size))

def center_resize_crop_np(img: np.ndarray, crop_from=512, out_size=256) -> np.ndarray:
    """
    Center-crop a `crop_from` x `crop_from` patch from an (H, W[, C]) array, then resize it
    to `out_size` x `out_size` if the two sizes differ.

    Arrays whose shorter side is below `crop_from` are upscaled first (aspect ratio kept).
    All resizes use skimage ``resize`` with preserve_range and anti-aliasing, and are
    cast back to the input dtype.
    """
    def _resized(arr, shape):
        return resize(arr, shape, preserve_range=True, anti_aliasing=True).astype(arr.dtype)

    h0, w0 = img.shape[:2]
    short_side = min(h0, w0)
    if short_side < crop_from:
        scale = crop_from / short_side
        img = _resized(img, (int(round(h0 * scale)), int(round(w0 * scale))))

    h, w = img.shape[:2]
    top, left = (h - crop_from) // 2, (w - crop_from) // 2
    patch = img[top:top + crop_from, left:left + crop_from]
    return patch if crop_from == out_size else _resized(patch, (out_size, out_size))

# ---------------- ELA（→ int8） ----------------
def _to_int8_offset128_from_01(x01: np.ndarray) -> np.ndarray:
    """x01 in [0,1] → uint8(0..255) → int8(-128..127)"""
    u8 = np.rint(np.clip(x01, 0.0, 1.0) * 255.0).astype(np.uint8)
    i8 = (u8.astype(np.int16) - 128).astype(np.int8)
    return i8

def extract_ela_i8(p: Path) -> np.ndarray | None:
    """
    Error Level Analysis (ELA) map of one image, quantized to int8.

    Steps: load as RGB -> center_resize_crop_PIL to IMG_SIZE -> re-encode as JPEG
    (quality ELA_QUALITY, subsampling=0) -> absolute per-pixel difference multiplied by
    ELA_SCALE -> grayscale -> resize to ELA_FEASZ x ELA_FEASZ -> [0, 1] -> offset-128 int8.

    Returns:
        int8 array of shape (ELA_FEASZ, ELA_FEASZ) stored as uint8 - 128
        (decode with ``(q.astype(float) + 128) / 255``), or None if the image could
        not be processed (the reason is printed).
    """
    try:
        # Context managers close file handles deterministically: Pillow can keep
        # multi-frame files (TIFF/WebP/GIF) open until garbage collection, and this
        # runs over tens of thousands of files inside long-lived pool processes.
        with Image.open(p) as src:
            img = src.convert("RGB")
        img = center_resize_crop_PIL(img, IMG_SIZE)

        # JPEG re-compression and difference
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=int(ELA_QUALITY), subsampling=0, optimize=False)
        buf.seek(0)
        with Image.open(buf) as recompressed:
            diff = ImageChops.difference(img, recompressed)
        diff = diff.point(lambda x: x * ELA_SCALE)  # amplify error levels for contrast
        diff = diff.convert("L").resize((ELA_FEASZ, ELA_FEASZ))
        arr01 = np.asarray(diff, dtype=np.float32) / 255.0  # 0..1
        return _to_int8_offset128_from_01(arr01)
    except Exception as e:
        print("[ELA] skip", p.name, "|", e)
        return None

# ---------------- PRNU（→ int8） ----------------
def _sample_abs_vals(a: np.ndarray, k: int, rng=np.random.default_rng(SEED)) -> np.ndarray:
    v = a.reshape(-1).astype(np.float32, copy=False)
    if v.size <= k: return np.abs(v)
    idx = rng.integers(0, v.size, size=k, endpoint=False)
    return np.abs(v[idx])

def _fast_percentile(v: np.ndarray, q: float) -> float:
    if v.size == 0: return 1e-8
    k = int(q * (v.size - 1))
    val = np.partition(v, k)[k]
    return float(max(val, 1e-8))

def _prnu_quant_i8(a: np.ndarray, S: float) -> np.ndarray:
    x = np.clip(a, -S, S) / S * 127.0
    q = np.rint(x).astype(np.int16)
    q = np.clip(q, -127, 127).astype(np.int8)
    return q

def extract_prnu_i8(p: Path) -> np.ndarray | None:
    """
    PRNU-style noise residual of one image, symmetrically quantized to int8.

    Steps: read with skimage -> drop an alpha channel if present -> replicate grayscale
    to 3 channels -> float32 in [0, 1] -> center_resize_crop_np (PRNU_CROP_FROM ->
    PRNU_OUT_SIZE) -> channel mean -> residual = gray - denoise_wavelet(gray) ->
    zero-mean residual -> scale S -> _prnu_quant_i8.

    Scale S:
        PRNU_Q_MODE == "per_file": the PRNU_Q_PERC quantile of |residual|, estimated
            from PRNU_Q_SAMPLES sampled pixels.
        any other value: 6 * std(residual) of this same file (a dataset-global scale is
            not implemented).

    Returns:
        int8 array (PRNU_OUT_SIZE, PRNU_OUT_SIZE) with values in [-127, 127], or None if
        the image could not be processed (the reason is printed).
    """
    try:
        im = skio.imread(str(p))
        if im.ndim == 3 and im.shape[-1] in (2, 4):
            # Gray+alpha / RGBA: drop alpha so transparency is not mixed into the
            # luminance that the noise residual is computed from.
            im = im[..., :-1]
        if im.ndim == 3 and im.shape[-1] == 1:
            im = im[..., 0]
        if im.ndim == 2:
            im = np.repeat(im[..., None], 3, axis=-1)
        im = img_as_float32(im)  # 0..1
        crop = center_resize_crop_np(im, PRNU_CROP_FROM, PRNU_OUT_SIZE)
        gray = crop.mean(axis=2, dtype=np.float32)
        denoised = denoise_wavelet(gray, channel_axis=None, mode=PRNU_MODE,
                                   wavelet=PRNU_WAVELET, convert2ycbcr=False)
        residual = gray - denoised
        residual -= residual.mean()

        if PRNU_Q_MODE == "per_file":
            vals = _sample_abs_vals(residual, PRNU_Q_SAMPLES)
            S = _fast_percentile(vals, PRNU_Q_PERC)
        else:
            # Not a real global scale: a per-file fallback of 6 standard deviations.
            # A true global mode would need a first pass over the dataset to estimate S.
            S = max(1e-6, float(np.std(residual)) * 6.0)

        return _prnu_quant_i8(residual, S)
    except Exception as e:
        print("[PRNU] skip", p.name, "|", e)
        return None

# ---------------- CLIP (-> float32; penultimate layer; open_clip) ----------------
# Backbone and matching pretrained tag for open_clip.create_model_and_transforms.
CLIP_BACKBONE   = "ViT-L-14"         # alternatives: "ViT-B-32", "ViT-L-14-336", ...
CLIP_PRETRAINED_BY_BACKBONE = {
    "ViT-L-14":       "laion2b_s32b_b82k",
    "ViT-B-32":       "laion400m_e32",
    # open_clip's pretrained registry lists only "openai" weights for the 336px variant
    # (check open_clip.list_pretrained()); a LAION tag here makes model creation fail.
    "ViT-L-14-336":   "openai",
}
CLIP_PRETRAINED = CLIP_PRETRAINED_BY_BACKBONE.get(CLIP_BACKBONE, "laion2b_s32b_b82k")

# Lazily-populated open_clip state, filled in by load_openclip().
_openclip_model = None
_openclip_pre   = None
_openclip_dev   = "cpu"

def load_openclip():
    """
    Build the open_clip model on first use and cache it in module globals.

    First call: picks "cuda" when torch reports it available (else "cpu"), creates
    CLIP_BACKBONE with CLIP_PRETRAINED weights, moves it to that device and switches it
    to eval mode. Later calls return the cached objects unchanged.

    Returns:
        (model, preprocess, device_str) where `preprocess` is the evaluation transform.
    """
    global _openclip_model, _openclip_pre, _openclip_dev
    if _openclip_model is not None:
        return _openclip_model, _openclip_pre, _openclip_dev

    import torch, open_clip
    _openclip_dev = "cuda" if torch.cuda.is_available() else "cpu"

    created = open_clip.create_model_and_transforms(
        CLIP_BACKBONE, pretrained=CLIP_PRETRAINED
    )
    arity = len(created) if isinstance(created, tuple) else None
    if arity == 3:
        # (model, train_transform, eval_transform): inference uses the eval transform
        model, _train_transform, preprocess = created
    elif arity == 2:
        model, preprocess = created
    else:
        # Last resort for versions with a different return shape.
        model, preprocess = open_clip.create_model_from_pretrained(
            CLIP_BACKBONE, pretrained=CLIP_PRETRAINED
        )

    _openclip_model = model.to(_openclip_dev)
    _openclip_model.eval()
    _openclip_pre = preprocess
    return _openclip_model, _openclip_pre, _openclip_dev


def _encode_image_penultimate(model, image_tensor):
    """
    回傳倒數第二層（pre-projection）CLS 特徵：
    - 優先在 visual.ln_post 取得輸出（OpenAI/early open_clip ViT 結構）
    - 退而求其次用 trunk.forward_features() 取 CLS，再經 ln_post（若存在）
    - 最後的保底是用 encode_image（投影後），確保不會報錯
    """
    import torch
    feats = {}
    handle = None

    # 1) 嘗試 hook 在 ln_post（投影前最後一層）
    try:
        target = getattr(model.visual, "ln_post", None)
        if target is not None:
            def _hook(_m, _inp, out):
                feats["penult"] = out.detach()
            handle = target.register_forward_hook(_hook)
            _ = model.encode_image(image_tensor)  # 觸發 forward
            if handle is not None:
                handle.remove()
            if "penult" in feats:
                return feats["penult"]
    except Exception:
        if handle is not None:
            handle.remove()

    # 2) timm trunk：直接拿 forward_features 的 CLS，並套 ln_post（若有）
    try:
        visual = model.visual
        if hasattr(visual, "trunk") and hasattr(visual.trunk, "forward_features"):
            x = visual.trunk.forward_features(image_tensor)
            if isinstance(x, (tuple, list)):
                x = x[0]
            # 若仍是 token map，取 CLS
            if x.ndim == 3:
                x = x[:, 0, :]
            if hasattr(visual, "ln_post") and visual.ln_post is not None:
                x = visual.ln_post(x)
            return x
    except Exception:
        pass

    # 3) 保底：使用最終投影後向量（不是倒數第二層，但避免中斷流程）
    with torch.no_grad():
        return model.encode_image(image_tensor)

def extract_clip_vec(p: Path) -> np.ndarray | None:
    """
    L2-normalized penultimate-layer open_clip embedding of a single image.

    The image is loaded as RGB, center-cropped to a square, passed through the model's
    own preprocessing, encoded with _encode_image_penultimate and L2-normalized along the
    last axis. The batch axis is squeezed, so a [1, D] feature becomes shape (D,)
    (D depends on CLIP_BACKBONE).

    Returns None (after printing the reason) if anything fails, including p being None.
    """
    try:
        import torch
        model, pre, dev = load_openclip()
        # Close the file handle right away (multi-frame formats may otherwise keep it open).
        with Image.open(p) as src:
            img = src.convert("RGB")
        # Square center crop first; open_clip's preprocess then handles size/normalization.
        w, h = img.size
        s = min(w, h)
        img = img.crop(((w - s)//2, (h - s)//2, (w + s)//2, (h + s)//2))
        im = pre(img).unsqueeze(0).to(dev)

        with torch.no_grad():
            penult = _encode_image_penultimate(model, im).float()
            penult = penult / penult.norm(dim=-1, keepdim=True).clamp_min(1e-12)
            vec = penult.squeeze(0).cpu().numpy().astype(np.float32)
        return vec
    except Exception as e:
        # getattr: `p` may not be a Path (e.g. None when no test image was found).
        print("[CLIP] skip", getattr(p, "name", p), "|", e)
        return None

def _ela_worker(args):
    img_path, out_path = args
    try:
        if out_path.exists(): 
            return True
        arr = extract_ela_i8(img_path)
        if arr is None:
            return False
        _atomic_save_npy(out_path, arr)
        return True
    except Exception:
        return False

def _prnu_worker(args):
    img_path, out_path = args
    try:
        if out_path.exists():
            return True
        arr = extract_prnu_i8(img_path)
        if arr is None:
            return False
        _atomic_save_npy(out_path, arr)
        return True
    except Exception:
        return False


In [11]:
def _center_square(img: Image.Image) -> Image.Image:
    """Largest centered square crop of `img` (side = the shorter edge)."""
    width, height = img.size
    side = width if width < height else height
    left, top = (width - side) // 2, (height - side) // 2
    return img.crop((left, top, left + side, top + side))

class _ClipPathsDataset:
    def __init__(self, paths, pre):
        self.paths = paths
        self.pre   = pre
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        p = self.paths[i]
        try:
            img = Image.open(p).convert("RGB")
            img = _center_square(img)
            t   = self.pre(img)
            ok  = True
        except Exception:
            t, ok = None, False
        return p, t, ok

def _collate(batch):
    # batch: list of (path, tensor, ok)
    ps, ts = [], []
    for p, t, ok in batch:
        if ok and t is not None:
            ps.append(p); ts.append(t)
    if len(ts) == 0:
        return [], None
    import torch
    return ps, torch.stack(ts, dim=0)


In [12]:

# ---------------- Runner ----------------
def _slug(s: str) -> str:
    return "".join(c if (c.isalnum() or c in "-_.") else "_" for c in s)

def make_id(p: Path, root: Path) -> str:
    rel = p.relative_to(root)                 # a/b/c.jpg
    base = "_".join(rel.with_suffix("").parts)
    dataset = _slug(root.name)                # e.g., "unsplash" or "FLUX"
    return f"{dataset}__{base}"

def _pending_tasks(files, ids, out_dir: Path):
    """(image_path, output_path) pairs whose ``<id>.npy`` does not exist yet in `out_dir`."""
    pending = []
    for p, fid in zip(files, ids):
        outp = out_dir / f"{fid}.npy"
        if not outp.exists():
            pending.append((p, outp))
    return pending


def _run_cpu_feature(label: str, worker, tasks):
    """
    Run `worker` over `tasks` in a process pool of N_WORKERS_CPU processes and report
    how many succeeded. `worker` must be a module-level function taking one
    (image_path, output_path) tuple and returning True on success.
    """
    if not tasks:
        return
    print(f"{label} (CPU x{N_WORKERS_CPU}) → {len(tasks)}")
    with ProcessPoolExecutor(max_workers=N_WORKERS_CPU) as ex:
        futs = [ex.submit(worker, t) for t in tasks]
        ok = 0
        for f in tqdm(as_completed(futs), total=len(futs), desc=label, **TQDM_KW):
            ok += 1 if f.result() else 0
    print(f"{label} saved: {ok}/{len(tasks)}")


def _run_clip_feature(tasks):
    """
    Embed images in batches (CLIP_BATCH, DL_WORKERS loader workers) with the cached
    open_clip model and save one L2-normalized feature row per image.
    """
    if not tasks:
        return
    print(f"CLIP (GPU batch={CLIP_BATCH}, loader_workers={DL_WORKERS}) → {len(tasks)}")
    import torch
    model, pre, dev = load_openclip()
    ds_paths  = [p for p, _ in tasks]
    out_paths = dict(tasks)
    ds = _ClipPathsDataset(ds_paths, pre)
    dl = torch.utils.data.DataLoader(
        ds, batch_size=CLIP_BATCH, shuffle=False,
        num_workers=DL_WORKERS, pin_memory=PIN_MEMORY,
        collate_fn=_collate, drop_last=False
    )
    n_ok = 0
    # Note: images dropped by the collate step never advance this bar.
    pbar = tqdm(total=len(ds_paths), desc="CLIP", **TQDM_KW)
    try:
        with torch.no_grad():
            for paths, batch in dl:
                if not paths:  # every image in this batch failed to load
                    continue
                batch = batch.to(dev, non_blocking=True)
                feats = _encode_image_penultimate(model, batch).float()
                feats = feats / feats.norm(dim=-1, keepdim=True).clamp_min(1e-12)
                vecs  = feats.cpu().numpy().astype(np.float32)
                for pth, vec in zip(paths, vecs):
                    _atomic_save_npy(out_paths[pth], vec)
                    n_ok += 1
                pbar.update(len(paths))
    finally:
        pbar.close()
    print(f"CLIP saved: {n_ok}/{len(ds_paths)}")


def run_extract_for_class(cls: str, root: Path):
    """
    Extract every feature in SELECT_FEATURES for the images of one class.

    Images under `root` are listed in sorted order and, if MAX_PER_CLASS is set, a
    seeded random subset is taken. Outputs go to ``OUT_ROOT/<feature>_<cls>_npy/<id>.npy``;
    files that already exist are skipped, so the function can be re-run to resume.
    ELA and PRNU run in CPU process pools, CLIP in batches on the GPU (if available).
    """
    # Sort before sampling: rglob order is filesystem-dependent, so indexing into an
    # unsorted list would select a different subset per machine/run despite SEED.
    files = sorted(all_images(root))
    if MAX_PER_CLASS is not None and len(files) > MAX_PER_CLASS:
        rng = np.random.default_rng(SEED)
        idx = rng.choice(len(files), size=MAX_PER_CLASS, replace=False)
        files = [files[i] for i in idx]
    print(f"→ {cls} ({len(files)} images) from {root}")

    # Compute each output id once; warn about images that map to the same id.
    ids = [make_id(p, root) for p in files]
    n_dup = len(ids) - len(set(ids))
    if n_dup:
        print(f"[warn] {n_dup} image(s) share a feature id with another image "
              f"(e.g. same stem, different extension); only one file per id is kept")

    dirs = {feat: OUT_ROOT / f"{feat}_{cls}_npy"
            for feat in ("ela", "prnu", "clip") if feat in SELECT_FEATURES}
    for d in dirs.values():
        d.mkdir(parents=True, exist_ok=True)

    # ELA / PRNU first (CPU process pools), then CLIP (GPU, single process + DataLoader).
    if "ela" in dirs:
        _run_cpu_feature("ELA", _ela_worker, _pending_tasks(files, ids, dirs["ela"]))
    if "prnu" in dirs:
        _run_cpu_feature("PRNU", _prnu_worker, _pending_tasks(files, ids, dirs["prnu"]))
    if "clip" in dirs:
        _run_clip_feature(_pending_tasks(files, ids, dirs["clip"]))


In [13]:

# ---------------- Go! ----------------
# Uses ensure_dirs() from the Utils cell. This cell previously carried an identical
# re-definition, which silently shadowed that one and could drift out of sync.
# NOTE(review): the PRNU training cell further down reassigns REAL_DIR / FAKE_DIR to
# .npy feature folders; re-run the config cell before this one if that cell has run.

ensure_dirs()
if "real" in RUN_CLASSES: run_extract_for_class("real", REAL_DIR)
if "fake" in RUN_CLASSES: run_extract_for_class("fake", FAKE_DIR)
print("✅ Done. Features saved under:", OUT_ROOT)


→ fake (3650 images) from /home/yaya/ai-detect-proj/data/6kflux
ELA (CPU x7) → 3650


ELA:   0%|          | 0/3650 [00:00<?, ?it/s]

ELA saved: 3650/3650
PRNU (CPU x7) → 3650


PRNU:   0%|          | 0/3650 [00:00<?, ?it/s]

PRNU saved: 3650/3650
CLIP (GPU batch=64, loader_workers=4) → 3650


CLIP:   0%|          | 0/3650 [00:00<?, ?it/s]

CLIP saved: 3650/3650
✅ Done. Features saved under: /home/yaya/ai-detect-proj/Script/features_test


In [None]:
def cleanup_tmp_npy(root: Path):
    """
    Delete leftover ``*.tmp.npy`` files (from interrupted atomic saves) anywhere under
    `root`, then print how many were removed. Files that cannot be deleted are skipped.
    """
    removed = 0
    for stale in root.rglob("*.tmp.npy"):
        try:
            stale.unlink()
        except Exception:
            continue
        removed += 1
    print(f"🧹 cleaned {removed} orphan .tmp.npy files under {root}")

# Run once: delete any *.tmp.npy left under OUT_ROOT (e.g. by an interrupted save).
cleanup_tmp_npy(OUT_ROOT)


In [None]:
# Smoke test: load the CLIP model and embed one image.
m, pre, dev = load_openclip()
print("device:", dev)
# First supported image (any extension in IMG_EXTS) from the real folder, else the fake one.
# NOTE: the PRNU training cell reassigns REAL_DIR/FAKE_DIR to .npy feature folders; run this
# before that cell (or re-run the config cell) or no image will be found.
p = next(
    (q for d in (REAL_DIR, FAKE_DIR) for q in d.rglob("*")
     if q.suffix.lower() in IMG_EXTS and q.is_file()),
    None,
)
assert p is not None, f"no test image found under {REAL_DIR} or {FAKE_DIR}"
print("test image:", p)
v = extract_clip_vec(p)
print("vec shape:", None if v is None else v.shape)


In [None]:
# ===== GPU 版 CLIP→SVM（一格可跑）=====
from pathlib import Path
import os, json, random, numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix

# -------- Basic parameters --------
SEED = 42
random.seed(SEED); np.random.seed(SEED)

# NOTE(review): this reassigns the module-level OUT_ROOT that the extraction and cleanup
# cells use (features_test -> features_256). Re-running those cells after this one would
# read/write features_256. Consider a separate name once no later cell depends on this.
OUT_ROOT = Path("/home/yaya/ai-detect-proj/Script/features_256")
CLIP_REAL_DIR = OUT_ROOT / "clip_real_npy"
CLIP_FAKE_DIR = OUT_ROOT / "clip_fake_npy"
SPLIT_JSON = Path("/home/yaya/ai-detect-proj/Script/splits/combined_split.json")  # random split if this file is missing
N_PER_CLASS = 10000  # per-class cap (see make_splits for how each branch applies it)
C = 1.0  # SVM C (larger = fits the training set more tightly, i.e. weaker regularization)

# -------- File listing & id lookup --------
def list_npy(d):
    """Sorted .npy files directly inside directory `d` (non-recursive)."""
    return sorted(d.glob("*.npy"))

def fid(p: Path):
    """Feature id = .npy file name without its extension."""
    return p.stem

real_files = list_npy(CLIP_REAL_DIR)
fake_files = list_npy(CLIP_FAKE_DIR)
assert real_files and fake_files, "找不到 CLIP 特徵 .npy，請先完成特徵抽取。"

# class name -> {feature id -> .npy path}
id2path = {cls_name: {fid(path): path for path in paths}
           for cls_name, paths in (("real", real_files), ("fake", fake_files))}

def intersect_ids(need_ids, pool_dict):
    """Ids from `need_ids`, in their original order, that are keys of `pool_dict`."""
    return list(filter(pool_dict.__contains__, need_ids))

def make_splits():
    """
    Build train/val/test lists of (feature_id, label) pairs, label 0 = real, 1 = fake.

    If SPLIT_JSON exists, its ["ids"][split] lists are used, keeping only ids that have a
    feature file in id2path for the class. Otherwise each class's ids are shuffled with the
    global `random` state and split 80/10/10 with train_test_split(random_state=SEED).
    Every split is then shuffled in place and its per-class counts are printed.

    NOTE(review): N_PER_CLASS caps each class *per split* in the JSON branch, but caps
    each class *in total* (before splitting) in the random branch -- confirm intended.
    """
    from sklearn.model_selection import train_test_split
    if SPLIT_JSON.exists():
        ids = json.loads(SPLIT_JSON.read_text())["ids"]
        def pick(sp, cls):
            # ids listed for split `sp` that exist in class `cls`, truncated to N_PER_CLASS
            arr = intersect_ids(ids[sp], id2path[cls])
            if N_PER_CLASS is not None:
                arr = arr[:min(len(arr), N_PER_CLASS)]
            return arr
        splits = {
            "train": [(i,0) for i in pick("train","real")] + [(i,1) for i in pick("train","fake")],
            "val":   [(i,0) for i in pick("val","real")]   + [(i,1) for i in pick("val","fake")],
            "test":  [(i,0) for i in pick("test","real")]  + [(i,1) for i in pick("test","fake")],
        }
        print(f"使用 split.json：{SPLIT_JSON}")
    else:
        real_ids = list(id2path["real"].keys())
        fake_ids = list(id2path["fake"].keys())
        # Consumes the global `random` state; the result depends on the seed set earlier in the cell.
        random.shuffle(real_ids); random.shuffle(fake_ids)
        if N_PER_CLASS is not None:
            real_ids = real_ids[:min(len(real_ids), N_PER_CLASS)]
            fake_ids = fake_ids[:min(len(fake_ids), N_PER_CLASS)]
        # 8:1:1 (train : val : test), done per class
        r_tr, r_tmp = train_test_split(real_ids, test_size=0.2, random_state=SEED)
        f_tr, f_tmp = train_test_split(fake_ids, test_size=0.2, random_state=SEED)
        r_va, r_te  = train_test_split(r_tmp, test_size=0.5, random_state=SEED)
        f_va, f_te  = train_test_split(f_tmp, test_size=0.5, random_state=SEED)
        splits = {
            "train": [(i,0) for i in r_tr] + [(i,1) for i in f_tr],
            "val":   [(i,0) for i in r_va] + [(i,1) for i in f_va],
            "test":  [(i,0) for i in r_te] + [(i,1) for i in f_te],
        }
        print("使用隨機切分（無 split.json）")
    # Mix real and fake within each split and report class balance.
    for sp in splits:
        random.shuffle(splits[sp])
        n0 = sum(1 for _,y in splits[sp] if y==0); n1 = len(splits[sp])-n0
        print(f"{sp}: total={len(splits[sp])} | real={n0} fake={n1}")
    return splits
splits = make_splits()

def load_pairs(pairs):
    """
    Load (feature_id, label) pairs into a feature matrix and a label vector.

    Label 0 is read from id2path["real"], any other label from id2path["fake"]; each .npy
    is cast to float32 and flattened into one row.
    Returns (X: float32 [N, D], y: int32 [N]); np.stack raises if `pairs` is empty.
    """
    rows = []
    labels = []
    for feat_id, label in pairs:
        cls_name = "real" if label == 0 else "fake"
        arr = np.load(id2path[cls_name][feat_id], allow_pickle=False)
        rows.append(arr.astype(np.float32, copy=False).reshape(-1))
        labels.append(label)
    return np.stack(rows, 0), np.array(labels, dtype=np.int32)

X_train, y_train = load_pairs(splits["train"])
X_val,   y_val   = load_pairs(splits["val"])
X_test,  y_test  = load_pairs(splits["test"])
print("shapes:", X_train.shape, X_val.shape, X_test.shape)

# -------- Backend: cuML GPU SVM if importable, otherwise a PyTorch linear SVM --------
try:
    import cuml, cupy as cp
    from cuml.svm import SVC
except Exception as e:
    USE_BACKEND = "torch"
    print("⚠️ cuML 不可用，改用 PyTorch 線性 SVM（GPU）:", e)
else:
    USE_BACKEND = "cuml"
    print("✅ 使用 cuML GPU SVM")

def evaluate_scores(y_true, scores, name):
    """
    Print accuracy, ROC-AUC, confusion matrix and per-class report for one split.

    Args:
        y_true: labels (0 = real, 1 = fake).
        scores: decision scores; a score > 0 is predicted as fake (1).
        name:   split name used in the printed header.

    Labels are pinned to [0, 1] so the confusion matrix is always 2x2 and the report's
    target names still line up when a split contains only one class; ROC-AUC is undefined
    in that case and is printed as nan instead of raising.
    """
    y_pred = (scores > 0).astype(np.int32)  # computed once, reused below
    acc = accuracy_score(y_true, y_pred)
    if np.unique(y_true).size > 1:
        auc = roc_auc_score(y_true, scores)
    else:
        auc = float("nan")
    print(f"[{name}] acc={acc:.4f} auc={auc:.4f}")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))
    print(classification_report(y_true, y_pred, labels=[0, 1], target_names=["real","fake"],
                                digits=4, zero_division=0))

# The trained SVM is written to saved_models/. Create it here: on a fresh kernel no
# earlier cell creates that folder (the PRNU training cell below does, but too late).
_SVM_SAVE_DIR = Path("/home/yaya/ai-detect-proj/Script/saved_models")
_SVM_SAVE_DIR.mkdir(parents=True, exist_ok=True)

if USE_BACKEND == "cuml":
    # ---- cuML SVC with a linear kernel (GPU) ----
    Xtr = cp.asarray(X_train); ytr = cp.asarray(y_train)
    Xva = cp.asarray(X_val);   yva = cp.asarray(y_val)
    Xte = cp.asarray(X_test);  yte = cp.asarray(y_test)

    clf = SVC(kernel="linear", C=C, probability=False, max_iter=100000, tol=1e-3)
    clf.fit(Xtr, ytr)

    # decision_function > 0 is treated as fake (label 1)
    s_val = clf.decision_function(Xva).get()
    s_te  = clf.decision_function(Xte).get()

    evaluate_scores(y_val,  s_val,  "val")
    evaluate_scores(y_test, s_te,   "test")

    # Save
    import joblib
    joblib.dump({"backend":"cuml","model":clf}, str(_SVM_SAVE_DIR / "clip_svm_gpu.pkl"))
    print("✅ saved: saved_models/clip_svm_gpu.pkl")

else:
    # ---- PyTorch linear SVM (hinge loss) ----
    import torch, torch.nn as nn

    # Seed torch too: weight init and the per-epoch randperm below otherwise differ on
    # every run (only `random` and numpy were seeded above).
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    Xtr = torch.from_numpy(X_train).to(device)
    Xva = torch.from_numpy(X_val).to(device)
    Xte = torch.from_numpy(X_test).to(device)
    # y in {0,1} -> y' in {-1,+1}
    ytr = torch.from_numpy(np.where(y_train==1, 1, -1).astype(np.float32)).to(device)
    yva = torch.from_numpy(np.where(y_val==1,   1, -1).astype(np.float32)).to(device)
    yte = torch.from_numpy(np.where(y_test==1,  1, -1).astype(np.float32)).to(device)

    D = Xtr.shape[1]
    model = nn.Linear(D, 1, bias=True).to(device)

    # Objective: 0.5*||w||^2 + C * mean_i max(0, 1 - y_i*(w.x_i + b))
    # (AdamW's weight_decay adds a further, decoupled shrinkage of all parameters.)
    def hinge_loss(out, y):
        """Mean hinge loss; out: scores [N] or [N,1], y: labels [N] in {-1,+1}."""
        # Flatten both: y.unsqueeze(1) * out would broadcast to an [N, N] matrix when
        # `out` is already [N] (as for the squeezed validation scores below), which
        # corrupted the validation objective used for model selection.
        margins = 1 - y.reshape(-1) * out.reshape(-1)
        return torch.clamp(margins, min=0).mean()

    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    EPOCHS, BATCH = 50, 1024
    N = Xtr.shape[0]
    best_val = float("inf"); best = None

    for ep in range(1, EPOCHS+1):
        model.train()
        perm = torch.randperm(N, device=device)
        total = 0.0
        for i in range(0, N, BATCH):
            idx = perm[i:i+BATCH]
            xb, yb = Xtr[idx], ytr[idx]
            out = model(xb)                 # [B,1]
            reg = 0.5 * (model.weight**2).sum()
            loss = reg + C * hinge_loss(out, yb)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            total += loss.item() * xb.size(0)
        # Validation objective (same formula as training) for best-checkpoint selection
        model.eval()
        with torch.no_grad():
            s_val = model(Xva).squeeze(1)
            val_loss = (0.5*(model.weight**2).sum() + C*hinge_loss(s_val, yva)).item()
        if val_loss < best_val:
            best_val = val_loss
            best = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}
        if ep % 5 == 0 or ep == 1:
            print(f"[ep{ep:02d}] train_loss={total/N:.6f} val_obj={val_loss:.6f}")

    if best is not None:
        model.load_state_dict(best)

    with torch.no_grad():
        s_val = model(Xva).squeeze(1).detach().cpu().numpy()
        s_te  = model(Xte).squeeze(1).detach().cpu().numpy()

    evaluate_scores(y_val,  s_val, "val")
    evaluate_scores(y_test, s_te,  "test")

    # Save (torch weights)
    torch.save({"state_dict": model.state_dict(), "D": int(D), "C": C},
               str(_SVM_SAVE_DIR / "clip_svm_gpu_torch.pt"))
    print("✅ saved: saved_models/clip_svm_gpu_torch.pt")


In [None]:
# ===== PRNU(int8) → Fast CNN (speed-optimized) =====
from pathlib import Path
import os, json, random, math, time
import numpy as np

# tqdm (notebook-friendly; falls back to tqdm.auto if that import fails)
try:
    from tqdm.notebook import tqdm
except Exception:
    from tqdm.auto import tqdm
TQDM_KW = dict(dynamic_ncols=True, leave=False)

# ---------------- Config ----------------
# NOTE(review): REAL_DIR / FAKE_DIR / SEED / SPLIT_JSON / N_PER_CLASS / PIN_MEMORY / BATCH
# below overwrite globals of earlier cells. In particular REAL_DIR/FAKE_DIR were the *image*
# folders of the extraction cell; re-running extraction after this cell would read the
# .npy feature folders instead. Consider PRNU-specific names once later cells allow it.
SEED = 42
random.seed(SEED); np.random.seed(SEED)

SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")
FEA_ROOT    = SCRIPT_ROOT / "features_256"
REAL_DIR    = FEA_ROOT / "prnu_real_npy"
FAKE_DIR    = FEA_ROOT / "prnu_fake_npy"
SPLIT_JSON  = SCRIPT_ROOT / "splits/combined_split.json"  # random split if this file is missing
SAVE_DIR    = SCRIPT_ROOT / "saved_models"; SAVE_DIR.mkdir(parents=True, exist_ok=True)
BEST_PATH   = SAVE_DIR / "prnu_fastcnn_u8_best.pt"

# Data & training
N_PER_CLASS = 10000         # max samples per class (None = all)
BATCH       = 64          # 64-128 suggested on an RTX 4060; lower it on OOM
EPOCHS      = 15
LR          = 2e-3
WEIGHT_DECAY= 1e-4
EARLY_STOP  = 5

# DataLoader
NUM_WORKERS     = min(8, os.cpu_count() or 4)
PIN_MEMORY      = True
PREFETCH_FACTOR = 4
PERSISTENT      = True

# Dataset caching strategy: "ram" | "memmap" | None
CACHE_MODE = "ram"         # ~4k files x 64 KB (256x256 int8) ~= 256 MB fits easily in RAM -> fastest

# ---------------- File lists & id lookup ----------------
def list_npy(d: Path):
    """Sorted .npy files directly inside `d`; asserts that the directory exists."""
    assert d.exists(), f"Not found: {d}"
    return sorted(d.glob("*.npy"))

real_files = list_npy(REAL_DIR)
fake_files = list_npy(FAKE_DIR)
assert real_files and fake_files, "找不到 PRNU 特徵 .npy，請先完成特徵抽取。"

def fid(p: Path):
    """Feature id = .npy file name without its extension."""
    return p.stem

# class name -> {feature id -> .npy path}
id2path = {cls_name: {fid(path): path for path in paths}
           for cls_name, paths in (("real", real_files), ("fake", fake_files))}

from sklearn.model_selection import train_test_split
# Ids from need_ids (original order) that are keys of pool_dict.
def intersect_ids(need_ids, pool_dict): return [i for i in need_ids if i in pool_dict]

def make_splits():
    """
    Build train/val/test lists of (feature_id, label) pairs, label 0 = real, 1 = fake.

    If SPLIT_JSON exists, its ["ids"][split] lists are used, keeping only ids that have a
    PRNU feature file for the class. Otherwise each class's ids are shuffled with the global
    `random` state and split 80/10/10 with train_test_split(random_state=SEED).
    Every split is then shuffled in place and its per-class counts are printed.

    NOTE(review): N_PER_CLASS caps each class *per split* in the JSON branch, but caps each
    class *in total* (before splitting) in the random branch -- confirm intended.
    """
    if SPLIT_JSON.exists():
        js  = json.loads(SPLIT_JSON.read_text())
        ids = js["ids"]
        def pick(sp, cls):
            # ids listed for split `sp` that exist in class `cls`, truncated to N_PER_CLASS
            arr = intersect_ids(ids[sp], id2path[cls])
            if N_PER_CLASS is not None:
                arr = arr[:min(len(arr), N_PER_CLASS)]
            return arr
        splits = {
            "train": [(i,0) for i in pick("train","real")] + [(i,1) for i in pick("train","fake")],
            "val":   [(i,0) for i in pick("val","real")]   + [(i,1) for i in pick("val","fake")],
            "test":  [(i,0) for i in pick("test","real")]  + [(i,1) for i in pick("test","fake")],
        }
        print(f"使用 split.json：{SPLIT_JSON}")
    else:
        real_ids = list(id2path["real"].keys())
        fake_ids = list(id2path["fake"].keys())
        # Consumes the global `random` state; the result depends on the seed set in the config.
        random.shuffle(real_ids); random.shuffle(fake_ids)
        if N_PER_CLASS is not None:
            real_ids = real_ids[:min(len(real_ids), N_PER_CLASS)]
            fake_ids = fake_ids[:min(len(fake_ids), N_PER_CLASS)]
        r_tr, r_tmp = train_test_split(real_ids, test_size=0.2, random_state=SEED)
        f_tr, f_tmp = train_test_split(fake_ids, test_size=0.2, random_state=SEED)
        r_va, r_te  = train_test_split(r_tmp, test_size=0.5, random_state=SEED)  # -> 10% val / 10% test overall
        f_va, f_te  = train_test_split(f_tmp, test_size=0.5, random_state=SEED)
        splits = {
            "train": [(i,0) for i in r_tr] + [(i,1) for i in f_tr],
            "val":   [(i,0) for i in r_va] + [(i,1) for i in f_va],
            "test":  [(i,0) for i in r_te] + [(i,1) for i in f_te],
        }
        print("使用隨機切分（無 split.json）")
    # Mix real and fake within each split and report class balance.
    for sp in splits:
        random.shuffle(splits[sp])
        n0 = sum(1 for _,y in splits[sp] if y==0); n1 = len(splits[sp])-n0
        print(f"{sp}: total={len(splits[sp])} | real={n0} fake={n1}")
    return splits

splits = make_splits()

# ---------------- Dataset / DataLoader ----------------
import torch
from torch.utils.data import Dataset, DataLoader

class PRNUNPY(Dataset):
    """
    Dataset over (feature_id, label) pairs of PRNU int8 .npy files.

    Item i is (x, y): x is a float32 tensor [1, H, W] equal to arr / 127 with its
    per-sample mean subtracted and clipped to [-2, 2]; y is a long tensor (0 = real,
    1 = fake).

    cache_mode:
        "ram"    -- every array is loaded once in __init__ and kept in memory;
        "memmap" -- np.load(..., mmap_mode="r") on every access;
        other    -- plain np.load on every access.
    `augment` is stored but currently has no effect.
    """
    def __init__(self, pairs, id2path, augment=False, cache_mode="ram"):
        self.pairs = pairs
        self.id2p  = id2path
        self.augment = augment
        self.cache_mode = cache_mode
        self.cache = []
        if cache_mode == "ram":
            self.cache = [np.load(self._path_of(feat_id, label), allow_pickle=False).copy()
                          for feat_id, label in pairs]

    def _path_of(self, feat_id, label):
        """.npy path of a feature id (label 0 -> real folder, otherwise fake folder)."""
        return self.id2p["real" if label == 0 else "fake"][feat_id]

    def _raw(self, i, feat_id, label):
        """int8 array for item i, from the RAM cache or from disk per cache_mode."""
        if self.cache_mode == "ram":
            return self.cache[i]
        path = self._path_of(feat_id, label)
        if self.cache_mode == "memmap":
            return np.load(path, allow_pickle=False, mmap_mode='r')
        return np.load(path, allow_pickle=False)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        feat_id, label = self.pairs[i]
        x = self._raw(i, feat_id, label).astype(np.float32) / 127.0
        x -= x.mean()
        x = np.clip(x, -2.0, 2.0)[None, ...]  # [1, H, W]
        return torch.from_numpy(x), torch.tensor(label, dtype=torch.long)

def make_loader(pairs, train=False):
    """
    DataLoader over PRNUNPY(pairs).

    `train=True` enables shuffling, the dataset's augment flag and drop_last. Batch size,
    worker count, pinning, prefetching and worker persistence come from the config
    constants.
    """
    dataset = PRNUNPY(pairs, id2path, augment=train, cache_mode=CACHE_MODE)
    options = dict(
        batch_size=BATCH,
        shuffle=train,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        prefetch_factor=PREFETCH_FACTOR,
        persistent_workers=PERSISTENT,
        drop_last=train,
    )
    return DataLoader(dataset, **options)

train_loader, val_loader, test_loader = (
    make_loader(splits[name], train=(name == "train")) for name in ("train", "val", "test")
)

# ---------------- Fast CNN 模型 ----------------
import torch.nn as nn
import torch.nn.functional as F

class DSBlock(nn.Module):
    """Depthwise-separable conv block.

    Depthwise 3x3 conv (one filter per input channel, optional stride) -> BN -> ReLU,
    then pointwise 1x1 conv to `c_out` channels -> BN -> ReLU.
    """
    def __init__(self, c_in, c_out, stride=1):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the parameter-initialization RNG sequence.
        self.dw = nn.Conv2d(c_in, c_in, kernel_size=3, stride=stride, padding=1, groups=c_in, bias=False)
        self.bn1 = nn.BatchNorm2d(c_in)
        self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        depthwise = self.act(self.bn1(self.dw(x)))
        return self.act(self.bn2(self.pw(depthwise)))

class PRNUFastCNN(nn.Module):
    """Lightweight depthwise-separable CNN for single-channel PRNU maps.

    1x256x256 -> stem 3x3 conv (32) -> DS(64, s=1) -> DS(128, s=2) -> DS(128, s=1)
               -> DS(256, s=2) -> DS(256, s=1) -> 1x1 conv head (256) -> GAP
               -> dropout(0.2) -> FC(num_classes)

    About 0.2M parameters (200,194 for nc=1, num_classes=2; most of them in the
    1x1 convolutions), so throughput is high; intended to pick up texture cues.

    Args:
        nc: number of input channels.
        num_classes: number of output logits.
    """
    def __init__(self, nc=1, num_classes=2):
        super().__init__()
        # Submodule names are part of the saved checkpoint (state_dict keys); keep them.
        self.stem = nn.Sequential(
            nn.Conv2d(nc, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.stage = nn.Sequential(
            DSBlock(32, 64,  stride=1),   # 256
            DSBlock(64, 128, stride=2),   # 128
            DSBlock(128,128, stride=1),   # 128
            DSBlock(128,256, stride=2),   # 64
            DSBlock(256,256, stride=1),   # 64
        )
        self.head = nn.Sequential(
            nn.Conv2d(256, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        self.fc   = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.stage(x)
        x = self.head(x)
        x = self.pool(x).flatten(1)  # [B, 256, 1, 1] -> [B, 256]
        x = self.dropout(x)
        return self.fc(x)

# ---------------- 訓練 & 評估 ----------------
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix

def set_seed(seed=SEED):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def evaluate(model, loader, device):
    """Run `model` over `loader` in eval mode without gradients (fp32, channels_last inputs).

    Returns:
        (accuracy, ROC-AUC of the class-1 softmax probability, y_true, y_pred, prob1),
        where the last three are NumPy arrays and y_pred is the argmax over logits.
    """
    model.eval()
    batch_logits = []
    batch_labels = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True).contiguous(memory_format=torch.channels_last)
            targets = targets.to(device, non_blocking=True)
            batch_logits.append(model(inputs).detach().cpu().numpy())
            batch_labels.append(targets.detach().cpu().numpy())
    all_logits = np.concatenate(batch_logits, axis=0)
    y_true = np.concatenate(batch_labels, axis=0)
    y_pred = all_logits.argmax(axis=1)
    acc = np.mean(y_pred == y_true)
    prob1 = torch.softmax(torch.from_numpy(all_logits), dim=1).numpy()[:, 1]
    auc = roc_auc_score(y_true, prob1)
    return acc, auc, y_true, y_pred, prob1

def train_prnu_fastcnn():
    """Train PRNUFastCNN, keep the best-validation-AUC weights, and report test metrics.

    Uses the notebook globals train_loader / val_loader / test_loader, splits and
    the hyper-parameters SEED, LR, WEIGHT_DECAY, EPOCHS, EARLY_STOP, BEST_PATH and
    TQDM_KW. On CUDA it trains under autocast: bf16 when supported, otherwise fp16
    with a GradScaler. Autocast and loss scaling are disabled on CPU.

    The best checkpoint (by validation ROC-AUC) is written to BEST_PATH. An
    independent CPU snapshot is also kept in memory, so the final test evaluation
    and the returned model use the best epoch's weights, not the last epoch's.

    Returns:
        The trained model (on the training device, left in eval mode).
    """
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    torch.backends.cudnn.benchmark = True

    model = PRNUFastCNN().to(device)
    model = model.to(memory_format=torch.channels_last)

    # Prefer bf16 (no loss scaling needed, supported on Ada GPUs); fall back to fp16 + GradScaler.
    AMP_DTYPE = torch.bfloat16 if (use_cuda and torch.cuda.is_bf16_supported()) else torch.float16
    USE_SCALER = use_cuda and AMP_DTYPE == torch.float16
    scaler = torch.cuda.amp.GradScaler(enabled=USE_SCALER)

    # Inverse-frequency class weights, normalized to mean 1, to offset class imbalance.
    # max(..., 1) avoids a ZeroDivisionError if a class is absent from the training split.
    n0 = sum(1 for _, y in splits["train"] if y == 0)
    n1 = len(splits["train"]) - n0
    w = torch.tensor([1.0 / max(n0, 1), 1.0 / max(n1, 1)], dtype=torch.float32, device=device)
    w = w / w.mean()

    crit  = nn.CrossEntropyLoss(weight=w)
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS)

    best_auc, best_state, no_improve = -1.0, None, 0
    for ep in range(1, EPOCHS + 1):
        model.train()
        losses = []
        pbar = tqdm(train_loader, desc=f"train ep{ep}", **TQDM_KW)
        for x, y in pbar:
            x = x.to(device, non_blocking=True).contiguous(memory_format=torch.channels_last)
            y = y.to(device, non_blocking=True)
            with torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=use_cuda):
                out  = model(x)
                loss = crit(out, y)
            optim.zero_grad(set_to_none=True)
            if USE_SCALER:
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                optim.step()
            losses.append(loss.item())
            pbar.set_postfix(loss=f"{np.mean(losses):.4f}")
        sched.step()

        # Validation
        val_acc, val_auc, *_ = evaluate(model, val_loader, device)
        print(f"[EP {ep:02d}] train_loss={np.mean(losses):.4f} | val acc={val_acc:.4f} auc={val_auc:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            # state_dict() tensors share storage with the live parameters; clone them to CPU
            # so later optimizer steps cannot overwrite this "best" snapshot in memory.
            best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_state = {"model": best_weights,
                          "meta": {"arch":"PRNUFastCNN","seed":SEED,"epochs_done":ep,
                                   "val_auc":float(val_auc),"val_acc":float(val_acc),
                                   "input":"PRNU int8 → float32/127, zero-mean",
                                   "shape":[1,256,256]}}
            torch.save(best_state, BEST_PATH)
            print("  ↳ saved best:", BEST_PATH)
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= EARLY_STOP:
                print(f"⏹ Early stop (no AUC improvement {EARLY_STOP} epochs).")
                break

    # Restore the best weights (from memory, or from disk as a fallback) and evaluate on test.
    if best_state is None and BEST_PATH.exists():
        best_state = torch.load(BEST_PATH, map_location="cpu")
    if best_state is not None:
        model.load_state_dict(best_state["model"])

    test_acc, test_auc, y_true, y_pred, prob1 = evaluate(model, test_loader, device)
    print(f"[TEST] acc={test_acc:.4f} auc={test_auc:.4f}")
    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred, target_names=["real","fake"], digits=4))
    return model

# Train the PRNU model: the best val-AUC checkpoint is saved to BEST_PATH and test metrics are printed.
# NOTE: `model` is a notebook global; predict_prnu_npy() reads it, and the ELA cell below rebinds it.
model = train_prnu_fastcnn()

# ---- 單張 .npy 推論 ----
def predict_prnu_npy(npy_path: Path):
    arr = np.load(npy_path, allow_pickle=False).astype(np.float32)
    x = (arr/127.0); x = x - x.mean()
    x = torch.from_numpy(x[None, None, ...])
    device = next(model.parameters()).device
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16):
        logit = model(x.to(device).contiguous(memory_format=torch.channels_last))
        prob  = torch.softmax(logit, dim=1)[0,1].item()
        pred  = int(prob >= 0.5)  # 1=fake, 0=real
    return pred, prob


In [None]:
# ===== ELA(int8) → Fast CNN（不使用 split 檔，隨機切 8:1:1）=====
from pathlib import Path
import os, random, re, json
import numpy as np

# tqdm（Notebook 友善）
try:
    from tqdm.notebook import tqdm
except Exception:
    from tqdm.auto import tqdm
TQDM_KW = dict(dynamic_ncols=True, leave=False)  # extra kwargs for the per-epoch tqdm bar

# ---------------- Config ----------------
SEED = 42
random.seed(SEED); np.random.seed(SEED)  # seeds the global RNGs used by cap()/shuffles below

SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")
FEA_ROOT    = SCRIPT_ROOT / "features_npy"
ELA_REAL    = FEA_ROOT / "ela_real_npy"
ELA_FAKE    = FEA_ROOT / "ela_fake_npy"

SAVE_DIR    = SCRIPT_ROOT / "saved_models"; SAVE_DIR.mkdir(parents=True, exist_ok=True)
BEST_PATH   = SAVE_DIR / "ela_fastcnn_u8_best.pt"  # best-val-AUC checkpoint is written here

# Max ids sampled per class; None = use every available id
N_PER_CLASS = 10000

# Training settings
BATCH       = 64          # 64-128 is stable on an RTX 4060; lower it on OOM
EPOCHS      = 15
LR          = 2e-3
WEIGHT_DECAY= 1e-4
EARLY_STOP  = 5           # stop after this many epochs without val-AUC improvement

# DataLoader
NUM_WORKERS     = min(8, os.cpu_count() or 4)
PIN_MEMORY      = True
PREFETCH_FACTOR = 4
PERSISTENT      = True

# Dataset caching: "ram" (preload all) | "memmap" (memory-map per access) | None (plain load per access)
CACHE_MODE = "ram"

# ELA normalization: i8(-128..127) -> (i8+128)/255 in [0,1], then optionally subtract the per-sample mean
ELA_ZERO_MEAN = True

# ---------------- 掃檔（支援 __qXX，優先 q90） ----------------
def list_npy(d: Path):
    """Return the ``*.npy`` files directly inside `d`, sorted by path.

    Raises AssertionError if `d` does not exist.
    """
    assert d.exists(), f"Not found: {d}"
    found = list(d.glob("*.npy"))
    found.sort()
    return found

# Matches a trailing quality tag such as "__q90" at the end of a file stem.
_q_pat = re.compile(r"__q(\d+)$")


def base_id_from_stem(stem: str):
    """Return `stem` with a trailing ``__q<digits>`` quality tag removed (unchanged if absent)."""
    match = _q_pat.search(stem)
    if match is None:
        return stem
    return stem[:match.start()]


def quality_from_stem(stem: str):
    """Return the integer from a trailing ``__q<digits>`` tag, or None when there is no tag."""
    match = _q_pat.search(stem)
    if match is None:
        return None
    return int(match.group(1))

def build_id2path(files):
    """Map base id -> one .npy path, collapsing ``__qXX`` variants of the same image.

    A base id with a single file maps to that file. When several files share a
    base id, quality 90 is preferred, then the quality closest to 90; files without
    a quality tag rank last (penalty 999). Ties are broken by the path's string
    form, so the choice is deterministic.
    """
    groups = {}
    for path in files:
        groups.setdefault(base_id_from_stem(path.stem), []).append(path)

    def rank(path):
        q = quality_from_stem(path.stem)
        if q is None:
            penalty = 999
        elif q == 90:
            penalty = 0
        else:
            penalty = abs(q - 90)
        return (penalty, str(path))

    return {
        base: paths[0] if len(paths) == 1 else min(paths, key=rank)
        for base, paths in groups.items()
    }

# Scan the ELA feature folders (one .npy per image) and build a per-class base-id -> path lookup.
real_files = list_npy(ELA_REAL)
fake_files = list_npy(ELA_FAKE)
assert real_files and fake_files, "找不到 ELA 特徵 .npy，請先完成特徵抽取。"

# {"real": {base_id: Path}, "fake": {base_id: Path}}; read by the dataset via make_loader().
id2path = {
    "real": build_id2path(real_files),
    "fake": build_id2path(fake_files),
}

# ---------------- 隨機切 8:1:1（不使用 split 檔） ----------------
from sklearn.model_selection import train_test_split

def cap(ids, k):
    """Shuffle `ids` with the global `random` state and keep at most `k` of them.

    k=None keeps everything. Always returns a new list.
    """
    pool = list(ids)
    random.shuffle(pool)
    if k is not None and len(pool) > k:
        return pool[:k]
    return pool

# Randomly cap each class at N_PER_CLASS ids (uses the global `random` state seeded in the config).
# Statement order matters for reproducibility: cap(real) and cap(fake) consume the same RNG stream.
real_ids = cap(list(id2path["real"].keys()), N_PER_CLASS)
fake_ids = cap(list(id2path["fake"].keys()), N_PER_CLASS)

# Per class: 80/20, then split the 20% in half -> 8:1:1 train/val/test, stratified by construction.
r_tr, r_tmp = train_test_split(real_ids, test_size=0.2, random_state=SEED)
f_tr, f_tmp = train_test_split(fake_ids, test_size=0.2, random_state=SEED)
r_va, r_te  = train_test_split(r_tmp, test_size=0.5, random_state=SEED)  # 0.1/0.1
f_va, f_te  = train_test_split(f_tmp, test_size=0.5, random_state=SEED)

# Each split is a list of (base_id, label) with label 0 = real, 1 = fake.
splits = {
    "train": [(i,0) for i in r_tr] + [(i,1) for i in f_tr],
    "val":   [(i,0) for i in r_va] + [(i,1) for i in f_va],
    "test":  [(i,0) for i in r_te] + [(i,1) for i in f_te],
}
# Shuffle each split in place (global RNG, dict order train/val/test) and print class counts.
for sp in splits:
    random.shuffle(splits[sp])
    n0 = sum(1 for _,y in splits[sp] if y==0); n1 = len(splits[sp])-n0
    print(f"{sp}: total={len(splits[sp])} | real={n0} fake={n1}")

# ---------------- Dataset / DataLoader ----------------
import torch
from torch.utils.data import Dataset, DataLoader

class ELAFromNPY(Dataset):
    """Dataset over int8 ELA feature maps stored one per ``.npy`` file.

    Stored values are ``uint8 - 128``. ``__getitem__`` maps them back to [0, 1]
    with ``(i8 + 128) / 255``, subtracts the per-sample mean when the global
    ``ELA_ZERO_MEAN`` is true, clips to [-2, 2] and adds a leading channel axis.

    Args:
        pairs: list of (id, label); label 0 = real, 1 = fake.
        id2path: {"real": {id: path}, "fake": {id: path}}.
        augment: stored only; no augmentation is applied.
        cache_mode: "ram" preloads every array at construction, "memmap" opens
            each file memory-mapped on access, anything else loads on access.
    """
    def __init__(self, pairs, id2path, augment=False, cache_mode="ram"):
        self.pairs = pairs
        self.id2p = id2path
        self.augment = augment
        self.cache_mode = cache_mode
        self.cache = []
        if cache_mode == "ram":
            # One int8 array per pair, in the same order as `pairs`.
            self.cache = [
                np.load(self.id2p["real" if lab == 0 else "fake"][idx], allow_pickle=False).copy()
                for idx, lab in pairs
            ]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        idx, lab = self.pairs[i]
        if self.cache_mode == "ram":
            raw = self.cache[i]
        else:
            path = self.id2p["real" if lab == 0 else "fake"][idx]
            mode = "r" if self.cache_mode == "memmap" else None
            raw = np.load(path, allow_pickle=False, mmap_mode=mode)

        img = (raw.astype(np.float32) + 128.0) / 255.0
        if ELA_ZERO_MEAN:
            img -= img.mean()
        img = np.clip(img, -2.0, 2.0)[np.newaxis, ...]  # [1, H, W]
        return torch.from_numpy(img), torch.tensor(lab, dtype=torch.long)

def make_loader(pairs, train=False):
    """Build a DataLoader over ELAFromNPY for one split.

    Training loaders shuffle, drop the last partial batch and pass augment=True
    (stored but unused by the dataset). `prefetch_factor` and `persistent_workers`
    are only forwarded when NUM_WORKERS > 0, because DataLoader raises ValueError
    for them in single-process mode. Pinned memory is only requested when CUDA
    is available.

    Args:
        pairs: list of (id, label) for this split.
        train: True for the training split.
    """
    ds = ELAFromNPY(pairs, id2path, augment=train, cache_mode=CACHE_MODE)
    worker_kw = {}
    if NUM_WORKERS > 0:
        worker_kw = dict(prefetch_factor=PREFETCH_FACTOR, persistent_workers=PERSISTENT)
    return DataLoader(
        ds, batch_size=BATCH, shuffle=train,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY and torch.cuda.is_available(),
        drop_last=train,
        **worker_kw,
    )

# One loader per split; make_loader() shuffles and drops the last partial batch only for train.
# With CACHE_MODE == "ram", constructing each dataset loads all of that split's .npy files up front.
train_loader = make_loader(splits["train"], train=True)
val_loader   = make_loader(splits["val"],   train=False)
test_loader  = make_loader(splits["test"],  train=False)

# ---------------- 模型：Depthwise-Separable CNN（快） ----------------
import torch.nn as nn
import torch.nn.functional as F

class DSBlock(nn.Module):
    """Depthwise-separable conv block.

    Depthwise 3x3 conv (one filter per input channel, optional stride) -> BN -> ReLU,
    then pointwise 1x1 conv to `c_out` channels -> BN -> ReLU.
    """
    def __init__(self, c_in, c_out, stride=1):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the parameter-initialization RNG sequence.
        self.dw = nn.Conv2d(c_in, c_in, kernel_size=3, stride=stride, padding=1, groups=c_in, bias=False)
        self.bn1 = nn.BatchNorm2d(c_in)
        self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        depthwise = self.act(self.bn1(self.dw(x)))
        return self.act(self.bn2(self.pw(depthwise)))

class ELAFastCNN(nn.Module):
    """Lightweight depthwise-separable CNN for single-channel ELA maps.

    stem 3x3 conv (nc -> 32) -> DSBlock 32->64 -> 64->128 (stride 2) -> 128->128
    -> 128->256 (stride 2) -> 256->256 -> 1x1 conv head (256) -> global average pool
    -> dropout(0.2) -> linear(256 -> num_classes). Spatial size shrinks 4x overall
    and the pool makes the head size-agnostic. About 0.2M parameters for nc=1,
    num_classes=2.
    """
    def __init__(self, nc=1, num_classes=2):
        super().__init__()
        # Submodule names/order define checkpoint keys and init RNG order; keep them.
        self.stem = nn.Sequential(
            nn.Conv2d(nc, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.stage = nn.Sequential(
            DSBlock(32, 64, stride=1),     # H
            DSBlock(64, 128, stride=2),    # H/2
            DSBlock(128, 128, stride=1),   # H/2
            DSBlock(128, 256, stride=2),   # H/4
            DSBlock(256, 256, stride=1),   # H/4
        )
        self.head = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        features = self.head(self.stage(self.stem(x)))
        pooled = self.pool(features).flatten(1)  # [B, 256, 1, 1] -> [B, 256]
        return self.fc(self.dropout(pooled))

# ---------------- 訓練 & 評估 ----------------
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix

def set_seed(seed=SEED):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def evaluate(model, loader, device):
    """Run `model` over `loader` in eval mode without gradients (fp32, channels_last inputs).

    Returns:
        (accuracy, ROC-AUC of the class-1 softmax probability, y_true, y_pred, prob1),
        where the last three are NumPy arrays and y_pred is the argmax over logits.
    """
    model.eval()
    batch_logits = []
    batch_labels = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True).contiguous(memory_format=torch.channels_last)
            targets = targets.to(device, non_blocking=True)
            batch_logits.append(model(inputs).detach().cpu().numpy())
            batch_labels.append(targets.detach().cpu().numpy())
    all_logits = np.concatenate(batch_logits, axis=0)
    y_true = np.concatenate(batch_labels, axis=0)
    y_pred = all_logits.argmax(axis=1)
    acc = np.mean(y_pred == y_true)
    prob1 = torch.softmax(torch.from_numpy(all_logits), dim=1).numpy()[:, 1]
    auc = roc_auc_score(y_true, prob1)
    return acc, auc, y_true, y_pred, prob1

import torch
def train_ela_fastcnn():
    """Train ELAFastCNN, keep the best-validation-AUC weights, and report test metrics.

    Uses the notebook globals train_loader / val_loader / test_loader, splits,
    ELA_ZERO_MEAN and the hyper-parameters SEED, LR, WEIGHT_DECAY, EPOCHS,
    EARLY_STOP, BEST_PATH and TQDM_KW. On CUDA it trains under autocast: bf16
    when supported, otherwise fp16 with a GradScaler. Both are disabled on CPU.

    The best checkpoint (by validation ROC-AUC) is written to BEST_PATH. An
    independent CPU snapshot is also kept in memory, so the final test evaluation
    and the returned model use the best epoch's weights, not the last epoch's.

    Returns:
        The trained model (on the training device, left in eval mode).
    """
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    torch.backends.cudnn.benchmark = True

    model = ELAFastCNN().to(device)
    model = model.to(memory_format=torch.channels_last)

    # Prefer bf16 (no loss scaling needed); fall back to fp16 + GradScaler.
    AMP_DTYPE = torch.bfloat16 if (use_cuda and torch.cuda.is_bf16_supported()) else torch.float16
    USE_SCALER = use_cuda and AMP_DTYPE == torch.float16
    scaler = torch.cuda.amp.GradScaler(enabled=USE_SCALER)

    # Inverse-frequency class weights, normalized to mean 1 (max(..., 1) guards an absent class).
    n0 = sum(1 for _, y in splits["train"] if y == 0)
    n1 = len(splits["train"]) - n0
    w = torch.tensor([1.0 / max(n0, 1), 1.0 / max(n1, 1)], dtype=torch.float32, device=device)
    w = w / w.mean()

    crit  = torch.nn.CrossEntropyLoss(weight=w)
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS)

    best_auc, best_state, no_improve = -1.0, None, 0
    for ep in range(1, EPOCHS + 1):
        model.train()
        losses = []
        pbar = tqdm(train_loader, desc=f"train ep{ep}", **TQDM_KW)
        for x, y in pbar:
            x = x.to(device, non_blocking=True).contiguous(memory_format=torch.channels_last)
            y = y.to(device, non_blocking=True)
            with torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=use_cuda):
                out  = model(x)
                loss = crit(out, y)
            optim.zero_grad(set_to_none=True)
            if USE_SCALER:
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                optim.step()
            losses.append(loss.item())
            pbar.set_postfix(loss=f"{np.mean(losses):.4f}")
        sched.step()

        val_acc, val_auc, *_ = evaluate(model, val_loader, device)
        print(f"[EP {ep:02d}] train_loss={np.mean(losses):.4f} | val acc={val_acc:.4f} auc={val_auc:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            # state_dict() tensors share storage with the live parameters; clone them to CPU
            # so later optimizer steps cannot overwrite this "best" snapshot in memory.
            best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_state = {"model": best_weights,
                          "meta": {"arch":"ELAFastCNN","seed":SEED,"epochs_done":ep,
                                   "val_auc":float(val_auc),"val_acc":float(val_acc),
                                   "input":"ELA int8→(i8+128)/255"+("→zero-mean" if ELA_ZERO_MEAN else ""),
                                   "shape":"[1,H,W]"}}
            torch.save(best_state, BEST_PATH)
            print("  ↳ saved best:", BEST_PATH)
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= EARLY_STOP:
                print(f"⏹ Early stop (no AUC improvement {EARLY_STOP} epochs).")
                break

    # Restore the best weights (from memory, or from disk as a fallback) and evaluate on test.
    if best_state is None and BEST_PATH.exists():
        best_state = torch.load(BEST_PATH, map_location="cpu")
    if best_state is not None:
        model.load_state_dict(best_state["model"])

    test_acc, test_auc, y_true, y_pred, prob1 = evaluate(model, test_loader, device)
    print(f"[TEST] acc={test_acc:.4f} auc={test_auc:.4f}")
    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred, target_names=["real","fake"], digits=4))
    return model

# Train the ELA model (checkpoint at BEST_PATH, test metrics printed). This rebinds the notebook
# global `model`, which predict_ela_npy() reads.
model = train_ela_fastcnn()

# ---- 單張 .npy 推論 ----
def predict_ela_npy(npy_path: Path):
    arr = np.load(npy_path, allow_pickle=False).astype(np.float32)
    x = (arr + 128.0) / 255.0
    if ELA_ZERO_MEAN: x = x - x.mean()
    x = torch.from_numpy(x[None, None, ...])
    device = next(model.parameters()).device
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16):
        logit = model(x.to(device).contiguous(memory_format=torch.channels_last))
        prob  = torch.softmax(logit, dim=1)[0,1].item()
        pred  = int(prob >= 0.5)  # 1=fake, 0=real
    return pred, prob


In [None]:
# ================ Export CLIP pooled(1024) per-split → .npy (memmap) ================
# 不做訓練，只把資料準備好。支援：
#  - 量化 .npz：優先讀 'pooled'；沒有就動態還原再 pool
#  - 原始 .npy：形狀 [257,1024] 或 [1024]
# 產物：exports/clip_pooled/<split>/{X.npy,y.npy,ids.txt,meta.json}

from pathlib import Path
import os, json, random, numpy as np

# tqdm（console 版，避免 IProgress）
try:
    from tqdm import tqdm
except Exception:
    def tqdm(x, **kw): return x

# ---------------- Config (edit here) ----------------
SEED = 42
random.seed(SEED); np.random.seed(SEED)

SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")
SPLIT_JSON  = SCRIPT_ROOT / "splits/combined_split.json"
SPLIT_KEY   = "iid"          # "iid" / "smoke_10p" / "ood_gen_strict.sd3" ... (dot-path into SPLIT_JSON)
EXPORT_ROOT = SCRIPT_ROOT / "exports/clip_pooled"  # output root directory

# Feature sources. All listed dirs are scanned; for a given id the first directory that
# has it wins, and within one directory .npz is preferred over .npy.
CLIP_REAL_DIRS = [
    SCRIPT_ROOT / "features_256q/clip_real_q",   # quantized (.npz)
    SCRIPT_ROOT / "features_256/clip_real_npy",  # original (.npy)
]
CLIP_FAKE_DIRS = [
    SCRIPT_ROOT / "features_256q/clip_fake_q",
    SCRIPT_ROOT / "features_256/clip_fake_npy",
]

SAVE_DTYPE   = "float32"   # output vector dtype: 'float32' (training-friendly) or 'float16' (smaller on disk)
SMOKE_FRAC   = None        # export only a per-class fraction (e.g. 0.1 = 10%); None = export everything

# ---------------- 掃檔 → id→path（.npz 優先） ----------------
def scan_first_hit(dirs, exts=(".npz",".npy")):
    """Build {file stem: path} over `dirs`, keeping the first hit per stem.

    Directories are visited in the given order (missing ones are skipped); inside
    each directory the extensions are tried in `exts` order and files are visited
    in sorted order. A stem already seen is never overwritten.
    """
    lut = {}
    for folder in dirs:
        if not folder.exists():
            continue
        for ext in exts:
            pattern = "*." + ext.lstrip(".")
            for path in sorted(folder.glob(pattern)):
                lut.setdefault(path.stem, path)
    return lut

# Per-class id -> feature file lookup over the CLIP dirs (first directory that has an id wins).
# NOTE(review): this rebinds the notebook-global `id2path` that earlier cells' make_loader() reads;
# re-running those loaders after this cell would point them at CLIP files.
id2path = {"real": scan_first_hit(CLIP_REAL_DIRS),
           "fake": scan_first_hit(CLIP_FAKE_DIRS)}
assert id2path["real"] and id2path["fake"], "❌ 找不到 CLIP 檔案，請檢查路徑"

# ---------------- 讀 unified split（支援 dot-path） ----------------
def load_split_ids(json_path: Path, split_key: str):
    """Load (train, val, test) id lists from a unified split JSON file.

    `split_key` is a dot-separated path into the JSON object, e.g. "iid" or
    "ood_gen_strict.sd3". When it is None, "" or "ids" and the root has an "ids"
    entry, that entry is used. A None/"" key with no "ids" entry selects the root
    object itself. The selected node must contain "train", "val" and "test".

    Returns:
        (train_ids, val_ids, test_ids) exactly as stored in the JSON.
    Raises:
        KeyError: a component of `split_key` is missing (message names it).
        AssertionError: the selected node lacks train/val/test.
    """
    data = json.loads(json_path.read_text(encoding="utf-8"))
    node = data
    if "ids" in node and split_key in (None, "", "ids"):
        node = node["ids"]
    elif split_key:
        for k in split_key.split("."):
            if not isinstance(node, dict) or k not in node:
                raise KeyError(f"split key {split_key!r}: component {k!r} not found in {json_path}")
            node = node[k]
    missing = [k for k in ("train", "val", "test") if k not in node]
    assert not missing, f"split {split_key!r} in {json_path} lacks {missing}"
    return node["train"], node["val"], node["test"]

def attach(ids):
    """Turn split ids into (id, label) pairs using the global `id2path`.

    An id found under "real" gets label 0; otherwise one found under "fake" gets
    label 1. Ids present in neither are dropped, and a count of them is printed.
    """
    kept, miss = [], 0
    for sample_id in ids:
        if sample_id in id2path["real"]:
            label = 0
        elif sample_id in id2path["fake"]:
            label = 1
        else:
            miss += 1
            continue
        kept.append((sample_id, label))
    if miss:
        print(f"⚠️ split 有 {miss} 個 id 在磁碟找不到，已忽略。")
    return kept

def stratified_frac(pairs, frac, seed=SEED):
    """Per-class subsample keeping about `frac` of the real (0) and fake (1) pairs.

    Returns `pairs` itself when `frac` is falsy or >= 1. Otherwise both classes are
    shuffled with a single random.Random(seed) (reals first, then fakes), and at
    least one item of each non-empty class is kept. Result: reals followed by fakes.
    """
    if not frac or frac >= 1:
        return pairs
    reals, fakes = [], []
    for (i, y) in pairs:
        if y == 0:
            reals.append((i, y))
        if y == 1:
            fakes.append((i, y))
    rng = random.Random(seed)
    rng.shuffle(reals)
    rng.shuffle(fakes)

    def head(group):
        return group[:max(1, int(len(group) * frac))]

    return head(reals) + head(fakes)

# Load the requested split and attach labels. attach() checks "real" before "fake", so an id
# present in both lookups is labeled real. Ids missing on disk are dropped (with a warning).
tr_ids, va_ids, te_ids = load_split_ids(SPLIT_JSON, SPLIT_KEY)
train_pairs, val_pairs, test_pairs = attach(tr_ids), attach(va_ids), attach(te_ids)
# Optional smoke subset: a different seed per split so the three draws are independent.
if SMOKE_FRAC:
    train_pairs = stratified_frac(train_pairs, SMOKE_FRAC, SEED)
    val_pairs   = stratified_frac(val_pairs,   SMOKE_FRAC, SEED+1)
    test_pairs  = stratified_frac(test_pairs,  SMOKE_FRAC, SEED+2)

for name, pairs in [("train",train_pairs),("val",val_pairs),("test",test_pairs)]:
    n0 = sum(1 for _,y in pairs if y==0); n1 = len(pairs)-n0
    print(f"{name}: total={len(pairs)} | real={n0} fake={n1}")

# ---------------- 量化 .npz → 還原（如需） + pooling(mean_excl_cls) ----------------
def _unpack_int4(packed, orig_size):
    u = np.empty(orig_size + (orig_size % 2), dtype=np.uint8)
    u[0::2] = packed & 0x0F
    u[1::2] = (packed >> 4) & 0x0F
    u = u[:orig_size]
    return (u.astype(np.int16) - 8).astype(np.int8)

def _expand_block_scales(S, B, D):
    """Repeat each per-block scale B times along the last axis and trim to D entries."""
    return np.repeat(S, B, axis=-1)[..., :D]


def dequantize_npz(npz_path: Path):
    """Reconstruct a float32 CLIP feature array from a quantized ``.npz`` file.

    The archive holds "q" (quantized payload) and "meta" (a JSON string with at
    least "mode" and "shape", plus "block" for block modes). Non-fp16 modes also
    hold "scales". Supported modes:
      - "fp16":         q is stored as floats; cast to float32.
      - "int8_tensor":  a single scale for the whole tensor.
      - "int8_row":     one scale per row (a single scale for 1-D shapes).
      - "int8_block32": one scale per `block` consecutive values of each row.
      - any mode containing "int4": two 4-bit codes per byte (see _unpack_int4),
        one scale per `block` values of each row.
    The payload is reshaped to meta["shape"] before per-row/per-block scaling, so
    flat and pre-shaped storage both work. The file is closed before returning.

    Returns:
        float32 ndarray of shape meta["shape"] ([T, D] or [D]).
    Raises:
        ValueError: unknown quantization mode.
    """
    with np.load(npz_path, allow_pickle=False) as z:
        meta = json.loads(str(z["meta"][()]))
        mode = meta["mode"]
        shape = tuple(meta["shape"])

        if mode == "fp16":
            return z["q"].astype(np.float32).reshape(shape)

        if mode == "int8_tensor" or (mode == "int8_row" and len(shape) == 1):
            q = z["q"].astype(np.int8).astype(np.float32)
            return (q * float(z["scales"].reshape(-1)[0])).reshape(shape)

        if mode == "int8_row":
            T, D = shape
            q = z["q"].astype(np.int8).reshape(shape).astype(np.float32)
            S = z["scales"].astype(np.float32).reshape(-1)[:T]
            return q * S[:, None]

        if mode == "int8_block32":
            B = int(meta["block"])
            q = z["q"].astype(np.int8).reshape(shape).astype(np.float32)
            S = z["scales"].astype(np.float32)
            if len(shape) == 1:
                (D,) = shape
                return q * _expand_block_scales(S.reshape(-1), B, D)
            T, D = shape
            return q * _expand_block_scales(S.reshape(T, -1), B, D)

        if "int4" in mode:
            B = int(meta["block"])
            packed = z["q"].astype(np.uint8).reshape(-1)
            S = z["scales"].astype(np.float32)
            if len(shape) == 1:
                (D,) = shape
                q = _unpack_int4(packed, D).astype(np.float32)
                return q * _expand_block_scales(S.reshape(-1), B, D)
            T, D = shape
            q = _unpack_int4(packed, T * D).reshape(T, D).astype(np.float32)
            return q * _expand_block_scales(S.reshape(T, -1), B, D)

    raise ValueError(f"Unknown quant mode: {mode}")

def pooled_vec_from_file(p: Path):
    """Return an L2-normalized float32 pooled CLIP vector for one feature file.

    - ``.npz``: use the stored "pooled" array when it exists and is non-empty.
      Otherwise dequantize the payload and pool it.
    - anything else: load as ``.npy``.
    Pooling: 1-D arrays are used as-is. 2-D token maps are averaged over rows
    1..T-1 (row 0, presumably CLS, is excluded — "mean_excl_cls").
    A 1e-12 epsilon keeps the normalization finite for an all-zero vector.
    """
    if p.suffix == ".npz":
        # Context manager closes the archive instead of relying on garbage collection.
        with np.load(p, allow_pickle=False) as z:
            pooled = z["pooled"] if "pooled" in z else None
        if pooled is not None and pooled.size > 0:   # fast path: precomputed pooled vector
            v = pooled.astype(np.float32)
        else:
            X = dequantize_npz(p)                    # [T, D] or [D]
            v = X if X.ndim == 1 else X[1:].mean(axis=0)
    else:
        arr = np.load(p, allow_pickle=False)
        v = arr.astype(np.float32) if arr.ndim == 1 else arr[1:].astype(np.float32).mean(axis=0)
    n = np.linalg.norm(v) + 1e-12
    return (v / n).astype(np.float32)

# ---------------- 輸出工具：.npy（open_memmap，邊寫邊落盤） ----------------
def export_split(name, pairs):
    """Export one split's pooled CLIP vectors under EXPORT_ROOT / SPLIT_KEY / <name>/.

    Files written:
      X.npy     - (N, D) array of dtype SAVE_DTYPE, filled row by row through
                  numpy.lib.format.open_memmap so the whole matrix is never held in RAM.
      y.npy     - (N,) int32 labels (0 = real, 1 = fake).
      ids.txt   - one tab-separated "id, label, source path" line per row, in X/y row
                  order, so every X row can be traced back to its id.
      meta.json - split key/name, dtype, N, D, pooling info and source dirs.
    D is taken from the first sample's vector. Prints a note and returns when
    `pairs` is empty.

    Args:
        name: split name ("train", "val", "test"); used as the output subfolder.
        pairs: list of (id, label); ids are resolved through the global `id2path`.
    """
    if not pairs:
        print(f"{name}: 無資料，略過"); return
    from numpy.lib.format import open_memmap

    out_dir = EXPORT_ROOT / SPLIT_KEY / name
    out_dir.mkdir(parents=True, exist_ok=True)

    def src_path(iid, lab):
        # Label selects the lookup table: 0 -> real, 1 -> fake.
        return id2path["real" if lab == 0 else "fake"][iid]

    # Peek at the first vector to learn D before allocating X.
    iid0, y0 = pairs[0]
    p0 = src_path(iid0, y0)
    v0 = pooled_vec_from_file(p0)
    D = v0.shape[0]

    X_path = out_dir / "X.npy"
    y_path = out_dir / "y.npy"
    X_mm = open_memmap(X_path, mode="w+", dtype=SAVE_DTYPE, shape=(len(pairs), D))
    y_arr = np.empty((len(pairs),), dtype=np.int32)
    print(f"→ Export {name}: N={len(pairs)} D={D} → {X_path}")

    # `with` guarantees ids.txt is closed even if a feature file fails to load.
    with (out_dir / "ids.txt").open("w", encoding="utf-8") as ids_txt:
        X_mm[0] = v0.astype(SAVE_DTYPE, copy=False)
        y_arr[0] = y0
        ids_txt.write(f"{iid0}\t{y0}\t{p0}\n")
        rest = pairs[1:]
        for row, (iid, lab) in enumerate(tqdm(rest, total=len(rest), desc=f"build {name}"), start=1):
            p = src_path(iid, lab)
            X_mm[row] = pooled_vec_from_file(p).astype(SAVE_DTYPE, copy=False)
            y_arr[row] = lab
            ids_txt.write(f"{iid}\t{lab}\t{p}\n")

    # Push memmapped rows to disk and release the mapping before reporting sizes.
    X_mm.flush()
    del X_mm

    np.save(y_path, y_arr, allow_pickle=False)
    meta = {
        "split_key": SPLIT_KEY,
        "split": name,
        "dtype": SAVE_DTYPE,
        "N": int(len(pairs)),
        "D": int(D),
        "pool": "mean_excl_cls",
        "l2norm": True,
        "source_priority": [str(d) for d in (CLIP_REAL_DIRS+CLIP_FAKE_DIRS)],
    }
    (out_dir / "meta.json").write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")

    def human(n):
        # Bytes -> human-readable binary units (KiB, MiB, ...).
        units = ["B", "KiB", "MiB", "GiB", "TiB"]
        i, f = 0, float(n)
        while f >= 1024 and i < len(units) - 1:
            f /= 1024
            i += 1
        return f"{f:.1f} {units[i]}"

    print(f"{name} saved: X={human(X_path.stat().st_size)} y={human(y_path.stat().st_size)} → {out_dir}")

# ---------------- Run：逐 split 匯出 ----------------
# Export each split independently; an empty split is skipped with a message.
export_split("train", train_pairs)
export_split("val",   val_pairs)
export_split("test",  test_pairs)

print("✅ Done. Exports at:", EXPORT_ROOT / SPLIT_KEY)


In [None]:
# ===================== Step 1 + 2: 原始 CLIP 池化 → per-file；三模態交集 → 共用 splits =====================
from pathlib import Path
import json, random, time, re, numpy as np

# ---- Basic settings (edit here) ----
SEED = 42
random.seed(SEED); np.random.seed(SEED)

SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")
FEA_ROOT    = SCRIPT_ROOT / "features_256"

# 1) Source of the raw CLIP token maps (one .npy per image)
CLIP_RAW_REAL = FEA_ROOT / "clip_real_npy"
CLIP_RAW_FAKE = FEA_ROOT / "clip_fake_npy"

# 1) Destination for pooled vectors (one .npy per image, float32, 1024-d)
CLIP_POOL_REAL = FEA_ROOT / "clip_pooled_real_npy"
CLIP_POOL_FAKE = FEA_ROOT / "clip_pooled_fake_npy"

# PRNU/ELA sources (one .npy per image); only their file stems are used here
PRNU_REAL = FEA_ROOT / "prnu_real_npy"
PRNU_FAKE = FEA_ROOT / "prnu_fake_npy"
ELA_REAL  = FEA_ROOT / "ela_real_npy"
ELA_FAKE  = FEA_ROOT / "ela_fake_npy"

# 2) Where the combined split file is written, and its parameters
SPLIT_OUT   = SCRIPT_ROOT / "splits" / "combined_split.json"
SMOKE_FRAC  = 0.10           # 10% smoke
IID_RATIO   = (0.8, 0.1, 0.1)  # train/val/test

# Fake-image generators used for OOD splits, inferred from the file-stem prefix.
# guess_generator() uses prefix (startswith) matching in this dict order, so longer aliases
# such as "midjourney-v6" are already covered by "midjourney".
GEN_CANON = {
    "sd3": ["sd3"],
    "midjourney": ["midjourney", "midjourney-v6", "midjourney-v6-llava", "mj"],
    "flux": ["flux", "black-forest-labs", "flux-dev", "flux-1"],
    "dalle3": ["dalle3", "dall-e-3", "dalle-3"]
}

# ===================== 工具 =====================
try:
    from tqdm import tqdm
except Exception:
    def tqdm(x, **kw): return x

def ensure_dir(d: Path):
    """Create directory `d` and any missing parents; a no-op if it already exists."""
    d.mkdir(exist_ok=True, parents=True)

def list_files(d: Path, ext=".npy"):
    """Sorted list of entries directly inside `d` whose names end with `ext`."""
    return sorted(d.glob(f"*{ext}"))

def stem_set(d: Path, ext=".npy"):
    """Set of stems (name without the last suffix) of entries in `d` ending with `ext`."""
    return {entry.stem for entry in d.glob(f"*{ext}")}

def pool_clip_file(npy_path: Path) -> np.ndarray:
    """Load one raw CLIP feature file and return an L2-normalized float32[1024] vector.

    Accepted shapes:
      - (1024,):   already pooled; used as-is.
      - (T, 1024): token map; row 0 (CLS) is excluded and the remaining rows are
                   averaged ("mean_excl_cls").
    Any other shape raises ValueError naming the shape and file.
    """
    feats = np.load(npy_path, allow_pickle=False)
    if feats.ndim == 1 and feats.shape[0] == 1024:
        vec = feats.astype(np.float32)
    elif feats.ndim == 2 and feats.shape[1] == 1024:
        vec = feats[1:].astype(np.float32).mean(axis=0)
    else:
        raise ValueError(f"Unsupported CLIP shape {feats.shape} @ {npy_path.name}")
    return (vec / (np.linalg.norm(vec) + 1e-12)).astype(np.float32)

def atomic_save(path: Path, arr: np.ndarray):
    """Write `arr` to `path` as .npy via a temporary file followed by a rename.

    The temporary file "<name>.npy.tmp" is written through an open file handle.
    np.save appends ".npy" to string/Path targets that lack it, so saving by path
    would actually create "<name>.npy.tmp.npy", and the rename would then fail.
    The rename (Path.replace) is atomic when both paths are on the same filesystem.
    If writing or renaming fails, the temporary file is removed and the error re-raised.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "wb") as fh:
            np.save(fh, arr, allow_pickle=False)
        tmp.replace(path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise

def pool_all_clip(src_dir: Path, dst_dir: Path, limit=None):
    """Pool every raw CLIP .npy in `src_dir` into one vector file per image under `dst_dir`.

    Outputs that already exist are skipped, so the run is resumable. A file that
    fails to load, pool or save is reported individually and counted under
    "failed" without aborting the run, so a systematic failure is visible in the
    final summary line. When `limit` is truthy, only the first `limit` files in
    sorted order are processed.
    """
    ensure_dir(dst_dir)
    files = list_files(src_dir, ".npy")
    if limit:
        files = files[:limit]
    n_ok = n_skip = n_fail = 0
    for p in tqdm(files, desc=f"pool {src_dir.name} → {dst_dir.name}"):
        out = dst_dir / (p.stem + ".npy")
        if out.exists():
            n_skip += 1
            continue
        try:
            atomic_save(out, pool_clip_file(p))
        except Exception as e:  # best-effort: log, count, continue
            n_fail += 1
            print(f"[pool] skip {p.name} | {e}")
        else:
            n_ok += 1
    print(f"→ {dst_dir}: wrote={n_ok} skipped={n_skip} failed={n_fail}")

def guess_generator(img_id: str) -> str | None:
    """Infer the canonical generator name from an image id's dataset prefix.

    The prefix is the part before the first "__" (the whole id if there is none),
    lower-cased and stripped of characters outside [a-z0-9-_]. The first GEN_CANON
    entry (in dict order) with an alias that the prefix starts with is returned,
    e.g. 'midjourney-v6-llava__xxxx' -> 'midjourney'. Returns None if nothing matches.
    """
    dataset_tag = re.sub(r"[^a-z0-9\-_]+", "", img_id.split("__", 1)[0].lower())
    return next(
        (
            canon
            for canon, aliases in GEN_CANON.items()
            if any(dataset_tag.startswith(a) for a in aliases)
        ),
        None,
    )

def stratified_split(real_ids, fake_ids, ratios=(0.8,0.1,0.1), seed=SEED):
    """Class-stratified random split without sklearn.

    Real and fake ids are shuffled with one random.Random(seed) (reals first, then
    fakes). Each class is then cut into train/val/test: train and val sizes are
    round(n * ratio), and test takes whatever remains.

    Returns:
        {'train': [...], 'val': [...], 'test': [...]}, each holding that split's
        real ids followed by its fake ids.
    """
    assert abs(sum(ratios)-1.0) < 1e-6 and len(ratios)==3
    r = list(real_ids); f = list(fake_ids)
    rnd = random.Random(seed)
    rnd.shuffle(r); rnd.shuffle(f)

    def cut(arr):
        n = len(arr)
        n_tr = int(round(n * ratios[0]))
        n_va = int(round(n * ratios[1]))
        # Test is the remainder; slicing keeps this valid even if rounding overshoots n.
        return arr[:n_tr], arr[n_tr:n_tr + n_va], arr[n_tr + n_va:]

    r_tr, r_va, r_te = cut(r)
    f_tr, f_va, f_te = cut(f)
    return {
        "train": r_tr + f_tr,
        "val":   r_va + f_va,
        "test":  r_te + f_te,
    }

def summary(name, ids_dict):
    """Print and return a per-split breakdown of IDs by guessed generator.

    Args:
        name: label printed in front of each split line (e.g. "iid").
        ids_dict: mapping with 'train'/'val'/'test' lists of image IDs.

    Returns:
        {split: {generator: count}}. IDs for which ``guess_generator`` finds no
        known prefix are counted under "unknown".

    NOTE: this is a filename-based estimate only (presumably real images carry no
    generator prefix — verify). The authoritative real/fake counts come from
    membership in the tri-modal intersection sets (see ``count_split``).
    """
    result = {}
    for sp in ["train", "val", "test"]:
        ids = ids_dict.get(sp, [])
        by_gen = {}
        for i in ids:
            key = guess_generator(i) or "unknown"
            by_gen[key] = by_gen.get(key, 0) + 1
        result[sp] = by_gen
        print(f"[{name}/{sp}] total={len(ids)} | by generator: {dict(sorted(by_gen.items()))}")
    return result

# ===================== Step 1: pool raw CLIP features into one file per image =====================
print("== Step 1: Pool original CLIP to per-file 1024 ==")
# Explicit check rather than `assert`, which is skipped entirely under `python -O`.
if not (CLIP_RAW_REAL.exists() and CLIP_RAW_FAKE.exists()):
    raise FileNotFoundError("找不到原始 CLIP .npy 資料夾")
pool_all_clip(CLIP_RAW_REAL, CLIP_POOL_REAL)
pool_all_clip(CLIP_RAW_FAKE, CLIP_POOL_FAKE)

# ===================== Step 2: tri-modal intersection -> splits =====================
print("\n== Step 2: Build unified splits (IID / OOD / smoke_10p) with tri-modal intersection ==")

# Keep only IDs present in all three modalities (pooled CLIP, PRNU, ELA).
ids_real = stem_set(CLIP_POOL_REAL) & stem_set(PRNU_REAL) & stem_set(ELA_REAL)
ids_fake = stem_set(CLIP_POOL_FAKE) & stem_set(PRNU_FAKE) & stem_set(ELA_FAKE)
print(f"交集數量 → real: {len(ids_real)} | fake: {len(ids_fake)}")

# Later cells label an ID as real via `i in ids_real`, and the two classes are
# shuffled/split independently, so a stem present in both classes would be
# mislabeled and could land in two different splits. Surface it loudly.
ids_overlap = ids_real & ids_fake
if ids_overlap:
    print(f"⚠ {len(ids_overlap)} IDs exist in both real and fake, e.g. {sorted(ids_overlap)[:5]}")

# --- IID: stratified 8/1/1 split ---
iid = stratified_split(sorted(ids_real), sorted(ids_fake), ratios=IID_RATIO, seed=SEED)

def count_split(ids_dict, ids_real_all, ids_fake_all):
    """Print and return the real/fake composition of each split.

    Args:
        ids_dict: mapping with 'train', 'val' and 'test' lists of image IDs.
        ids_real_all: container of every real ID; membership means "real".
        ids_fake_all: container of every fake ID; membership (and not real) means "fake".

    Returns:
        {split: {"total": int, "real": int, "fake": int, "unlabeled": int}}.
    """
    counts = {}
    for sp in ["train", "val", "test"]:
        ids = ids_dict[sp]
        n0 = sum(1 for i in ids if i in ids_real_all)
        n1 = sum(1 for i in ids if i not in ids_real_all and i in ids_fake_all)
        # IDs found in neither set are reported separately instead of being
        # folded into the fake count.
        n_unl = len(ids) - n0 - n1
        extra = f" unlabeled={n_unl}" if n_unl else ""
        print(f"[{sp}] total={len(ids)} | real={n0} fake={n1}{extra}")
        counts[sp] = {"total": len(ids), "real": n0, "fake": n1, "unlabeled": n_unl}
    return counts

print("== IID summary ==")
count_split(iid, ids_real, ids_fake)

# --- OOD: for each generator g, train/val never contain g; all g images go to test ---
# Bucket the fake IDs by generator ("other" when no known prefix matches).
# Iterate in sorted order: str hashing (and hence set iteration order) is
# randomized per interpreter process, so iterating the raw set would make the
# order of IDs inside each bucket differ from run to run.
fake_by_gen = {}
for i in sorted(ids_fake):
    g = guess_generator(i) or "other"
    fake_by_gen.setdefault(g, []).append(i)

def ood_split_for(gen_key: str, iid_splits=None, real_ids=None, fake_buckets=None):
    """Build a leave-one-generator-out split that holds out ``gen_key``.

    Real IDs keep their IID placement, so every OOD variant shares the same reals.
    Fakes from the other generators keep their IID train/val/test placement.
    Every fake from ``gen_key`` goes to test, so it is never seen in train/val.

    Args:
        gen_key: generator bucket to hold out (a key of ``fake_buckets``).
        iid_splits: IID split dict; defaults to the module-level ``iid``.
        real_ids: set of real IDs; defaults to the module-level ``ids_real``.
        fake_buckets: generator -> list of fake IDs; defaults to ``fake_by_gen``.

    Returns:
        {'train': [...], 'val': [...], 'test': [...]}; reals come first in each
        split, and the held-out fakes are appended to test in sorted order.
    """
    if iid_splits is None:
        iid_splits = iid
    if real_ids is None:
        real_ids = ids_real
    if fake_buckets is None:
        fake_buckets = fake_by_gen

    holdout = set(fake_buckets.get(gen_key, []))
    allowed = set()
    for g, members in fake_buckets.items():
        if g != gen_key:
            allowed.update(members)

    def _pick(split, pool):
        # Preserve the IID order of the split while filtering by membership.
        return [i for i in iid_splits[split] if i in pool]

    r_tr, r_va, r_te = _pick("train", real_ids), _pick("val", real_ids), _pick("test", real_ids)
    f_tr, f_va = _pick("train", allowed), _pick("val", allowed)
    f_te = _pick("test", allowed) + sorted(holdout)
    # Test must include the fakes (in-distribution test fakes + the held-out
    # generator); returning only the reals made every OOD test set fake-free.
    return {"train": r_tr + f_tr, "val": r_va + f_va, "test": r_te + f_te}

ood_gen = {}
for g in ["sd3","midjourney","flux","dalle3"]:
    sp = ood_split_for(g)
    ood_gen[g] = sp
    # Real / fake counts per split, used only for the printed summary.
    n_real = {k: sum(i in ids_real for i in sp[k]) for k in ("train", "val", "test")}
    n_fake = {k: sum(i in ids_fake for i in sp[k]) for k in ("train", "val", "test")}
    # How many images of the held-out generator ended up in test.
    n_g_test = sum(guess_generator(i) == g for i in sp["test"])
    print(f"== OOD-{g} summary ==")
    print(f"[train] total={len(sp['train'])} | real={n_real['train']} fake={n_fake['train']}")
    print(f"[val]   total={len(sp['val'])}   | real={n_real['val']} fake={n_fake['val']}")
    print(f"[test]  total={len(sp['test'])}  | real={n_real['test']} fake={n_fake['test']} (holdout {g} in test: {n_g_test})")

# --- smoke_10p: take 10% of each IID split independently (stratified random) ---
def take_frac(ids_list, frac, seed):
    """Stratified random subsample keeping ``frac`` of the real and fake IDs.

    Class membership comes from the module-level ``ids_real`` / ``ids_fake``
    sets; IDs in neither are dropped. Each non-empty class keeps at least one
    item (``max(1, round(n * frac))``). Reals are shuffled before fakes with
    ``random.Random(seed)``, and the kept reals are listed first.
    """
    rng = random.Random(seed)
    reals = [i for i in ids_list if i in ids_real]
    fakes = [i for i in ids_list if i in ids_fake]
    rng.shuffle(reals)
    rng.shuffle(fakes)

    def _keep(pool):
        if not pool:
            return 0
        return max(1, int(round(len(pool) * frac)))

    return reals[:_keep(reals)] + fakes[:_keep(fakes)]

# Each split gets its own seed (SEED, SEED+1, SEED+2) so the samples are independent.
smoke_10p = {
    split_name: take_frac(iid[split_name], SMOKE_FRAC, SEED + offset)
    for offset, split_name in enumerate(("train", "val", "test"))
}
print("== smoke_10p summary ==")
count_split(smoke_10p, ids_real, ids_fake)

# ===================== Write JSON (keep the old file as a backup, then overwrite) =====================
out = {
    "meta": {
        "seed": SEED,
        "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "tri_sync_modalities": ["clip_pooled(1024)", "prnu", "ela"],
        "paths": {
            "clip_pooled_real": str(CLIP_POOL_REAL),
            "clip_pooled_fake": str(CLIP_POOL_FAKE),
            "prnu_real": str(PRNU_REAL), "prnu_fake": str(PRNU_FAKE),
            "ela_real":  str(ELA_REAL),  "ela_fake":  str(ELA_FAKE),
        },
        "iid_ratio": IID_RATIO,
        "generators": list(GEN_CANON.keys()),
        "note": "IDs are image stems without extension; splits contain tri-modal intersection only."
    },
    "iid": iid,
    "ood_gen": ood_gen,
    "smoke_10p": smoke_10p
}

# Serialize before touching the disk: if this raises, the existing split file
# must not already have been moved away to a backup.
payload = json.dumps(out, ensure_ascii=False, indent=2)

ensure_dir(SPLIT_OUT.parent)
# Back up the previous file under a timestamped name instead of overwriting it.
if SPLIT_OUT.exists():
    bk = SPLIT_OUT.with_suffix(f".bak_{int(time.time())}.json")
    SPLIT_OUT.replace(bk)
    print("↻ backup old split →", bk)
# ensure_ascii=False may emit non-ASCII characters (e.g. in paths), so write
# UTF-8 explicitly instead of relying on the platform's locale encoding.
SPLIT_OUT.write_text(payload, encoding="utf-8")
print("✅ saved:", SPLIT_OUT)


In [16]:
from pathlib import Path
import os

import numpy as np

# Repair files left behind when atomic_save passed its temp Path to np.save:
# np.save appends ".npy" to a name lacking it, so data was written to
# "<stem>.npy.tmp.npy" and the rename to "<stem>.npy" failed. Fix both folders.
DIRS = [
    Path("/home/yaya/ai-detect-proj/Script/features_256/clip_pooled_real_npy"),
    Path("/home/yaya/ai-detect-proj/Script/features_256/clip_pooled_fake_npy"),
]
TMP_SUFFIX = ".npy.tmp.npy"

for root in DIRS:
    if not root.exists():
        print("skip (not found):", root)
        continue
    n_fix = n_drop = n_bad = 0
    for tmpf in root.glob("*" + TMP_SUFFIX):
        # Strip only the trailing suffix: "<stem>.npy.tmp.npy" -> "<stem>.npy".
        dst = tmpf.with_name(tmpf.name[: -len(TMP_SUFFIX)] + ".npy")
        if dst.exists():
            # The final file already exists; this temp copy is redundant.
            tmpf.unlink()
            n_drop += 1
            continue
        # Never promote a file that doesn't load (e.g. an interrupted write);
        # leave it in place and report it so it can be inspected/regenerated.
        try:
            np.load(tmpf, allow_pickle=False)
        except Exception as e:
            print(f"[bad tmp] {tmpf.name} | {e}")
            n_bad += 1
            continue
        os.replace(tmpf, dst)
        n_fix += 1
    print(f"[{root.name}] fixed={n_fix} removed_tmp={n_drop}" + (f" unreadable={n_bad}" if n_bad else ""))


[clip_pooled_real_npy] fixed=0 removed_tmp=0
[clip_pooled_fake_npy] fixed=60106 removed_tmp=0


In [17]:
import numpy as np, random

# Spot-check: load up to five random pooled CLIP files and print name + shape.
check_dirs = [Path("/home/yaya/ai-detect-proj/Script/features_256/clip_pooled_fake_npy")]
for d in check_dirs:
    files = list(d.glob("*.npy"))
    n_show = min(5, len(files))
    for p in random.sample(files, n_show):
        arr = np.load(p, allow_pickle=False)
        shape = arr.shape
        print(p.name, shape)

FLUX__0000_00018124.npy (1024,)
SD3__0000_00049062.npy (1024,)
SD3__0000_00021448.npy (1024,)
dalle3__008127.npy (1024,)
SD3__0000_00047005.npy (1024,)
