In [4]:
# ==== Downsample FAKE: each generator/source -> keep at most 25K IDs ====
from pathlib import Path
import re, random, time, shutil
from collections import defaultdict, Counter

# ---------- 設定 ----------
FEA_ROOT = Path("/home/yaya/ai-detect-proj/Script/features_256")  # 你的特徵根目錄（會遞迴掃描）
PER_GEN_TARGET = 25_000               # 每個生成器想保留的 ID 數上限
SEED = 42; random.seed(SEED)

DRY_RUN = False                        # 先預覽；OK 後改 False 真的動作
MODE = "trash"                        # "trash"=移到垃圾桶 | "delete"=直接刪除
KEEP_ALL_ELA_QUALS = True             # True=保留該ID所有 ELA 品質檔；False=只留一個（靠近 q90）

TRASH_DIR = FEA_ROOT / f"_trash_fake_pergen_{time.strftime('%Y%m%d_%H%M%S')}"

# ---------- 小工具 ----------
ELA_Q_PAT = re.compile(r"__q(\d+)$")
def base_id(stem: str):
    m = ELA_Q_PAT.search(stem)
    return stem[:m.start()] if m else stem

def ela_quality(stem: str):
    m = ELA_Q_PAT.search(stem)
    return int(m.group(1)) if m else None

def list_feature_dirs(root: Path, feat: str, cls: str):
    return [d for d in root.glob(f"**/{feat}_{cls}_npy") if d.is_dir()]

def list_npy(d: Path):
    return sorted(p for p in d.glob("*.npy"))

def dataset_of(img_id: str):
    return img_id.split("__", 1)[0]

def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def move_to_trash(p: Path):
    rel = p.relative_to(FEA_ROOT)
    dst = TRASH_DIR / rel
    ensure_dir(dst.parent)
    shutil.move(str(p), str(dst))

# ---------- 掃描 fake 的各特徵路徑 ----------
dirs = {
    ("ela","fake"):  list_feature_dirs(FEA_ROOT, "ela",  "fake"),
    ("clip","fake"): list_feature_dirs(FEA_ROOT, "clip", "fake"),
    ("prnu","fake"): list_feature_dirs(FEA_ROOT, "prnu", "fake"),
}
present_feats = [feat for (feat,_), lst in dirs.items() if lst]
assert present_feats, "找不到任何 *_fake_npy 目錄，請確認 FEA_ROOT 設定。"
print("偵測到特徵：", present_feats)

# ---------- 建立各特徵的 fake ID 集合（ELA 用 base id） ----------
id_sets = {}
for feat in present_feats:
    s = set()
    for d in dirs[(feat,"fake")]:
        for p in list_npy(d):
            s.add(base_id(p.stem) if feat=="ela" else p.stem)
    id_sets[feat] = s

# 交集（僅在目前可同時取得之 IDs 中動作）
ids_common = None
for feat, s in id_sets.items():
    ids_common = s if ids_common is None else (ids_common & s)
print("可操作（特徵交集）fake IDs：", len(ids_common))

# ---------- 依生成器分桶，決定每個生成器保留到 25K ----------
bucket = defaultdict(list)
for i in ids_common:
    bucket[dataset_of(i)].append(i)
for k in bucket: random.shuffle(bucket[k])

keep_ids = set()
per_src_keep = {}
per_src_avail = {src: len(lst) for src, lst in bucket.items()}
for src, lst in bucket.items():
    k = min(len(lst), PER_GEN_TARGET)
    per_src_keep[src] = k
    keep_ids.update(lst[:k])

print("\n來源現況（可用 → 保留 ≤ 25K）：")
for src in sorted(bucket.keys()):
    print(f"  {src:22s} {per_src_avail[src]:6d} → {per_src_keep[src]:6d}")
print("將移除的 ID 數：", len(ids_common) - len(keep_ids))

# ---------- 建立刪除清單 ----------
to_remove = []           # 要刪/移的檔案路徑
to_keep_one_ela = {}     # 當只保留單一品質時，記錄每個 base id 應保留的那個檔

