In [None]:
# =========================
# FULL PIPELINE: Cross-Attention Fusion (Leak-safe split)
# =========================
import os
# Set before torch is imported further down in this cell.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid and is
# expected to slow GPU work — confirm it is still wanted for regular runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import sys, subprocess, warnings
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")

def ensure_pkg(import_name, pip_name=None):
    """Make sure a module is importable, pip-installing it when it is not.

    Args:
        import_name: Module name handed to ``__import__`` as the probe.
        pip_name: Distribution name handed to ``pip install``; falls back to
            ``import_name`` when left as ``None``.

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.
    """
    distribution = import_name if pip_name is None else pip_name
    try:
        __import__(import_name)
    except Exception:
        # Any failure of the probe import triggers a quiet install using the
        # interpreter that is running this notebook.
        install_cmd = [sys.executable, "-m", "pip", "install", "-q", distribution]
        subprocess.check_call(install_cmd)

print("📦 Installing packages...")
# Each pair is (module name used for the import probe, name passed to pip).
for pkg in [
    ("transformers", "transformers"),
    ("sklearn", "scikit-learn"),
    ("tqdm", "tqdm"),
    ("pandas", "pandas"),
    ("numpy", "numpy"),
    ("torch", "torch"),
]:
    ensure_pkg(pkg[0], pkg[1])

import re, random, copy
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

from sklearn.metrics import classification_report, f1_score, accuracy_score
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
from tqdm.auto import tqdm


# ============================================================
# CONFIG
# ============================================================

# Single source of settings for the whole pipeline; every later section reads
# its hyper-parameters, paths and device from this dict.
CONFIG = {
    "data_path": "/content/sample_data/fake_news_dataset.csv",

    "electra_name": "FPTAI/velectra-base-discriminator-cased",
    "phobert_name": "vinai/phobert-base",

    "max_length_electra": 256,
    "max_length_phobert": 256,

    # Cross-attn train (fusion model)
    "batch_size": 4,                 # cross-attention is heavy -> small batch
    "learning_rate": 2e-4,           # because only the fusion/head is trained (backbones frozen)
    "epochs": 20,
    "warmup_steps": 80,
    "weight_decay": 0.01,
    "dropout": 0.3,
    "grad_clip": 1.0,

    # Freeze strategy (recommended for ~2k)
    "freeze_backbones": True,        # True: freeze both encoders
    "unfreeze_last_n_layers_electra": 0,  # to unfreeze a few of the last layers, try 2-4
    "unfreeze_last_n_layers_phobert": 0,

    # Cross-attn config
    "cross_attn_heads": 8,
    "cross_attn_dropout": 0.1,

    # Leak/Near-dup settings
    "near_dup_threshold": 0.92,      # cosine similarity at or above this links two texts
    "near_dup_k": 20,                # neighbours inspected per text
    "char_ngram_range": (4, 6),
    "min_df": 2,

    # Split ratios
    "test_ratio": 0.10,              # fraction of the full dataset
    "val_ratio_of_trainval": 0.11,   # fraction of what remains after the test split

    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "seed": 42,
}

# Run banner.
print("="*90)
print("🚀 VIETNAMESE FAKE NEWS - CROSS-ATTENTION FUSION (LEAK-SAFE SPLIT)")
print("="*90)
print(f"🖥️ Device: {CONFIG['device']}")
print("🧾 Label mapping: 0=REAL, 1=FAKE")


# ============================================================
# SEED
# ============================================================

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available the CUDA generators are seeded as well; a failure
    of that last step is deliberately ignored (best effort).

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.manual_seed_all(seed)
    except Exception:
        pass

# Seed every RNG once, before clustering, splitting and model construction.
set_seed(CONFIG["seed"])


# ============================================================
# CLEAN TEXT
# ============================================================

def clean_text(text):
    """Normalise one raw document into a single-line plain string.

    Missing values (as judged by ``pd.isna``) become ``""``. Otherwise the
    value is stringified, HTML-like tags and http/https URLs are replaced by
    a space, and all whitespace runs are collapsed to one space and trimmed.

    Args:
        text: Raw cell value (string, number, None/NaN, ...).

    Returns:
        The cleaned string.
    """
    if pd.isna(text):
        return ""

    cleaned = str(text)
    # Tags first, then URLs; each match is replaced by a single space so that
    # neighbouring words do not get glued together.
    for pattern in (r"<[^>]+>", r"http[s]?://\S+"):
        cleaned = re.sub(pattern, " ", cleaned)
    return re.sub(r"\s+", " ", cleaned).strip()

def is_valid(text):
    """Return True when `text` has at least 8 whitespace-separated words."""
    words = text.split()
    return len(words) >= 8


# ============================================================
# Union-Find for clustering near-duplicates
# ============================================================

