In [1]:
from __future__ import annotations
import os, math, json, uuid, tempfile
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List, Tuple

# 建議的 CUDA / I/O 環境變數
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128")
os.environ.setdefault("PYTHONUTF8","1")
os.environ.setdefault("PYTHONIOENCODING","utf-8")

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageChops, ImageOps, ImageDraw, ImageFont
from skimage.restoration import denoise_wavelet

# 主要根目錄：預設當前工作目錄（Colab 會是 /content）
SCRIPT_ROOT = Path.cwd()
SAVED_MODELS = SCRIPT_ROOT / "saved_models"
SAVED_MODELS.mkdir(exist_ok=True, parents=True)

# 模型/融合器路徑
PRNU_MODEL_PATH = SAVED_MODELS / "prnu_fastcnn_best.pt"
ELA_MODEL_PATH  = SAVED_MODELS / "ela_fastcnn_best.pt"

CLIP_LOGREG_PKL = SAVED_MODELS / "clip_logreg_gpu.pkl"    # （若有 RAPIDS 可用）
CLIP_SVM_CUML   = SAVED_MODELS / "clip_svm_gpu.pkl"       # 回退
CLIP_SVM_TORCH  = SAVED_MODELS / "clip_svm_gpu_torch.pt"  # 回退
CLIP_PLATT_PKL  = SAVED_MODELS / "clip_platt.pkl"         # 可選

FUSER_PKL       = SAVED_MODELS / "fusion_lr.pkl"
FUSER_META      = SAVED_MODELS / "fusion_lr_meta.json"

# 參數
SEED            = 42
TILE            = 256
STRIDE          = 128
ELA_QUALITY     = 90
ELA_SCALE       = 15
ELA_FEASZ       = 128
PRNU_MODE       = "soft"
PRNU_WAVELET    = "db8"
PRNU_Q_MODE     = "per_file"
PRNU_Q_PERC     = 0.999
PRNU_Q_SAMPLES  = 4096

FORCE_JPG_NONJPG = True
JPEG_FORCE_QUALITY     = 95
JPEG_FORCE_SUBSAMPLING = 0     # 4:4:4

CLIP_BACKBONE   = "ViT-L-14"
CLIP_PRETRAINED = {"ViT-L-14":"laion2b_s32b_b82k","ViT-B-32":"laion400m_e32","ViT-L-14-336":"laion2b_s32b_b82k"}.get(CLIP_BACKBONE, "laion2b_s32b_b82k")

# 默認啟用模態（可在呼叫時覆寫）
DEFAULT_ENABLED = {"prnu": True, "ela": True, "clip": True}

# 裝置與 AMP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = (device.type == "cuda")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        pass

rng = np.random.default_rng(SEED)


In [2]:
def _register_heif_avif():
    ok = False
    try:
        import pillow_heif
        pillow_heif.register_heif_opener()
        ok = True
        print("✅ pillow-heif registered")
    except Exception as e:
        print("ℹ️ pillow-heif not available:", e)
    # avif 兩種常見套件命名，擇一成功即可
    try:
        import pillow_avif  # from pillow-avif-python
        from pillow_avif import AvifImagePlugin  # noqa
        ok = True
        print("✅ pillow-avif (pillow-avif-python) registered")
    except Exception:
        try:
            import avif  # from pillow-avif-plugin
            ok = True
            print("✅ pillow-avif-plugin registered")
        except Exception as e:
            print("ℹ️ AVIF plugin not available:", e)
    return ok

_register_heif_avif()

def _open_image_any(p: Path) -> Image.Image:
    p = Path(p)
    try:
        img = Image.open(p)
        img.load()
        return img
    except Exception as e:
        if p.suffix.lower() in {".heic",".heif",".heifs",".hif",".avif"}:
            try:
                import pillow_heif
                h = pillow_heif.read_heif(str(p))
                img = Image.frombytes(h.mode, h.size, h.data, "raw")
                return img
            except Exception as e2:
                raise RuntimeError(f"HEIF/AVIF decode failed: {e2}") from e2
        raise

def _to_rgb_no_alpha(img: Image.Image) -> Image.Image:
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass
    if img.mode in ("RGBA","LA") or (img.mode=="P" and "transparency" in img.info):
        bg = Image.new("RGB", img.size, (255,255,255))
        bg.paste(img.convert("RGBA"), mask=img.convert("RGBA").split()[-1])
        return bg
    if img.mode != "RGB":
        return img.convert("RGB")
    return img