if not KEEP_ALL_ELA_QUALS:
    # 先決定每個保留 ID 的唯一 ELA 檔（優先 q90，其次最接近 90）
    for d in dirs.get(("ela","fake"), []):
        # 把該目錄所有檔按 base 分桶
        by_base = defaultdict(list)
        for p in list_npy(d):
            b = base_id(p.stem)
            if b in keep_ids:
                by_base[b].append(p)
        for b, lst in by_base.items():
            best = None
            best_score = (999, "")
            for pp in lst:
                q = ela_quality(pp.stem)
                score = (0 if q == 90 else (abs(q-90) if q is not None else 999), pp.name)
                if score < best_score:
                    best_score, best = score, pp
            to_keep_one_ela[b] = best

# 只對「在交集內而不在 keep_ids」的檔動作（避免誤刪其他孤兒或未抽到的特徵）
for feat in present_feats:
    for d in dirs[(feat,"fake")]:
        for p in list_npy(d):
            id_base = base_id(p.stem) if feat=="ela" else p.stem
            if id_base in ids_common:
                if id_base not in keep_ids:
                    to_remove.append(p)
                elif feat=="ela" and (not KEEP_ALL_ELA_QUALS) and (to_keep_one_ela.get(id_base) != p):
                    to_remove.append(p)

print(f"\nPlanned removals: {len(to_remove)} files")
for p in to_remove[:10]:
    print("  -", p)

# ---------- 執行 ----------
if DRY_RUN:
    print("\nDRY_RUN=True → 只預覽，不動作。確認 OK 後把 DRY_RUN=False 重新跑。")
else:
    removed = 0
    if MODE == "trash":
        ensure_dir(TRASH_DIR)
    for p in to_remove:
        try:
            if MODE == "trash":
                move_to_trash(p)
            else:
                p.unlink(missing_ok=True)
            removed += 1
        except Exception as e:
            print("skip (error):", p, "|", e)
    print(("🗑 已移到垃圾桶：" if MODE=="trash" else "🗑 已刪除："), removed)

    # 收尾統計（現在各來源的交集/保留狀況）
    # 重新掃描交集
    id_sets2 = {}
    for feat in present_feats:
        s2 = set()
        for d in dirs[(feat,"fake")]:
            for p in list_npy(d):
                s2.add(base_id(p.stem) if feat=="ela" else p.stem)
        id_sets2[feat] = s2
    ids_common2 = None
    for feat, s in id_sets2.items():
        ids_common2 = s if ids_common2 is None else (ids_common2 & s)

    by_src2 = Counter(dataset_of(i) for i in ids_common2)
    print("After cleanup → intersection fake IDs by source:")
    for src in sorted(by_src2.keys()):
        print(f"  {src:22s} {by_src2[src]:6d}")
    if MODE=="trash":
        print("Trash dir:", TRASH_DIR)


偵測到特徵： ['ela', 'clip', 'prnu']
可操作（特徵交集）fake IDs： 89000

來源現況（可用 → 保留 ≤ 25K）：
  FLUX                    20000 →  20000
  SD3                     50000 →  25000
  dalle3                  19000 →  19000
將移除的 ID 數： 25000

Planned removals: 75000 files
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000003.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000007.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000012.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000013.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000015.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000016.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000018.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/SD3__0000_00000019.npy
  - /home/yaya/ai-detect-proj/Script/features_256/ela_fake_npy/

In [3]:
# ===== Build IID + OOD splits (complete version, with your quotas) =====
from pathlib import Path
import re, json, random
from collections import defaultdict, Counter

# ---------- 基本設定 ----------
SEED = 42
random.seed(SEED)

FEA_ROOT = Path("/home/yaya/ai-detect-proj/Script/features_256")            # ← 改成你的 features 根目錄
OUT_JSON = Path("/home/yaya/ai-detect-proj/Script/splits_iid_ood.json") # ← 輸出 JSON 路徑

# 用哪些特徵做交集（建議三者；只想快一點可暫時改 ['ela']）
INTERSECT_FEATURES = ['ela','clip','prnu']

# 你的配額（固定）
FAKE_QUOTAS = {"sd3":25_000, "midjourney":30_000, "flux":20_000, "dalle3":19_000}   # 94k
REAL_QUOTAS = {"imagenet":30_000, "flickr30k":30_000, "unsplash":24_000}            # 84k

# 切分比例
RATIOS = {"train":0.8, "val":0.1, "test":0.1}

