<a href="https://colab.research.google.com/github/changmin-jen/clickbait-detection/blob/main/clickbait_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [42]:
# ============================================================
# Clickbait Pair-Fusion (Title/Thumb + STT/Keyframe, Var-Length)
# End-to-end Colab script (attention pooling over the STT / keyframe sequences)
# ============================================================

# -------- 0) Basic setup & Google Drive mount --------
import os, random, warnings
warnings.filterwarnings("ignore")  # silences every warning for the rest of the session

from google.colab import drive
drive.mount('/content/drive')  # Colab-only: exposes Drive files under /content/drive

# Input locations on the mounted Drive.
ROOT = "/content/drive/MyDrive/clickbait_data"
EMB_NPZ = f"{ROOT}/embeddings_sbert_dot_v1.npz"  # embedding archive, opened with np.load further down
LABEL_XLSX = f"{ROOT}/clickbait_enriched_caps_QWEN_thumbcaps_qwen.xlsx"  # label sheet, opened with pd.read_excel further down

# Column names looked up in the label spreadsheet.
ID_COL = "id"; LABEL_COL = "label"; SPLIT_COL = "split"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [43]:
# Hyperparameters
BATCH_TRAIN = 64      # batch size of the training DataLoader
BATCH_EVAL  = 256     # batch size of the validation / test DataLoaders
LR = 1e-4             # AdamW learning rate
EPOCHS = 50           # upper bound on epochs per run (early stopping may end sooner)
MU = 1.5 # weight multiplied onto the fused-logit loss term
BR_W = 0.5   # weight multiplied onto the averaged branch loss term
WD = 1e-3             # AdamW weight decay
HIDDEN = 64           # default hidden width of BranchMLP
DROPOUT = 0.4         # default dropout probability of BranchMLP
LABEL_SMOOTH = 0.01   # label-smoothing factor applied to the BCE targets
SEED = 42             # default seed used by set_seed

In [44]:
# -------- 1) 시드 고정 --------
import numpy as np, torch, pandas as pd
def set_seed(seed=SEED):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
set_seed(SEED)  # seed all RNGs once, before any model or DataLoader is created

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

device: cpu


In [45]:
# -------- 2) Load the embedding NPZ --------
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code from an untrusted file — only load archives you produced yourself.
npz = np.load(EMB_NPZ, allow_pickle=True)
print("NPZ keys ->", list(npz.files))

def npz_get(npz_obj, candidates):
    """Return the array in ``npz_obj`` whose key matches one of ``candidates``.

    Lookup runs in two passes over ``candidates`` (in the given order):

    1. exact key, then the same name with a ``.npy`` suffix;
    2. substring match, accepted only when exactly one key contains the
       candidate.

    Raises:
        KeyError: if neither pass finds a match.
    """
    available = list(npz_obj.files)

    # Pass 1: exact name, then "<name>.npy".
    for name in candidates:
        for key in (name, f"{name}.npy"):
            if key in available:
                return npz_obj[key]

    # Pass 2: unambiguous substring match.
    for name in candidates:
        hits = [key for key in available if name in key]
        if len(hits) == 1:
            return npz_obj[hits[0]]

    raise KeyError(f"NPZ에 {candidates} 중 일치 키 없음. 존재키={available}")

# Pull each embedding table out of the archive, accepting either the long or the short key name.
emb_title = npz_get(npz, ["emb_title","title"])
emb_thumb = npz_get(npz, ["emb_thumb","thumb"])
emb_stt   = npz_get(npz, ["emb_stt","stt"])
emb_kf    = npz_get(npz, ["emb_kf","kf"])
index_ids = npz_get(npz, ["index"])  # only its length is printed below; rows are later matched by position

def to_2d(arr):
    """Coerce ``arr`` to a 2-D array where that is well defined.

    * non-ndarray input and arrays that are already 2-D are returned unchanged;
    * a 1-D object array whose entries all flatten to the same length is
      stacked into a regular ``(n, d)`` array;
    * a 1-D object array that is ragged or empty (cannot be stacked) is
      returned unchanged;
    * any other object array is returned unchanged;
    * a plain 1-D array is reshaped to ``(1, d)``.

    The original version returned every object array before reaching the
    stacking branch, which made that branch unreachable.
    """
    if not isinstance(arr, np.ndarray) or arr.ndim == 2:
        return arr

    if arr.ndim == 1 and arr.dtype == object:
        try:
            return np.stack([np.asarray(v).reshape(-1) for v in arr], axis=0)
        except ValueError:
            # Ragged rows or an empty array: leave it as an object sequence.
            return arr

    if arr.dtype == object:
        return arr
    if arr.ndim == 1:
        return arr.reshape(1, -1)
    return arr