def as_jpg_if_needed(p: Path,
                     quality: int = JPEG_FORCE_QUALITY,
                     subsampling: int | str = JPEG_FORCE_SUBSAMPLING) -> Tuple[Path, str | None]:
    p = Path(p)
    if p.suffix.lower() in {".jpg",".jpeg"}:
        return p, None
    img = _to_rgb_no_alpha(_open_image_any(p))
    tmp = p.with_suffix(f".tmp_infer_{uuid.uuid4().hex[:8]}.jpg")
    buf = BytesIO(); img.save(buf, format="JPEG", quality=int(quality), subsampling=subsampling, optimize=False)
    with open(tmp, "wb") as f: f.write(buf.getvalue())
    return tmp, str(tmp)


✅ pillow-heif registered
✅ pillow-avif (pillow-avif-python) registered


In [3]:
class DSBlock(nn.Module):
    def __init__(self, c_in, c_out, stride=1):
        super().__init__()
        self.dw = nn.Conv2d(c_in, c_in, 3, stride=stride, padding=1, groups=c_in, bias=False)
        self.bn1= nn.BatchNorm2d(c_in)
        self.pw = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn2= nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.act(self.bn1(self.dw(x)))
        x = self.act(self.bn2(self.pw(x)))
        return x

class FastCNN_1ch(nn.Module):
    def __init__(self, base=32, num_classes=2):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(1, base, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(base), nn.ReLU(inplace=True))
        self.stage= nn.Sequential(DSBlock(base,base*2,1), DSBlock(base*2,base*4,2),
                                  DSBlock(base*4,base*4,1), DSBlock(base*4,base*8,2),
                                  DSBlock(base*8,base*8,1))
        self.head = nn.Sequential(nn.Conv2d(base*8, base*8, 1, bias=False),
                                  nn.BatchNorm2d(base*8), nn.ReLU(inplace=True))
        self.pool = nn.AdaptiveAvgPool2d(1); self.fc = nn.Linear(base*8,2)
    def forward(self, x):
        x=self.stem(x); x=self.stage(x); x=self.head(x); x=self.pool(x).flatten(1); return self.fc(x)

from collections import OrderedDict

def _safe_load(path, map_location="cpu"):
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=map_location)

def _extract_state_dict(blob):
    if isinstance(blob, dict):
        for k in ("state_dict","model","weights"):
            if k in blob and isinstance(blob[k], dict):
                sd = blob[k]; break
        else:
            sd = {k:v for k,v in blob.items() if torch.is_tensor(v)}
            if not sd: raise RuntimeError("未知 checkpoint 格式")
    elif isinstance(blob, torch.nn.Module):
        sd = blob.state_dict()
    else:
        sd = blob
    if len(sd)>0 and next(iter(sd)).startswith("module."):
        sd = OrderedDict((k[len("module."):], v) for k,v in sd.items())
    return sd

# 嘗試載入；缺檔時標示不可用，預設開關會自動關掉
AVAILABLE = {"prnu": False, "ela": False, "clip": True}
prnu_model = None; ela_model = None

try:
    prnu_model = FastCNN_1ch().to(device).eval()
    prnu_model.load_state_dict(_extract_state_dict(_safe_load(PRNU_MODEL_PATH, device)))
    AVAILABLE["prnu"] = True
    print("✅ PRNU model loaded")
except Exception as e:
    print("ℹ️ PRNU model not available:", e)

try:
    ela_model  = FastCNN_1ch().to(device).eval()
    ela_model.load_state_dict(_extract_state_dict(_safe_load(ELA_MODEL_PATH,  device)))
    AVAILABLE["ela"] = True
    print("✅ ELA model loaded")
except Exception as e:
    print("ℹ️ ELA model not available:", e)

# 若缺少某模態，預設關閉它
for k in list(DEFAULT_ENABLED):
    if not AVAILABLE.get(k, False):
        DEFAULT_ENABLED[k] = False

@torch.no_grad()
def prnu_logit_from_i8(arr_i8: np.ndarray) -> float:
    # channels_last + AMP
    x = arr_i8
    if x.dtype == np.int8:
        x = x.astype(np.float32) / 127.0
    elif x.dtype == np.uint8:
        x = x.astype(np.float32) / 255.0
    else:
        x = x.astype(np.float32); x = x/255.0 if x.max()>1.5 else x
    x = x - x.mean()
    t = torch.from_numpy(x[None,None,...]).to(device).contiguous(memory_format=torch.channels_last)
    with torch.autocast(device_type=device.type, enabled=USE_AMP):
        p1 = torch.softmax(prnu_model(t), dim=1)[0,1].item()
    p1 = float(np.clip(p1, 1e-6, 1-1e-6))
    z = math.log(p1) - math.log(1.0 - p1)
    return float(np.clip(z, -20.0, 20.0))