# ---------- 來源別名（前綴 -> 標準名）----------
ALIAS = {
    # real
    "flick":"flickr30k","flicker":"flickr30k","flickr":"flickr30k","flickr30k":"flickr30k",
    "unsplash":"unsplash","imagenet":"imagenet",
    # fake
    "sd3":"sd3","stable-diffusion-3":"sd3","sd3.5":"sd3",
    "midjourney":"midjourney","midjourney-v6-llava":"midjourney","mj":"midjourney",
    "flux":"flux","FLUX":"flux",
    "dalle":"dalle3","dalle3":"dalle3","dall-e-3":"dalle3",
}

# ---------- 小工具 ----------
ELA_Q_PAT = re.compile(r"__q(\d+)$")
def base_id(stem: str):
    m = ELA_Q_PAT.search(stem)
    return stem[:m.start()] if m else stem

def norm_source(raw: str):
    return ALIAS.get(raw.lower(), raw.lower())

def dataset_of(img_id: str):
    return norm_source(img_id.split("__",1)[0])

def list_feature_dirs(feat: str, cls: str):
    return [d for d in FEA_ROOT.glob(f"**/{feat}_{cls}_npy") if d.is_dir()]

def list_ids_for(feat: str, cls: str):
    ids = set()
    for d in list_feature_dirs(feat, cls):
        for p in d.glob("*.npy"):
            s = p.stem
            s = base_id(s) if feat == "ela" else s
            ids.add(s)
    return ids

def bucket_by_source(ids):
    b = defaultdict(list)
    for i in ids:
        b[dataset_of(i)].append(i)
    for k in b: random.shuffle(b[k])
    return b

def sample_strict_quota(bucket, quotas):
    """嚴格依配額抽樣；不足就盡量拿。回傳 keep 與各來源可用量。"""
    keep = []
    avail = {k: len(v) for k,v in bucket.items()}
    for src, q in quotas.items():
        got = bucket.get(src, [])[:q]
        keep.extend(got)
    return keep, avail

def scale_real_quotas_to(total_target, avail_per_src, base_quotas):
    """把 real 配額按比例縮放到 total_target；不超過可用量；不足就輪詢補齊。"""
    s = sum(base_quotas.values())
    if s == 0: return {}
    quotas = {}
    acc = 0
    keys = sorted(base_quotas.keys())
    for k in keys[:-1]:
        q = int(round(total_target * base_quotas[k] / s))
        quotas[k] = q; acc += q
    quotas[keys[-1]] = max(0, total_target - acc)
    # 截斷到可用
    for k in list(quotas.keys()):
        quotas[k] = min(quotas[k], avail_per_src.get(k, 0))
    # 若總量仍不足，從有餘量的來源輪詢補
    need = total_target - sum(quotas.values())
    if need > 0:
        remain = {k: max(0, avail_per_src.get(k,0) - quotas.get(k,0)) for k in avail_per_src}
        srcs = sorted(remain.keys())
        while need > 0:
            progressed = False
            for sname in srcs:
                if remain[sname] > 0:
                    quotas[sname] = quotas.get(sname,0) + 1
                    remain[sname] -= 1
                    need -= 1
                    progressed = True
                    if need == 0: break
            if not progressed:
                break
    return quotas

def split_8_1_1_per_source(selected_ids):
    """每來源各自 8/1/1，再合併；保持來源比例。"""
    from sklearn.model_selection import train_test_split
    out = {"train":[], "val":[], "test":[]}
    per_src = bucket_by_source(selected_ids)
    for src, ids in per_src.items():
        if len(ids) < 10:
            n_tr = int(len(ids) * RATIOS["train"])
            n_va = int(len(ids) * RATIOS["val"])
            out["train"].extend(ids[:n_tr])
            out["val"].extend(ids[n_tr:n_tr+n_va])
            out["test"].extend(ids[n_tr+n_va:])
        else:
            tr, tmp = train_test_split(ids, test_size=(1-RATIOS["train"]), random_state=SEED)
            va, te  = train_test_split(tmp, test_size=RATIOS["test"]/(RATIOS["test"]+RATIOS["val"]), random_state=SEED)
            out["train"].extend(tr); out["val"].extend(va); out["test"].extend(te)
    for k in out: random.shuffle(out[k])
    return out

def report_counts(ids):
    c = Counter(dataset_of(i) for i in ids)
    return {"total": len(ids), "by_source": dict(sorted(c.items()))}