# Only the title / thumbnail tables go through to_2d; emb_stt and emb_kf are left
# as loaded and reported by length only.
emb_title = to_2d(emb_title)
emb_thumb = to_2d(emb_thumb)
print("shapes:", emb_title.shape, (len(emb_stt),"object_seq"),
      emb_thumb.shape, (len(emb_kf),"object_seq"), "index_len:", len(index_ids))

NPZ keys -> ['emb_title', 'emb_thumb', 'emb_stt', 'emb_kf', 'index', 'title_col', 'stt_col', 'kf_col', 'thumb_col']
shapes: (398, 768) (398, 'object_seq') (398, 768) (398, 'object_seq') index_len: 398


In [46]:
# -------- 3) Load labels & map rows to embeddings (position-based) --------
df = pd.read_excel(LABEL_XLSX)
# Normalise the id / label / split columns to stripped strings (label and split lower-cased).
df[ID_COL] = df[ID_COL].astype(str).str.strip()
df[LABEL_COL] = df[LABEL_COL].astype(str).str.strip().str.lower()
df[SPLIT_COL] = df[SPLIT_COL].astype(str).str.strip().str.lower().replace({"val":"valid"})

# All embedding tables and the label sheet must have the same number of rows.
N = emb_title.shape[0]
assert emb_thumb.shape[0]==N and len(emb_stt)==N and len(emb_kf)==N
assert len(df)==N, f"라벨({len(df)}) != 임베딩({N})"

# NOTE(review): row i of the spreadsheet is paired with row i of every embedding table.
# The NPZ "index" array is not consulted here, so this assumes both files share the
# same row order — confirm against how the NPZ was produced.
df = df.reset_index(drop=True).copy()
df["emb_idx"] = np.arange(N, dtype=int)
df["y"] = (df[LABEL_COL] == "clickbait").astype("float32")  # 1.0 for "clickbait", 0.0 for anything else

df_tr = df[df[SPLIT_COL]=="train"].reset_index(drop=True)
df_va = df[df[SPLIT_COL]=="valid"].reset_index(drop=True)
df_te = df[df[SPLIT_COL]=="test"].reset_index(drop=True)
print("split sizes:", len(df_tr), len(df_va), len(df_te))

# Class weight (optional): negatives / positives on the training split,
# with the positive count floored at 1 to avoid division by zero.
pos = float(df_tr["y"].sum()); neg = float((1 - df_tr["y"]).sum())
pos_weight = torch.tensor([(neg / max(pos, 1.0))], device=device)
print("pos_weight:", pos_weight.item())

split sizes: 200 98 100
pos_weight: 1.0


In [47]:
# -------- 4) Dataset / DataLoader (정규화 포함) --------
import torch
from torch.utils.data import Dataset, DataLoader

# L2 normalisation for the fixed-length embeddings
def l2n(t):
    """Scale every row of the 2-D tensor ``t`` to unit L2 norm.

    With unit-length rows a dot product equals cosine similarity. A small
    epsilon is added to the norm so all-zero rows stay finite.
    """
    row_norm = t.norm(dim=1, keepdim=True)
    return t / (row_norm + 1e-8)

# Row-normalised float32 tensors for the title / thumbnail embeddings; indexed by the Dataset below.
T_title = l2n(torch.from_numpy(emb_title).float())
T_thumb = l2n(torch.from_numpy(emb_thumb).float())

# Sequence clean-up helper
def get_seq(obj_arr, i):
    """Return entry ``i`` of ``obj_arr`` as a row-normalised float32 2-D tensor.

    A 1-D entry is treated as a single row. An entry with zero rows is
    replaced by one all-zero row so that later padding always sees at least
    one step. The 1-D promotion happens before the emptiness check, so a 1-D
    entry of length 0 becomes shape ``(1, 0)`` rather than being caught as
    empty.
    """
    seq = torch.as_tensor(obj_arr[i], dtype=torch.float32)
    if seq.ndim == 1:
        seq = seq[None, :]
    if seq.size(0) == 0:
        # Guard against empty sequences.
        seq = torch.zeros(1, seq.size(1))
    # Per-row L2 normalisation inside the sequence.
    row_norm = seq.norm(dim=1, keepdim=True)
    return seq / (row_norm + 1e-8)