@torch.no_grad()
def ela_logit_from_i8(arr_i8: np.ndarray) -> float:
    a = arr_i8
    if a.dtype == np.int8:
        x = (a.astype(np.float32) + 128.0) / 255.0
    elif a.dtype == np.uint8:
        x = a.astype(np.float32) / 255.0
    else:
        x = a.astype(np.float32); x = x/255.0 if x.max()>1.5 else x
    x = x - x.mean()
    t = torch.from_numpy(x[None,None,...]).to(device).contiguous(memory_format=torch.channels_last)
    with torch.autocast(device_type=device.type, enabled=USE_AMP):
        p1 = torch.softmax(ela_model(t), dim=1)[0,1].item()
    p1 = float(np.clip(p1, 1e-6, 1-1e-6))
    z = math.log(p1) - math.log(1.0 - p1)
    return float(np.clip(z, -20.0, 20.0))


✅ PRNU model loaded
✅ ELA model loaded


In [4]:
_openclip = {"model":None,"pre":None,"dev":"cpu"}

def _cuml_ok():
    try:
        import cuml, cupy  # noqa
        return True
    except Exception:
        return False

@torch.no_grad()
def load_openclip():
    if _openclip["model"] is None:
        import open_clip
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        model, _, pre = open_clip.create_model_and_transforms(CLIP_BACKBONE, pretrained=CLIP_PRETRAINED)
        model = model.to(dev).eval()
        _openclip.update(model=model, pre=pre, dev=dev)
    return _openclip["model"], _openclip["pre"], _openclip["dev"]