# ---------- 掃描 & 特徵交集 ----------
present = {}
for feat in INTERSECT_FEATURES:
    present[feat] = {"real": list_ids_for(feat,"real"),
                     "fake": list_ids_for(feat,"fake")}
sets_real = [present[feat]["real"] for feat in INTERSECT_FEATURES if present[feat]["real"]]
sets_fake = [present[feat]["fake"] for feat in INTERSECT_FEATURES if present[feat]["fake"]]
assert sets_real and sets_fake, "找不到可用的 real 或 fake IDs，請檢查特徵路徑與 INTERSECT_FEATURES。"

common = {
    "real": set.intersection(*sets_real) if len(sets_real)>1 else sets_real[0],
    "fake": set.intersection(*sets_fake) if len(sets_fake)>1 else sets_fake[0],
}
print("交集數量 → real:", len(common["real"]), "| fake:", len(common["fake"]))

real_bkt = bucket_by_source(common["real"])
fake_bkt = bucket_by_source(common["fake"])

# ---------- IID ----------
# fake：嚴格依配額
fake_keep_iid, fake_avail = sample_strict_quota(fake_bkt, FAKE_QUOTAS)

# real：嘗試縮放到與 fake 同總量（如不足就拿到上限）
avail_real = {k: len(v) for k,v in real_bkt.items()}
real_quota_scaled = scale_real_quotas_to(len(fake_keep_iid), avail_real, REAL_QUOTAS)
real_keep_iid, _ = sample_strict_quota(real_bkt, real_quota_scaled)

iid_real = split_8_1_1_per_source(real_keep_iid)
iid_fake = split_8_1_1_per_source(fake_keep_iid)
iid = { "train": iid_real["train"] + iid_fake["train"],
        "val":   iid_real["val"]   + iid_fake["val"],
        "test":  iid_real["test"]  + iid_fake["test"] }
for k in iid: random.shuffle(iid[k])

# ---------- OOD-by-Generator（四個） ----------
def build_ood_for(hold_src: str):
    hold_src = hold_src.lower()
    # holdout：整個來源，按配額上限截斷
    hold_quota = FAKE_QUOTAS.get(hold_src, len(fake_bkt.get(hold_src, [])))
    hold_ids   = fake_bkt.get(hold_src, [])[:min(hold_quota, len(fake_bkt.get(hold_src, [])))]
    # 其餘 fake：照配額抽樣
    remain_quota = {s:q for s,q in FAKE_QUOTAS.items() if s != hold_src}
    fake_keep, _ = sample_strict_quota(fake_bkt, remain_quota)
    # real：縮放到與「訓練用 fake」同總量
    real_quota_scaled = scale_real_quotas_to(len(fake_keep), avail_real, REAL_QUOTAS)
    real_keep, _ = sample_strict_quota(real_bkt, real_quota_scaled)
    # split
    r = split_8_1_1_per_source(real_keep)
    f = split_8_1_1_per_source(fake_keep)
    ood = { "train": r["train"] + f["train"],
            "val":   r["val"]   + f["val"],
            "test":  r["test"]  + f["test"] + hold_ids }
    for k in ood: random.shuffle(ood[k])
    rep = {
        "train": report_counts(ood["train"]),
        "val":   report_counts(ood["val"]),
        "test":  report_counts(ood["test"]),
        "holdout_source": hold_src,
        "holdout_size": len(hold_ids)
    }
    return ood, rep

ood_all = {}
ood_reports = {}
for src in ["sd3","midjourney","flux","dalle3"]:
    ood_all[src], ood_reports[src] = build_ood_for(src)

# ---------- 報告 & 存檔 ----------
def quick_report(name, split, real_pool):
    rset = set(real_pool["train"] + real_pool["val"] + real_pool["test"])
    for sp in ("train","val","test"):
        both = split[sp]
        n_total = len(both)
        n_real  = sum(1 for i in both if i in rset)
        n_fake  = n_total - n_real
        print(f"[{name} {sp}] total={n_total} | real={n_real} fake={n_fake}")

print("\n== IID summary ==")
quick_report("IID", iid, iid_real)
print("  real per-src:", report_counts(iid_real["train"]+iid_real["val"]+iid_real["test"])["by_source"])
print("  fake per-src:", report_counts(iid_fake["train"]+iid_fake["val"]+iid_fake["test"])["by_source"])