class PairFusionVarLenDS(Dataset):
    """Dataset yielding title, thumbnail, STT-sequence and keyframe-sequence tensors.

    ``frame`` must provide an ``emb_idx`` column (row index into the embedding
    tables) and a ``y`` column (label). Each item is a dict with keys
    ``t``, ``th``, ``s``, ``k`` and ``y``.
    """

    def __init__(self, frame):
        self.idx = frame["emb_idx"].to_numpy().astype(int)
        self.y = frame["y"].to_numpy().astype(np.float32)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        row = int(self.idx[i])
        sample = {"t": T_title[row], "th": T_thumb[row]}
        # Variable-length sequences are built on demand from the object arrays.
        sample["s"] = get_seq(emb_stt, row)
        sample["k"] = get_seq(emb_kf, row)
        sample["y"] = torch.tensor(self.y[i])
        return sample

def pad_and_mask(seqs):
    """Zero-pad a list of ``(n_i, d)`` tensors to a common length.

    Returns:
        ``(padded, mask)`` where ``padded`` has shape ``(B, max_n, d)`` and
        ``mask`` is a bool tensor of shape ``(B, max_n)`` that is True on the
        real (non-padded) steps.
    """
    lengths = [s.size(0) for s in seqs]
    longest = max(lengths)
    width = seqs[0].size(1)  # feature size is taken from the first sequence

    padded = torch.zeros(len(seqs), longest, width)
    for row, (s, n) in enumerate(zip(seqs, lengths)):
        padded[row, :n] = s

    # Step j of sample i is valid iff j < length_i.
    steps = torch.arange(longest).unsqueeze(0)
    mask = steps < torch.tensor(lengths).unsqueeze(1)
    return padded, mask

def collate_fn(batch):
    """Merge a list of dataset items into one batch dict.

    Fixed-size fields (``t``, ``th``, ``y``) are stacked along a new first
    axis; the variable-length ``s`` / ``k`` sequences are zero-padded and
    returned with their validity masks (``s_mask`` / ``k_mask``).
    """
    def column(key):
        return [item[key] for item in batch]

    s_pad, s_mask = pad_and_mask(column("s"))
    k_pad, k_mask = pad_and_mask(column("k"))
    return {
        "t": torch.stack(column("t"), 0),
        "th": torch.stack(column("th"), 0),
        "s": s_pad, "s_mask": s_mask,
        "k": k_pad, "k_mask": k_mask,
        "y": torch.stack(column("y"), 0),
    }

# DataLoader definitions: only the training loader shuffles; all three share the same
# worker / pinning / collate settings.
_LOADER_KW = dict(num_workers=2, pin_memory=True, collate_fn=collate_fn)
train_dl = DataLoader(PairFusionVarLenDS(df_tr), batch_size=BATCH_TRAIN, shuffle=True, **_LOADER_KW)
val_dl = DataLoader(PairFusionVarLenDS(df_va), batch_size=BATCH_EVAL, shuffle=False, **_LOADER_KW)
test_dl = DataLoader(PairFusionVarLenDS(df_te), batch_size=BATCH_EVAL, shuffle=False, **_LOADER_KW)



In [51]:
# -------- 5) 모델  --------
import torch.nn as nn
import torch.nn.functional as F

# Fixed branch order: TK = title x keyframe, TS = title x STT,
# ThK = thumbnail x keyframe, ThS = thumbnail x STT.
BRANCH_NAMES = ["TK","TS","ThK","ThS"]
# Branch name -> column position in the stacked branch tensors.
BRANCH_IDX = dict(zip(BRANCH_NAMES, range(len(BRANCH_NAMES))))

def pair_features(h1, h2):
    """Build the pair representation ``[h1, h2, |h1 - h2|, h1 * h2]``.

    The four parts are concatenated on the last axis, so the output's last
    dimension is four times that of the inputs.
    """
    gap = (h1 - h2).abs()
    product = h1 * h2
    return torch.cat((h1, h2, gap, product), dim=-1)

