In [None]:
import os
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is set before torch is imported; it is
# commonly used to make CUDA launches synchronous for debugging (at a speed cost)
# — confirm it is still wanted for normal training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import sys, subprocess, warnings
# Suppresses every Python warning for the rest of the process (no category filter),
# so library deprecation/data warnings will not be shown.
warnings.filterwarnings("ignore")

def ensure_pkg(import_name, pip_name=None):
    """Make sure `import_name` is importable, pip-installing it if it is missing.

    Args:
        import_name: module name passed to the import system (e.g. "sklearn").
        pip_name: distribution name for pip (e.g. "scikit-learn"); defaults to
            `import_name`.

    Raises:
        subprocess.CalledProcessError: if pip fails.
        Exception: any error other than ImportError raised while importing an
            installed package propagates instead of triggering a reinstall.
    """
    import importlib

    if pip_name is None:
        pip_name = import_name
    try:
        importlib.import_module(import_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pip_name])
        # Packages installed while the interpreter is running are only picked up
        # reliably after the import system's finder caches are invalidated.
        importlib.invalidate_caches()

print("📦 Installing packages...")
# (import name, pip distribution name) pairs checked at start-up.
_REQUIRED_PACKAGES = (
    ("transformers", "transformers"),
    ("sklearn", "scikit-learn"),
    ("tqdm", "tqdm"),
    ("pandas", "pandas"),
    ("numpy", "numpy"),
    ("torch", "torch"),
)
for _module_name, _dist_name in _REQUIRED_PACKAGES:
    ensure_pkg(_module_name, _dist_name)

import re, random, copy
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup

from sklearn.metrics import classification_report, f1_score, accuracy_score
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
from tqdm.auto import tqdm

# ============================================================
# CONFIG
# ============================================================

CONFIG = {
    # NOTE(review): presumably a Google Colab path — adjust for other environments.
    "data_path": "/content/sample_data/fake_news_dataset.csv",

    # Hugging Face model ids for the two experts.
    # NOTE(review): PhoBERT is documented upstream as expecting word-segmented
    # Vietnamese input; the text fed here is only cleaned, not segmented — confirm.
    "electra_name": "FPTAI/velectra-base-discriminator-cased",
    "phobert_name": "vinai/phobert-base",

    # Tokenizer truncation/padding lengths per expert.
    "max_length_electra": 256,
    "max_length_phobert": 256,

    "batch_size_electra": 16,
    "batch_size_phobert": 8,

    # Stage-1 backbone fine-tuning (learning rate shared by both backbones).
    "learning_rate": 1e-5,
    "epochs_electra": 5,
    "epochs_phobert": 5,
    "warmup_steps": 100,
    "weight_decay": 0.01,
    "dropout": 0.2,  # requested hidden/attention dropout for the backbones

    # Freeze the embeddings and/or the first N encoder layers of each backbone.
    "freeze_electra_layers": 8,
    "freeze_electra_embeddings": True,
    "freeze_phobert_layers": 0,
    "freeze_phobert_embeddings": False,

    # MoE / Gating
    "gate_hidden": 256,        # hidden width of the gate MLP
    "gate_dropout": 0.3,
    "gate_lr": 2e-4,           # learning rate for the gate head only
    "epochs_gate": 15,
    "gate_warmup_steps": 50,
    "entropy_reg": 0.01,  # encourage the gate not to collapse to 0/1 too early
    "use_style_features": True,  # add clickbait-style surface features to the gate input

    # Leak/Near-dup settings: char_wb TF-IDF cosine similarity >= threshold
    # (among each text's k nearest neighbours) links texts into one group.
    "near_dup_threshold": 0.92,
    "near_dup_k": 20,
    "char_ngram_range": (4, 6),
    "min_df": 2,

    # Split ratios (approx): grouped folds are uneven, so the closest fold is picked.
    "test_ratio": 0.10,
    "val_ratio_of_trainval": 0.11,

    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "seed": 42,
}

print("="*90)
print("🚀 VIETNAMESE FAKE NEWS - MoE GATING (LEAK-SAFE SPLIT)")
print("="*90)
print(f"🖥️ Device: {CONFIG['device']}")
print("🧾 Label mapping: 0=REAL, 1=FAKE")

# ============================================================
# SEED
# ============================================================

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch RNGs, plus all CUDA devices if present.

    CUDA seeding is best effort: any error it raises is ignored.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.manual_seed_all(seed)
    except Exception:
        pass

# Seed all RNGs once, before the data split and model construction that follow.
set_seed(CONFIG["seed"])

# ============================================================
# CLEAN TEXT
# ============================================================

def clean_text(text):
    """Normalize a raw article into a single line of plain text.

    Missing values (NaN/None) become "". Otherwise HTML-like tags and http(s)
    URLs are replaced by spaces, whitespace runs collapse to one space, and the
    ends are stripped.
    """
    if pd.isna(text):
        return ""
    cleaned = str(text)
    # Order matters: tags first, then URLs, then whitespace collapsing.
    for pattern in (r"<[^>]+>", r"http[s]?://\S+"):
        cleaned = re.sub(pattern, " ", cleaned)
    return re.sub(r"\s+", " ", cleaned).strip()

def is_valid(text, min_words=8):
    """Return True when `text` has at least `min_words` whitespace-separated tokens.

    Used to drop short fragments while loading data and to reject short inputs in
    predict_one(). The default of 8 keeps the original behavior.
    """
    return len(text.split()) >= min_words

# ============================================================
# Union-Find for clustering near-duplicates
# ============================================================

class UnionFind:
    """Disjoint-set forest over 0..n-1 with union by rank and path halving.

    `p` holds each node's parent and `r` the rank of each root.
    """

    def __init__(self, n):
        self.p = [i for i in range(n)]
        self.r = [0 for _ in range(n)]

    def find(self, x):
        """Return the root of `x`, pointing visited nodes at their grandparents."""
        parent = self.p
        while x != parent[x]:
            grand = parent[parent[x]]
            parent[x] = grand
            x = grand
        return x

    def union(self, a, b):
        """Merge the sets of `a` and `b`; the higher-rank root (or `a`'s on ties) wins."""
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            return
        rank = self.r
        if rank[root_a] < rank[root_b]:
            root_a, root_b = root_b, root_a
        self.p[root_b] = root_a
        if rank[root_a] == rank[root_b]:
            rank[root_a] += 1

def build_near_dup_groups(texts, threshold=0.92, k=20, ngram_range=(4,6), min_df=2):
    """Cluster near-duplicate texts and return one contiguous group id per text.

    Texts are embedded with character (char_wb) TF-IDF. Each text is linked to
    any of its `k` nearest neighbours whose cosine similarity is >= `threshold`,
    and linked texts are merged transitively with union-find.

    Args:
        texts: list of strings.
        threshold: minimum cosine similarity for two texts to be linked.
        k: neighbours examined per text (capped at len(texts)).
        ngram_range, min_df: TfidfVectorizer settings.

    Returns:
        int ndarray of shape (len(texts),). Ids are 0..G-1, numbered in order of
        first appearance.
    """
    n = len(texts)
    if n < 2:
        # Zero or one text: nothing can be a near-duplicate of anything else
        # (and TfidfVectorizer rejects min_df=2 for a single document).
        return np.zeros(n, dtype=int)

    vec = TfidfVectorizer(
        analyzer="char_wb",
        ngram_range=ngram_range,
        min_df=min_df,
        dtype=np.float32
    )
    try:
        X = vec.fit_transform(texts)
    except ValueError:
        # No n-gram survived min_df pruning (or n < min_df): there is no shared
        # vocabulary to compare on, so treat every text as its own group.
        return np.arange(n, dtype=int)

    # Named `knn` rather than `nn` so it does not shadow torch.nn.
    knn = NearestNeighbors(
        n_neighbors=min(k, n),
        metric="cosine",
        algorithm="brute",
        n_jobs=-1
    )
    knn.fit(X)
    dists, idxs = knn.kneighbors(X, return_distance=True)

    # Vectorized link test; float64 matches the scalar 1.0 - float(dist) arithmetic.
    sims = 1.0 - dists.astype(np.float64)
    linked = (idxs != np.arange(n)[:, None]) & (sims >= threshold)

    uf = UnionFind(n)
    # np.nonzero walks row-major, i.e. the same (i, neighbour) order as a nested loop.
    for row, col in zip(*np.nonzero(linked)):
        uf.union(int(row), int(idxs[row, col]))

    roots = np.array([uf.find(i) for i in range(n)], dtype=int)
    # Relabel roots to 0..G-1 in order of first appearance.
    gid, _ = pd.factorize(roots)
    return gid.astype(int)

def report_leak_exact(a_texts, b_texts, name):
    """Print and return how many distinct strings occur in both collections."""
    shared = set(a_texts).intersection(b_texts)
    count = len(shared)
    print(f"Leak exact {name}: {count}")
    return count