print("\n== OOD summaries (by generator) ==")
for src in ood_all:
    quick_report(f"OOD-{src}", ood_all[src], iid_real)  # 只是用來區分 real/fake
    print(f"  holdout {src} →", ood_reports[src]["holdout_size"])

meta = {
    "seed": SEED,
    "intersect_features": INTERSECT_FEATURES,
    "fake_quotas": FAKE_QUOTAS,
    "real_quotas": REAL_QUOTAS,
    "avail": {
        "real_by_src": {k:len(v) for k,v in real_bkt.items()},
        "fake_by_src": {k:len(v) for k,v in fake_bkt.items()},
        "real_intersection_total": len(common["real"]),
        "fake_intersection_total": len(common["fake"]),
    },
    "ood_reports": ood_reports,
}
payload = {"meta": meta, "iid": iid, "ood_gen": ood_all}
OUT_JSON.parent.mkdir(parents=True, exist_ok=True)
OUT_JSON.write_text(json.dumps(payload, ensure_ascii=False, indent=2))
print("\n✅ saved:", OUT_JSON)


交集數量 → real: 86772 | fake: 81720

== IID summary ==
[IID train] total=102398 | real=51198 fake=51200
[IID val] total=12801 | real=6401 fake=6400
[IID test] total=12801 | real=6401 fake=6400
  real per-src: {'flickr30k': 22857, 'imagenet': 22857, 'unsplash': 18286}
  fake per-src: {'dalle3': 19000, 'flux': 20000, 'sd3': 25000}

== OOD summaries (by generator) ==
[OOD-sd3 train] total=62399 | real=31199 fake=31200
[OOD-sd3 val] total=7800 | real=3900 fake=3900
[OOD-sd3 test] total=32801 | real=3901 fake=28900
  holdout sd3 → 25000
[OOD-midjourney train] total=102398 | real=51198 fake=51200
[OOD-midjourney val] total=12801 | real=6401 fake=6400
[OOD-midjourney test] total=12801 | real=6401 fake=6400
  holdout midjourney → 0
[OOD-flux train] total=70399 | real=35199 fake=35200
[OOD-flux val] total=8799 | real=4399 fake=4400
[OOD-flux test] total=28802 | real=4402 fake=24400
  holdout flux → 20000
[OOD-dalle3 train] total=71998 | real=35998 fake=36000
[OOD-dalle3 val] total=9000 | real=4500

In [7]:
# ===== Unified split builder: IID / IID_balanced / OOD(by generator) / OOD_strict / smoke_10p =====
from pathlib import Path
import re, json, random
from collections import defaultdict, Counter

# ---------- 基本設定（改這裡） ----------
SEED = 42
random.seed(SEED)

SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")
FEA_ROOT    = SCRIPT_ROOT / "features_256"                         # features 根
OUT_JSON    = SCRIPT_ROOT / "splits/combined_split.json"           # 輸出檔

# 用哪些特徵做交集（建議三者；要加速可改成 ['ela']）
INTERSECT_FEATURES = ['ela','clip','prnu']

# 你的配額（固定）
FAKE_QUOTAS = {"sd3":25_000, "midjourney":30_000, "flux":20_000, "dalle3":19_000}  # 94k
REAL_QUOTAS = {"imagenet":30_000, "flickr30k":30_000, "unsplash":24_000}           # 84k

# 切分比例
RATIOS = {"train":0.8, "val":0.1, "test":0.1}

# 產生一份 10% 的 smoke，從哪個 split 抽
SMOKE_BASE_KEY = "iid"   # 可改成 "ood_gen.sd3" / "ood_gen.midjourney" / "ood_gen.flux" / "ood_gen.dalle3"

# ---------- 來源別名（前綴 -> 標準名）----------
ALIAS = {
    # real
    "flick":"flickr30k","flicker":"flickr30k","flickr":"flickr30k","flickr30k":"flickr30k",
    "unsplash":"unsplash","imagenet":"imagenet",
    # fake
    "sd3":"sd3","stable-diffusion-3":"sd3","sd3.5":"sd3",
    "midjourney":"midjourney","midjourney-v6":"midjourney","mj":"midjourney",
    "flux":"flux","FLUX":"flux",
    "dalle":"dalle3","dalle3":"dalle3","dall-e-3":"dalle3",
}

REAL_SOURCES = {"imagenet","flickr30k","unsplash"}

