In [None]:
# === 只做量化：PRNU float -> int8（無 zst 版）===
# 特色：
# 1) MODE='per_dataset'：每個資料集一個 scale S，兼顧精度與速度
# 2) 取分位數用 np.partition 的近似法 + 子抽樣像素，極快且準
# 3) 多進程平行處理
# 4) 直接輸出 .npy（dtype=int8），並寫 meta.json 紀錄各 dataset 的 S

import os, re, json, math, time
from pathlib import Path
from typing import List, Dict, Tuple
from functools import partial
import numpy as np
from tqdm import tqdm
import multiprocessing as mp

# ========= Paths =========
SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")

# Source float PRNU maps (real / fake).
SRC_REAL_F32 = SCRIPT_ROOT / "features_npy" / "prnu_real_npy"
SRC_FAKE_F32 = SCRIPT_ROOT / "features_npy" / "prnu_fake_npy"

# Destination directories for the int8-quantized maps (real / fake); created up front.
DST_REAL_I8 = SCRIPT_ROOT / "features_i8" / "prnu_real_i8_npy"
DST_FAKE_I8 = SCRIPT_ROOT / "features_i8" / "prnu_fake_i8_npy"
DST_REAL_I8.mkdir(parents=True, exist_ok=True)
DST_FAKE_I8.mkdir(parents=True, exist_ok=True)

# ========= Main parameters =========
# Scale strategy: 'global' (fastest), 'per_dataset' (recommended), 'per_file' (most faithful, slowest).
MODE = "per_dataset"
PERCENTILE = 0.999              # p99.9 of |PRNU| is used as the clipping scale S
SAMPLE_FILES_PER_DS = 500       # files sampled per dataset when estimating S
SAMPLE_PIXELS_PER_FILE = 4096   # pixels sampled per file for the percentile estimate
NPROC = max(1, (os.cpu_count() or 2) - 1)  # worker processes: all cores but one, at least 1

# Substrings looked up (in this order, first hit wins) in lower-cased file stems.
DATASET_KEYWORDS = [
    "imagenet", "flickr30k", "unsplash", "places365", "coco2017", "coco", "div2k",
    "sd3", "midjourney-v6-llava", "flux", "dalle3", "stablediffusion", "midjourney",
]

# ========= 工具 =========
def list_files(root: Path) -> List[Path]:
    """Return every ``.npy`` and ``.npz`` file below *root* (recursive), sorted by path."""
    found: List[Path] = []
    for pattern in ("*.npy", "*.npz"):
        found.extend(root.rglob(pattern))
    found.sort()
    return found

def infer_dataset_from_stem(stem: str) -> str:
    """Guess the dataset tag of a file from its stem.

    The first entry of ``DATASET_KEYWORDS`` (in list order) that occurs as a
    substring of the lower-cased stem is returned. Failing that, the leading
    run of ``[a-z0-9-]`` characters that is followed by ``_`` or ``-`` is
    returned (the match is greedy, so it may itself contain hyphens). If
    neither applies, the result is ``"unknown"``.
    """
    lowered = stem.lower()
    hit = next((kw for kw in DATASET_KEYWORDS if kw in lowered), None)
    if hit is not None:
        return hit
    prefix = re.match(r'([a-z0-9\-]+)[_\-]', lowered)
    if prefix is None:
        return "unknown"
    return prefix.group(1)

def load_prnu_2d(path: Path) -> np.ndarray:
    """Load a PRNU map from a ``.npy``/``.npz`` file as a 2-D float32 array.

    For ``.npz`` archives the first key present among ``prnu``, ``noise``,
    ``arr``, ``arr_0`` and ``data`` is used; otherwise the first stored array.
    A 3-D array whose trailing axis has size 1 or 3 is averaged over that axis
    (presumably channels-last). Otherwise, if the leading axis has size 1 or 3,
    it is averaged over that axis (presumably channels-first). Any other 3-D
    array is squeezed.

    Raises:
        ValueError: if the archive is empty or the result is not 2-D. This is
            an explicit raise rather than ``assert`` so the check survives
            ``python -O``.
    """
    # mmap_mode only affects plain .npy files; np.load ignores it for .npz.
    z = np.load(path, mmap_mode='r')
    if isinstance(z, np.lib.npyio.NpzFile):
        # Close the zip archive deterministically instead of relying on GC;
        # z[key] reads the member fully into memory, so closing afterwards is safe.
        with z:
            preferred = ('prnu', 'noise', 'arr', 'arr_0', 'data')
            fallback = z.files[0] if z.files else None
            key = next((k for k in preferred if k in z.files), fallback)
            if key is None:
                raise ValueError(f"Empty .npz archive: {path}")
            a = z[key]
    else:
        a = z
    a = np.asarray(a, dtype=np.float32)
    if a.ndim == 3:
        if a.shape[-1] in (1, 3):
            a = a.mean(axis=2)
        elif a.shape[0] in (1, 3):
            a = a.mean(axis=0)
        else:
            a = a.squeeze()
    if a.ndim != 2:
        raise ValueError(f"Expect 2D, got {a.shape} from {path}")
    return a