class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Attributes are kept as plain lists:
      p -- parent pointer of every element (a root points to itself)
      r -- rank of every element, used to decide which root absorbs the other
    """

    def __init__(self, n):
        self.p = [i for i in range(n)]
        self.r = [0 for _ in range(n)]

    def find(self, x):
        """Return the root of `x`, shortening the path while walking up."""
        node = x
        while True:
            parent = self.p[node]
            if parent == node:
                return node
            # Re-point the current node at its grandparent, then jump there.
            grandparent = self.p[parent]
            self.p[node] = grandparent
            node = grandparent

    def union(self, a, b):
        """Merge the sets containing `a` and `b` (no-op if already merged)."""
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            return

        rank_a, rank_b = self.r[root_a], self.r[root_b]
        if rank_a < rank_b:
            self.p[root_a] = root_b
            return

        # root_a wins on higher rank and on ties; a tie also bumps its rank.
        self.p[root_b] = root_a
        if rank_a == rank_b:
            self.r[root_a] += 1

def build_near_dup_groups(texts, threshold=0.92, k=20, ngram_range=(4,6), min_df=2):
    """Cluster texts into near-duplicate groups and return one group id per text.

    Texts are vectorised with a character n-gram ("char_wb") TF-IDF, each text
    is compared with its `k` nearest neighbours under cosine distance, and any
    pair whose similarity (1 - distance) reaches `threshold` is linked.
    Linked texts end up in the same group (links are transitive).

    Args:
        texts: Sequence of strings.
        threshold: Minimum cosine similarity for two texts to be linked.
        k: Number of neighbours queried per text (capped at ``len(texts)``).
        ngram_range: Character n-gram range for the TF-IDF vectoriser.
        min_df: ``min_df`` passed to the TF-IDF vectoriser.

    Returns:
        Integer NumPy array of length ``len(texts)``; ids are consecutive from
        0 in order of first appearance. Empty input yields an empty array.
    """
    n = len(texts)
    if n == 0:
        return np.array([], dtype=int)

    tfidf = TfidfVectorizer(
        analyzer="char_wb", ngram_range=ngram_range, min_df=min_df, dtype=np.float32
    )
    X = tfidf.fit_transform(texts)

    knn = NearestNeighbors(
        n_neighbors=min(k, n), metric="cosine", algorithm="brute", n_jobs=-1
    )
    knn.fit(X)
    dists, idxs = knn.kneighbors(X, return_distance=True)

    uf = UnionFind(n)
    for i, (row_dists, row_idxs) in enumerate(zip(dists, idxs)):
        for dist, j in zip(row_dists, row_idxs):
            # Skip the self-match; link everything similar enough.
            if j != i and 1.0 - float(dist) >= threshold:
                uf.union(i, int(j))

    # Resolve all roots first, then relabel them 0, 1, 2, ... by first appearance.
    roots = [uf.find(i) for i in range(n)]
    first_seen = {}
    gid = np.zeros(n, dtype=int)
    for i, root in enumerate(roots):
        gid[i] = first_seen.setdefault(root, len(first_seen))
    return gid

def report_leak_exact(a_texts, b_texts, name):
    """Print and return how many distinct texts occur in both collections.

    Args:
        a_texts: First iterable of strings.
        b_texts: Second iterable of strings.
        name: Label used in the printed line.

    Returns:
        Number of distinct strings shared by the two collections.
    """
    shared = set(a_texts).intersection(set(b_texts))
    inter = len(shared)
    print(f"Leak exact {name}: {inter}")
    return inter

def report_leak_near(a_texts, b_texts, threshold=0.92, ngram_range=(4,6), min_df=2):
    """Summarise how close each text in `b_texts` is to its nearest text in `a_texts`.

    A character n-gram TF-IDF is fitted on `a_texts` only; `b_texts` is then
    projected into that space and, for every text in `b_texts`, the cosine
    similarity to its single nearest neighbour in `a_texts` is taken.

    Args:
        a_texts: Reference collection (the vectoriser is fitted on it).
        b_texts: Query collection.
        threshold: Similarity at or above which a query counts as a near-dup.
        ngram_range: Character n-gram range for the TF-IDF vectoriser.
        min_df: ``min_df`` passed to the TF-IDF vectoriser.

    Returns:
        Dict with keys "mean", "median", "p95", "max" (similarity statistics)
        and "count_ge_thr" (number of queries with similarity >= threshold).
        All values are zero when either collection is empty.
    """
    if len(a_texts) == 0 or len(b_texts) == 0:
        # Same schema as the normal path so callers can always read every key.
        return {"mean": 0, "median": 0, "p95": 0, "max": 0, "count_ge_thr": 0}

    vec = TfidfVectorizer(analyzer="char_wb", ngram_range=ngram_range, min_df=min_df, dtype=np.float32)
    X_a = vec.fit_transform(a_texts)
    X_b = vec.transform(b_texts)

    nnm = NearestNeighbors(n_neighbors=1, metric="cosine", algorithm="brute", n_jobs=-1).fit(X_a)
    dists, _ = nnm.kneighbors(X_b, return_distance=True)
    # Cosine distance -> cosine similarity, one value per query text.
    sims = 1.0 - dists.reshape(-1)

    stats = {
        "mean": float(np.mean(sims)),
        "median": float(np.median(sims)),
        "p95": float(np.quantile(sims, 0.95)),
        "max": float(np.max(sims)),
        "count_ge_thr": int(np.sum(sims >= threshold))
    }
    return stats


# ============================================================
# LOAD DATA + CLEAN + EXACT DEDUP
# ============================================================

print("\n" + "="*90)
print("📂 LOADING DATA")
print("="*90)

if not os.path.exists(CONFIG["data_path"]):
    raise FileNotFoundError(f"❌ Not found: {CONFIG['data_path']}")

df = pd.read_csv(CONFIG["data_path"])

# detect columns: fall back to the first known alias when "text" is absent
if "text" not in df.columns:
    for c in ["content", "article", "news", "body", "title"]:
        if c in df.columns:
            df["text"] = df[c]
            break
if "text" not in df.columns:
    raise ValueError("❌ Cannot find text column")

# same alias fallback for the label column
if "label" not in df.columns:
    for c in ["class", "category", "y"]:
        if c in df.columns:
            df["label"] = df[c]
            break
if "label" not in df.columns:
    raise ValueError("❌ Cannot find label column")

# Keep only the two working columns and drop rows missing either of them.
df = df[["text", "label"]].dropna()
df["label"] = df["label"].astype(int)

# Only the binary mapping 0=REAL / 1=FAKE is supported downstream.
bad = df[~df["label"].isin([0, 1])]
if len(bad) > 0:
    raise ValueError(f"❌ Found labels not in {{0,1}}. Examples:\n{bad.head()}")

print("🧹 Cleaning + exact dedup...")
df["text_clean"] = df["text"].apply(clean_text)
# Drop texts that are too short after cleaning (see is_valid).
df = df[df["text_clean"].apply(is_valid)].copy()
before = len(df)
df = df.drop_duplicates(subset=["text_clean"], keep="first").reset_index(drop=True)
print(f"✅ After exact dedup: {len(df)} (removed {before-len(df)})")

n = len(df)
if n == 0:
    # Fail with a clear message instead of a ZeroDivisionError in the summary below.
    raise ValueError("❌ No samples left after cleaning and dedup")
c0 = int((df["label"]==0).sum())
c1 = int((df["label"]==1).sum())
print(f"✅ Final: {n} samples | REAL={c0} ({c0/n:.1%}) | FAKE={c1} ({c1/n:.1%})")


# ============================================================
# BUILD NEAR-DUP GROUPS (CLUSTERING)
# ============================================================

print("\n" + "="*90)
print("🧩 BUILDING NEAR-DUP CLUSTERS")
print("="*90)

# One group id per row; rows linked as near-duplicates share an id so the
# splitter below can keep them on the same side of every split.
groups = build_near_dup_groups(
    df["text_clean"].tolist(),
    threshold=float(CONFIG["near_dup_threshold"]),
    k=int(CONFIG["near_dup_k"]),
    ngram_range=tuple(CONFIG["char_ngram_range"]),
    min_df=int(CONFIG["min_df"])
)
df["group"] = groups

# Quick summary of how much near-duplication was found.
n_groups = int(df["group"].nunique())
group_sizes = df["group"].value_counts()
print(f"✅ Groups: {n_groups} | Largest group size: {int(group_sizes.max())}")
print(f"Top 5 group sizes:\n{group_sizes.head(5).to_string()}")


# ============================================================
# STRATIFIED GROUP SPLIT: TRAIN / VAL / TEST
# ============================================================

print("\n" + "="*90)
print("✂️ LEAK-SAFE SPLIT (STRATIFIED BY LABEL, GROUP-AWARE)")
print("="*90)

y = df["label"].values
g = df["group"].values

# Stage 1: carve out TEST. With 10 folds each held-out part is roughly 1/10 of
# the data; the fold whose held-out fraction is closest to test_ratio is used.
# NOTE(review): the fold count is hard-coded and is not derived from test_ratio.
kfold = 10
sgkf = StratifiedGroupKFold(n_splits=kfold, shuffle=True, random_state=CONFIG["seed"])

best_fold = None
best_diff = 1e9
test_target = float(CONFIG["test_ratio"])

splits = list(sgkf.split(df, y, groups=g))
for fold_i, (trainval_idx, test_idx) in enumerate(splits):
    ratio = len(test_idx)/len(df)
    diff = abs(ratio - test_target)
    if diff < best_diff:
        best_diff = diff
        best_fold = (trainval_idx, test_idx, ratio, fold_i)

trainval_idx, test_idx, ratio, fold_i = best_fold
print(f"Picked fold {fold_i} for TEST: size={len(test_idx)} ({ratio:.3f})")

df_trainval = df.iloc[trainval_idx].reset_index(drop=True)
df_test = df.iloc[test_idx].reset_index(drop=True)

# Stage 2: same procedure on the remaining rows to carve out VAL, using 9 folds
# and a different random_state (seed + 7).
# NOTE(review): the fold count is hard-coded and is not derived from val_ratio_of_trainval.
val_target = float(CONFIG["val_ratio_of_trainval"])
kfold2 = 9
sgkf2 = StratifiedGroupKFold(n_splits=kfold2, shuffle=True, random_state=CONFIG["seed"]+7)
splits2 = list(sgkf2.split(df_trainval, df_trainval["label"].values, groups=df_trainval["group"].values))

best2 = None
best2_diff = 1e9
for fold_i2, (train_idx, val_idx) in enumerate(splits2):
    ratio2 = len(val_idx)/len(df_trainval)
    diff2 = abs(ratio2 - val_target)
    if diff2 < best2_diff:
        best2_diff = diff2
        best2 = (train_idx, val_idx, ratio2, fold_i2)

train_idx, val_idx, ratio2, fold_i2 = best2
print(f"Picked fold {fold_i2} for VAL: size={len(val_idx)} ({ratio2:.3f})")

# Final frames used by the rest of the pipeline (indices reset to 0..len-1).
train_df = df_trainval.iloc[train_idx].reset_index(drop=True)
val_df   = df_trainval.iloc[val_idx].reset_index(drop=True)
test_df  = df_test.copy()

def dist_print(name, dfx):
    """Print size, class balance and number of groups of one split.

    Args:
        name: Label printed in front of the line (e.g. "Train").
        dfx: DataFrame with a "label" column (0=REAL, 1=FAKE) and a "group"
            column.

    An empty frame prints 0.0% shares instead of raising ZeroDivisionError.
    """
    n = len(dfx)
    c0 = int((dfx["label"]==0).sum())
    c1 = int((dfx["label"]==1).sum())
    # Only the percentage denominator is guarded; the printed size stays n.
    denom = max(n, 1)
    print(f"{name}: {n} | REAL={c0} ({c0/denom:.1%}) | FAKE={c1} ({c1/denom:.1%}) | groups={dfx['group'].nunique()}")

dist_print("Train", train_df)
dist_print("Val  ", val_df)
dist_print("Test ", test_df)

# Sanity check: no near-dup group id may appear in two splits (all three
# counts are expected to be 0).
overlap_tv = set(train_df["group"]) & set(val_df["group"])
overlap_tt = set(train_df["group"]) & set(test_df["group"])
overlap_vt = set(val_df["group"]) & set(test_df["group"])
print(f"\nGroup overlap Train∩Val={len(overlap_tv)} | Train∩Test={len(overlap_tt)} | Val∩Test={len(overlap_vt)}")


# ============================================================
# LEAK REPORT (EXACT + NEAR)
# ============================================================

print("\n" + "="*90)
print("🧪 LEAK REPORT")
print("="*90)

# Exact-string overlap between every pair of splits.
report_leak_exact(train_df["text_clean"], val_df["text_clean"], "Train∩Val")
report_leak_exact(train_df["text_clean"], test_df["text_clean"], "Train∩Test")
report_leak_exact(val_df["text_clean"], test_df["text_clean"], "Val∩Test")

# Nearest-neighbour similarity of the second split against the first
# (first argument is the reference set, second argument is the query set).
stats_tv = report_leak_near(train_df["text_clean"].tolist(), val_df["text_clean"].tolist(),
                            threshold=CONFIG["near_dup_threshold"],
                            ngram_range=CONFIG["char_ngram_range"],
                            min_df=CONFIG["min_df"])
stats_tt = report_leak_near(train_df["text_clean"].tolist(), test_df["text_clean"].tolist(),
                            threshold=CONFIG["near_dup_threshold"],
                            ngram_range=CONFIG["char_ngram_range"],
                            min_df=CONFIG["min_df"])
stats_vt = report_leak_near(val_df["text_clean"].tolist(), test_df["text_clean"].tolist(),
                            threshold=CONFIG["near_dup_threshold"],
                            ngram_range=CONFIG["char_ngram_range"],
                            min_df=CONFIG["min_df"])

print("\nNear-dup cosine stats (char-ngram TFIDF):")
print("Val->Train:", stats_tv)
print("Test->Train:", stats_tt)
print("Test->Val:", stats_vt)


# ============================================================
# DATASET / DATALOADER (dual tokenizers)
# ============================================================

# One tokenizer per backbone: each encoder receives ids from its own tokenizer.
# use_fast=False asks for the non-"fast" tokenizer implementation.
tokE = AutoTokenizer.from_pretrained(CONFIG["electra_name"], use_fast=False)
tokP = AutoTokenizer.from_pretrained(CONFIG["phobert_name"], use_fast=False)

class DualTokDataset(Dataset):
    """Dataset that tokenizes every text twice, once per backbone tokenizer.

    Each item is a dict with the Electra ids/mask ("e_input_ids", "e_attn"),
    the PhoBERT ids/mask ("p_input_ids", "p_attn") and the integer "label".
    Sequences are padded/truncated to `maxE` / `maxP` tokens respectively.
    Tokenization happens lazily, on every ``__getitem__`` call.
    """

    def __init__(self, texts, labels, tokE, tokP, maxE=256, maxP=256):
        self.texts, self.labels = list(texts), list(labels)
        self.tokE, self.tokP = tokE, tokP
        self.maxE, self.maxP = int(maxE), int(maxP)

    def __len__(self):
        return len(self.texts)

    @staticmethod
    def _encode(tokenizer, text, max_length):
        """Tokenize one text; return (input_ids, attention_mask) as 1-D long tensors."""
        enc = tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        ids = enc["input_ids"].squeeze(0).long()
        mask = enc["attention_mask"].squeeze(0).long()
        return ids, mask

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = int(self.labels[idx])

        e_ids, e_mask = self._encode(self.tokE, text, self.maxE)
        p_ids, p_mask = self._encode(self.tokP, text, self.maxP)

        # Guard against ids outside each tokenizer's reported size (len(tokenizer)).
        if int(e_ids.max()) >= len(self.tokE):
            raise ValueError("Bad token id for Electra tokenizer")
        if int(p_ids.max()) >= len(self.tokP):
            raise ValueError("Bad token id for PhoBERT tokenizer")

        item = {
            "e_input_ids": e_ids,
            "e_attn": e_mask,
            "p_input_ids": p_ids,
            "p_attn": p_mask,
            "label": torch.tensor(label, dtype=torch.long),
        }
        return item

def make_dual_loaders(bs):
    """Build the train/val/test DataLoaders over the module-level split frames.

    Args:
        bs: Batch size used for all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles.
    """
    def _loader(frame, shuffle):
        dataset = DualTokDataset(
            frame["text_clean"].values,
            frame["label"].values,
            tokE,
            tokP,
            CONFIG["max_length_electra"],
            CONFIG["max_length_phobert"],
        )
        return DataLoader(dataset, batch_size=int(bs), shuffle=shuffle)

    train_loader = _loader(train_df, True)
    val_loader = _loader(val_df, False)
    test_loader = _loader(test_df, False)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = make_dual_loaders(CONFIG["batch_size"])

def compute_class_weights(y):
    """Compute balanced class weights ``total / (2 * count)`` for the labels in `y`.

    Args:
        y: 1-D array of non-negative integer labels.

    Returns:
        Tuple ``(counts, weights)``: the per-class counts (at least two
        entries) and a float32 tensor of weights on ``CONFIG["device"]``.
        A class with zero samples is treated as count 1 to avoid dividing by 0.
    """
    counts = np.bincount(y, minlength=2)
    safe_counts = np.maximum(counts, 1)
    weights = counts.sum() / (2.0 * safe_counts)
    weight_tensor = torch.tensor(weights, dtype=torch.float32, device=CONFIG["device"])
    return counts, weight_tensor


# ============================================================
# MODEL: Cross-attention fusion
#   PhoBERT tokens (Q) attend to vELECTRA tokens (K,V)
# ============================================================

def set_requires_grad(module, flag: bool):
    """Set ``requires_grad`` to `flag` on every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad = flag