class CLIPSingle:
    def __init__(self):
        self.mode=None; self.backend=None; self.model=None; self.D=None; self.platt=None
        self._load()
    def _load(self):
        # cuML LR（優先）
        if CLIP_LOGREG_PKL.exists():
            import joblib
            obj = joblib.load(CLIP_LOGREG_PKL)
            self.model = obj["model"] if isinstance(obj, dict) and "model" in obj else obj
            self.mode, self.backend, self.D = "logreg", "cuml", 1024
            print("✅ CLIP: cuML LogisticRegression（logit）"); return
        # cuML SVM（回退）
        if CLIP_SVM_CUML.exists() and _cuml_ok():
            import joblib
            clf = joblib.load(CLIP_SVM_CUML)
            self.model = clf["model"] if isinstance(clf, dict) and "model" in clf else clf
            self.mode, self.backend, self.D = "svm", "cuml", 1024
            print("⚠️ CLIP: cuML SVM（margin）")
        # Torch 線性 SVM（回退）
        elif CLIP_SVM_TORCH.exists():
            sd = _safe_load(CLIP_SVM_TORCH, "cpu")
            state = sd["state_dict"] if isinstance(sd, dict) and "state_dict" in sd else sd
            self.D = int(sd["D"]) if isinstance(sd, dict) and "D" in sd else state["weight"].shape[1]
            lin = nn.Linear(self.D, 1, bias=True).eval(); lin.load_state_dict(state)
            self.model = lin; self.mode, self.backend = "svm", "torch"
            print(f"⚠️ CLIP: Torch SVM（margin, D={self.D}）")
        else:
            print("ℹ️ 未找到 CLIP 分類器（跳過 CLIP）")
            return
        if self.mode == "svm" and CLIP_PLATT_PKL.exists():
            import joblib
            try:
                self.platt = joblib.load(CLIP_PLATT_PKL)
                print("✅ 載入 Platt 標定器（SVM → prob）")
            except Exception as e:
                print("⚠️ Platt 載入失敗：", e)

    @torch.no_grad()
    def predict_logit(self, pil_tile_256: Image.Image) -> float | None:
        if self.model is None: return None
        tmp = Path(f"/tmp/clip_tile_{uuid.uuid4().hex[:8]}.jpg")
        buf = BytesIO(); pil_tile_256.save(buf, format="JPEG", quality=95, subsampling=0)
        with open(tmp, "wb") as f: f.write(buf.getvalue())
        try:
            model, pre, dev = load_openclip()
            img = Image.open(tmp).convert("RGB")
            im  = pre(img).unsqueeze(0).to(dev)
            visual = model.visual

            # 嘗試抓出 token map（最後一層）
            tokens=None; ok=False
            def _apply_norm(x):
                if x.ndim==2: x=x.unsqueeze(0)
                if hasattr(visual, "ln_post") and visual.ln_post is not None:
                    x = visual.ln_post(x)
                elif hasattr(visual, "trunk") and hasattr(visual.trunk, "norm") and visual.trunk.norm is not None:
                    x = visual.trunk.norm(x)
                return x.squeeze(0)

            for try_path in ("trunk.forward_features","visual.forward_features","hook"):
                try:
                    if try_path=="trunk.forward_features" and hasattr(visual,"trunk") and hasattr(visual.trunk,"forward_features"):
                        out = visual.trunk.forward_features(im)
                        if isinstance(out,(tuple,list)): out=out[0]
                        if isinstance(out,dict): out=out.get("x", out.get("tokens", out))
                        if torch.is_tensor(out) and out.ndim==3:
                            tokens=_apply_norm(out.detach()); ok=True; break
                    if try_path=="visual.forward_features" and hasattr(visual,"forward_features"):
                        out = visual.forward_features(im)
                        if isinstance(out,(tuple,list)): out=out[0]
                        if isinstance(out,dict): out=out.get("x", out.get("tokens", out))
                        if torch.is_tensor(out) and out.ndim==3:
                            tokens=_apply_norm(out.detach()); ok=True; break
                    if try_path=="hook":
                        feats={}; h=None
                        try:
                            if hasattr(visual,"trunk") and hasattr(visual.trunk,"blocks") and len(visual.trunk.blocks)>0:
                                target=visual.trunk.blocks[-1]
                                def _hook(_m,_i,o): feats["x"]=o
                                h=target.register_forward_hook(_hook)
                                _ = getattr(visual.trunk, "forward_features", visual.trunk.forward)(im)
                            elif hasattr(visual,"transformer") and hasattr(visual.transformer,"resblocks") and len(visual.transformer.resblocks)>0:
                                target=visual.transformer.resblocks[-1]
                                def _hook(_m,_i,o): feats["x"]=o
                                h=target.register_forward_hook(_hook); _=model.encode_image(im)
                        finally:
                            if h is not None:
                                try: h.remove()
                                except Exception: pass
                        x=feats.get("x",None)
                        if torch.is_tensor(x) and x.ndim==3:
                            tokens=_apply_norm(x.detach()); ok=True; break
                except Exception:
                    pass
            if not ok:
                return None
            tokens = tokens.float().cpu().numpy().astype(np.float32)
            if tokens.shape[0] <= 1: return None
            pooled = tokens[1:].mean(axis=0)
            pooled /= (np.linalg.norm(pooled)+1e-12)

            if self.backend == "cuml":
                import cupy as cp
                s = self.model.decision_function(cp.asarray(pooled[None,:])).get().astype(np.float32)[0]
                return float(np.clip(s, -20.0, 20.0))
            else:
                with torch.no_grad():
                    s = self.model(torch.from_numpy(pooled[None,:]).float()).squeeze(1).numpy().astype(np.float32)[0]
                if self.mode == "svm" and self.platt is not None:
                    from sklearn.linear_model import LogisticRegression  # noqa
                    import numpy as _np
                    p = float(self.platt.predict_proba(_np.array(s, _np.float32).reshape(1,1))[:,1][0])
                    z = math.log(max(p,1e-6)) - math.log(max(1-p,1e-6))
                    return float(np.clip(z, -20.0, 20.0))
                return float(np.clip(s, -20.0, 20.0))
        finally:
            try: os.unlink(tmp)
            except Exception: pass

clip_single = CLIPSingle()


✅ CLIP: cuML LogisticRegression（logit）


In [5]:
def read_json_utf8(p: Path):
    with p.open("r", encoding="utf-8") as f:
        return json.load(f)

def load_fuser_and_meta():
    meta=None; kind="dummy"; fuser=None
    try:
        if FUSER_PKL.exists():
            import joblib
            fuser = joblib.load(FUSER_PKL); kind="lr"
            print("✅ fusion_lr.pkl loaded (sklearn LR)")
    except Exception as e:
        print("ℹ️ fusion_lr.pkl not available:", e)
    if FUSER_META.exists():
        try:
            meta = read_json_utf8(FUSER_META); print("✅ fusion_lr_meta.json loaded")
        except Exception as e:
            print("ℹ️ fusion meta read fail:", e)
    return kind, fuser, meta

fuser_kind, fuser, fmeta = load_fuser_and_meta()