def subsample_vals_abs(a: np.ndarray, k: int) -> np.ndarray:
    """Return |x| as float32 for up to *k* randomly chosen elements of *a*.

    When *a* has at most *k* elements, all of them are returned. Otherwise
    *k* flat positions are drawn with replacement from a generator re-seeded
    with 1337 on every call. Consequently, two inputs of the same size are
    sampled at the same positions. Only the sampled elements are converted
    and abs'ed, not the whole map.
    """
    flat = a.reshape(-1)
    if flat.size > k:
        picks = np.random.default_rng(1337).integers(0, flat.size, size=k, endpoint=False)
        flat = flat[picks]
    return np.abs(flat.astype(np.float32, copy=False))

def fast_percentile_abs(v: np.ndarray, q: float) -> float:
    """Return the *q*-quantile (0 <= q <= 1) of the finite values in *v*.

    Uses the lower order statistic ``sorted(v)[int(q * (n - 1))]``, found with
    ``np.partition`` in O(n) average time instead of a full sort.

    Non-finite entries (NaN/±inf) are discarded first. Otherwise a NaN can
    become the selected element, and ``max(nan, 1e-8)`` returns nan, which
    would poison every value quantized with this scale.

    The result is floored at 1e-8 so it is safe to divide by. An empty or
    all-non-finite input also yields 1e-8.
    """
    flat = np.asarray(v).reshape(-1)
    finite = flat[np.isfinite(flat)]
    if finite.size == 0:
        return 1e-8
    k = int(q * (finite.size - 1))
    vk = np.partition(finite, k)[k]
    return float(max(vk, 1e-8))

def quantize_i8(a: np.ndarray, S: float) -> np.ndarray:
    """Symmetrically quantize *a* to int8 using clipping scale *S*.

    Values are clipped to [-S, S], mapped linearly onto [-127, 127] and
    rounded half-to-even (``np.rint``). -128 is never produced, so the code
    range is symmetric; the inverse mapping is ``q * (S / 127)``.

    NaN inputs map to 0, because casting NaN to an integer type is undefined
    and yields platform-dependent values. ±inf saturate to ±127.

    Raises:
        ValueError: if *S* is not a finite, strictly positive number.
    """
    S = float(S)
    if not (math.isfinite(S) and S > 0.0):
        raise ValueError(f"scale S must be finite and > 0, got {S!r}")
    x = np.clip(a, -S, S) / S * 127.0
    x = np.nan_to_num(x, nan=0.0, copy=False)
    q = np.rint(x).astype(np.int16)
    return np.clip(q, -127, 127).astype(np.int8)