def _unfreeze_last_n_encoder_layers(backbone, n_last: int):
    """Enable gradients on the last `n_last` entries of ``backbone.encoder.layer``.

    Does nothing when `n_last` is not positive or when `backbone` has no
    ``encoder`` attribute. Asking for more layers than exist unfreezes all.
    """
    if n_last <= 0 or not hasattr(backbone, "encoder"):
        return
    layers = backbone.encoder.layer
    for i in range(max(0, len(layers) - n_last), len(layers)):
        for p in layers[i].parameters():
            p.requires_grad = True

def unfreeze_last_n_layers_electra(electra_model, n_last: int):
    """Unfreeze the last `n_last` encoder layers of the Electra backbone.

    Args:
        electra_model: Backbone whose layers are read from
            ``electra_model.encoder.layer``.
        n_last: Number of trailing layers to make trainable (<= 0: no-op).
    """
    _unfreeze_last_n_encoder_layers(electra_model, n_last)

def unfreeze_last_n_layers_roberta(roberta_model, n_last: int):
    """Unfreeze the last `n_last` encoder layers of the RoBERTa-style backbone.

    Args:
        roberta_model: Backbone whose layers are read from
            ``roberta_model.encoder.layer``.
        n_last: Number of trailing layers to make trainable (<= 0: no-op).
    """
    _unfreeze_last_n_encoder_layers(roberta_model, n_last)