# ---------- 小工具 ----------
ELA_Q_PAT = re.compile(r"__q(\d+)$")
def base_id(stem: str):
    m = ELA_Q_PAT.search(stem)
    return stem[:m.start()] if m else stem

def norm_source(raw: str): return ALIAS.get(raw.lower(), raw.lower())
def dataset_of(img_id: str): return norm_source(img_id.split("__",1)[0])
def is_real_id(img_id: str): return dataset_of(img_id) in REAL_SOURCES

def list_feature_dirs(feat: str, cls: str):
    return [d for d in FEA_ROOT.glob(f"**/{feat}_{cls}_npy") if d.is_dir()]

def list_ids_for(feat: str, cls: str):
    ids = set()
    for d in list_feature_dirs(feat, cls):
        for p in d.glob("*.npy"):
            s = p.stem
            s = base_id(s) if feat == "ela" else s
            ids.add(s)
    return ids

def bucket_by_source(ids):
    b = defaultdict(list)
    for i in ids:
        b[dataset_of(i)].append(i)
    for k in b: random.shuffle(b[k])
    return b

def sample_quota(bucket, quotas):
    """依配額抽樣；不足就拿可用量。回傳 keep 與各來源可用量"""
    keep = []
    avail = {k: len(v) for k,v in bucket.items()}
    for src, q in quotas.items():
        got = bucket.get(src, [])[:q]
        keep.extend(got)
    return keep, avail

def scale_real_quotas_to(total_target, avail_per_src, base_quotas):
    """把 real 配額按比例縮放到 total_target；不超過可用量；不足就輪詢補齊。"""
    s = sum(base_quotas.values())
    if s == 0: return {}
    quotas = {}
    acc = 0
    keys = sorted(base_quotas.keys())
    for k in keys[:-1]:
        q = int(round(total_target * base_quotas[k] / s))
        quotas[k] = q; acc += q
    quotas[keys[-1]] = max(0, total_target - acc)
    # 截到可用
    for k in list(quotas.keys()):
        quotas[k] = min(quotas[k], avail_per_src.get(k, 0))
    # 若仍不足，輪詢補
    need = total_target - sum(quotas.values())
    if need > 0:
        remain = {k: max(0, avail_per_src.get(k,0) - quotas.get(k,0)) for k in avail_per_src}
        srcs = sorted(remain.keys())
        while need > 0:
            progressed = False
            for sname in srcs:
                if remain[sname] > 0:
                    quotas[sname] = quotas.get(sname,0) + 1
                    remain[sname] -= 1
                    need -= 1
                    progressed = True
                    if need == 0: break
            if not progressed: break
    return quotas

from sklearn.model_selection import train_test_split
def split_8_1_1_per_source(selected_ids):
    """各來源 8/1/1，再合併，保持來源比例。"""
    out = {"train":[], "val":[], "test":[]}
    per_src = bucket_by_source(selected_ids)
    for src, ids in per_src.items():
        if len(ids) < 10:
            n_tr = int(len(ids) * RATIOS["train"])
            n_va = int(len(ids) * RATIOS["val"])
            out["train"].extend(ids[:n_tr])
            out["val"].extend(ids[n_tr:n_tr+n_va])
            out["test"].extend(ids[n_tr+n_va:])
        else:
            tr, tmp = train_test_split(ids, test_size=(1-RATIOS["train"]), random_state=SEED, shuffle=True)
            va, te  = train_test_split(tmp, test_size=RATIOS["test"]/(RATIOS["test"]+RATIOS["val"]), random_state=SEED, shuffle=True)
            out["train"].extend(tr); out["val"].extend(va); out["test"].extend(te)
    for k in out: random.shuffle(out[k])
    return out

def report_counts(ids):
    c = Counter(dataset_of(i) for i in ids)
    return {"total": len(ids), "by_source": dict(sorted(c.items()))}

def summarize(name, split, real_pool):
    rset = set(real_pool["train"] + real_pool["val"] + real_pool["test"])
    for sp in ("train","val","test"):
        both = split[sp]
        n_total = len(both)
        n_real  = sum(1 for i in both if i in rset)
        n_fake  = n_total - n_real
        print(f"[{name} {sp}] total={n_total} | real={n_real} fake={n_fake}")

# ---------- 掃描 & 特徵交集 ----------
present = {}
for feat in INTERSECT_FEATURES:
    present[feat] = {"real": list_ids_for(feat,"real"),
                     "fake": list_ids_for(feat,"fake")}