def save_npy(out_base: Path, q: np.ndarray):
    """Write *q* to ``out_base.with_suffix(".npy")`` atomically and return that path.

    The array is first written to a process-unique temporary file in the
    target directory, then moved into place with ``os.replace``. On POSIX,
    within one filesystem, the replace is atomic. Callers skip outputs that
    already exist, so an interrupted run must not leave a truncated ``.npy``
    behind that later runs would silently skip.

    NOTE: ``with_suffix`` replaces any suffix already present in
    ``out_base.name``, so a stem containing '.' loses everything after its
    last dot.
    """
    out_path = out_base.with_suffix(".npy")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # The PID in the name keeps concurrent pool workers from sharing a temp file.
    tmp_path = out_path.with_name(f".{out_path.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "wb") as f:
            # allow_pickle=False: refuse object arrays so only plain numeric data is written.
            np.save(f, q, allow_pickle=False)
        os.replace(tmp_path, out_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    return out_path

# ========= 估 scale S =========
def estimate_global_S(files: List[Path], sample_files=20000, sample_pixels=4096, q=0.999) -> float:
    """Estimate a single clipping scale S shared by all *files*.

    At most *sample_files* files are used; larger lists are subsampled
    without replacement with seed 1337. From each chosen file,
    *sample_pixels* |values| are drawn, and the pooled *q*-quantile is
    returned. If nothing is sampled, the pool is just ``[1.0]``.
    """
    if len(files) > sample_files:
        rng = np.random.default_rng(1337)
        chosen = list(rng.choice(files, sample_files, replace=False))
    else:
        chosen = files
    pooled = [subsample_vals_abs(load_prnu_2d(f), sample_pixels)
              for f in tqdm(chosen, desc="scan global S")]
    if not pooled:
        pooled = [np.array([1.0], dtype=np.float32)]
    return fast_percentile_abs(np.concatenate(pooled, axis=0), q)

def estimate_S_per_dataset(files: List[Path], sample_files_per_ds=500, sample_pixels=4096, q=0.999) -> Dict[str, float]:
    """Estimate one clipping scale S per dataset tag found among *files*.

    Files are grouped by ``infer_dataset_from_stem(stem)``. Each group is
    capped at *sample_files_per_ds* files (seed 1337, without replacement).
    A group's S is the *q*-quantile of *sample_pixels* |values| pooled from
    every chosen file. Returns ``{tag: S}`` in first-seen tag order.
    """
    groups: Dict[str, List[Path]] = {}
    for f in files:
        tag = infer_dataset_from_stem(f.stem)
        if tag not in groups:
            groups[tag] = []
        groups[tag].append(f)

    scales: Dict[str, float] = {}
    for tag, members in groups.items():
        if len(members) > sample_files_per_ds:
            chosen = list(np.random.default_rng(1337).choice(members, sample_files_per_ds, replace=False))
        else:
            chosen = members
        pooled = [subsample_vals_abs(load_prnu_2d(f), sample_pixels)
                  for f in tqdm(chosen, desc=f"scan S[{tag}]")]
        flat = np.concatenate(pooled, 0) if pooled else np.array([1.0], dtype=np.float32)
        scales[tag] = fast_percentile_abs(flat, q)
    return scales

# ========= worker：單檔處理 =========
def worker_convert_one(path: Path, out_dir: Path, mode: str, S_global: float, ds2S: Dict[str,float],
                       skip_existing: bool=True) -> Tuple[str, str]:
    """Quantize one PRNU file into ``out_dir/<stem>.npy`` (runs inside a pool worker).

    *mode* selects the scale:
      * ``"per_file"`` estimates S from this file's own sampled pixels;
      * ``"per_dataset"`` looks up the file's tag in *ds2S*, falling back
        to *S_global*;
      * anything else uses *S_global*.

    Returns a ``(status, message)`` pair, and every ``Exception`` is caught,
    so one bad file does not kill the pool:
      * ``("skip", out_path)`` when the output exists and *skip_existing* is set;
      * ``("ok", out_path)`` after writing;
      * ``("err", "<path> -> <error>")`` on failure.

    NOTE: outputs are named by stem only, so same-named inputs in different
    subdirectories (or ``x.npy`` next to ``x.npz``) map to one output file.
    Per-file scales are not returned or recorded.
    """
    try:
        target = (out_dir / path.stem).with_suffix(".npy")
        if skip_existing and target.exists():
            return ("skip", str(target))
        prnu = load_prnu_2d(path)
        if mode == "per_file":
            scale = fast_percentile_abs(subsample_vals_abs(prnu, SAMPLE_PIXELS_PER_FILE), PERCENTILE)
        elif mode == "per_dataset":
            scale = ds2S.get(infer_dataset_from_stem(path.stem), S_global)
        else:
            scale = S_global
        save_npy(out_dir / path.stem, quantize_i8(prnu, scale))
        return ("ok", str(target))
    except Exception as e:
        return ("err", f"{path} -> {e}")

# ========= 主流程：一側（real 或 fake）量化 =========
def run_quantize_side(src_dir: Path, dst_dir: Path, mode="per_dataset"):
    """Quantize one side (real or fake): every .npy/.npz PRNU map under *src_dir* -> int8 .npy in *dst_dir*.

    1. Estimate a global S. This is always done, because it is the fallback
       for tags without their own S.
    2. In ``per_dataset`` mode, also estimate one S per dataset tag.
    3. Write a JSON meta record containing the scales.
    4. Convert all files in a process pool; outputs that already exist are skipped.

    The meta record is written to two files in ``dst_dir.parent``:
      * ``<dst_dir.name>_meta.json`` is specific to this side, so running the
        real side and then the fake side keeps both sets of scales;
      * ``prnu_i8_meta.json`` is the legacy shared path, kept for existing
        readers. Every run rewrites it, so it reflects only the last run.

    NOTE: in ``per_file`` mode the individual per-file scales are not recorded.
    """
    files = list_files(src_dir)
    if not files:
        print(f"[WARN] no files in {src_dir}")
        return

    print(f"MODE = {mode}")
    S_global = estimate_global_S(files, sample_files=min(20000, len(files)),
                                 sample_pixels=SAMPLE_PIXELS_PER_FILE, q=PERCENTILE)
    print(f"Global S = {S_global:.6g}")

    ds2S = {}
    if mode == "per_dataset":
        ds2S = estimate_S_per_dataset(files, sample_files_per_ds=SAMPLE_FILES_PER_DS,
                                      sample_pixels=SAMPLE_PIXELS_PER_FILE, q=PERCENTILE)
        print("Per-dataset S:", {k: round(v,6) for k,v in ds2S.items()})

    # Meta record, so the int8 data can later be dequantized or traced back.
    meta = {
        "mode": mode,
        "percentile": PERCENTILE,
        "global_S": S_global,
        "per_dataset_S": ds2S,
        "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "src": str(src_dir), "dst": str(dst_dir),
        "output_format": "npy(int8)"
    }
    dst_dir.mkdir(parents=True, exist_ok=True)
    meta_text = json.dumps(meta, ensure_ascii=False, indent=2)
    # Explicit UTF-8: with ensure_ascii=False, non-ASCII path characters must not
    # depend on the locale's default encoding.
    side_meta_path = dst_dir.parent / f"{dst_dir.name}_meta.json"
    side_meta_path.write_text(meta_text, encoding="utf-8")
    (dst_dir.parent / "prnu_i8_meta.json").write_text(meta_text, encoding="utf-8")
    print(f"meta -> {side_meta_path}")

    # Parallel conversion; each worker returns a (status, message) tuple.
    fn = partial(worker_convert_one, out_dir=dst_dir, mode=mode, S_global=S_global, ds2S=ds2S,
                 skip_existing=True)
    ok = skip = err = 0
    with mp.Pool(processes=NPROC, maxtasksperchild=200) as pool:
        for status, msg in tqdm(pool.imap_unordered(fn, files, chunksize=64), total=len(files), desc=f"quantize→{dst_dir.name}"):
            if status == "ok":
                ok += 1
            elif status == "skip":
                skip += 1
            else:
                err += 1
                print("[ERR]", msg)
    if skip:
        # Skipped files were written by an earlier run, possibly with another S.
        print(f"[WARN] skipped {skip} existing outputs; they may have been quantized "
              f"with a different S than the one recorded in this meta")
    print(f"完成：ok={ok}, skip={skip}, err={err}, out_dir={dst_dir}")

# ======== Run: one side at a time, real first and then fake ========
for _src, _dst in ((SRC_REAL_F32, DST_REAL_I8), (SRC_FAKE_F32, DST_FAKE_I8)):
    run_quantize_side(_src, _dst, mode=MODE)
del _src, _dst

print("✅ Done. 輸出為 .npy（int8）；meta 寫在：", DST_REAL_I8.parent / "prnu_i8_meta.json")


In [2]:
# === PRNU float -> int8（單一資料集版；無 zst）===
# 只針對指定資料集（例如 dalle3）做量化；其餘略過
import os, re, json, time, math
from pathlib import Path
from typing import List, Dict, Tuple
from functools import partial
import numpy as np
from tqdm import tqdm
import multiprocessing as mp

# ========= Paths (adjust to your environment) =========
SCRIPT_ROOT = Path("/home/yaya/ai-detect-proj/Script")

# Source float PRNU maps; to quantize fake/dalle3 this is normally prnu_fake_npy.
SRC_DIR = SCRIPT_ROOT.joinpath("features_npy", "prnu_fake_npy")
# Destination for the int8-quantized maps (created up front).
DST_DIR = SCRIPT_ROOT.joinpath("features_i8", "prnu_fake_i8_npy")
DST_DIR.mkdir(parents=True, exist_ok=True)

# Only datasets whose tag is in this set are processed; several may be listed. Default: dalle3 only.
INCLUDE_TAGS = {"dalle3"}

# ========= Quantization parameters =========
MODE = "per_dataset"            # one of 'global' | 'per_dataset' | 'per_file'
PERCENTILE = 0.999              # p99.9 of |PRNU| is used as the clipping scale S
SAMPLE_FILES_PER_DS = 500       # max files sampled per dataset when estimating S
SAMPLE_PIXELS_PER_FILE = 4096   # pixels sampled per file for the percentile estimate
NPROC = max(1, (os.cpu_count() or 2) - 1)  # all cores but one, at least 1
SEED = 1337                     # fixed seed for reproducible file/pixel sampling

# ========= Dataset keywords / aliases (e.g. dalle-3 -> dalle3) =========
# Maps a lower-case substring that may appear in a file stem to its canonical tag.
# Keys must be lower-case, because stems and tags are lower-cased before lookup.
# An upper-case key (such as the former "flickr30K") can never match.
# NOTE(review): "sdxl" is folded into "sd3"; confirm that grouping is intended.
ALIASES = {
    "imagenet1k": "imagenet", "imgnet": "imagenet", "imagenet": "imagenet",
    "unslpash": "unsplash", "unsplash": "unsplash",
    "flicker30k": "flickr30k", "flickr30k": "flickr30k",
    "places365": "places365", "coco2017": "coco2017", "div2k": "div2k",
    "sd3": "sd3", "sdxl": "sd3",
    "flux": "flux", "black-forest-labs": "flux",
    "dalle-3": "dalle3", "dalle3": "dalle3",
    "midjourney-v6-llava": "midjourney", "midjourney": "midjourney",
}
# Every alias and every canonical tag, ordered longest first with an alphabetical
# tie-break. A bare list(set(...)) has an order that varies between interpreter
# runs (string hash randomization), which made equal-length matches non-deterministic.
DATASET_KEYWORDS = sorted(set(ALIASES) | set(ALIASES.values()), key=lambda k: (-len(k), k))

def canonical(tag: str) -> str:
    """Normalize *tag* (lower-case, then strip) and map it through ``ALIASES``.

    Tags that are not aliases are returned in their normalized form.
    """
    key = tag.lower().strip()
    return ALIASES.get(key, key)

def infer_dataset_from_stem(stem: str) -> str:
    """Return the canonical dataset tag for a file stem.

    Every entry of ``DATASET_KEYWORDS`` is tried as a substring of the
    lower-cased stem, longest keyword first (so "midjourney-v6-llava" beats
    "midjourney"). The first hit is mapped through ``canonical``.

    Equal-length keywords are tried in alphabetical order, so the result does
    not depend on the iteration order of ``DATASET_KEYWORDS``. That matters
    because a list built from a ``set`` has a run-dependent order.

    Without a keyword hit, the leading ``[a-z0-9-]`` run followed by ``_`` or
    ``-`` is canonicalized. If that also fails, the result is ``"unknown"``.
    """
    s = stem.lower()
    for k in sorted(DATASET_KEYWORDS, key=lambda kw: (-len(kw), kw)):
        if k in s:
            return canonical(k)
    m = re.match(r'([a-z0-9\-]+)[_\-]', s)
    return canonical(m.group(1)) if m else "unknown"

# ========= IO / 讀檔 =========
def list_files(root: Path) -> List[Path]:
    """Recursively collect the ``*.npy`` and ``*.npz`` files under *root*, sorted by path."""
    return sorted(p for pattern in ("*.npy", "*.npz") for p in root.rglob(pattern))

def load_prnu_2d(path: Path) -> np.ndarray:
    """Load a PRNU map from a ``.npy``/``.npz`` file as a 2-D float32 array.

    For ``.npz`` archives the first key present among ``prnu``, ``noise``,
    ``arr``, ``arr_0`` and ``data`` is used; otherwise the first stored array.
    A 3-D array whose trailing axis has size 1 or 3 is averaged over that axis
    (presumably channels-last). Otherwise, if the leading axis has size 1 or 3,
    it is averaged over that axis (presumably channels-first). Any other 3-D
    array is squeezed.

    Raises:
        ValueError: if the archive is empty or the result is not 2-D. This is
            an explicit raise rather than ``assert`` so the check survives
            ``python -O``.
    """
    # mmap_mode only affects plain .npy files; np.load ignores it for .npz.
    z = np.load(path, mmap_mode='r')
    if isinstance(z, np.lib.npyio.NpzFile):
        # Close the zip archive deterministically instead of relying on GC;
        # z[key] reads the member fully into memory, so closing afterwards is safe.
        with z:
            preferred = ('prnu', 'noise', 'arr', 'arr_0', 'data')
            fallback = z.files[0] if z.files else None
            key = next((k for k in preferred if k in z.files), fallback)
            if key is None:
                raise ValueError(f"Empty .npz archive: {path}")
            a = z[key]
    else:
        a = z
    a = np.asarray(a, dtype=np.float32)
    if a.ndim == 3:
        if a.shape[-1] in (1, 3):
            a = a.mean(axis=2)
        elif a.shape[0] in (1, 3):
            a = a.mean(axis=0)
        else:
            a = a.squeeze()
    if a.ndim != 2:
        raise ValueError(f"Expect 2D, got {a.shape} from {path}")
    return a

def subsample_vals_abs(a: np.ndarray, k: int) -> np.ndarray:
    """Return |x| as float32 for up to *k* randomly sampled elements of *a*.

    Arrays with at most *k* elements are returned whole. Larger ones are
    sampled with replacement from a fresh ``default_rng(SEED)`` on each call,
    so equally sized arrays are sampled at identical flat positions.
    """
    flat = a.reshape(-1)
    total = flat.size
    if total > k:
        positions = np.random.default_rng(SEED).integers(0, total, size=k, endpoint=False)
        return np.abs(flat[positions].astype(np.float32, copy=False))
    return np.abs(flat.astype(np.float32, copy=False))

def fast_percentile_abs(v: np.ndarray, q: float) -> float:
    """Return the *q*-quantile (0 <= q <= 1) of the finite values in *v*.

    Uses the lower order statistic ``sorted(v)[int(q * (n - 1))]``, found with
    ``np.partition`` in O(n) average time instead of a full sort.

    Non-finite entries (NaN/±inf) are discarded first. Otherwise a NaN can
    become the selected element, and ``max(nan, 1e-8)`` returns nan, which
    would poison every value quantized with this scale.

    The result is floored at 1e-8 so it is safe to divide by. An empty or
    all-non-finite input also yields 1e-8.
    """
    flat = np.asarray(v).reshape(-1)
    finite = flat[np.isfinite(flat)]
    if finite.size == 0:
        return 1e-8
    k = int(q * (finite.size - 1))
    vk = np.partition(finite, k)[k]
    return float(max(vk, 1e-8))

def quantize_i8(a: np.ndarray, S: float) -> np.ndarray:
    """Symmetrically quantize *a* to int8 using clipping scale *S*.

    Values are clipped to [-S, S], mapped linearly onto [-127, 127] and
    rounded half-to-even (``np.rint``). -128 is never produced, so the code
    range is symmetric; the inverse mapping is ``q * (S / 127)``.

    NaN inputs map to 0, because casting NaN to an integer type is undefined
    and yields platform-dependent values. ±inf saturate to ±127.

    Raises:
        ValueError: if *S* is not a finite, strictly positive number.
    """
    S = float(S)
    if not (math.isfinite(S) and S > 0.0):
        raise ValueError(f"scale S must be finite and > 0, got {S!r}")
    x = np.clip(a, -S, S) / S * 127.0
    x = np.nan_to_num(x, nan=0.0, copy=False)
    q = np.rint(x).astype(np.int16)
    return np.clip(q, -127, 127).astype(np.int8)

def save_npy(out_base: Path, q: np.ndarray):
    """Write *q* to ``out_base.with_suffix(".npy")`` atomically and return that path.

    The array is first written to a process-unique temporary file in the
    target directory, then moved into place with ``os.replace``. On POSIX,
    within one filesystem, the replace is atomic. Callers skip outputs that
    already exist, so an interrupted run must not leave a truncated ``.npy``
    behind that later runs would silently skip.

    NOTE: ``with_suffix`` replaces any suffix already present in
    ``out_base.name``, so a stem containing '.' loses everything after its
    last dot.
    """
    out_path = out_base.with_suffix(".npy")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # The PID in the name keeps concurrent pool workers from sharing a temp file.
    tmp_path = out_path.with_name(f".{out_path.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "wb") as f:
            # allow_pickle=False: refuse object arrays so only plain numeric data is written.
            np.save(f, q, allow_pickle=False)
        os.replace(tmp_path, out_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    return out_path

# ========= 僅保留指定資料集的檔案 =========
def filter_files_by_tags(files: List[Path], include_tags:set) -> List[Path]:
    """Keep only the files whose inferred dataset tag is one of *include_tags*.

    The wanted tags are passed through ``canonical`` first, so alias or case
    spellings ("dalle-3", "DALLE3") select the same files as "dalle3".
    ``infer_dataset_from_stem`` already returns canonical tags, so without this
    step such spellings silently matched nothing. Input order is preserved.
    """
    wanted = {canonical(t) for t in include_tags}
    return [p for p in files if infer_dataset_from_stem(p.stem) in wanted]

# ========= 估 S =========
def estimate_global_S(files: List[Path], sample_files=20000, sample_pixels=4096, q=0.999) -> float:
    """Return one clipping scale S for all *files*: the pooled *q*-quantile of sampled |PRNU|.

    Every file is used when there are at most *sample_files*; otherwise a
    ``SEED``-seeded sample is drawn without replacement. *sample_pixels*
    values are taken from each chosen file. If nothing is sampled, the pool
    is ``[1.0]``.
    """
    chosen = files
    if len(chosen) > sample_files:
        chosen = list(np.random.default_rng(SEED).choice(files, sample_files, replace=False))
    samples = []
    for fp in tqdm(chosen, desc="scan global S"):
        samples.append(subsample_vals_abs(load_prnu_2d(fp), sample_pixels))
    pool = np.concatenate(samples, axis=0) if samples else np.array([1.0], dtype=np.float32)
    return fast_percentile_abs(pool, q)

def estimate_S_per_dataset(files: List[Path], sample_files_per_ds=500, sample_pixels=4096, q=0.999) -> Dict[str, float]:
    """Compute one clipping scale S per canonical dataset tag present in *files*.

    Files are grouped by ``infer_dataset_from_stem``. Each group is capped at
    *sample_files_per_ds* files via a ``SEED``-seeded sample without
    replacement. The group's S is the *q*-quantile of the pooled samples,
    *sample_pixels* per file. The result maps tag -> S in first-seen order.
    """
    grouped: Dict[str, List[Path]] = {}
    for fp in files:
        grouped.setdefault(infer_dataset_from_stem(fp.stem), []).append(fp)

    result: Dict[str, float] = {}
    for tag, members in grouped.items():
        chosen = members
        if len(members) > sample_files_per_ds:
            rng = np.random.default_rng(SEED)
            chosen = list(rng.choice(members, sample_files_per_ds, replace=False))
        samples = [subsample_vals_abs(load_prnu_2d(fp), sample_pixels)
                   for fp in tqdm(chosen, desc=f"scan S[{tag}]")]
        if samples:
            pool = np.concatenate(samples, 0)
        else:
            pool = np.array([1.0], dtype=np.float32)
        result[tag] = fast_percentile_abs(pool, q)
    return result

# ========= 單檔 worker =========
def worker_convert_one(path: Path, out_dir: Path, mode: str, S_global: float, ds2S: Dict[str,float],
                       skip_existing: bool=True) -> Tuple[str, str]:
    """Pool task: quantize *path* into ``out_dir/<stem>.npy``.

    The scale comes from *mode*:
      * ``"per_file"``: this file's own sampled percentile;
      * ``"per_dataset"``: ``ds2S[tag]``, falling back to *S_global*;
      * otherwise: *S_global*.

    Every ``Exception`` is converted into a status tuple rather than
    propagated. The possible results are ``("skip", out_path)``,
    ``("ok", out_path)`` and ``("err", "<path> -> <error>")``.

    NOTE: outputs are keyed by stem only, so same-named inputs collide, and
    per-file scales are not recorded.
    """
    try:
        base = out_dir / path.stem
        dest = base.with_suffix(".npy")
        if skip_existing and dest.exists():
            return ("skip", str(dest))
        arr = load_prnu_2d(path)
        if mode == "per_file":
            sampled = subsample_vals_abs(arr, SAMPLE_PIXELS_PER_FILE)
            scale = fast_percentile_abs(sampled, PERCENTILE)
        elif mode == "per_dataset":
            scale = ds2S.get(infer_dataset_from_stem(path.stem), S_global)
        else:
            scale = S_global
        save_npy(base, quantize_i8(arr, scale))
        return ("ok", str(dest))
    except Exception as e:
        return ("err", f"{path} -> {e}")

# ========= 主流程（僅單一/少量資料集）=========
def run_quantize_single_dataset(src_dir: Path, dst_dir: Path, include_tags:set, mode="per_dataset"):
    """Quantize only the *include_tags* files under *src_dir* into int8 .npy files in *dst_dir*.

    Files are filtered by inferred dataset tag. Then:
      1. estimate a global S over the filtered files;
      2. optionally estimate per-dataset S values;
      3. write the meta record;
      4. convert in a process pool, skipping outputs that already exist.

    The meta record is written to two files in ``dst_dir.parent``:
      * ``<dst_dir.name>_<tags>_meta.json`` is specific to this selection, so
        runs over other sides or datasets do not clobber it;
      * ``prnu_i8_meta.json`` is the legacy shared path, kept for existing
        readers. Every run rewrites it, so it reflects only the last run.

    Raises:
        FileNotFoundError: if no file under *src_dir* matches *include_tags*.
    """
    all_files = list_files(src_dir)
    files = filter_files_by_tags(all_files, include_tags)
    if not files:
        raise FileNotFoundError(f"在 {src_dir} 找不到符合 {include_tags} 的 .npy/.npz 檔案")
    print(f"[INFO] 將處理 {len(files)} 個檔案，datasets = {sorted(list(include_tags))}")

    print(f"MODE = {mode}")
    S_global = estimate_global_S(files, sample_files=min(20000, len(files)),
                                 sample_pixels=SAMPLE_PIXELS_PER_FILE, q=PERCENTILE)
    print(f"Global S (filtered) = {S_global:.6g}")

    ds2S = {}
    if mode == "per_dataset":
        ds2S = estimate_S_per_dataset(files, sample_files_per_ds=SAMPLE_FILES_PER_DS,
                                      sample_pixels=SAMPLE_PIXELS_PER_FILE, q=PERCENTILE)
        print("Per-dataset S:", {k: round(v,6) for k,v in ds2S.items()})

    meta = {
        "mode": mode,
        "percentile": PERCENTILE,
        "global_S_filtered": S_global,
        "per_dataset_S": ds2S,
        "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "src": str(src_dir), "dst": str(dst_dir),
        "include_tags": sorted(list(include_tags)),
        "output_format": "npy(int8)"
    }
    dst_dir.mkdir(parents=True, exist_ok=True)
    meta_text = json.dumps(meta, ensure_ascii=False, indent=2)
    # Explicit UTF-8: with ensure_ascii=False, non-ASCII path characters must not
    # depend on the locale's default encoding.
    tag_part = "_".join(sorted(include_tags))
    run_meta_path = dst_dir.parent / f"{dst_dir.name}_{tag_part}_meta.json"
    run_meta_path.write_text(meta_text, encoding="utf-8")
    (dst_dir.parent / "prnu_i8_meta.json").write_text(meta_text, encoding="utf-8")
    print(f"meta -> {run_meta_path}")

    fn = partial(worker_convert_one, out_dir=dst_dir, mode=mode, S_global=S_global, ds2S=ds2S,
                 skip_existing=True)
    ok = skip = err = 0
    with mp.Pool(processes=NPROC, maxtasksperchild=200) as pool:
        for status, msg in tqdm(pool.imap_unordered(fn, files, chunksize=64), total=len(files), desc=f"quantize→{dst_dir.name}"):
            if status == "ok":
                ok += 1
            elif status == "skip":
                skip += 1
            else:
                err += 1
                print("[ERR]", msg)
    if skip:
        # Skipped files were written by an earlier run, possibly with another S.
        print(f"[WARN] skipped {skip} existing outputs; they may have been quantized "
              f"with a different S than the one recorded in this meta")
    print(f"完成：ok={ok}, skip={skip}, err={err}, out_dir={dst_dir}")

# ======== Run (INCLUDE_TAGS only, dalle3 by default) ========
run_quantize_single_dataset(SRC_DIR, DST_DIR, INCLUDE_TAGS, mode=MODE)
print("✅ Done. 只量化這些資料集：", sorted(INCLUDE_TAGS))
print("   輸出為 int8 .npy；meta 寫在：", DST_DIR.parent / "prnu_i8_meta.json")


[INFO] 將處理 19000 個檔案，datasets = ['dalle3']
MODE = per_dataset


scan global S: 100%|██████████| 19000/19000 [00:37<00:00, 510.29it/s]


Global S (filtered) = 0.0596475


scan S[dalle3]: 100%|██████████| 500/500 [00:00<00:00, 3525.38it/s]

Per-dataset S: {'dalle3': 0.046448}



quantize→prnu_fake_i8_npy: 100%|██████████| 19000/19000 [00:04<00:00, 4230.24it/s]


完成：ok=19000, skip=0, err=0, out_dir=/home/yaya/ai-detect-proj/Script/features_i8/prnu_fake_i8_npy
✅ Done. 只量化這些資料集： ['dalle3']
   輸出為 int8 .npy；meta 寫在： /home/yaya/ai-detect-proj/Script/features_i8/prnu_i8_meta.json