def report_leak_near(a_texts, b_texts, threshold=0.92, ngram_range=(4,6), min_df=2):
    """Summarize how similar each text in `b_texts` is to its closest text in `a_texts`.

    A char_wb TF-IDF is fitted on `a_texts`. For every b-text, the cosine
    similarity to its nearest a-text is computed.

    Returns:
        dict with "mean", "median", "p95", "max" of those similarities and
        "count_ge_thr" (how many are >= `threshold`). The same keys are returned
        (all zero) when either input is empty.
    """
    if len(a_texts) == 0 or len(b_texts) == 0:
        return {"mean": 0.0, "median": 0.0, "p95": 0.0, "max": 0.0, "count_ge_thr": 0}

    # An integer min_df larger than the reference corpus makes TfidfVectorizer raise.
    if isinstance(min_df, (int, np.integer)):
        min_df = min(int(min_df), len(a_texts))

    vec = TfidfVectorizer(analyzer="char_wb", ngram_range=ngram_range, min_df=min_df, dtype=np.float32)
    X_a = vec.fit_transform(a_texts)
    X_b = vec.transform(b_texts)

    # Named `knn` rather than `nn` so it does not shadow torch.nn.
    knn = NearestNeighbors(n_neighbors=1, metric="cosine", algorithm="brute", n_jobs=-1).fit(X_a)
    dists, _ = knn.kneighbors(X_b, return_distance=True)
    sims = 1.0 - dists.reshape(-1)

    return {
        "mean": float(np.mean(sims)),
        "median": float(np.median(sims)),
        "p95": float(np.quantile(sims, 0.95)),
        "max": float(np.max(sims)),
        "count_ge_thr": int(np.sum(sims >= threshold))
    }

# ============================================================
# LOAD DATA + CLEAN + EXACT DEDUP
# ============================================================

print("\n" + "="*90)
print("📂 LOADING DATA")
print("="*90)

if not os.path.exists(CONFIG["data_path"]):
    raise FileNotFoundError(f"❌ Not found: {CONFIG['data_path']}")

df = pd.read_csv(CONFIG["data_path"])

# Accept common alternative names for the text and label columns.
if "text" not in df.columns:
    for c in ["content", "article", "news", "body", "title"]:
        if c in df.columns:
            df["text"] = df[c]
            break
if "text" not in df.columns:
    raise ValueError("❌ Cannot find text column")

if "label" not in df.columns:
    for c in ["class", "category", "y"]:
        if c in df.columns:
            df["label"] = df[c]
            break
if "label" not in df.columns:
    raise ValueError("❌ Cannot find label column")

df = df[["text", "label"]].dropna().copy()

# Validate BEFORE casting: astype(int) truncates, so labels such as 0.5 or 1.7
# would otherwise become 0/1 and pass the check. to_numeric keeps "0"/"1"
# strings working and maps non-numeric values to NaN, which fails isin().
label_num = pd.to_numeric(df["label"], errors="coerce")
bad = df[~label_num.isin([0, 1])]
if len(bad) > 0:
    raise ValueError(f"❌ Found labels not in {{0,1}}. Examples:\n{bad.head()}")
df["label"] = label_num.astype(int)

print("🧹 Cleaning + exact dedup...")
df["text_clean"] = df["text"].apply(clean_text)
df = df[df["text_clean"].apply(is_valid)].copy()

# Identical cleaned texts with different labels contradict each other; the
# keep="first" dedup below silently picks one, so at least report them.
n_conflict = int((df.groupby("text_clean")["label"].nunique() > 1).sum())
if n_conflict > 0:
    print(f"⚠️ {n_conflict} duplicated texts have conflicting labels (keeping first occurrence)")

before = len(df)
df = df.drop_duplicates(subset=["text_clean"], keep="first").reset_index(drop=True)
print(f"✅ After exact dedup: {len(df)} (removed {before-len(df)})")

n = len(df)
if n == 0:
    raise ValueError("❌ No samples left after cleaning/filtering")
c0 = int((df["label"]==0).sum())
c1 = int((df["label"]==1).sum())
print(f"✅ Final: {n} samples | REAL={c0} ({c0/n:.1%}) | FAKE={c1} ({c1/n:.1%})")

# ============================================================
# BUILD NEAR-DUP GROUPS (CLUSTERING)
# ============================================================

print("\n" + "="*90)
print("🧩 BUILDING NEAR-DUP CLUSTERS")
print("="*90)

# Assign each cleaned text a group id; texts linked as near-duplicates share an id
# so the group-aware split below keeps them on the same side.
groups = build_near_dup_groups(
    df["text_clean"].tolist(),
    threshold=float(CONFIG["near_dup_threshold"]),
    k=int(CONFIG["near_dup_k"]),
    ngram_range=tuple(CONFIG["char_ngram_range"]),
    min_df=int(CONFIG["min_df"])
)
df["group"] = groups

# Sanity report: number of groups and the largest clusters.
n_groups = int(df["group"].nunique())
group_sizes = df["group"].value_counts()
print(f"✅ Groups: {n_groups} | Largest group size: {int(group_sizes.max())}")
print(f"Top 5 group sizes:\n{group_sizes.head(5).to_string()}")

# ============================================================
# STRATIFIED GROUP SPLIT: TRAIN / VAL / TEST
# ============================================================

print("\n" + "="*90)
print("✂️ LEAK-SAFE SPLIT (STRATIFIED BY LABEL, GROUP-AWARE)")
print("="*90)


def _n_splits_for_ratio(target):
    """Fold count whose single held-out fold is ~`target` of the data (0.10 -> 10)."""
    if not 0.0 < target < 1.0:
        raise ValueError(f"❌ Split ratio must be in (0, 1), got {target}")
    return max(2, int(round(1.0 / target)))


def _closest_fold(fold_splits, n_total, target):
    """Return (keep_idx, held_idx, held_ratio, fold_index) for the fold whose
    held-out fraction is closest to `target`; the earliest fold wins ties."""
    ratios = [len(held) / n_total for _, held in fold_splits]
    best_i = min(range(len(fold_splits)), key=lambda i: abs(ratios[i] - target))
    keep_idx, held_idx = fold_splits[best_i]
    return keep_idx, held_idx, ratios[best_i], best_i


y = df["label"].values
g = df["group"].values

# TEST: derive the fold count from CONFIG instead of hard-coding it, then pick
# the grouped fold whose size is closest to the requested ratio.
test_target = float(CONFIG["test_ratio"])
kfold = _n_splits_for_ratio(test_target)
sgkf = StratifiedGroupKFold(n_splits=kfold, shuffle=True, random_state=CONFIG["seed"])
splits = list(sgkf.split(df, y, groups=g))

best_fold = _closest_fold(splits, len(df), test_target)
trainval_idx, test_idx, ratio, fold_i = best_fold
best_diff = abs(ratio - test_target)
print(f"Picked fold {fold_i} for TEST: size={len(test_idx)} ({ratio:.3f})")

df_trainval = df.iloc[trainval_idx].reset_index(drop=True)
df_test = df.iloc[test_idx].reset_index(drop=True)

# VAL: same procedure on the remaining train+val rows, with a different seed.
val_target = float(CONFIG["val_ratio_of_trainval"])
kfold2 = _n_splits_for_ratio(val_target)
sgkf2 = StratifiedGroupKFold(n_splits=kfold2, shuffle=True, random_state=CONFIG["seed"]+7)
splits2 = list(sgkf2.split(df_trainval, df_trainval["label"].values, groups=df_trainval["group"].values))

best2 = _closest_fold(splits2, len(df_trainval), val_target)
train_idx, val_idx, ratio2, fold_i2 = best2
best2_diff = abs(ratio2 - val_target)
print(f"Picked fold {fold_i2} for VAL: size={len(val_idx)} ({ratio2:.3f})")

train_df = df_trainval.iloc[train_idx].reset_index(drop=True)
val_df = df_trainval.iloc[val_idx].reset_index(drop=True)
test_df = df_test.copy()

def dist_print(name, dfx):
    """Print a split's size, REAL/FAKE counts with percentages, and distinct group count.

    `dfx` must have "label" (0=REAL, 1=FAKE) and "group" columns. An empty split
    prints 0.0% instead of raising ZeroDivisionError.
    """
    n = len(dfx)
    c0 = int((dfx["label"] == 0).sum())
    c1 = int((dfx["label"] == 1).sum())
    denom = max(n, 1)  # empty split -> percentages of 0.0% rather than a crash
    print(f"{name}: {n} | REAL={c0} ({c0/denom:.1%}) | FAKE={c1} ({c1/denom:.1%}) | groups={dfx['group'].nunique()}")

# Per-split class balance and group counts.
dist_print("Train", train_df)
dist_print("Val  ", val_df)
dist_print("Test ", test_df)

# Near-dup groups shared between splits; expected to be 0 because both splits
# above are group-aware (NOTE(review): this check verifies that assumption).
overlap_tv = set(train_df["group"]) & set(val_df["group"])
overlap_tt = set(train_df["group"]) & set(test_df["group"])
overlap_vt = set(val_df["group"]) & set(test_df["group"])
print(f"\nGroup overlap Train∩Val={len(overlap_tv)} | Train∩Test={len(overlap_tt)} | Val∩Test={len(overlap_vt)}")