class SeqPool(nn.Module):
    """Pool a padded sequence batch ``X`` of shape (B, L, d) into one vector per sample.

    Modes:
        "mean":     masked average over the valid steps.
        "attn":     softmax (temperature ``tau``) over the cosine similarity
                    between each step and the query ``q``; weighted sum.
        "topk":     average of the (up to) ``k`` valid steps with the largest
                    L2 norm.
        "topk_sim": average of the (up to) ``k`` valid steps most
                    cosine-similar to ``q``.

    ``forward`` returns ``(pooled, aux)`` where ``aux`` is ``None`` ("mean"),
    the attention weights ("attn"), the selected indices ("topk") or the
    masked similarity scores ("topk_sim").

    ``d`` is accepted by the constructor but not used; the module has no
    learnable parameters.
    """

    def __init__(self, d, mode="attn", k=3, tau=0.5):
        super().__init__()
        self.mode, self.k, self.tau = mode, k, tau

    def _topk_mean(self, X, mask, score):
        """Average the top-``k`` steps by ``score``, ignoring padded picks.

        When a sample has fewer than ``k`` valid steps, ``torch.topk`` still
        returns ``k`` indices and some of them point at padding. Those rows
        are excluded from both the sum and the divisor so they no longer
        dilute the pooled vector.
        """
        k = min(self.k, X.size(1))
        idx = torch.topk(score, k, dim=1).indices
        b = torch.arange(X.size(0), device=X.device)[:, None]
        picked = X[b, idx]                                   # (B, k, d)
        valid = mask[b, idx].unsqueeze(-1).to(picked.dtype)  # (B, k, 1)
        # clamp keeps a fully-masked sample at zero instead of dividing by zero
        pooled = (picked * valid).sum(1) / valid.sum(1).clamp_min(1.0)
        return pooled, idx

    def forward(self, X, mask, q=None):
        if self.mode == "mean":
            pooled = (X * mask.unsqueeze(-1)).sum(1) / (mask.sum(1, keepdim=True) + 1e-8)
            return pooled, None

        if self.mode == "attn":
            assert q is not None
            # Padded steps get a very negative score so softmax gives them ~0 weight.
            score = F.cosine_similarity(X, q.unsqueeze(1), dim=-1).masked_fill(~mask, -1e9)
            alpha = F.softmax(score / self.tau, dim=1)
            pooled = (X * alpha.unsqueeze(-1)).sum(1)
            return pooled, alpha

        if self.mode == "topk":
            score = X.norm(dim=-1).masked_fill(~mask, -1e9)
            pooled, idx = self._topk_mean(X, mask, score)
            return pooled, idx

        if self.mode == "topk_sim":
            assert q is not None
            score = F.cosine_similarity(X, q.unsqueeze(1), dim=-1).masked_fill(~mask, -1e9)
            pooled, _ = self._topk_mean(X, mask, score)
            return pooled, score

        raise ValueError(self.mode)