def fuse_logits(logits_dict: Dict[str, float | None],
                enabled: List[str] | Tuple[str, ...] | None = None) -> float:
    full_order = ["prnu", "ela", "clip"]
    enabled = list(enabled) if enabled is not None else full_order
    enabled = [k for k in full_order if k in enabled]

    # 3模態全開 + 有 LR → 直接用
    if fuser_kind == "lr" and enabled == full_order:
        z = [float(np.nan_to_num(logits_dict.get(k, 0.0), nan=0.0, posinf=20.0, neginf=-20.0)) for k in full_order]
        X = np.array([z], np.float32)
        try:
            proba = float(fuser.predict_proba(X)[:, 1][0])
            return float(np.clip(proba, 1e-6, 1-1e-6))
        except Exception as e:
            print("⚠️ LR fusion failed; fallback to weights:", e)

    # 備援權重
    w = None
    if fmeta and "weights_if_no_sklearn" in fmeta:
        meta_order = [s.lower() for s in fmeta.get("order", ["PRNU","ELA","CLIP"])]
        meta_w_all = np.array(fmeta["weights_if_no_sklearn"], np.float32)
        pick = [meta_order.index(k) for k in enabled]
        w = meta_w_all[pick]
        s = float(w.sum()); w = (w/s) if s>0 else None
    if w is None:
        w = np.ones((len(enabled),), np.float32) / max(1, len(enabled))

    z = [float(np.nan_to_num(logits_dict.get(k, 0.0), nan=0.0, posinf=20.0, neginf=-20.0)) for k in enabled]
    zz = float(np.clip(np.dot(z, w), -20.0, 20.0))
    return float(1.0/(1.0+np.exp(-zz)))

# ---- 特徵抽取（tile 級）----
def _to_int8_offset128_from_01(x01: np.ndarray) -> np.ndarray:
    u8 = np.rint(np.clip(x01, 0.0, 1.0) * 255.0).astype(np.uint8)
    return (u8.astype(np.int16) - 128).astype(np.int8)

def ela_i8_from_tile(pil_tile_256: Image.Image) -> np.ndarray:
    buf = BytesIO(); pil_tile_256.save(buf, format="JPEG", quality=int(ELA_QUALITY), subsampling=0, optimize=False)
    buf.seek(0)
    diff = ImageChops.difference(pil_tile_256, Image.open(buf)).point(lambda x: x * ELA_SCALE)
    diff = diff.convert("L").resize((ELA_FEASZ, ELA_FEASZ))
    arr01 = np.asarray(diff, dtype=np.float32) / 255.0
    return _to_int8_offset128_from_01(arr01)

def prnu_i8_from_tile(np_tile_rgb: np.ndarray) -> np.ndarray:
    if np_tile_rgb.ndim == 2:
        np_tile_rgb = np.repeat(np_tile_rgb[...,None], 3, axis=-1)
    gray = np_tile_rgb.mean(axis=2, dtype=np.float32)
    try:
        den = denoise_wavelet(gray, channel_axis=None, mode=PRNU_MODE, wavelet=PRNU_WAVELET, convert2ycbcr=False)
    except TypeError:
        den = denoise_wavelet(gray, multichannel=False, mode=PRNU_MODE, wavelet=PRNU_WAVELET, convert2ycbcr=False)
    residual = gray - den
    residual -= residual.mean()
    if PRNU_Q_MODE == "per_file":
        v = residual.reshape(-1).astype(np.float32, copy=False)
        if v.size > PRNU_Q_SAMPLES:
            idx = rng.integers(0, v.size, size=PRNU_Q_SAMPLES, endpoint=False)
            v = np.abs(v[idx])
        else:
            v = np.abs(v)
        k = int(PRNU_Q_PERC * max(1, v.size-1))
        S = float(max(1e-8, np.partition(v, k)[k]))
    else:
        S = max(1e-6, float(np.std(residual)) * 6.0)
    x = np.clip(residual, -S, S) / S * 127.0
    q = np.rint(x).astype(np.int16)
    q = np.clip(q, -127, 127).astype(np.int8)
    return q

def make_grid(w: int, h: int, tile: int = TILE, stride: int = STRIDE) -> List[Tuple[int,int,int,int]]:
    xs = list(range(0, max(1, w - tile + 1), stride))
    ys = list(range(0, max(1, h - tile + 1), stride))
    if xs[-1] != w - tile: xs.append(max(0, w - tile))
    if ys[-1] != h - tile: ys.append(max(0, h - tile))
    coords = [(x, y, tile, tile) for y in ys for x in xs]
    return coords

# ---- 可視化 ----
def _text_size(draw, text, font):
    if hasattr(draw, "textbbox"):
        l, t, r, b = draw.textbbox((0, 0), text, font=font); return (r - l), (b - t)
    if hasattr(font, "getbbox"):
        l, t, r, b = font.getbbox(text); return (r - l), (b - t)
    if hasattr(font, "getsize"):
        return font.getsize(text)
    return (8 * len(text), 12)