# ============================================================
# LEAK REPORT (EXACT + NEAR)
# ============================================================

print("\n" + "="*90)
print("🧪 LEAK REPORT")
print("="*90)

# Exact-string overlap between each pair of splits.
for _left, _right, _pair_name in (
    (train_df, val_df, "Train∩Val"),
    (train_df, test_df, "Train∩Test"),
    (val_df, test_df, "Val∩Test"),
):
    report_leak_exact(_left["text_clean"], _right["text_clean"], _pair_name)

# Nearest-neighbour similarity of each evaluated split against a reference split,
# using the same TF-IDF settings as the near-dup clustering.
_near_kwargs = dict(
    threshold=CONFIG["near_dup_threshold"],
    ngram_range=CONFIG["char_ngram_range"],
    min_df=CONFIG["min_df"],
)
stats_tv = report_leak_near(train_df["text_clean"].tolist(), val_df["text_clean"].tolist(), **_near_kwargs)
stats_tt = report_leak_near(train_df["text_clean"].tolist(), test_df["text_clean"].tolist(), **_near_kwargs)
stats_vt = report_leak_near(val_df["text_clean"].tolist(), test_df["text_clean"].tolist(), **_near_kwargs)

print("\nNear-dup cosine stats (char-ngram TFIDF):")
for _title, _stats in (("Val->Train:", stats_tv), ("Test->Train:", stats_tt), ("Test->Val:", stats_vt)):
    print(_title, _stats)

# ============================================================
# DATASET / DATALOADER
# ============================================================

class NewsDataset(Dataset):
    """Map-style dataset that tokenizes one (text, label) pair on each access.

    Items are dicts with 1-D int64 "input_ids" and "attention_mask" (padded or
    truncated to `max_length`) and a scalar int64 "label".

    Raises:
        ValueError: if the tokenizer emits an id >= len(tokenizer) (presumably a
            guard against ids outside the model's embedding table).
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts, self.labels = list(texts), list(labels)
        self.tok = tokenizer
        self.max_length = int(max_length)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw = self.texts[idx]
        target = int(self.labels[idx])
        encoded = self.tok(
            raw,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drop the leading batch dimension of size 1.
        ids = encoded["input_ids"].squeeze(0).to(torch.long)
        mask = encoded["attention_mask"].squeeze(0).to(torch.long)

        top_id = int(ids.max().item())
        vocab = len(self.tok)
        if top_id >= vocab:
            raise ValueError(f"Bad token id: max={top_id} >= len(tokenizer)={vocab}")

        return {
            "input_ids": ids,
            "attention_mask": mask,
            "label": torch.tensor(target, dtype=torch.long),
        }

def make_loaders(tokenizer, max_length, batch_size):
    """Build (train, val, test) DataLoaders over the module-level train_df/val_df/test_df.

    Texts come from "text_clean"; only the training loader shuffles.
    """
    batch = int(batch_size)

    def _loader(frame, shuffle):
        dataset = NewsDataset(frame["text_clean"].values, frame["label"].values, tokenizer, max_length)
        return DataLoader(dataset, batch_size=batch, shuffle=shuffle)

    return _loader(train_df, True), _loader(val_df, False), _loader(test_df, False)

def compute_class_weights(y, device=None, num_classes=2):
    """Return per-class counts and "balanced" class weights for a label array.

    weight[c] = N / (num_classes * max(count[c], 1)), so rarer classes get larger
    weights and an absent class does not cause division by zero.

    Args:
        y: 1-D array of non-negative integer labels.
        device: torch device for the weight tensor; defaults to CONFIG["device"].
        num_classes: minimum number of classes to count (2 = REAL/FAKE).

    Returns:
        (counts ndarray, float32 weight tensor on `device`).
    """
    counts = np.bincount(y, minlength=num_classes)
    weights = counts.sum() / (float(num_classes) * np.maximum(counts, 1))
    if device is None:
        device = CONFIG["device"]
    return counts, torch.tensor(weights, dtype=torch.float32, device=device)

def _freeze_params(module):
    """Set requires_grad=False on every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad = False


def freeze_backbone(model, model_type: str, freeze_layers: int, freeze_embeddings: bool):
    """Freeze the embeddings and/or the lowest `freeze_layers` encoder layers in place.

    The transformer body is looked up as `model.electra` when model_type is
    "electra", otherwise as `model.roberta` (falling back to `model.bert`). Missing
    submodules are skipped silently, so unknown architectures are left untouched.
    Both model types now use the same guards, including the `encoder.layer` check.
    """
    if freeze_layers <= 0 and not freeze_embeddings:
        return

    if model_type == "electra":
        base = getattr(model, "electra", None)
    else:
        base = getattr(model, "roberta", None)
        if base is None:
            base = getattr(model, "bert", None)
    if base is None:
        return

    if freeze_embeddings and hasattr(base, "embeddings"):
        _freeze_params(base.embeddings)

    layers = getattr(getattr(base, "encoder", None), "layer", None)
    if layers is None:
        return
    for i, layer in enumerate(layers):
        if i >= freeze_layers:
            break
        _freeze_params(layer)

@torch.no_grad()
def infer_probs(model, dataloader):
    """Run a 2-class model over `dataloader` in eval mode.

    Returns:
        (p_fake, conf) float32 arrays: softmax probability of class 1 (FAKE) and
        max(p_fake, 1 - p_fake) for every example, in loader order.
    """
    model.eval()
    fake_parts, conf_parts = [], []
    for batch in tqdm(dataloader, desc="Infer", leave=False):
        device = CONFIG["device"]
        logits = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        ).logits
        p = F.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        fake_parts.append(p)
        conf_parts.append(np.maximum(p, 1.0 - p))

    def _join(parts):
        if not parts:
            return np.array([], dtype=np.float32)
        return np.concatenate(parts).astype(np.float32)

    return _join(fake_parts), _join(conf_parts)