sets_real = [present[feat]["real"] for feat in INTERSECT_FEATURES if present[feat]["real"]]
sets_fake = [present[feat]["fake"] for feat in INTERSECT_FEATURES if present[feat]["fake"]]
assert sets_real and sets_fake, "找不到可用的 real 或 fake IDs，請檢查特徵路徑與 INTERSECT_FEATURES。"

common = {
    "real": set.intersection(*sets_real) if len(sets_real)>1 else sets_real[0],
    "fake": set.intersection(*sets_fake) if len(sets_fake)>1 else sets_fake[0],
}
print("交集數量 → real:", len(common["real"]), "| fake:", len(common["fake"]))

real_bkt = bucket_by_source(common["real"])
fake_bkt = bucket_by_source(common["fake"])
avail_real = {k: len(v) for k,v in real_bkt.items()}

# ---------- IID ----------
# fake：嚴格依配額
iid_fake_keep, fake_avail = sample_quota(fake_bkt, FAKE_QUOTAS)
# real：固定用你的 REAL_QUOTAS（不縮放）
iid_real_keep, _ = sample_quota(real_bkt, REAL_QUOTAS)

iid_real = split_8_1_1_per_source(iid_real_keep)
iid_fake = split_8_1_1_per_source(iid_fake_keep)
iid = { "train": iid_real["train"] + iid_fake["train"],
        "val":   iid_real["val"]   + iid_fake["val"],
        "test":  iid_real["test"]  + iid_fake["test"] }
for k in iid: random.shuffle(iid[k])

# ---------- IID_balanced（讓 fake 下採樣到 real 的總量） ----------
total_real_target = sum(len(v) for v in iid_real.values())
# 這裡用各生成器的配額占比來決定 fake 下採樣比例
sum_fake_quota = sum(FAKE_QUOTAS.values())
bal_fake_quota = {src: int(round(total_real_target * q / sum_fake_quota)) for src, q in FAKE_QUOTAS.items()}
bal_fake_keep, _ = sample_quota(fake_bkt, bal_fake_quota)
bal_fake = split_8_1_1_per_source(bal_fake_keep)
iid_balanced = { "train": iid_real["train"] + bal_fake["train"],
                 "val":   iid_real["val"]   + bal_fake["val"],
                 "test":  iid_real["test"]  + bal_fake["test"] }
for k in iid_balanced: random.shuffle(iid_balanced[k])

# ---------- OOD-by-Generator（混合 test） ----------
def build_ood_for(hold_src: str, strict=False):
    hold_src = hold_src.lower()
    # holdout：整個來源，按配額截斷
    hold_quota = FAKE_QUOTAS.get(hold_src, len(fake_bkt.get(hold_src, [])))
    hold_ids   = fake_bkt.get(hold_src, [])[:min(hold_quota, len(fake_bkt.get(hold_src, [])))]
    # 其餘 fake：照配額抽樣
    remain_quota = {s:q for s,q in FAKE_QUOTAS.items() if s != hold_src}
    fake_keep, _ = sample_quota(fake_bkt, remain_quota)
    # real：縮放到與「訓練用 fake」同總量
    real_quota_scaled = scale_real_quotas_to(len(fake_keep), avail_real, REAL_QUOTAS)
    real_keep, _ = sample_quota(real_bkt, real_quota_scaled)
    # split
    r = split_8_1_1_per_source(real_keep)
    f = split_8_1_1_per_source(fake_keep)
    if strict:
        test_fake = hold_ids                      # 只放 holdout
    else:
        test_fake = f["test"] + hold_ids          # 混合 test
    ood = { "train": r["train"] + f["train"],
            "val":   r["val"]   + f["val"],
            "test":  r["test"]  + test_fake }
    for k in ood: random.shuffle(ood[k])
    rep = {
        "train": report_counts(ood["train"]),
        "val":   report_counts(ood["val"]),
        "test":  report_counts(ood["test"]),
        "holdout_source": hold_src,
        "holdout_size": len(hold_ids)
    }
    return ood, rep

ood_gen = {}
ood_gen_strict = {}
ood_reports = {}
for src in ["sd3","midjourney","flux","dalle3"]:
    ood_gen[src], ood_reports[src] = build_ood_for(src, strict=False)
    ood_gen_strict[src], _         = build_ood_for(src, strict=True)