def overlay_tiles(base_img: Image.Image,
                  tiles: list[dict],
                  alpha: float = 0.35,
                  draw_frame: bool = True,
                  show_score: bool = True) -> Image.Image:
    W, H = base_img.size
    heat = np.zeros((H, W), dtype=np.float32)
    cnt  = np.zeros((H, W), dtype=np.float32)
    for t in tiles:
        x, y, w, h = t['x'], t['y'], t['w'], t['h']
        p = float(t['prob_fake'])
        heat[y:y+h, x:x+w] += p
        cnt [y:y+h, x:x+w] += 1.0
    cnt[cnt == 0] = 1.0
    heat = heat / cnt
    heat_u8 = np.rint(np.clip(heat, 0.0, 1.0) * 255.0).astype(np.uint8)
    heat_rgb = np.zeros((H, W, 3), dtype=np.uint8); heat_rgb[..., 0] = heat_u8
    base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
    overlay = Image.blend(base_rgb, heat_img, alpha).convert('RGBA')
    draw = ImageDraw.Draw(overlay, 'RGBA')
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
    except Exception:
        font = ImageFont.load_default()
    if draw_frame or show_score:
        for t in tiles:
            x, y, w, h = t['x'], t['y'], t['w'], t['h']
            p  = float(t['prob_fake'])
            lab = int(t['pred_label'])
            color = (255, 0, 0, 255) if lab == 1 else (0, 255, 0, 255)
            if draw_frame:
                draw.rectangle([x, y, x + w, y + h], outline=color, width=2)
            if show_score:
                txt = f"{p:.2f}"; tw, th = _text_size(draw, txt, font)
                draw.rectangle([x, y, x + tw + 6, y + th + 4], fill=(0, 0, 0, 127))
                draw.text((x + 3, y + 2), txt, fill=(255, 255, 255, 255), font=font)
    return overlay


✅ fusion_lr.pkl loaded (sklearn LR)
✅ fusion_lr_meta.json loaded


In [6]:
def _safe_logit(p: float, eps: float = 1e-6, clamp: float = 20.0) -> float:
    p = float(np.clip(p, eps, 1.0 - eps))
    z = math.log(p) - math.log(1.0 - p)
    return float(np.clip(z, -clamp, clamp))