def train_backbone(model_name, model_type, epochs, max_length, batch_size, freeze_layers, freeze_embeddings):
    """Fine-tune one backbone for REAL(0)/FAKE(1) classification on the module-level splits.

    The weights with the best validation macro-F1 are restored at the end.

    Args:
        model_name: Hugging Face id used for both tokenizer and model.
        model_type: "electra", or anything else for the RoBERTa/BERT freezing path.
        epochs: number of training epochs.
        max_length: tokenizer padding/truncation length.
        batch_size: DataLoader batch size.
        freeze_layers: number of lowest encoder layers to freeze.
        freeze_embeddings: whether to freeze the embedding layer.

    Returns:
        dict with name, tokenizer, model, p_val/c_val and p_test/c_test
        (P(FAKE) and confidence arrays), and max_length.
    """
    print("\n" + "="*90)
    print(f"🤖 TRAINING BACKBONE: {model_name}")
    print("="*90)

    device = CONFIG["device"]
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    train_loader, val_loader, test_loader = make_loaders(tokenizer, max_length, batch_size)

    # Dropout rates must be given while the model is built: the Dropout layers
    # store their probability at construction, so setting model.config fields
    # after from_pretrained() does not change the dropout actually applied.
    dropout = float(CONFIG["dropout"])
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        hidden_dropout_prob=dropout,
        attention_probs_dropout_prob=dropout,
    ).to(device)
    try:
        model.resize_token_embeddings(len(tokenizer))
    except Exception as exc:  # best effort: keep the pretrained embedding size
        print(f"⚠️ resize_token_embeddings skipped: {exc}")

    freeze_backbone(model, model_type, int(freeze_layers), bool(freeze_embeddings))

    y_train = train_df["label"].values.astype(int)
    counts, class_w = compute_class_weights(y_train)
    print("Class counts [REAL, FAKE]:", counts, "| class_weights:", class_w.detach().cpu().numpy())
    criterion = nn.CrossEntropyLoss(weight=class_w)

    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=float(CONFIG["learning_rate"]),
        weight_decay=float(CONFIG["weight_decay"]),
    )

    total_steps = max(len(train_loader) * int(epochs), 1)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(CONFIG["warmup_steps"]),
        num_training_steps=total_steps
    )

    y_val = val_df["label"].values.astype(int)
    best_state, best_f1 = None, -1.0

    for ep in range(int(epochs)):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train ep{ep+1}", leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad(set_to_none=True)
            out = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(out.logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += float(loss.item())

        p_val, _ = infer_probs(model, val_loader)
        pred_val = (p_val >= 0.5).astype(int)
        f1m = f1_score(y_val, pred_val, average="macro")
        acc = accuracy_score(y_val, pred_val)
        print(f"Epoch {ep+1}/{epochs} | loss={total_loss/max(len(train_loader),1):.4f} | val_acc={acc*100:.2f}% | val_macroF1={f1m:.4f}")

        if f1m > best_f1:
            best_f1 = f1m
            # Snapshot on CPU: deep-copying state_dict() would keep a second
            # full copy of the weights in GPU memory.
            best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
            print("✅ New best")

    if best_state is not None:
        model.load_state_dict(best_state)

    p_val, c_val = infer_probs(model, val_loader)
    p_test, c_test = infer_probs(model, test_loader)

    return {
        "name": model_name,
        "tokenizer": tokenizer,
        "model": model,
        "p_val": p_val,
        "c_val": c_val,
        "p_test": p_test,
        "c_test": c_test,
        "max_length": int(max_length),
    }

# ============================================================
# STAGE 1: TRAIN 2 BACKBONES (same as your old flow)
# ============================================================

# Stage 1: fine-tune each expert independently on the train split. Each pack
# holds the best-val-macro-F1 model, its tokenizer and val/test P(FAKE) arrays.
electra_pack = train_backbone(
    model_name=CONFIG["electra_name"],
    model_type="electra",
    epochs=CONFIG["epochs_electra"],
    max_length=CONFIG["max_length_electra"],
    batch_size=CONFIG["batch_size_electra"],
    freeze_layers=CONFIG["freeze_electra_layers"],
    freeze_embeddings=CONFIG["freeze_electra_embeddings"],
)

# model_type="roberta" selects the roberta/bert freezing path in freeze_backbone().
phobert_pack = train_backbone(
    model_name=CONFIG["phobert_name"],
    model_type="roberta",
    epochs=CONFIG["epochs_phobert"],
    max_length=CONFIG["max_length_phobert"],
    batch_size=CONFIG["batch_size_phobert"],
    freeze_layers=CONFIG["freeze_phobert_layers"],
    freeze_embeddings=CONFIG["freeze_phobert_embeddings"],
)

# ============================================================
# STYLE FEATURES for GATE (optional but recommended)
# ============================================================

# Lowercase clickbait cue phrases (Vietnamese plus a few English/numeric ones);
# style_features() counts how many of them occur in a text. Several entries
# overlap (e.g. "sốc" / "gây sốc" / "cực sốc"), so one phrase can add several hits.
CLICKBAIT_WORDS = [
    "hot", "sốc", "shock", "khẩn", "cực sốc", "gây sốc", "bất ngờ", "không ngờ",
    "lộ diện", "lộ", "thực hư", "sự thật", "cảnh báo", "ngay", "lập tức",
    "chấn động", "rúng động", "kinh hoàng", "đừng bỏ qua", "xem ngay",
    "100%", "cam kết", "bạn sẽ", "không tin nổi", "tức giận", "phẫn nộ",
]

def style_features(text: str, keywords=None):
    """Return a 7-dim float32 vector of surface "clickbait style" cues for `text`.

    Features, in order:
        0. number of "!"
        1. number of "?"
        2. number of ellipses ("..." or the single character "…")
        3. uppercase letters / alphabetic characters
        4. number of word tokens (regex word runs)
        5. number of characters
        6. number of `keywords` (default CLICKBAIT_WORDS) found as whole
           words/phrases in the lowercased text

    Args:
        text: input text (converted with str()).
        keywords: optional iterable of lowercase cue phrases.
    """
    if keywords is None:
        keywords = CLICKBAIT_WORDS
    t = str(text)
    tl = t.lower()

    exclam = t.count("!")
    ques = t.count("?")
    # Count both the three-dot and the single-character ellipsis forms.
    dots = t.count("...") + t.count("…")
    upper = sum(1 for ch in t if ch.isalpha() and ch.isupper())
    alpha = sum(1 for ch in t if ch.isalpha())
    caps_ratio = upper / max(alpha, 1)

    n_words = len(re.findall(r"\w+", t))
    n_chars = len(t)

    # Whole-word match: a bare substring test lets short cues such as "hot"
    # fire inside unrelated words ("photo", "shot").
    cb = sum(
        1 for w in keywords
        if re.search(r"(?<!\w)" + re.escape(w) + r"(?!\w)", tl)
    )

    return np.array([
        exclam, ques, dots,
        caps_ratio,
        n_words,
        n_chars,
        cb
    ], dtype=np.float32)

def build_style_matrix(texts):
    """Stack style_features() rows and z-score each column.

    Returns:
        (normalized matrix, column mean, column std + 1e-6); the mean and std
        keep a leading axis of size 1 so they can be reused on other splits.
    """
    rows = [style_features(t) for t in texts]
    feats = np.vstack(rows)
    mean = np.mean(feats, axis=0, keepdims=True)
    std = np.std(feats, axis=0, keepdims=True) + 1e-6  # epsilon guards constant columns
    normalized = (feats - mean) / std
    return normalized, mean, std

# ============================================================
# GATE DATASET: align 2 tokenizers per sample
# ============================================================

class GateDataset(Dataset):
    """Dataset giving each text two encodings (ELECTRA and PhoBERT tokenizers).

    Items contain "e_input_ids"/"e_attn" (length maxE), "p_input_ids"/"p_attn"
    (length maxP), a scalar int64 "label", and, when `style_mat` is given,
    "style" = style_mat[idx] as float32.
    """

    def __init__(self, texts, labels, tokE, tokP, maxE, maxP, style_mat=None):
        self.texts, self.labels = list(texts), list(labels)
        self.tokE, self.tokP = tokE, tokP
        self.maxE, self.maxP = int(maxE), int(maxP)
        self.style_mat = style_mat  # array indexed by sample, or None

    def __len__(self):
        return len(self.texts)

    @staticmethod
    def _encode(tokenizer, text, max_len):
        """Tokenize one text to fixed length; return (ids, mask) as 1-D int64 tensors."""
        enc = tokenizer(text, max_length=max_len, padding="max_length", truncation=True, return_tensors="pt")
        return enc["input_ids"].squeeze(0).to(torch.long), enc["attention_mask"].squeeze(0).to(torch.long)

    def __getitem__(self, idx):
        text = self.texts[idx]
        target = int(self.labels[idx])
        e_ids, e_mask = self._encode(self.tokE, text, self.maxE)
        p_ids, p_mask = self._encode(self.tokP, text, self.maxP)

        sample = {
            "e_input_ids": e_ids,
            "e_attn": e_mask,
            "p_input_ids": p_ids,
            "p_attn": p_mask,
            "label": torch.tensor(target, dtype=torch.long),
        }
        if self.style_mat is not None:
            sample["style"] = torch.tensor(self.style_mat[idx], dtype=torch.float32)
        return sample

def make_gate_loaders(tokE, tokP, style_train=None, style_val=None, style_test=None, bs=8):
    """Build (train, val, test) GateDataset loaders over the module-level splits.

    Each split gets its own optional style matrix; only the train loader shuffles.
    """
    max_e = CONFIG["max_length_electra"]
    max_p = CONFIG["max_length_phobert"]
    batch = int(bs)
    plan = (
        (train_df, style_train, True),
        (val_df, style_val, False),
        (test_df, style_test, False),
    )
    return tuple(
        DataLoader(
            GateDataset(frame["text_clean"].values, frame["label"].values, tokE, tokP, max_e, max_p, style),
            batch_size=batch,
            shuffle=shuffle,
        )
        for frame, style, shuffle in plan
    )

# ============================================================
# MoE GATING MODEL
# ============================================================

class MoEGateNet(nn.Module):
    """Two-expert mixture with a learned per-sample scalar gate.

        g    = sigmoid(MLP([CLS_E, CLS_P, style?]))
        pMix = g * softmax(logits_E) + (1 - g) * softmax(logits_P)

    CLS_* is position 0 of each expert's last hidden layer. Experts are stored as
    `self.e` (ELECTRA) and `self.p` (PhoBERT) and are typically frozen in stage 2.
    forward() returns (pMix [B, 2], g [B, 1]).
    """

    def __init__(self, electra_model, phobert_model, hidden=256, dropout=0.3, style_dim=0):
        super().__init__()
        self.e = electra_model
        self.p = phobert_model

        gate_in = self.e.config.hidden_size + self.p.config.hidden_size + style_dim
        layers = [
            nn.Linear(gate_in, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.gate = nn.Sequential(*layers)

    @staticmethod
    def _run_expert(expert, ids, mask):
        """Return (CLS vector of the last hidden layer, class probabilities) for one expert."""
        out = expert(input_ids=ids, attention_mask=mask, output_hidden_states=True, return_dict=True)
        return out.hidden_states[-1][:, 0, :], F.softmax(out.logits, dim=1)

    def forward(self, e_input_ids, e_attn, p_input_ids, p_attn, style=None):
        cls_e, prob_e = self._run_expert(self.e, e_input_ids, e_attn)
        cls_p, prob_p = self._run_expert(self.p, p_input_ids, p_attn)

        parts = [cls_e, cls_p] if style is None else [cls_e, cls_p, style]
        g = torch.sigmoid(self.gate(torch.cat(parts, dim=1)))  # [B, 1] in (0, 1)

        mixed = g * prob_e + (1.0 - g) * prob_p  # [B, 2]
        return mixed, g

def set_requires_grad(model, flag: bool):
    """Enable (True) or disable (False) gradients for every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad_(flag)

# ============================================================
# STAGE 2: TRAIN GATE (freeze 2 backbones)
# ============================================================

print("\n" + "="*90)
print("🧠 STAGE 2: TRAIN MoE GATE (end-to-end gating, no ensemble grid/override)")
print("="*90)

# Tokenizers from packs
tokE = electra_pack["tokenizer"]
tokP = phobert_pack["tokenizer"]

# Build style matrices (optional). mu/sd come from the train split only and are
# reused for val/test (and later for single-text prediction).
style_dim = 0
style_train = style_val = style_test = None
if CONFIG["use_style_features"]:
    tr, mu, sd = build_style_matrix(train_df["text_clean"].values)
    va = (np.stack([style_features(x) for x in val_df["text_clean"].values], axis=0) - mu) / sd
    te = (np.stack([style_features(x) for x in test_df["text_clean"].values], axis=0) - mu) / sd
    style_train, style_val, style_test = tr.astype(np.float32), va.astype(np.float32), te.astype(np.float32)
    style_dim = style_train.shape[1]
    print(f"✅ Using style features for gate: dim={style_dim}")

gate_train_loader, gate_val_loader, gate_test_loader = make_gate_loaders(
    tokE, tokP,
    style_train=style_train, style_val=style_val, style_test=style_test,
    bs=min(CONFIG["batch_size_phobert"], 8)  # small gate batch (at most 8) to stay within GPU memory
)

# Build gate model on top of the stage-1 fine-tuned experts (same model objects).
gate_model = MoEGateNet(
    electra_pack["model"],
    phobert_pack["model"],
    hidden=int(CONFIG["gate_hidden"]),
    dropout=float(CONFIG["gate_dropout"]),
    style_dim=int(style_dim),
).to(CONFIG["device"])

# Freeze both experts for stable gate training (recommended for a ~2k-sample dataset)
set_requires_grad(gate_model.e, False)
set_requires_grad(gate_model.p, False)

# Only train gate parameters
optimizer = AdamW(gate_model.gate.parameters(), lr=float(CONFIG["gate_lr"]), weight_decay=float(CONFIG["weight_decay"]))

total_steps = max(len(gate_train_loader) * int(CONFIG["epochs_gate"]), 1)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(CONFIG["gate_warmup_steps"]),
    num_training_steps=total_steps
)

# Balanced class weights from the train split.
y_train = train_df["label"].values.astype(int)
counts, class_w = compute_class_weights(y_train)
# NOTE: nn.CrossEntropyLoss applies log_softmax internally, i.e. it expects
# unnormalized logits, not probabilities.
criterion = nn.CrossEntropyLoss(weight=class_w)

def gate_entropy(g):
    """Mean binary entropy (in nats) of gate values `g` in [0, 1].

    Values are clamped to [1e-6, 1 - 1e-6] so the logs stay finite.
    """
    q = torch.clamp(g, 1e-6, 1 - 1e-6)
    per_sample = -(q * q.log() + (1 - q) * (1 - q).log())
    return per_sample.mean()

@torch.no_grad()
def eval_gate(model, loader, y_true, use_style):
    """Evaluate the MoE gate model on `loader` (eval mode, no gradients).

    Predicts FAKE when the mixed P(FAKE) >= 0.5. Style features are fed only when
    `use_style` is true and the batch carries a "style" entry.

    Returns:
        (accuracy, macro-F1, mean gate, 10th-percentile gate, 90th-percentile gate)
    """
    model.eval()
    mixed_parts, gate_parts = [], []
    for batch in tqdm(loader, desc="EvalGate", leave=False):
        dev = CONFIG["device"]
        e_ids = batch["e_input_ids"].to(dev)
        e_mask = batch["e_attn"].to(dev)
        p_ids = batch["p_input_ids"].to(dev)
        p_mask = batch["p_attn"].to(dev)
        style = None
        if use_style and "style" in batch:
            style = batch["style"].to(dev)

        mixed, gate = model(e_ids, e_mask, p_ids, p_mask, style=style)
        mixed_parts.append(mixed.detach().cpu())
        gate_parts.append(gate.detach().cpu())

    probs = torch.cat(mixed_parts, dim=0).numpy()
    gates = torch.cat(gate_parts, dim=0).numpy().reshape(-1)
    pred = (probs[:, 1] >= 0.5).astype(int)

    acc = accuracy_score(y_true, pred)
    macro_f1 = f1_score(y_true, pred, average="macro")
    return acc, macro_f1, float(gates.mean()), float(np.quantile(gates, 0.1)), float(np.quantile(gates, 0.9))

best_state = None
best_f1 = -1.0
y_val = val_df["label"].values.astype(int)
gate_params = list(gate_model.gate.parameters())

for ep in range(int(CONFIG["epochs_gate"])):
    gate_model.train()
    # The experts are frozen feature extractors in this stage. gate_model.train()
    # would re-enable their dropout and add noise to the CLS features and
    # probabilities the gate is learning to mix, so keep them in eval mode.
    gate_model.e.eval()
    gate_model.p.eval()
    total_loss = 0.0

    for batch in tqdm(gate_train_loader, desc=f"GateTrain ep{ep+1}", leave=False):
        e_input_ids = batch["e_input_ids"].to(CONFIG["device"])
        e_attn = batch["e_attn"].to(CONFIG["device"])
        p_input_ids = batch["p_input_ids"].to(CONFIG["device"])
        p_attn = batch["p_attn"].to(CONFIG["device"])
        labels = batch["label"].to(CONFIG["device"])
        style = batch["style"].to(CONFIG["device"]) if (CONFIG["use_style_features"] and "style" in batch) else None

        optimizer.zero_grad(set_to_none=True)
        pMix, g = gate_model(e_input_ids, e_attn, p_input_ids, p_attn, style=style)

        # pMix is already a probability distribution. nn.CrossEntropyLoss would
        # apply log_softmax on top of it (treating probabilities as logits), which
        # squashes the loss and its gradients. Class-weighted cross-entropy on
        # probabilities is the weighted NLL of their log.
        log_pMix = torch.log(pMix.clamp_min(1e-8))
        loss_cls = F.nll_loss(log_pMix, labels, weight=class_w)
        loss_ent = -float(CONFIG["entropy_reg"]) * gate_entropy(g)  # maximize entropy a bit
        loss = loss_cls + loss_ent

        loss.backward()
        torch.nn.utils.clip_grad_norm_(gate_params, 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += float(loss.item())

    acc, f1m, gmean, g10, g90 = eval_gate(gate_model, gate_val_loader, y_val, CONFIG["use_style_features"])
    print(f"Gate Epoch {ep+1}/{CONFIG['epochs_gate']} | loss={total_loss/max(len(gate_train_loader),1):.4f} "
          f"| val_acc={acc*100:.2f}% | val_macroF1={f1m:.4f} | gate_mean={gmean:.3f} (p10={g10:.3f}, p90={g90:.3f})")

    if f1m > best_f1:
        best_f1 = f1m
        # Only the gate head is optimized here, so snapshot just its weights
        # (on CPU) instead of deep-copying both transformer experts each time.
        best_state = {k: v.detach().to("cpu", copy=True) for k, v in gate_model.gate.state_dict().items()}
        print("✅ New best gate")

if best_state is not None:
    gate_model.gate.load_state_dict(best_state)

# ============================================================
# FINAL EVAL ON TEST
# ============================================================

print("\n" + "="*90)
print("🎯 FINAL EVALUATION ON TEST (MoE GATING, LEAK-SAFE)")
print("="*90)

y_test = test_df["label"].values.astype(int)

# A single inference pass over the test set; every metric below is derived from it.
# no_grad: the gate head still has requires_grad=True, so without it each batch
# would build an autograd graph that is never used.
gate_model.eval()
p_list = []
g_list = []
with torch.no_grad():
    for batch in tqdm(gate_test_loader, desc="InferTest", leave=False):
        e_input_ids = batch["e_input_ids"].to(CONFIG["device"])
        e_attn = batch["e_attn"].to(CONFIG["device"])
        p_input_ids = batch["p_input_ids"].to(CONFIG["device"])
        p_attn = batch["p_attn"].to(CONFIG["device"])
        style = batch["style"].to(CONFIG["device"]) if (CONFIG["use_style_features"] and "style" in batch) else None
        pMix, g = gate_model(e_input_ids, e_attn, p_input_ids, p_attn, style=style)
        p_list.append(pMix.detach().cpu())
        g_list.append(g.detach().cpu())

pMix = torch.cat(p_list, dim=0).numpy()
g_all = torch.cat(g_list, dim=0).numpy().reshape(-1)

p_fake = pMix[:, 1]
pred_test = (p_fake >= 0.5).astype(int)

f1_fake = f1_score(y_test, pred_test, pos_label=1, average="binary")
f1_real = f1_score(y_test, pred_test, pos_label=0, average="binary")
macro = (f1_fake + f1_real) / 2
acc = accuracy_score(y_test, pred_test)
# Same summary values eval_gate() reports, computed from this pass.
f1m = f1_score(y_test, pred_test, average="macro")
gmean = float(g_all.mean())
g10 = float(np.quantile(g_all, 0.1))
g90 = float(np.quantile(g_all, 0.9))

print(f"Accuracy:    {acc*100:.2f}%")
print(f"Macro-F1:    {macro:.4f}")
print(f"F1 FAKE(1):  {f1_fake:.4f}")
print(f"F1 REAL(0):  {f1_real:.4f}")
print(f"Gap:         {abs(f1_fake - f1_real):.4f}")
print(f"Gate mean g: {gmean:.3f} | p10={g10:.3f} | p90={g90:.3f}")
print("\n📋 Classification Report:")
print(classification_report(y_test, pred_test, target_names=["REAL (0)", "FAKE (1)"], digits=4))

# ============================================================
# QUICK INTERACTIVE PREDICT (optional)
# ============================================================

@torch.no_grad()
def predict_one(text: str):
    """Classify one raw text with the trained MoE gate model.

    The text is passed through ``clean_text`` and checked with ``is_valid`` first;
    texts rejected there return ``{"error": ...}`` instead of a prediction.

    Args:
        text: raw input text.

    Returns:
        On success, a dict with:
          - ``pred``: 1 (FAKE) if ``p_fake >= 0.5`` else 0 (REAL)
          - ``p_fake``: column 1 of the gate model's mixed output
          - ``conf``: ``max(p_fake, 1 - p_fake)``
          - ``gate_g``: the gate value returned by ``gate_model``
          - ``note``: a human-readable explanation of ``gate_g``
        Otherwise ``{"error": "Text too short after cleaning"}``.
    """
    text = clean_text(text)
    if not is_valid(text):
        return {"error": "Text too short after cleaning"}

    device = CONFIG["device"]

    # Optional style features, normalized with mu / sd from an earlier cell.
    st = None
    if CONFIG["use_style_features"]:
        feat = np.asarray(style_features(text), dtype=np.float32)
        try:
            feat = (feat - mu.squeeze(0)) / sd.squeeze(0)
        except NameError:
            # mu / sd were never defined in this session: fall back to raw features,
            # but say so — the gate presumably saw normalized features in training,
            # so raw inputs can shift its output.
            print("⚠️ predict_one: style mu/sd not defined -> using un-normalized style features")
        st = torch.tensor(feat, dtype=torch.float32).unsqueeze(0).to(device)

    # Inference only: make sure dropout is off even if the model was left in train mode.
    gate_model.eval()

    encE = tokE(text, max_length=CONFIG["max_length_electra"], padding="max_length", truncation=True, return_tensors="pt")
    encP = tokP(text, max_length=CONFIG["max_length_phobert"], padding="max_length", truncation=True, return_tensors="pt")

    e_input_ids = encE["input_ids"].to(device)
    e_attn = encE["attention_mask"].to(device)
    p_input_ids = encP["input_ids"].to(device)
    p_attn = encP["attention_mask"].to(device)

    pMix, g = gate_model(e_input_ids, e_attn, p_input_ids, p_attn, style=st)
    pMix = pMix.squeeze(0).cpu().numpy()
    g = float(g.squeeze(0).cpu().item())

    p_fake = float(pMix[1])
    pred = 1 if p_fake >= 0.5 else 0
    conf = max(p_fake, 1 - p_fake)

    return {
        "pred": pred,
        "p_fake": p_fake,
        "conf": conf,
        "gate_g": g,  # g close to 1 => leans on vELECTRA; close to 0 => leans on PhoBERT
        "note": "gate_g≈1 => dựa nhiều vELECTRA (văn phong); gate_g≈0 => dựa nhiều PhoBERT (ngữ nghĩa)"
    }

print("\n✅ Ready. Try: predict_one('Tin sốc: ...')\n")


📦 Installing packages...
🚀 VIETNAMESE FAKE NEWS - MoE GATING (LEAK-SAFE SPLIT)
🖥️ Device: cuda
🧾 Label mapping: 0=REAL, 1=FAKE

📂 LOADING DATA
🧹 Cleaning + exact dedup...
✅ After exact dedup: 1843 (removed 0)
✅ Final: 1843 samples | REAL=974 (52.8%) | FAKE=869 (47.2%)

🧩 BUILDING NEAR-DUP CLUSTERS
✅ Groups: 1806 | Largest group size: 8
Top 5 group sizes:
group
35     8
136    7
26     4
704    2
147    2

✂️ LEAK-SAFE SPLIT (STRATIFIED BY LABEL, GROUP-AWARE)
Picked fold 2 for TEST: size=185 (0.100)
Picked fold 2 for VAL: size=182 (0.110)
Train: 1476 | REAL=776 (52.6%) | FAKE=700 (47.4%) | groups=1444
Val  : 182 | REAL=98 (53.8%) | FAKE=84 (46.2%) | groups=180
Test : 185 | REAL=100 (54.1%) | FAKE=85 (45.9%) | groups=182

Group overlap Train∩Val=0 | Train∩Test=0 | Val∩Test=0

🧪 LEAK REPORT
Leak exact Train∩Val: 0
Leak exact Train∩Test: 0
Leak exact Val∩Test: 0

Near-dup cosine stats (char-ngram TFIDF):
Val->Train: {'mean': 0.3505195379257202, 'median': 0.30619990825653076, 'p95': 0.68805

config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]



pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/197 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/443M [00:00<?, ?B/s]

ElectraForSequenceClassification LOAD REPORT from: FPTAI/velectra-base-discriminator-cased
Key                                               | Status     | 
--------------------------------------------------+------------+-
discriminator_predictions.dense_prediction.bias   | UNEXPECTED | 
discriminator_predictions.dense.bias              | UNEXPECTED | 
discriminator_predictions.dense.weight            | UNEXPECTED | 
discriminator_predictions.dense_prediction.weight | UNEXPECTED | 
classifier.out_proj.bias                          | MISSING    | 
classifier.dense.bias                             | MISSING    | 
classifier.out_proj.weight                        | MISSING    | 
classifier.dense.weight                           | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


Class counts [REAL, FAKE]: [776 700] | class_weights: [0.9510309 1.0542858]


Train ep1:   0%|          | 0/93 [00:00<?, ?it/s]

Infer:   0%|          | 0/12 [00:00<?, ?it/s]

Epoch 1/5 | loss=0.5200 | val_acc=95.60% | val_macroF1=0.9555
✅ New best


Train ep2:   0%|          | 0/93 [00:00<?, ?it/s]

Infer:   0%|          | 0/12 [00:00<?, ?it/s]

Epoch 2/5 | loss=0.1304 | val_acc=93.96% | val_macroF1=0.9386


Train ep3:   0%|          | 0/93 [00:00<?, ?it/s]

Infer:   0%|          | 0/12 [00:00<?, ?it/s]

Epoch 3/5 | loss=0.0672 | val_acc=96.70% | val_macroF1=0.9667
✅ New best


Train ep4:   0%|          | 0/93 [00:00<?, ?it/s]

Infer:   0%|          | 0/12 [00:00<?, ?it/s]

Epoch 4/5 | loss=0.0438 | val_acc=96.15% | val_macroF1=0.9612


Train ep5:   0%|          | 0/93 [00:00<?, ?it/s]

Infer:   0%|          | 0/12 [00:00<?, ?it/s]

Epoch 5/5 | loss=0.0309 | val_acc=96.70% | val_macroF1=0.9667


Infer:   0%|          | 0/12 [00:00<?, ?it/s]

Infer:   0%|          | 0/12 [00:00<?, ?it/s]


🤖 TRAINING BACKBONE: vinai/phobert-base


config.json:   0%|          | 0.00/557 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

bpe.codes: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/543M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/197 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/543M [00:00<?, ?B/s]

RobertaForSequenceClassification LOAD REPORT from: vinai/phobert-base
Key                             | Status     | 
--------------------------------+------------+-
lm_head.layer_norm.bias         | UNEXPECTED | 
lm_head.dense.bias              | UNEXPECTED | 
roberta.embeddings.position_ids | UNEXPECTED | 
roberta.pooler.dense.bias       | UNEXPECTED | 
roberta.pooler.dense.weight     | UNEXPECTED | 
lm_head.decoder.bias            | UNEXPECTED | 
lm_head.dense.weight            | UNEXPECTED | 
lm_head.layer_norm.weight       | UNEXPECTED | 
lm_head.bias                    | UNEXPECTED | 
lm_head.decoder.weight          | UNEXPECTED | 
classifier.out_proj.bias        | MISSING    | 
classifier.dense.bias           | MISSING    | 
classifier.out_proj.weight      | MISSING    | 
classifier.dense.weight         | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initia

Class counts [REAL, FAKE]: [776 700] | class_weights: [0.9510309 1.0542858]


Train ep1:   0%|          | 0/185 [00:00<?, ?it/s]

Infer:   0%|          | 0/23 [00:00<?, ?it/s]

Epoch 1/5 | loss=0.4404 | val_acc=97.25% | val_macroF1=0.9725
✅ New best


Train ep2:   0%|          | 0/185 [00:00<?, ?it/s]

Infer:   0%|          | 0/23 [00:00<?, ?it/s]

Epoch 2/5 | loss=0.1446 | val_acc=98.35% | val_macroF1=0.9834
✅ New best


Train ep3:   0%|          | 0/185 [00:00<?, ?it/s]

Infer:   0%|          | 0/23 [00:00<?, ?it/s]

Epoch 3/5 | loss=0.0748 | val_acc=95.60% | val_macroF1=0.9555


Train ep4:   0%|          | 0/185 [00:00<?, ?it/s]

Infer:   0%|          | 0/23 [00:00<?, ?it/s]

Epoch 4/5 | loss=0.0437 | val_acc=98.35% | val_macroF1=0.9834


Train ep5:   0%|          | 0/185 [00:00<?, ?it/s]

Infer:   0%|          | 0/23 [00:00<?, ?it/s]

Epoch 5/5 | loss=0.0352 | val_acc=98.35% | val_macroF1=0.9834


Infer:   0%|          | 0/23 [00:00<?, ?it/s]

Infer:   0%|          | 0/24 [00:00<?, ?it/s]


🧠 STAGE 2: TRAIN MoE GATE (end-to-end gating, no ensemble grid/override)
✅ Using style features for gate: dim=7


GateTrain ep1:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 1/15 | loss=0.3255 | val_acc=96.70% | val_macroF1=0.9667 | gate_mean=0.583 (p10=0.463, p90=0.764)
✅ New best gate


GateTrain ep2:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 2/15 | loss=0.3228 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.636 (p10=0.531, p90=0.792)
✅ New best gate


GateTrain ep3:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 3/15 | loss=0.3226 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.603 (p10=0.456, p90=0.863)


GateTrain ep4:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 4/15 | loss=0.3223 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.548 (p10=0.391, p90=0.767)


GateTrain ep5:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 5/15 | loss=0.3216 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.578 (p10=0.416, p90=0.812)


GateTrain ep6:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 6/15 | loss=0.3219 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.528 (p10=0.359, p90=0.785)


GateTrain ep7:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 7/15 | loss=0.3212 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.578 (p10=0.438, p90=0.827)


GateTrain ep8:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 8/15 | loss=0.3222 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.583 (p10=0.414, p90=0.836)


GateTrain ep9:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 9/15 | loss=0.3213 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.647 (p10=0.471, p90=0.815)


GateTrain ep10:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 10/15 | loss=0.3210 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.553 (p10=0.405, p90=0.766)


GateTrain ep11:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 11/15 | loss=0.3200 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.617 (p10=0.468, p90=0.804)


GateTrain ep12:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 12/15 | loss=0.3202 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.550 (p10=0.398, p90=0.712)


GateTrain ep13:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 13/15 | loss=0.3201 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.599 (p10=0.446, p90=0.753)


GateTrain ep14:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 14/15 | loss=0.3204 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.585 (p10=0.443, p90=0.745)


GateTrain ep15:   0%|          | 0/185 [00:00<?, ?it/s]

EvalGate:   0%|          | 0/23 [00:00<?, ?it/s]

Gate Epoch 15/15 | loss=0.3201 | val_acc=97.25% | val_macroF1=0.9723 | gate_mean=0.580 (p10=0.426, p90=0.745)

🎯 FINAL EVALUATION ON TEST (MoE GATING, LEAK-SAFE)


EvalGate:   0%|          | 0/24 [00:00<?, ?it/s]

InferTest:   0%|          | 0/24 [00:00<?, ?it/s]

Accuracy:    97.84%
Macro-F1:    0.9782
F1 FAKE(1):  0.9762
F1 REAL(0):  0.9802
Gap:         0.0040
Gate mean g: 0.619 | p10=0.522 | p90=0.683

📋 Classification Report:
              precision    recall  f1-score   support

    REAL (0)     0.9706    0.9900    0.9802       100
    FAKE (1)     0.9880    0.9647    0.9762        85

    accuracy                         0.9784       185
   macro avg     0.9793    0.9774    0.9782       185
weighted avg     0.9786    0.9784    0.9784       185


✅ Ready. Try: predict_one('Tin sốc: ...')



In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
from tqdm.auto import tqdm

DEVICE = CONFIG["device"]

# Set once the "mu/sd missing" warning has been printed, so per-sample loops don't spam it.
_STYLE_NORM_WARNED = False

# ---------- helper: style feature + normalize ----------
def _style_vec(text):
    """Return the style-feature tensor for ``text``, or None if style features are disabled.

    Features come from ``style_features(text)`` and are normalized with the ``mu`` / ``sd``
    statistics created earlier in the MoE code. When those names are not defined, the raw
    features are used and a one-time warning is printed. Other errors (e.g. a shape mismatch)
    propagate instead of being silently swallowed.

    Args:
        text: already-cleaned input text.

    Returns:
        A float32 tensor of shape (1, n_features) on ``DEVICE``, or None.
    """
    global _STYLE_NORM_WARNED
    if not CONFIG.get("use_style_features", False):
        return None
    feat = np.asarray(style_features(text), dtype=np.float32)
    try:
        feat = (feat - mu.squeeze(0)) / sd.squeeze(0)
    except NameError:
        if not _STYLE_NORM_WARNED:
            print("⚠️ _style_vec: style mu/sd not defined -> using un-normalized style features")
            _STYLE_NORM_WARNED = True
    return torch.tensor(feat, dtype=torch.float32).unsqueeze(0).to(DEVICE)

@torch.no_grad()
def predict_one_moe(text: str):
    """Run the MoE gate model on a single text and attach simple style red flags.

    Returns a dict. If the cleaned text fails ``is_valid`` the dict is
    ``{"ok": False, "error": ..., "text": <cleaned text>}``. Otherwise it holds
    ``ok``, ``pred`` (1 iff p_fake >= 0.5), ``p_fake``, ``conf`` (probability of
    the predicted class), ``gate_g`` (gate output), ``flags`` (names of triggered
    style red flags) and the cleaned ``text``.
    """
    cleaned = clean_text(text)
    if not is_valid(cleaned):
        return {"ok": False, "error": "Text too short after cleaning", "text": cleaned}

    style_tensor = _style_vec(cleaned)

    def _encode(tokenizer, max_len_key):
        # Tokenize with fixed padding and return (input_ids, attention_mask) on DEVICE.
        enc = tokenizer(cleaned, max_length=CONFIG[max_len_key], padding="max_length",
                        truncation=True, return_tensors="pt")
        return enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE)

    e_ids, e_mask = _encode(tokE, "max_length_electra")
    p_ids, p_mask = _encode(tokP, "max_length_phobert")

    gate_model.eval()
    mixed, gate_out = gate_model(e_ids, e_mask, p_ids, p_mask, style=style_tensor)

    p_fake = float(mixed.squeeze(0).detach().cpu().numpy()[1])
    gate_g = float(gate_out.squeeze(0).detach().cpu().item())
    pred = int(p_fake >= 0.5)
    conf = max(p_fake, 1 - p_fake)

    # Simple "style red flags": (feature index, threshold, label), checked in this order.
    raw_style = style_features(cleaned)
    red_flag_rules = (
        (0, 2, "many !"),
        (1, 2, "many ?"),
        (3, 0.25, "high CAPS ratio"),
        (6, 2, "clickbait keywords"),
    )
    flags = [label for idx, threshold, label in red_flag_rules if raw_style[idx] >= threshold]

    result = {"ok": True}
    result["pred"] = pred
    result["p_fake"] = p_fake
    result["conf"] = conf
    result["gate_g"] = gate_g  # ~1 => leaning vELECTRA (style); ~0 => leaning PhoBERT (semantics)
    result["flags"] = flags
    result["text"] = cleaned
    return result

# ---------- 1) TEST SET EVALUATION + error analysis ----------
@torch.no_grad()
def evaluate_on_test_moe(top_k_errors=15):
    """Score every test text with ``predict_one_moe`` and print metrics plus an error analysis.

    Iterates ``test_df["text_clean"]`` one text at a time, then prints:
      - accuracy, macro / per-class F1 and the gate statistics;
      - the confusion matrix and a classification report;
      - the ``top_k_errors`` wrong predictions, most confident first.

    Texts that ``predict_one_moe`` rejects (``ok=False``) get a fallback score
    (pred=0 / REAL, p_fake=0.0, conf=0.5), so every array stays aligned with
    ``y_true``. These placeholder rows are excluded from the gate statistics and
    their count is printed, so they no longer skew the reported numbers unseen.

    Args:
        top_k_errors: maximum number of misclassified samples to print.

    Returns:
        None. All results are printed.
    """
    y_true = test_df["label"].values.astype(int)
    texts = test_df["text_clean"].values

    preds = []
    p_fakes = []
    confs = []
    gs = []
    valid = []

    for t in tqdm(texts, desc="MoE test inference"):
        out = predict_one_moe(t)
        if not out["ok"]:
            # Fallback: REAL with low confidence (rare on the cleaned test set).
            preds.append(0)
            p_fakes.append(0.0)
            confs.append(0.5)
            gs.append(0.5)
            valid.append(False)
            continue
        preds.append(out["pred"])
        p_fakes.append(out["p_fake"])
        confs.append(out["conf"])
        gs.append(out["gate_g"])
        valid.append(True)

    preds = np.array(preds, dtype=int)
    p_fakes = np.array(p_fakes, dtype=float)
    confs = np.array(confs, dtype=float)
    gs = np.array(gs, dtype=float)
    valid = np.array(valid, dtype=bool)
    n_invalid = int((~valid).sum())

    acc = accuracy_score(y_true, preds)
    f1_fake = f1_score(y_true, preds, pos_label=1)
    f1_real = f1_score(y_true, preds, pos_label=0)
    macro = (f1_fake + f1_real) / 2

    print("="*90)
    print("📌 TEST SET METRICS (MoE)")
    print("="*90)
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Macro-F1: {macro:.4f} | F1_FAKE={f1_fake:.4f} | F1_REAL={f1_real:.4f} | Gap={abs(f1_fake-f1_real):.4f}")

    # Gate stats only over texts the model actually scored (fallback g=0.5 is synthetic).
    gs_valid = gs[valid]
    if gs_valid.size:
        print(f"Gate g mean={gs_valid.mean():.3f} | p10={np.quantile(gs_valid,0.1):.3f} | p90={np.quantile(gs_valid,0.9):.3f}")
    else:
        print("Gate g: n/a (no valid test texts)")
    if n_invalid:
        print(f"⚠️ {n_invalid} invalid test text(s) scored by fallback as REAL (0); excluded from gate stats")

    cm = confusion_matrix(y_true, preds, labels=[0,1])
    print("\nConfusion matrix [rows=true 0/1, cols=pred 0/1]:")
    print(cm)

    print("\nClassification report:")
    print(classification_report(y_true, preds, target_names=["REAL (0)", "FAKE (1)"], digits=4))

    # --- top errors by confidence (model very sure but wrong) ---
    # Indices are positional (row order of test_df), not test_df index labels.
    wrong = np.flatnonzero(preds != y_true)
    if wrong.size == 0:
        print("\n✅ No errors on test (unlikely but possible).")
        return

    # Rank by confidence, descending; a stable sort keeps test order among ties.
    wrong_sorted = wrong[np.argsort(-confs[wrong], kind="stable")]
    shown = wrong_sorted[:top_k_errors]
    print("\n" + "="*90)
    print(f"❌ TOP {len(shown)} WRONG PREDICTIONS (highest confidence first)")
    print("="*90)

    for i in shown:
        full_text = str(texts[i])
        snippet = full_text[:260].replace("\n", " ")
        ellipsis = "..." if len(full_text) > 260 else ""
        note = "" if valid[i] else " | INVALID (fallback)"
        print(f"\nIDX={i} | TRUE={y_true[i]} | PRED={preds[i]} | p_fake={p_fakes[i]:.3f} | conf={confs[i]:.3f} | gate_g={gs[i]:.3f}{note}")
        print(f"TEXT: {snippet}{ellipsis}")

# ---------- 2) GENERALIZATION: 10 "real-world" samples ----------
# Hand-written texts from outside the dataset, used as the default sample set of
# generalization_check_10. They carry no ground-truth labels; each comment states the
# kind of text it imitates (i.e. the label the author intends: "fake" -> 1, "real"/neutral -> 0).
GENERALIZATION_10 = [
    # 1 (sensational clickbait headline; very fake-ish)
    "TIN SỐC!!! Chỉ cần uống nước chanh theo cách này 3 ngày là khỏi hoàn toàn tiểu đường, bác sĩ cũng bất ngờ. Xem ngay!!!",
    # 2 (real news styled as an administrative announcement)
    "Sở Giao thông Vận tải TP.HCM thông báo điều chỉnh tổ chức giao thông một số tuyến đường khu vực trung tâm để phục vụ thi công, thời gian áp dụng từ ngày 10/02.",
    # 3 (fake science with exaggerated claims)
    "Các nhà khoa học xác nhận người ngoài hành tinh đã hạ cánh ở Việt Nam và để lại thiết bị lạ, video bằng chứng đang lan truyền mạnh.",
    # 4 (neutral economic news)
    "Giá vàng trong nước sáng nay biến động nhẹ, nhiều doanh nghiệp điều chỉnh tăng/giảm vài chục nghìn đồng mỗi lượng so với cuối ngày hôm qua.",
    # 5 (fake "warning" chain message)
    "CẢNH BÁO KHẨN: Ai nhận cuộc gọi số lạ đọc 3 số cuối CCCD sẽ bị trừ tiền tài khoản ngay lập tức. Hãy chia sẻ để cứu mọi người!",
    # 6 (neutral social / public-order news)
    "Công an cho biết đang xác minh thông tin lan truyền trên mạng liên quan đến vụ việc tại một khu dân cư, đồng thời đề nghị người dân không chia sẻ thông tin chưa kiểm chứng.",
    # 7 (fake "get-rich-quick scheme")
    "Không cần vốn, chỉ với điện thoại bạn có thể kiếm 5 triệu/ngày đảm bảo 100%. Ai cũng làm được, đăng ký ngay kẻo lỡ!",
    # 8 (neutral weather / meteorology news)
    "Trung tâm dự báo khí tượng thủy văn nhận định trong vài ngày tới, khu vực Nam Bộ có mưa rào và dông rải rác vào chiều tối.",
    # 9 (health news that looks real but is hard to verify)
    "Bộ Y tế khuyến cáo người dân tiêm nhắc lại vắc-xin theo hướng dẫn và theo dõi thông tin từ các nguồn chính thống khi có dịch bệnh.",
    # 10 (fake in the "is it true?" / sensationalist style)
    "Thực hư chuyện một loại nước ngọt đang bị cấm bán vì gây ung thư ngay lập tức? Nhiều người hoang mang, sự thật khiến ai cũng sốc!"
]

def generalization_check_10(samples=None):
    """Print MoE predictions for hand-written, out-of-dataset texts.

    Args:
        samples: iterable of raw texts. Defaults to ``GENERALIZATION_10``. The default is
            resolved at call time instead of binding the list object as a default argument,
            so an edited or rebound sample list is picked up on re-run.

    Returns:
        None. One result line (or an ``[INVALID]`` line) plus the text is printed per sample.
        The header reports the actual number of samples rather than a hardcoded 10.
    """
    samples = list(GENERALIZATION_10 if samples is None else samples)

    print("\n" + "="*90)
    print(f"🌍 GENERALIZATION CHECK ({len(samples)} 'ngoài đời' samples)")
    print("="*90)
    print("Legend: pred 0=REAL, 1=FAKE | gate_g≈1 => leaning vELECTRA(style) | gate_g≈0 => leaning PhoBERT(semantics)\n")

    for i, s in enumerate(samples, 1):
        out = predict_one_moe(s)
        if not out["ok"]:
            print(f"{i:02d}. [INVALID] {out.get('error')}")
            continue
        print(f"{i:02d}. PRED={out['pred']} | p_fake={out['p_fake']:.3f} | conf={out['conf']:.3f} | gate_g={out['gate_g']:.3f} | flags={out['flags']}")
        print(f"    TEXT: {s}")

# ---------- run both ----------
# 1) Test-set metrics plus the 12 most confident mistakes;
# 2) the hand-written out-of-dataset sanity check (default GENERALIZATION_10 samples).
evaluate_on_test_moe(top_k_errors=12)
generalization_check_10()


MoE test inference:   0%|          | 0/185 [00:00<?, ?it/s]

📌 TEST SET METRICS (MoE)
Accuracy: 97.84%
Macro-F1: 0.9782 | F1_FAKE=0.9762 | F1_REAL=0.9802 | Gap=0.0040
Gate g mean=0.619 | p10=0.522 | p90=0.683

Confusion matrix [rows=true 0/1, cols=pred 0/1]:
[[99  1]
 [ 3 82]]

Classification report:
              precision    recall  f1-score   support

    REAL (0)     0.9706    0.9900    0.9802       100
    FAKE (1)     0.9880    0.9647    0.9762        85

    accuracy                         0.9784       185
   macro avg     0.9793    0.9774    0.9782       185
weighted avg     0.9786    0.9784    0.9784       185


❌ TOP 4 WRONG PREDICTIONS (highest confidence first)

IDX=69 | TRUE=1 | PRED=0 | p_fake=0.041 | conf=0.959 | gate_g=0.963
TEXT: HÀ NỘI: Học sinh từ cấp 2 trở lên dự kiến đi học trở lại từ 4/5, sau kì nghỉ lễ 30/4 Phương án chính thức sẽ được đưa ra vào cuộc họp ngày 27/4....

IDX=93 | TRUE=0 | PRED=1 | p_fake=0.859 | conf=0.859 | gate_g=0.641
TEXT: Toyota Hilux mới tăng giá 22 triệu tại Việt Nam / Fortuner mới tăng giá gần 50 t