# ---------- 10% smoke（從 SMOKE_BASE_KEY 抽） ----------
def get_by_key(tree, key: str):
    node = tree
    for k in key.split("."):
        node = node[k]
    return node

def subsample_by_class_source(ids, frac=0.1, min_each=1, seed=SEED):
    random.seed(seed)
    buckets = defaultdict(list)
    for i in ids:
        buckets[( "real" if is_real_id(i) else "fake", dataset_of(i) )].append(i)
    keep = []
    for (_c,_s), arr in buckets.items():
        random.shuffle(arr)
        k = max(min_each, int(round(len(arr)*frac)))
        keep.extend(arr[:k])
    random.shuffle(keep)
    return keep

tmp_all = {
    "iid": iid,
    "iid_balanced": iid_balanced,
    "ood_gen": ood_gen,
    "ood_gen_strict": ood_gen_strict
}
base = get_by_key(tmp_all, SMOKE_BASE_KEY)
smoke_10p = {
    "train": subsample_by_class_source(base["train"], 0.10, seed=SEED+1),
    "val":   subsample_by_class_source(base["val"],   0.10, seed=SEED+2),
    "test":  subsample_by_class_source(base["test"],  0.10, seed=SEED+3),
}

# ---------- 簡報 ----------
print("\n== IID summary =="); summarize("IID", iid, iid_real)
print("\n== IID_balanced summary =="); summarize("IID_BAL", iid_balanced, iid_real)
print("\n== OOD summaries ==")
for src in ood_gen:
    summarize(f"OOD-{src}", ood_gen[src], iid_real)
    print(f"  holdout {src} →", ood_reports[src]["holdout_size"])
print("\n== OOD_strict summaries ==")
for src in ood_gen_strict:
    summarize(f"OOD_STRICT-{src}", ood_gen_strict[src], iid_real)

# ---------- 存檔 ----------
OUT_JSON.parent.mkdir(parents=True, exist_ok=True)
payload = {
    "meta": {
        "seed": SEED,
        "intersect_features": INTERSECT_FEATURES,
        "fake_quotas": FAKE_QUOTAS,
        "real_quotas": REAL_QUOTAS,
        "avail": {
            "real_by_src": {k:len(v) for k,v in real_bkt.items()},
            "fake_by_src": {k:len(v) for k,v in fake_bkt.items()},
            "real_intersection_total": len(common["real"]),
            "fake_intersection_total": len(common["fake"]),
        },
        "notes": "iid: 固定配額；iid_balanced: fake 下採樣到 real 總量；ood_gen: 混合 test；ood_gen_strict: 嚴格 OOD",
    },
    "iid": iid,
    "iid_balanced": iid_balanced,
    "ood_gen": ood_gen,
    "ood_gen_strict": ood_gen_strict,
    "smoke_10p": smoke_10p,
}
OUT_JSON.write_text(json.dumps(payload, ensure_ascii=False, indent=2))
print("\n✅ saved:", OUT_JSON)


交集數量 → real: 86772 | fake: 81720

== IID summary ==
[IID train] total=132576 | real=67200 fake=65376
[IID val] total=16572 | real=8400 fake=8172
[IID test] total=16572 | real=8400 fake=8172

== IID_balanced summary ==
[IID_BAL train] total=127128 | real=67200 fake=59928
[IID_BAL val] total=15891 | real=8400 fake=7491
[IID_BAL test] total=15892 | real=8400 fake=7492

== OOD summaries ==
[OOD-sd3 train] total=90750 | real=45374 fake=45376
[OOD-sd3 val] total=11345 | real=5673 fake=5672
[OOD-sd3 test] total=36345 | real=5673 fake=30672
  holdout sd3 → 25000
[OOD-midjourney train] total=102398 | real=51198 fake=51200
[OOD-midjourney val] total=12801 | real=6401 fake=6400
[OOD-midjourney test] total=30521 | real=6401 fake=24120
  holdout midjourney → 17720
[OOD-flux train] total=98751 | real=49375 fake=49376
[OOD-flux val] total=12343 | real=6171 fake=6172
[OOD-flux test] total=32346 | real=6174 fake=26172
  holdout flux → 20000
[OOD-dalle3 train] total=100352 | real=50176 fake=50176
[OOD-d