def infer_image_by_tiles(img_path: str | Path,
                         tile: int = TILE,
                         stride: int = STRIDE,
                         aggregate: str = 'mean_prob',
                         use_clip: bool = True,
                         save_overlay: bool = True,
                         overlay_alpha: float = 0.35,
                         enable: Dict[str, bool] | None = None
                         ) -> Dict[str, Any]:
    """
    enable: 例如 {"prnu":True,"ela":True,"clip":False}
    若 None → 用 DEFAULT_ENABLED；且若 use_clip=False 會強制關閉 clip
    """
    p = Path(img_path)
    enable = dict(DEFAULT_ENABLED if enable is None else enable)
    if not use_clip:
        enable["clip"] = False
    # 若某模態模型不存在，自動關閉
    if not AVAILABLE.get("prnu", False):
        enable["prnu"] = False
    if not AVAILABLE.get("ela", False):
        enable["ela"]  = False

    enabled_keys = [k for k,v in enable.items() if v]
    if len(enabled_keys) == 0:
        raise ValueError("至少需啟用一個模態（prnu/ela/clip）")

    tmp_files = []
    # 只有需要 ELA 或 CLIP 時才轉 JPG
    need_jpg = FORCE_JPG_NONJPG and (enable.get("ela",False) or enable.get("clip",False))
    if need_jpg:
        p_jpg, t1 = as_jpg_if_needed(p)
        if t1: tmp_files.append(t1)
    else:
        p_jpg = p

    # 讀圖
    base_rgb = _to_rgb_no_alpha(_open_image_any(p))
    img_pil = Image.open(p_jpg).convert('RGB')
    W, H = img_pil.size
    coords = make_grid(W, H, tile=tile, stride=stride)
    tiles_out: List[Dict[str, Any]] = []

    np_img = np.asarray(base_rgb, dtype=np.uint8)

    # 逐 tile
    for (x,y,w,h) in coords:
        pil_tile = img_pil.crop((x,y,x+w,y+h))
        np_tile  = np.asarray(np_img[y:y+h, x:x+w, :], dtype=np.uint8)

        z_prnu = None
        if enable.get("prnu", False) and prnu_model is not None:
            prnu_i8 = prnu_i8_from_tile(np_tile)
            z_prnu = prnu_logit_from_i8(prnu_i8)

        z_ela = None
        if enable.get("ela", False) and ela_model is not None:
            ela_i8  = ela_i8_from_tile(pil_tile)
            z_ela   = ela_logit_from_i8(ela_i8)

        z_clip = None
        if enable.get("clip", False) and clip_single.model is not None:
            z_clip = clip_single.predict_logit(pil_tile)

        prob_fake = fuse_logits({"prnu": z_prnu, "ela": z_ela, "clip": z_clip}, enabled=enabled_keys)
        pred = int(prob_fake >= 0.5)

        tiles_out.append({
            "x":x, "y":y, "w":w, "h":h,
            "prnu_logit": (None if z_prnu is None else float(z_prnu)),
            "ela_logit":  (None if z_ela  is None else float(z_ela)),
            "clip_logit": (None if z_clip is None else float(z_clip)),
            "prob_fake": float(prob_fake),
            "pred_label": pred,
        })

    # 彙總
    probs = np.array([t['prob_fake'] for t in tiles_out], np.float32)
    if aggregate == 'mean_prob':
        whole_prob = float(np.clip(probs.mean() if len(probs)>0 else 0.0, 1e-6, 1-1e-6))
    elif aggregate == 'max_prob':
        whole_prob = float(np.clip(probs.max() if len(probs)>0 else 0.0, 1e-6, 1-1e-6))
    elif aggregate.startswith('topk_mean'):
        try:
            frac = float(aggregate.split(':',1)[1]) if ':' in aggregate else 0.5
        except Exception:
            frac = 0.5
        k = max(1, int(len(probs) * frac))
        idx = np.argsort(-probs)[:k]
        whole_prob = float(np.clip(probs[idx].mean(), 1e-6, 1-1e-6))
    elif aggregate == 'mean_logit':
        lgs = np.array([_safe_logit(float(p)) for p in probs], np.float32)
        lg  = float(np.clip(lgs.mean() if len(lgs)>0 else 0.0, -20.0, 20.0))
        whole_prob = float(1.0/(1.0+np.exp(-lg)))
    else:
        whole_prob = float(np.clip(probs.mean() if len(probs)>0 else 0.0, 1e-6, 1-1e-6))
    whole_pred = int(whole_prob >= 0.5)

    # 各模態平均 logit（跨 tiles）
    def _avg_logit(key: str):
        vals = [t[f"{key}_logit"] for t in tiles_out if t[f"{key}_logit"] is not None]
        return (None if len(vals)==0 else float(np.mean(np.array(vals, np.float32))))
    feature_avg_logits = {"prnu": _avg_logit("prnu"),
                          "ela":  _avg_logit("ela"),
                          "clip": _avg_logit("clip")}

    overlay_path = None
    if save_overlay:
        overlay_img = overlay_tiles(base_rgb, tiles_out, alpha=overlay_alpha,
                                    draw_frame=True, show_score=True)
        overlay_path = str(p.with_suffix(".tiles_overlay.png"))
        overlay_img.save(overlay_path)

    out = {
        "path": str(p),
        "image_size": [H, W],
        "tile": tile,
        "stride": stride,
        "aggregate": aggregate,
        "enabled_modalities": enable,
        "feature_avg_logits": feature_avg_logits,
        "overall": {"prob_fake": float(whole_prob), "pred_label": int(whole_pred)},
        "tiles": tiles_out,
        "overlay_path": overlay_path,
    }

    for t in tmp_files:
        try: os.unlink(t)
        except Exception: pass

    return out


In [7]:
def aggregate_probs_offline(tiles: List[dict], method: str) -> float:
    probs = np.array([t['prob_fake'] for t in tiles], np.float32)
    if probs.size == 0: 
        return 1e-6
    if method == "mean_prob":
        p = probs.mean()
    elif method == "max_prob":
        p = probs.max()
    elif method.startswith("topk_mean"):
        try:
            frac = float(method.split(":",1)[1]) if ":" in method else 0.5
        except Exception:
            frac = 0.5
        k = max(1, int(len(probs) * frac))
        idx = np.argsort(-probs)[:k]
        p = probs[idx].mean()
    elif method == "mean_logit":
        lgs = np.array([_safe_logit(float(x)) for x in probs], np.float32)
        lg = float(np.clip(lgs.mean(), -20.0, 20.0))
        p = float(1.0/(1.0+np.exp(-lg)))
    else:
        p = probs.mean()
    return float(np.clip(p, 1e-6, 1-1e-6))


In [None]:
import gradio as gr