class BranchMLP(nn.Module):
    """Two-layer MLP head (Linear -> ReLU -> Dropout -> Linear) producing one logit per row.

    Args:
        d4: input feature size.
        hidden: width of the hidden layer.
        dropout: dropout probability after the ReLU.
        prior_p: optional prior probability; when given, the output bias is
            set to ``log(prior_p / (1 - prior_p + 1e-8))``.
    """

    def __init__(self, d4, hidden=HIDDEN, dropout=DROPOUT, prior_p=None):
        super().__init__()
        layers = [
            nn.Linear(d4, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

        if prior_p is None:
            return
        # Start the output logit at the prior log-odds.
        prior_logit = np.log(prior_p/(1-prior_p + 1e-8))
        with torch.no_grad():
            self.net[-1].bias.fill_(prior_logit)

    def forward(self, z):
        logits = self.net(z)
        return logits.squeeze(-1)

class ClickbaitPairFusionVar(nn.Module):
    """Four-branch pair model with attention fusion over the selected branches.

    Each branch pairs a fixed vector (title ``t`` or thumbnail ``th``) with a
    pooled sequence (STT ``s`` or keyframes ``k``), where the pooling is
    queried by that same fixed vector. Every branch has its own MLP head; in
    addition, the branch feature vectors chosen via ``use_branches`` are
    combined with learned attention weights into one fused logit.

    ``forward`` returns a dict with ``branch_logits`` (B, 4), ``fuse_logit``
    (B,), ``alpha_sel`` (B, m), ``idxs`` (selected branch positions), ``Z``
    (B, 4, 4d) and the four normalised pooled sequences.
    """

    def __init__(self, d):
        super().__init__()
        # Query-conditioned pooling for the two variable-length inputs.
        self.pool_s  = SeqPool(d, mode="attn", tau=0.5)  # STT
        self.pool_k  = SeqPool(d, mode="attn", tau=0.5)  # Keyframe

        pair_dim = d * 4
        self.pre_norm  = nn.LayerNorm(d)
        self.norm_pair = nn.LayerNorm(pair_dim)

        # One head per branch; creation order is kept fixed (TK, TS, ThK, ThS).
        self.b_TK  = BranchMLP(pair_dim)
        self.b_TS  = BranchMLP(pair_dim)
        self.b_ThK = BranchMLP(pair_dim)
        self.b_ThS = BranchMLP(pair_dim)

        # Fusion: score each branch feature, softmax, weighted sum, final logit.
        self.norm_z   = nn.LayerNorm(pair_dim)
        self.att_w    = nn.Linear(pair_dim, 1, bias=False)
        self.fuse_out = nn.Linear(pair_dim, 1)

        self.branch_names = BRANCH_NAMES
        self.branch_idx   = BRANCH_IDX

    def forward(self, t, th, s_pad, s_mask, k_pad, k_mask, use_branches=("TK","TS","ThK","ThS")):
        # 1) Pool each sequence once per query, using the raw (pre-LayerNorm) queries.
        pooled = {
            "s_by_t":  self.pool_s(s_pad, s_mask, q=t)[0],
            "s_by_th": self.pool_s(s_pad, s_mask, q=th)[0],
            "k_by_t":  self.pool_k(k_pad, k_mask, q=t)[0],
            "k_by_th": self.pool_k(k_pad, k_mask, q=th)[0],
        }

        # 2) Shared LayerNorm on the fixed vectors and on every pooled vector.
        t = self.pre_norm(t)
        th = self.pre_norm(th)
        pooled = {name: self.pre_norm(vec) for name, vec in pooled.items()}

        # 3) Pair features per branch, in BRANCH_NAMES order (TK, TS, ThK, ThS).
        operands = (
            (t,  pooled["k_by_t"]),
            (t,  pooled["s_by_t"]),
            (th, pooled["k_by_th"]),
            (th, pooled["s_by_th"]),
        )
        feats = [self.norm_pair(pair_features(a, b)) for a, b in operands]

        # 4) Per-branch logits -> (B,4)
        heads = (self.b_TK, self.b_TS, self.b_ThK, self.b_ThS)
        branch_logits = torch.stack([head(z) for head, z in zip(heads, feats)], 1)

        # 5) Attention-fuse only the selected branches.
        Z = self.norm_z(torch.stack(feats, dim=1))                   # (B,4,4d)
        idxs = [self.branch_idx[b] for b in use_branches]
        Z_sel = Z[:, idxs, :]                                        # (B,m,4d)
        alpha_sel = F.softmax(self.att_w(Z_sel).squeeze(-1), dim=1)  # (B,m)
        fused = (Z_sel * alpha_sel.unsqueeze(-1)).sum(1)             # (B,4d)
        fuse_logit = self.fuse_out(fused).squeeze(-1)                # (B,)

        return {
            "branch_logits": branch_logits,
            "fuse_logit": fuse_logit,
            "alpha_sel": alpha_sel,
            "idxs": idxs,
            "Z": Z,
            # kept for debugging / interpretation
            "s_by_t": pooled["s_by_t"], "s_by_th": pooled["s_by_th"],
            "k_by_t": pooled["k_by_t"], "k_by_th": pooled["k_by_th"],
        }

# Embedding width is read from the title table; the model is built for that width.
d = T_title.shape[1]
model = ClickbaitPairFusionVar(d).to(device)
# Total parameter count, in millions.
print("params(M):", sum(p.numel() for p in model.parameters())/1e6)


params(M): 0.806917


In [None]:
# -------- 6) Training / evaluation --------
# Binary cross-entropy on logits, with positives re-weighted by pos_weight.
bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Same values as the definitions in the model cell; repeated so this cell can be re-run on its own.
BRANCH_NAMES = ["TK","TS","ThK","ThS"]
BRANCH_IDX = {n:i for i,n in enumerate(BRANCH_NAMES)}

def run_epoch(dl, train=True, return_preds=False, use_branches=("TK","TS","ThK","ThS")):
    """Run one pass over ``dl`` with the global ``model`` and report loss / AUC / accuracy.

    Uses the module-level ``model``, ``bce``, ``device``, ``LABEL_SMOOTH``,
    ``BR_W`` and ``MU``; when ``train`` is True it also steps the module-level
    ``optimizer``.

    Args:
        dl: iterable of batch dicts with keys ``t``, ``th``, ``s``, ``s_mask``,
            ``k``, ``k_mask`` and ``y``.
        train: if True, put the model in train mode and update weights;
            otherwise run in eval mode under ``torch.no_grad()``.
        return_preds: if True, also return the predicted probabilities and labels.
        use_branches: branch names passed to the model and included in the branch loss.

    Returns:
        ``(mean_loss, auc, acc)`` or, with ``return_preds``,
        ``(mean_loss, auc, acc, probs, labels)``. ``mean_loss`` is weighted by
        batch size; ``auc`` is NaN when the labels contain a single class;
        ``acc`` thresholds the fused probability at 0.5.

    NOTE(review): ``roc_auc_score`` and ``accuracy_score`` are not imported in
    any cell visible here — presumably they come from ``sklearn.metrics``;
    confirm they are imported before this function is called.
    """
    model.train(train)
    tot, n = 0.0, 0
    all_p, all_y = [], []
    ctx = torch.enable_grad() if train else torch.no_grad()

    idxs = [BRANCH_IDX[b] for b in use_branches]

    with ctx:
        for B in dl:
            y = B["y"].to(device)
            # Label smoothing: pull the targets slightly towards 0.5 (applied in eval too).
            y_sm = y * (1 - LABEL_SMOOTH) + 0.5 * LABEL_SMOOTH

            out = model(
                B["t"].to(device),
                B["th"].to(device),
                B["s"].to(device),
                B["s_mask"].to(device),
                B["k"].to(device),
                B["k_mask"].to(device),
                use_branches=use_branches
            )

            branch_logits = out["branch_logits"]  # (B,4)
            fuse_logit    = out["fuse_logit"]     # (B,)  logit fused over use_branches

            # ---- Loss: only the selected branches contribute to the branch loss ----
            branch_loss = sum(bce(branch_logits[:, i], y_sm) for i in idxs)/ len(idxs)
            loss = BR_W * branch_loss + MU * bce(fuse_logit, y_sm)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Accumulate a sample-weighted loss so the epoch mean is independent of batch sizes.
            bs = y.size(0)
            tot += float(loss) * bs
            n += bs

            # Metrics are computed from the fused prediction and the unsmoothed labels.
            p = torch.sigmoid(fuse_logit).detach().cpu().numpy()
            all_p.append(p)
            all_y.append(y.detach().cpu().numpy())

    all_p = np.concatenate(all_p)
    all_y = np.concatenate(all_y)
    auc = roc_auc_score(all_y, all_p) if len(np.unique(all_y)) > 1 else float("nan")
    acc = accuracy_score(all_y, (all_p >= 0.5).astype(int))

    if return_preds:
        return tot/n, auc, acc, all_p, all_y
    return tot/n, auc, acc


# -------- 7) Model training (one independent run per setting) --------
# Each entry maps a run name to the branches used for fusion and for the branch loss.
ABLATIONS = {
    "TS": ("TS",),
    "ThK": ("ThK",),
    "All Branches": ("TK","TS","ThK","ThS"),
}

results = {}

for name, branches in ABLATIONS.items():
    print("\n" + "="*80)
    print("TRAINING:", name, "use_branches=", branches)

    # Fresh model / optimizer / scheduler per setting. `model` and `optimizer` are
    # rebound at module level because run_epoch reads them as globals.
    model = ClickbaitPairFusionVar(d).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

    best_auc, best_state, no_improve = -1.0, None, 0
    patience = 8  # early-stopping patience, in epochs without a validation-AUC improvement

    for ep in range(1, EPOCHS+1):
        tr_loss, tr_auc, tr_acc = run_epoch(train_dl, True,  use_branches=branches)
        va_loss, va_auc, va_acc = run_epoch(val_dl,   False, use_branches=branches)

        print(f"[{ep:02d}] train loss {tr_loss:.4f} auc {tr_auc:.3f} acc {tr_acc:.3f} | "
              f"val loss {va_loss:.4f} auc {va_auc:.3f} acc {va_acc:.3f}")

        scheduler.step(va_auc)

        # A NaN validation AUC compares False here and therefore counts as "no improvement".
        if va_auc > best_auc:
            best_auc = va_auc
            # .clone() is required: state_dict() tensors share storage with the live
            # parameters, and .cpu() makes no copy when the tensor is already on the CPU.
            # Without the clone, later optimizer steps would overwrite this snapshot and
            # the "best" weights restored below would be the last epoch's weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stop at epoch {ep} (best VAL AUC={best_auc:.3f})")
                break

    # Restore the best-validation checkpoint before testing and saving.
    if best_state is not None:
        model.load_state_dict({k:v.to(device) for k,v in best_state.items()})

    te_loss, te_auc, te_acc = run_epoch(test_dl, False, use_branches=branches)
    print(f"[{name}] TEST loss {te_loss:.4f} | AUC {te_auc:.3f} | ACC {te_acc:.3f}")

    # Save under a per-setting file name.
    SAVE_PATH = f"{ROOT}/pair_fusion_{name.replace(' ','_')}.pt"
    torch.save(model.state_dict(), SAVE_PATH)
    print("saved to:", SAVE_PATH)

    results[name] = {"val_auc": best_auc, "test_auc": te_auc, "test_acc": te_acc}

print("\n=== Summary ===")
for k,v in results.items():
    print(k, v)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Branch order used for the plot labels; same order as the model's branch index.
BRANCH_NAMES = ["TK", "TS", "ThK", "ThS"]

@torch.no_grad()
def collect_branch_attention(model, dl, device="cuda",
                             use_branches=("TK","TS","ThK","ThS"),
                             n_branches=4):
    """Collect per-branch fusion attention over ``dl``.

    For every batch the model's ``alpha_sel`` (B, m) is scattered into a
    (B, ``n_branches``) tensor at the columns given by the model's ``idxs``;
    branches that were not selected stay at 0.

    Returns:
        ``(mean, std)`` — NumPy arrays of length ``n_branches`` with the
        per-branch mean and standard deviation over all samples.
    """
    model.eval()
    input_keys = ("t", "th", "s", "s_mask", "k", "k_mask")
    per_batch = []

    for batch in dl:
        inputs = [batch[key].to(device) for key in input_keys]
        out = model(*inputs, use_branches=use_branches)

        selected = out["alpha_sel"]   # (B,m)
        columns = out["idxs"]         # positions of the m selected branches

        # Expand back to (B, n_branches): selected columns get their weight, the rest stay 0.
        full = torch.zeros(selected.size(0), n_branches, device=selected.device)
        full[:, columns] = selected
        per_batch.append(full.detach().cpu())

    stacked = torch.cat(per_batch, dim=0)   # (N, n_branches)
    return stacked.mean(dim=0).numpy(), stacked.std(dim=0).numpy()


def plot_branch_contribution(val_mean, test_mean,
                             branch_names=BRANCH_NAMES,
                             title="Branch wise Attention Contribution",
                             ylabel="Average attention weight"):
    """Draw a grouped bar chart of per-branch attention for VAL and TEST.

    ``val_mean`` and ``test_mean`` hold one value per entry of
    ``branch_names``; the two series are drawn side by side and the y-axis is
    fixed to [0, 1].
    """
    positions = np.arange(len(branch_names))
    bar_w = 0.35

    plt.figure(figsize=(7,4.2))
    # VAL bars sit half a bar-width left of each tick, TEST bars half a bar-width right.
    series = ((-bar_w/2, val_mean, "VAL"), (bar_w/2, test_mean, "TEST"))
    for shift, heights, tag in series:
        plt.bar(positions + shift, heights, bar_w, label=tag)

    plt.xticks(positions, branch_names)
    plt.ylim(0, 1.0)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


# ====== Run ======
# The defaults are fine when `model` is the all-branches model. `model` here is whatever
# the global currently holds — after the training cell, that is the last ABLATIONS entry.
val_mean, val_std = collect_branch_attention(model, val_dl, device=device,
                                             use_branches=("TK","TS","ThK","ThS"))
test_mean, test_std = collect_branch_attention(model, test_dl, device=device,
                                               use_branches=("TK","TS","ThK","ThS"))

# Mean fusion weight per branch, rounded to 4 decimals.
print("VAL mean alpha :", dict(zip(BRANCH_NAMES, np.round(val_mean, 4))))
print("TEST mean alpha:", dict(zip(BRANCH_NAMES, np.round(test_mean, 4))))

plot_branch_contribution(val_mean, test_mean)