class CrossAttnFusionClassifier(nn.Module):
    """Two-encoder classifier fused with a single cross-attention layer.

    Pipeline of one forward pass:
      1. ``self.e`` (vELECTRA) and ``self.p`` (PhoBERT) encode their own token
         ids into hidden states [B, Le, He] and [B, Lp, Hp].
      2. Electra states are projected to width Hp when the widths differ.
      3. PhoBERT states act as queries, Electra states as keys/values; padded
         Electra positions are masked out of the attention.
      4. The attention output gets a residual connection to the PhoBERT states
         followed by LayerNorm.
      5. Features = [PhoBERT state at position 0 ; masked mean of step 4 over
         non-pad PhoBERT tokens], passed through dropout and an MLP head that
         produces 2 logits.
    """

    def __init__(self, electra_name, phobert_name, dropout=0.3, heads=8, cross_attn_dropout=0.1):
        super().__init__()
        self.e = AutoModel.from_pretrained(electra_name)
        self.p = AutoModel.from_pretrained(phobert_name)

        electra_dim = self.e.config.hidden_size
        phobert_dim = self.p.config.hidden_size

        # Bring keys/values to the query width only when the widths differ.
        if electra_dim == phobert_dim:
            self.proj_kv = nn.Identity()
        else:
            self.proj_kv = nn.Linear(electra_dim, phobert_dim)

        # batch_first=False: the attention layer works on [L, B, H] tensors.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=phobert_dim,
            num_heads=heads,
            dropout=cross_attn_dropout,
            batch_first=False,
        )

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(phobert_dim)

        self.head = nn.Sequential(
            nn.Linear(2 * phobert_dim, phobert_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(phobert_dim, 2),
        )

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average `hidden` [B, L, H] over positions where `attention_mask` [B, L] is 1."""
        weights = attention_mask.unsqueeze(-1).float()      # [B, L, 1]
        summed = (hidden * weights).sum(dim=1)              # [B, H]
        counts = weights.sum(dim=1).clamp(min=1.0)          # [B, 1], never 0
        return summed / counts

    def forward(self, e_input_ids, e_attn, p_input_ids, p_attn):
        """Return logits of shape [B, 2] for a batch of doubly-tokenized texts."""
        electra_out = self.e(input_ids=e_input_ids, attention_mask=e_attn, return_dict=True)
        phobert_out = self.p(input_ids=p_input_ids, attention_mask=p_attn, return_dict=True)

        electra_hidden = electra_out.last_hidden_state      # [B, Le, He]
        phobert_hidden = phobert_out.last_hidden_state      # [B, Lp, Hp]

        projected = self.proj_kv(electra_hidden)            # [B, Le, Hp]

        # Sequence-first layout expected by the attention layer.
        queries = phobert_hidden.transpose(0, 1)            # [Lp, B, Hp]
        keys = projected.transpose(0, 1)                    # [Le, B, Hp]
        values = projected.transpose(0, 1)                  # [Le, B, Hp]

        # True marks Electra PAD positions that must not be attended to.
        electra_pad = (e_attn == 0)                         # [B, Le]

        attended, _ = self.cross_attn(queries, keys, values, key_padding_mask=electra_pad)
        attended = attended.transpose(0, 1)                 # [B, Lp, Hp]

        # Residual connection + LayerNorm.
        fused = self.norm(attended + phobert_hidden)

        first_token = phobert_hidden[:, 0, :]               # [B, Hp]
        pooled = self._masked_mean(fused, p_attn)           # [B, Hp]

        features = torch.cat([first_token, pooled], dim=1)  # [B, 2Hp]
        return self.head(self.dropout(features))


# ============================================================
# Build model + freeze strategy
# ============================================================

# Instantiate the fusion model (loads both pretrained backbones) and move it
# to the configured device.
model = CrossAttnFusionClassifier(
    CONFIG["electra_name"],
    CONFIG["phobert_name"],
    dropout=float(CONFIG["dropout"]),
    heads=int(CONFIG["cross_attn_heads"]),
    cross_attn_dropout=float(CONFIG["cross_attn_dropout"]),
).to(CONFIG["device"])

if CONFIG["freeze_backbones"]:
    # Start from fully frozen encoders ...
    set_requires_grad(model.e, False)
    set_requires_grad(model.p, False)

    # optionally unfreeze last N layers for each backbone
    if int(CONFIG["unfreeze_last_n_layers_electra"]) > 0:
        unfreeze_last_n_layers_electra(model.e, int(CONFIG["unfreeze_last_n_layers_electra"]))
    if int(CONFIG["unfreeze_last_n_layers_phobert"]) > 0:
        unfreeze_last_n_layers_roberta(model.p, int(CONFIG["unfreeze_last_n_layers_phobert"]))

    # Always train fusion layers + head
    set_requires_grad(model.proj_kv, True)
    set_requires_grad(model.cross_attn, True)
    set_requires_grad(model.norm, True)
    set_requires_grad(model.head, True)

# ============================================================
# TRAINING
# ============================================================

# Class-weighted cross-entropy, with weights derived from the training labels only.
y_train = train_df["label"].values.astype(int)
counts, class_w = compute_class_weights(y_train)
print("\nClass counts [REAL, FAKE]:", counts, "| class_weights:", class_w.detach().cpu().numpy())

criterion = nn.CrossEntropyLoss(weight=class_w)

# Only parameters that are currently trainable are handed to the optimizer.
optimizer = AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=float(CONFIG["learning_rate"]),
    weight_decay=float(CONFIG["weight_decay"]),
)

# The scheduler is stepped once per batch, so the horizon is batches * epochs.
total_steps = max(len(train_loader) * int(CONFIG["epochs"]), 1)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(CONFIG["warmup_steps"]),
    num_training_steps=total_steps
)

@torch.no_grad()
def eval_model(loader, y_true):
    """Score the module-level `model` on every batch of `loader`.

    Args:
        loader: DataLoader yielding dicts with the four model inputs.
        y_true: Ground-truth labels, in the same order as the loader's samples.

    Returns:
        Tuple ``(accuracy, macro_f1, p_fake, pred)`` where `p_fake` is the
        predicted probability of class 1 per sample and `pred` is that
        probability thresholded at 0.5.
    """
    device = CONFIG["device"]
    input_keys = ("e_input_ids", "e_attn", "p_input_ids", "p_attn")

    model.eval()
    fake_probs = []
    for batch in tqdm(loader, desc="Eval", leave=False):
        inputs = [batch[key].to(device) for key in input_keys]
        logits = model(*inputs)
        # Keep only the probability of class 1 (FAKE).
        fake_probs.append(F.softmax(logits, dim=1)[:, 1].detach().cpu().numpy())

    p_fake = np.concatenate(fake_probs, axis=0)
    pred = (p_fake >= 0.5).astype(int)

    acc = accuracy_score(y_true, pred)
    f1m = f1_score(y_true, pred, average="macro")
    return acc, f1m, p_fake, pred

print("\n" + "="*90)
print("🧠 TRAINING: CROSS-ATTENTION FUSION")
print("="*90)

# Snapshot of the weights with the best validation macro-F1 seen so far.
best_state = None
best_f1 = -1.0

for ep in range(int(CONFIG["epochs"])):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Train ep{ep+1}", leave=False):
        e_input_ids = batch["e_input_ids"].to(CONFIG["device"])
        e_attn      = batch["e_attn"].to(CONFIG["device"])
        p_input_ids = batch["p_input_ids"].to(CONFIG["device"])
        p_attn      = batch["p_attn"].to(CONFIG["device"])
        labels      = batch["label"].to(CONFIG["device"])

        optimizer.zero_grad(set_to_none=True)
        logits = model(e_input_ids, e_attn, p_input_ids, p_attn)
        loss = criterion(logits, labels)

        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), float(CONFIG["grad_clip"]))
        optimizer.step()
        # Learning-rate schedule advances once per batch.
        scheduler.step()

        total_loss += float(loss.item())

    # Validate after every epoch.
    y_val = val_df["label"].values.astype(int)
    acc, f1m, _, _ = eval_model(val_loader, y_val)

    print(f"Epoch {ep+1}/{CONFIG['epochs']} | loss={total_loss/max(len(train_loader),1):.4f} "
          f"| val_acc={acc*100:.2f}% | val_macroF1={f1m:.4f}")

    # Strict ">" keeps the earliest epoch on ties.
    if f1m > best_f1:
        best_f1 = f1m
        best_state = copy.deepcopy(model.state_dict())
        print("✅ New best")

# Restore the best validation checkpoint before the test evaluation.
if best_state is not None:
    model.load_state_dict(best_state)

# ============================================================
# FINAL EVAL ON TEST
# ============================================================

print("\n" + "="*90)
print("🎯 FINAL EVALUATION ON TEST (CROSS-ATTENTION, LEAK-SAFE)")
print("="*90)

y_test = test_df["label"].values.astype(int)
acc, f1m, p_fake, pred_test = eval_model(test_loader, y_test)

# Per-class F1; the macro score printed below is their unweighted mean.
f1_fake = f1_score(y_test, pred_test, pos_label=1, average="binary")
f1_real = f1_score(y_test, pred_test, pos_label=0, average="binary")
macro = (f1_fake + f1_real)/2

print(f"Accuracy:    {acc*100:.2f}%")
print(f"Macro-F1:    {macro:.4f}")
print(f"F1 FAKE(1):  {f1_fake:.4f}")
print(f"F1 REAL(0):  {f1_real:.4f}")
# Absolute difference between the two per-class F1 scores.
print(f"Gap:         {abs(f1_fake - f1_real):.4f}")

print("\n📋 Classification Report:")
print(classification_report(y_test, pred_test, target_names=["REAL (0)", "FAKE (1)"], digits=4))

# ============================================================
# QUICK INTERACTIVE PREDICT
# ============================================================

@torch.no_grad()
def predict_one(text: str):
    """Classify a single raw text with the module-level `model`.

    The text goes through the same cleaning and minimum-length check as the
    training data before being tokenized for both backbones.

    Args:
        text: Raw input text.

    Returns:
        ``{"error": ...}`` when the cleaned text is too short; otherwise a
        dict with "pred" (0=REAL, 1=FAKE), "p_fake" (probability of class 1)
        and "conf" (probability of the predicted class).
    """
    cleaned = clean_text(text)
    if not is_valid(cleaned):
        return {"error": "Text too short after cleaning"}

    shared_kwargs = dict(padding="max_length", truncation=True, return_tensors="pt")
    enc_e = tokE(cleaned, max_length=CONFIG["max_length_electra"], **shared_kwargs)
    enc_p = tokP(cleaned, max_length=CONFIG["max_length_phobert"], **shared_kwargs)

    device = CONFIG["device"]
    inputs = (
        enc_e["input_ids"].to(device),
        enc_e["attention_mask"].to(device),
        enc_p["input_ids"].to(device),
        enc_p["attention_mask"].to(device),
    )

    model.eval()
    logits = model(*inputs)
    probs = F.softmax(logits, dim=1).squeeze(0).detach().cpu().numpy()

    p_fake = float(probs[1])
    return {
        "pred": 1 if p_fake >= 0.5 else 0,
        "p_fake": p_fake,
        "conf": max(p_fake, 1 - p_fake),
    }

print("\n✅ Ready. Try: predict_one('Tin sốc: ...')\n")


📦 Installing packages...
🚀 VIETNAMESE FAKE NEWS - CROSS-ATTENTION FUSION (LEAK-SAFE SPLIT)
🖥️ Device: cuda
🧾 Label mapping: 0=REAL, 1=FAKE

📂 LOADING DATA
🧹 Cleaning + exact dedup...
✅ After exact dedup: 1843 (removed 0)
✅ Final: 1843 samples | REAL=974 (52.8%) | FAKE=869 (47.2%)

🧩 BUILDING NEAR-DUP CLUSTERS
✅ Groups: 1806 | Largest group size: 8
Top 5 group sizes:
group
35     8
136    7
26     4
698    2
200    2

✂️ LEAK-SAFE SPLIT (STRATIFIED BY LABEL, GROUP-AWARE)
Picked fold 2 for TEST: size=185 (0.100)
Picked fold 2 for VAL: size=182 (0.110)
Train: 1476 | REAL=776 (52.6%) | FAKE=700 (47.4%) | groups=1444
Val  : 182 | REAL=98 (53.8%) | FAKE=84 (46.2%) | groups=180
Test : 185 | REAL=100 (54.1%) | FAKE=85 (45.9%) | groups=182

Group overlap Train∩Val=0 | Train∩Test=0 | Val∩Test=0

🧪 LEAK REPORT
Leak exact Train∩Val: 0
Leak exact Train∩Test: 0
Leak exact Val∩Test: 0

Near-dup cosine stats (char-ngram TFIDF):
Val->Train: {'mean': 0.3505195379257202, 'median': 0.30619990825653076, 'p

config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/557 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]



bpe.codes: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/197 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/443M [00:00<?, ?B/s]

ElectraModel LOAD REPORT from: FPTAI/velectra-base-discriminator-cased
Key                                               | Status     |  | 
--------------------------------------------------+------------+--+-
discriminator_predictions.dense.weight            | UNEXPECTED |  | 
discriminator_predictions.dense_prediction.weight | UNEXPECTED |  | 
discriminator_predictions.dense.bias              | UNEXPECTED |  | 
discriminator_predictions.dense_prediction.bias   | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


pytorch_model.bin:   0%|          | 0.00/543M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

RobertaModel LOAD REPORT from: vinai/phobert-base
Key                             | Status     |  | 
--------------------------------+------------+--+-
roberta.embeddings.position_ids | UNEXPECTED |  | 
lm_head.layer_norm.weight       | UNEXPECTED |  | 
lm_head.dense.weight            | UNEXPECTED |  | 
lm_head.layer_norm.bias         | UNEXPECTED |  | 
lm_head.decoder.bias            | UNEXPECTED |  | 
lm_head.dense.bias              | UNEXPECTED |  | 
lm_head.bias                    | UNEXPECTED |  | 
lm_head.decoder.weight          | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


model.safetensors:   0%|          | 0.00/543M [00:00<?, ?B/s]


Class counts [REAL, FAKE]: [776 700] | class_weights: [0.9510309 1.0542858]

🧠 TRAINING: CROSS-ATTENTION FUSION


Train ep1:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 1/20 | loss=0.2726 | val_acc=96.70% | val_macroF1=0.9668
✅ New best


Train ep2:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 2/20 | loss=0.1779 | val_acc=96.15% | val_macroF1=0.9610


Train ep3:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 3/20 | loss=0.1455 | val_acc=97.25% | val_macroF1=0.9723
✅ New best


Train ep4:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 4/20 | loss=0.1038 | val_acc=97.25% | val_macroF1=0.9724
✅ New best


Train ep5:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 5/20 | loss=0.0871 | val_acc=97.80% | val_macroF1=0.9780
✅ New best


Train ep6:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 6/20 | loss=0.0704 | val_acc=97.80% | val_macroF1=0.9779


Train ep7:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 7/20 | loss=0.0576 | val_acc=97.80% | val_macroF1=0.9779


Train ep8:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 8/20 | loss=0.0444 | val_acc=97.25% | val_macroF1=0.9723


Train ep9:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 9/20 | loss=0.0328 | val_acc=97.80% | val_macroF1=0.9779


Train ep10:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 10/20 | loss=0.0348 | val_acc=97.25% | val_macroF1=0.9723


Train ep11:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 11/20 | loss=0.0163 | val_acc=97.25% | val_macroF1=0.9723


Train ep12:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 12/20 | loss=0.0181 | val_acc=97.25% | val_macroF1=0.9723


Train ep13:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 13/20 | loss=0.0115 | val_acc=98.35% | val_macroF1=0.9834
✅ New best


Train ep14:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 14/20 | loss=0.0130 | val_acc=97.25% | val_macroF1=0.9723


Train ep15:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 15/20 | loss=0.0101 | val_acc=97.25% | val_macroF1=0.9723


Train ep16:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 16/20 | loss=0.0114 | val_acc=97.80% | val_macroF1=0.9779


Train ep17:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 17/20 | loss=0.0076 | val_acc=96.70% | val_macroF1=0.9668


Train ep18:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 18/20 | loss=0.0081 | val_acc=97.25% | val_macroF1=0.9723


Train ep19:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 19/20 | loss=0.0053 | val_acc=97.80% | val_macroF1=0.9779


Train ep20:   0%|          | 0/369 [00:00<?, ?it/s]

Eval:   0%|          | 0/46 [00:00<?, ?it/s]

Epoch 20/20 | loss=0.0054 | val_acc=97.80% | val_macroF1=0.9779

🎯 FINAL EVALUATION ON TEST (CROSS-ATTENTION, LEAK-SAFE)


Eval:   0%|          | 0/47 [00:00<?, ?it/s]

Accuracy:    99.46%
Macro-F1:    0.9946
F1 FAKE(1):  0.9942
F1 REAL(0):  0.9950
Gap:         0.0008

📋 Classification Report:
              precision    recall  f1-score   support

    REAL (0)     1.0000    0.9900    0.9950       100
    FAKE (1)     0.9884    1.0000    0.9942        85

    accuracy                         0.9946       185
   macro avg     0.9942    0.9950    0.9946       185
weighted avg     0.9947    0.9946    0.9946       185


✅ Ready. Try: predict_one('Tin sốc: ...')



In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
from tqdm.auto import tqdm

# Device that the inference helpers in this cell move their input tensors to;
# read from the shared CONFIG dict.
DEVICE = CONFIG["device"]

@torch.no_grad()
def predict_one_cross(text: str):
    """Classify a single text with the cross-attention fusion model.

    The text is cleaned with ``clean_text`` and rejected if ``is_valid``
    fails. Otherwise it is tokenized once per backbone tokenizer (``tokE``
    and ``tokP``), padded/truncated to the configured max lengths, and fed
    through ``model`` in eval mode with gradients disabled.

    Returns:
        On rejection: ``{"ok": False, "error": <message>, "text": <cleaned>}``.
        On success: ``{"ok": True, "pred": 0 or 1, "p_fake": float,
        "conf": float, "text": <cleaned>}`` where ``p_fake`` is the softmax
        probability of class index 1 (FAKE), ``pred`` is 1 when
        ``p_fake >= 0.5``, and ``conf`` is the probability of the predicted
        class.
    """
    cleaned = clean_text(text)
    if not is_valid(cleaned):
        return {"ok": False, "error": "Text too short after cleaning", "text": cleaned}

    # Both tokenizers share everything except the max length.
    shared_kwargs = dict(padding="max_length", truncation=True, return_tensors="pt")
    electra_batch = tokE(cleaned, max_length=CONFIG["max_length_electra"], **shared_kwargs)
    phobert_batch = tokP(cleaned, max_length=CONFIG["max_length_phobert"], **shared_kwargs)

    # Positional order expected by the model call below:
    # electra ids, electra mask, phobert ids, phobert mask.
    model_inputs = [
        batch[key].to(DEVICE)
        for batch in (electra_batch, phobert_batch)
        for key in ("input_ids", "attention_mask")
    ]

    model.eval()
    logits = model(*model_inputs)   # [1,2]
    class_probs = F.softmax(logits, dim=1).squeeze(0).detach().cpu().numpy()

    p_fake = float(class_probs[1])
    label = int(p_fake >= 0.5)
    confidence = max(p_fake, 1 - p_fake)

    return {
        "ok": True,
        "pred": label,
        "p_fake": p_fake,
        "conf": confidence,
        "text": cleaned,
    }

@torch.no_grad()
def evaluate_on_test_cross(top_k_errors=15):
    """Score the cross-attention model on ``test_df`` one text at a time and print a report.

    Every row of ``test_df["text_clean"]`` is passed through
    ``predict_one_cross``. Texts that it rejects are scored as REAL (0) with
    ``p_fake=0.0`` and ``conf=0.5`` so array lengths stay aligned with the
    labels; the number of such fallbacks is printed so they cannot silently
    distort the metrics.

    Printed output: accuracy, per-class and macro F1, the confusion matrix,
    the classification report, and the misclassified texts ordered by
    descending confidence.

    Args:
        top_k_errors: Maximum number of misclassified samples to print.

    Returns:
        None. All results are printed.
    """
    y_true = test_df["label"].values.astype(int)
    texts  = test_df["text_clean"].values

    preds, p_fakes, confs = [], [], []
    n_invalid = 0

    for t in tqdm(texts, desc="CrossAttn test inference"):
        out = predict_one_cross(t)
        if not out["ok"]:
            # Keep the arrays aligned with y_true: default to REAL with a
            # neutral confidence, and remember that a fallback happened.
            n_invalid += 1
            preds.append(0)
            p_fakes.append(0.0)
            confs.append(0.5)
            continue
        preds.append(out["pred"])
        p_fakes.append(out["p_fake"])
        confs.append(out["conf"])

    preds = np.array(preds, dtype=int)
    p_fakes = np.array(p_fakes, dtype=float)
    confs = np.array(confs, dtype=float)

    acc = accuracy_score(y_true, preds)
    f1_fake = f1_score(y_true, preds, pos_label=1)
    f1_real = f1_score(y_true, preds, pos_label=0)
    macro = (f1_fake + f1_real) / 2

    print("="*90)
    print("📌 TEST SET METRICS (CROSS-ATTENTION)")
    print("="*90)
    if n_invalid:
        print(f"⚠️ {n_invalid} text(s) were rejected after cleaning and scored as REAL (0) by default.")
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Macro-F1: {macro:.4f} | F1_FAKE={f1_fake:.4f} | F1_REAL={f1_real:.4f} | Gap={abs(f1_fake-f1_real):.4f}")

    cm = confusion_matrix(y_true, preds, labels=[0,1])
    print("\nConfusion matrix [rows=true 0/1, cols=pred 0/1]:")
    print(cm)

    print("\nClassification report:")
    # Pin labels explicitly (as confusion_matrix does above) so the two
    # target_names always line up with classes 0/1, even when only one class
    # appears in y_true/preds; without this the report errors out or mislabels.
    print(classification_report(
        y_true,
        preds,
        labels=[0, 1],
        target_names=["REAL (0)", "FAKE (1)"],
        digits=4,
        zero_division=0,
    ))

    wrong = np.where(preds != y_true)[0]
    if len(wrong) == 0:
        print("\n✅ No errors on test.")
        return

    # Most confidently wrong predictions first.
    wrong_sorted = wrong[np.argsort(-confs[wrong])]
    print("\n" + "="*90)
    print(f"❌ TOP {min(top_k_errors, len(wrong_sorted))} WRONG PREDICTIONS (highest confidence first)")
    print("="*90)

    for i in wrong_sorted[:top_k_errors]:
        snippet = texts[i][:260].replace("\n"," ")
        print(f"\nIDX={i} | TRUE={y_true[i]} | PRED={preds[i]} | p_fake={p_fakes[i]:.3f} | conf={confs[i]:.3f}")
        print(f"TEXT: {snippet}...")

# Ten "real-life" ("ngoài đời") Vietnamese sample texts used as the default input
# of the generalization check below. The list carries no labels, so the check
# only prints the model's predictions for manual inspection.
GENERALIZATION_10 = [
    "TIN SỐC!!! Chỉ cần uống nước chanh theo cách này 3 ngày là khỏi hoàn toàn tiểu đường, bác sĩ cũng bất ngờ. Xem ngay!!!",
    "Sở Giao thông Vận tải TP.HCM thông báo điều chỉnh tổ chức giao thông một số tuyến đường khu vực trung tâm để phục vụ thi công, thời gian áp dụng từ ngày 10/02.",
    "Các nhà khoa học xác nhận người ngoài hành tinh đã hạ cánh ở Việt Nam và để lại thiết bị lạ, video bằng chứng đang lan truyền mạnh.",
    "Giá vàng trong nước sáng nay biến động nhẹ, nhiều doanh nghiệp điều chỉnh tăng/giảm vài chục nghìn đồng mỗi lượng so với cuối ngày hôm qua.",
    "CẢNH BÁO KHẨN: Ai nhận cuộc gọi số lạ đọc 3 số cuối CCCD sẽ bị trừ tiền tài khoản ngay lập tức. Hãy chia sẻ để cứu mọi người!",
    "Công an cho biết đang xác minh thông tin lan truyền trên mạng liên quan đến vụ việc tại một khu dân cư, đồng thời đề nghị người dân không chia sẻ thông tin chưa kiểm chứng.",
    "Không cần vốn, chỉ với điện thoại bạn có thể kiếm 5 triệu/ngày đảm bảo 100%. Ai cũng làm được, đăng ký ngay kẻo lỡ!",
    "Trung tâm dự báo khí tượng thủy văn nhận định trong vài ngày tới, khu vực Nam Bộ có mưa rào và dông rải rác vào chiều tối.",
    "Bộ Y tế khuyến cáo người dân tiêm nhắc lại vắc-xin theo hướng dẫn và theo dõi thông tin từ các nguồn chính thống khi có dịch bệnh.",
    "Thực hư chuyện một loại nước ngọt đang bị cấm bán vì gây ung thư ngay lập tức? Nhiều người hoang mang, sự thật khiến ai cũng sốc!",
]

def generalization_check_10_cross(samples=GENERALIZATION_10):
    """Print the cross-attention model's prediction for each sample text.

    Each sample is scored with ``predict_one_cross``. Rejected samples are
    listed as ``[INVALID]`` together with the reported error; accepted ones
    show the predicted label, the FAKE probability and the confidence,
    followed by the original (uncleaned) text.

    Args:
        samples: Iterable of raw texts; defaults to ``GENERALIZATION_10``.
    """
    banner = "="*90
    print("\n" + banner)
    print("🌍 GENERALIZATION CHECK (10 'ngoài đời' samples) - CROSS-ATTENTION")
    print(banner)
    print("Legend: pred 0=REAL, 1=FAKE\n")

    for idx, sample in enumerate(samples, start=1):
        result = predict_one_cross(sample)
        if result["ok"]:
            print(f"{idx:02d}. PRED={result['pred']} | p_fake={result['p_fake']:.3f} | conf={result['conf']:.3f}")
            print(f"    TEXT: {sample}")
        else:
            print(f"{idx:02d}. [INVALID] {result.get('error')}")

# Run both checks: test-set metrics (printing at most 12 misclassified texts),
# then the model's predictions on the default generalization samples.
evaluate_on_test_cross(top_k_errors=12)
generalization_check_10_cross()


CrossAttn test inference:   0%|          | 0/185 [00:00<?, ?it/s]

📌 TEST SET METRICS (CROSS-ATTENTION)
Accuracy: 99.46%
Macro-F1: 0.9946 | F1_FAKE=0.9942 | F1_REAL=0.9950 | Gap=0.0008

Confusion matrix [rows=true 0/1, cols=pred 0/1]:
[[99  1]
 [ 0 85]]

Classification report:
              precision    recall  f1-score   support

    REAL (0)     1.0000    0.9900    0.9950       100
    FAKE (1)     0.9884    1.0000    0.9942        85

    accuracy                         0.9946       185
   macro avg     0.9942    0.9950    0.9946       185
weighted avg     0.9947    0.9946    0.9946       185


❌ TOP 1 WRONG PREDICTIONS (highest confidence first)

IDX=89 | TRUE=0 | PRED=1 | p_fake=1.000 | conf=1.000
TEXT: Kiểu đeo đeo nhẫn giữa ngón tay được nữ sinh đặc biệt yêu thích - Ảnh: Thanh Nam “Cách đeo nhẫn giữa ngón tay là mốt hiện nay, được cách tân mới lạ”, Ngọc Trâm, sinh viên Trường ĐH Kinh tế - Tài chính TP.HCM giơ ngón tay đeo nhẫn vui vẻ nói. Không phải ngẫu nh...

🌍 GENERALIZATION CHECK (10 'ngoài đời' samples) - CROSS-ATTENTION
Legend: pred 0=RE