def run_once(img_path_or_pil, tile, stride, use_clip, enable_prnu, enable_ela, enable_clip, alpha, progress=gr.Progress()):
    """
    跑一次完整推論：支援進度條（內部以階段更新），輸出 overlay 與 raw 結果。
    """
    progress(0.0, desc="準備中…")
    # 準備臨時檔（Gradio 可能傳 PIL 或檔路徑）
    if isinstance(img_path_or_pil, str):
        p = Path(img_path_or_pil)
    else:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            img_path_or_pil.save(f.name)
            p = Path(f.name)

    enable = {"prnu": bool(enable_prnu), "ela": bool(enable_ela), "clip": bool(enable_clip)}
    if not use_clip:
        enable["clip"] = False

    progress(0.2, desc="切 tile / 前處理…")
    res = infer_image_by_tiles(
        p, tile=int(tile), stride=int(stride),
        aggregate='mean_prob',              # 實際 tiles 的 prob_fake 與彙總無關；之後再事後重算
        use_clip=bool(use_clip),
        save_overlay=True,
        overlay_alpha=float(alpha),
        enable=enable
    )
    progress(0.85, desc="產生可視化…")
    overlay = None
    if res.get("overlay_path"):
        overlay = Image.open(res["overlay_path"]).convert("RGB")

    # 整理 caption 與平均 logit 顯示
    flog = res.get("feature_avg_logits", {})
    msg = (f"PRNU logit(avg)={flog.get('prnu')} | "
           f"ELA logit(avg)={flog.get('ela')} | "
           f"CLIP logit(avg)={flog.get('clip')}")
    progress(1.0, desc="完成")
    return msg, overlay, res

def reaggregate(raw_result: dict, methods: list[str]):
    if not raw_result or "tiles" not in raw_result:
        return "尚未有推論結果", []
    tiles = raw_result["tiles"]
    table = []
    best = ("", -1.0, 0)
    for m in (methods or []):
        p = aggregate_probs_offline(tiles, m)
        y = int(p >= 0.5)
        table.append([m, round(float(p), 6), y])
        if p > best[1]:
            best = (m, p, y)
    if not methods:
        return "請至少選擇一個彙總方式", []
    msg = f"最佳：{best[0]}  prob={best[1]:.4f}  pred={best[2]}"
    return msg, table

with gr.Blocks(title="Slide-256 Detector — PRNU/ELA/CLIP + Fusion (CUDA)") as demo:
    gr.Markdown("## Slide-256 Full-Image Inference — PRNU / ELA / CLIP + Fusion\n上傳或拍照 → 跑一次 → 右側彙總（不重跑）")

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="filepath", label="上傳或拍照", sources=["upload","webcam","clipboard"], height=320)
            with gr.Row():
                tile   = gr.Number(value=TILE, precision=0, label="Tile")
                stride = gr.Number(value=STRIDE, precision=0, label="Stride")
            use_clip   = gr.Checkbox(value=DEFAULT_ENABLED.get("clip", True), label="Use CLIP head")
            with gr.Row():
                enable_prnu = gr.Checkbox(value=DEFAULT_ENABLED.get("prnu", True), label="Enable PRNU")
                enable_ela  = gr.Checkbox(value=DEFAULT_ENABLED.get("ela",  True), label="Enable ELA")
                enable_clip = gr.Checkbox(value=DEFAULT_ENABLED.get("clip", True), label="Enable CLIP")
            alpha      = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Overlay alpha")
            run_btn    = gr.Button("開始推論", variant="primary")
        with gr.Column(scale=1):
            caption = gr.Label(label="平均 logit / 訊息")
            overlay = gr.Image(label="Overlay（熱度 + tile 分數）")
            raw     = gr.JSON(label="Raw 結果（tiles）")

            methods = gr.CheckboxGroup(
                choices=["mean_prob","max_prob","topk_mean:0.5","topk_mean:0.25","mean_logit"],
                value=["topk_mean:0.5","mean_logit"],
                label="選擇一個或多個彙總方式（不重跑）"
            )
            recompute = gr.Button("重新計算彙總")
            table = gr.Dataframe(
                headers=["method","prob_fake","pred_label"],
                datatype=["str","number","number"],
                interactive=False,
                label="各彙總方式結果"
            )
            out_msg = gr.Label(label="彙總訊息 / 最佳策略")

    run_btn.click(
        run_once,
        inputs=[inp, tile, stride, use_clip, enable_prnu, enable_ela, enable_clip, alpha],
        outputs=[caption, overlay, raw]
    )
    recompute.click(
        reaggregate,
        inputs=[raw, methods],
        outputs=[out_msg, table]
    )

demo.queue(max_size=16).launch(share=True, inline=True)  # share=True 方便手機示範


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://12cfe6a533f3b00520.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_img = Image.fromarray(heat_rgb, mode='RGB')
  base_rgb = base_img.convert('RGB'); heat_