In [2]:
# ============================================================
import os, json, math, random, numpy as np, pandas as pd, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.model_selection import KFold, train_test_split
from torch.amp import GradScaler, autocast

class CTRDataset(Dataset):
    def __init__(self, Xc, Xn, y):
        self.Xc = torch.as_tensor(Xc, dtype=torch.long)
        self.Xn = torch.as_tensor(Xn, dtype=torch.float32)
        self.y  = torch.as_tensor(y,  dtype=torch.float32)
    def __len__(self): return len(self.y)
    def __getitem__(self, i): return self.Xc[i], self.Xn[i], self.y[i].unsqueeze(-1)

class ImprovedAFM(nn.Module):
    """AFM with attention; supports masking + returns attention for faithfulness tests."""
    def __init__(self, d, attn_dim=64, dropout=0.1):
        super().__init__()
        self.W = nn.Linear(d, attn_dim, bias=False)
        self.h = nn.Linear(attn_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, E, pair_mask=None, return_attn=False):
        B, F, d = E.shape
        pairs = []
        for i in range(F):
            for j in range(i+1, F):
                pairs.append(E[:, i] * E[:, j])
        P = torch.stack(pairs, dim=1)   # [B, Pairs, d]
        P = self.dropout(P)

        attn_logits = self.h(torch.tanh(self.W(P)))  # [B, Pairs, 1]
        A = torch.softmax(attn_logits, dim=1)        # [B, Pairs, 1]

        pair_contrib = A * P                         # [B, Pairs, d]
        if pair_mask is not None:
            pair_contrib = pair_contrib * pair_mask.unsqueeze(-1)

        v = pair_contrib.sum(dim=1)                  # [B, d]
        if return_attn:
            return v, A.squeeze(-1)                  # [B,d], [B,Pairs]
        return v


class FeatureTokenizer(nn.Module):
    def __init__(self, cat_cardinalities, n_num, d_model):
        super().__init__()
        self.cat_embs = nn.ModuleList([nn.Embedding(card, d_model) for card in cat_cardinalities])
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_num)])
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)); nn.init.trunc_normal_(self.cls, std=0.02)
    def forward(self, x_cat, x_num):
        B = x_cat.size(0)
        cat_tokens = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embs)]
        num_tokens = [proj(x_num[:, i:i+1]) for i, proj in enumerate(self.num_proj)]
        field_embs = torch.stack(cat_tokens + num_tokens, dim=1)
        cls = self.cls.expand(B, -1, -1)
        tokens = torch.cat([cls, field_embs], dim=1)
        return tokens, field_embs

class FTTransformer(nn.Module):
    def __init__(self, d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
                                         dropout=dropout, batch_first=True, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
    def forward(self, tokens): return self.encoder(tokens)

class ImprovedFTAFM(nn.Module):
    """Improved FT+AFM with larger capacity and better fusion"""
    def __init__(self, cat_cards, n_num, d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64):
        super().__init__()
        self.tok = FeatureTokenizer(cat_cards, n_num, d_model)
        self.backbone = FTTransformer(d_model, nhead, ff, n_layers, dropout)
        self.afm = ImprovedAFM(d_model, afm_attn_dim, dropout=dropout)
        
        # Larger head
        fusion_dim = d_model + d_model
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )
        
    def forward(self, x_cat, x_num):
        tokens, field_embs = self.tok(x_cat, x_num)
        H = self.backbone(tokens); h_cls = H[:,0,:]
        v_afm = self.afm(field_embs)
        z = torch.cat([h_cls, v_afm], dim=1)
        return self.head(z).squeeze(1)

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_schema(path):
    with open(path) as f:
        return json.load(f)

def load_test_dl(dir_path, batch_size=4096, num_workers=4):
    Xc = np.load(os.path.join(dir_path, "Xc_test.npy"))
    Xn = np.load(os.path.join(dir_path, "Xn_test.npy"))
    y  = np.load(os.path.join(dir_path, "y_test.npy"))
    ds = CTRDataset(Xc, Xn, y)
    return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

def save_json(obj, path):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def mask_delete_topk(A, k):
    B, P = A.shape
    mask = torch.ones((B, P), device=A.device)
    idx = torch.topk(A, k, dim=1).indices
    mask.scatter_(1, idx, 0.0)
    return mask

def mask_delete_bottomk(A, k):
    B, P = A.shape
    mask = torch.ones((B, P), device=A.device)
    idx = torch.topk(A, k, dim=1, largest=False).indices
    mask.scatter_(1, idx, 0.0)
    return mask

def mask_delete_randomk(A, k):
    B, P = A.shape
    mask = torch.ones((B, P), device=A.device)
    idx = torch.stack([torch.randperm(P, device=A.device)[:k] for _ in range(B)], dim=0)
    mask.scatter_(1, idx, 0.0)
    return mask

@torch.no_grad()
def forward_with_afm_mask(model, x_cat, x_num, pair_mask=None, return_A=False):
    m = model.module if isinstance(model, nn.DataParallel) else model
    m.eval()

    tokens, field_embs = m.tok(x_cat, x_num)
    H = m.backbone(tokens)
    h_cls = H[:, 0, :]

    if return_A:
        v_afm, A = m.afm(field_embs, pair_mask=pair_mask, return_attn=True)
    else:
        v_afm = m.afm(field_embs, pair_mask=pair_mask, return_attn=False)

    z = torch.cat([h_cls, v_afm], dim=1)
    logits = m.head(z).squeeze(1)

    if return_A:
        return logits, A
    return logits

def eval_faithfulness(model, dl, k_list=(1,3,5,10), chunk=2048, max_batches=None):
    # ---------- Baseline ----------
    ys, p0_all = [], []
    with torch.no_grad():
        for b, (Xc, Xn, yb) in enumerate(dl):
            if max_batches is not None and b >= max_batches:
                break
            ys.append(yb.squeeze(-1).numpy())

            probs = []
            for i in range(0, Xc.size(0), chunk):
                Xc_c = Xc[i:i+chunk].to(DEVICE)
                Xn_c = Xn[i:i+chunk].to(DEVICE)
                logits = forward_with_afm_mask(model, Xc_c, Xn_c, pair_mask=None)
                probs.append(torch.sigmoid(logits).cpu().numpy())
            p0_all.append(np.concatenate(probs))

    y = np.concatenate(ys)
    p0 = np.clip(np.concatenate(p0_all), 1e-7, 1-1e-7)
    base_auc = roc_auc_score(y, p0)
    base_ll  = log_loss(y, p0)

    out = {
        "base_auc": float(base_auc),
        "base_ll": float(base_ll),
        "by_k": {}
    }

    # ---------- Masked variants ----------
    for k in k_list:
        p_top_all, p_rnd_all, p_bot_all = [], [], []

        with torch.no_grad():
            for b, (Xc, Xn, yb) in enumerate(dl):
                if max_batches is not None and b >= max_batches:
                    break

                pt, pr, pb = [], [], []
                for i in range(0, Xc.size(0), chunk):
                    Xc_c = Xc[i:i+chunk].to(DEVICE)
                    Xn_c = Xn[i:i+chunk].to(DEVICE)

                    # get attention for this chunk
                    _, A = forward_with_afm_mask(model, Xc_c, Xn_c, pair_mask=None, return_A=True)

                    m_top = mask_delete_topk(A, k)
                    m_rnd = mask_delete_randomk(A, k)
                    m_bot = mask_delete_bottomk(A, k)

                    pt.append(torch.sigmoid(forward_with_afm_mask(model, Xc_c, Xn_c, m_top)).cpu().numpy())
                    pr.append(torch.sigmoid(forward_with_afm_mask(model, Xc_c, Xn_c, m_rnd)).cpu().numpy())
                    pb.append(torch.sigmoid(forward_with_afm_mask(model, Xc_c, Xn_c, m_bot)).cpu().numpy())

                p_top_all.append(np.concatenate(pt))
                p_rnd_all.append(np.concatenate(pr))
                p_bot_all.append(np.concatenate(pb))

        def metrics(p_list):
            p = np.clip(np.concatenate(p_list), 1e-7, 1-1e-7)
            return roc_auc_score(y, p), log_loss(y, p)

        auc_top, ll_top = metrics(p_top_all)
        auc_rnd, ll_rnd = metrics(p_rnd_all)
        auc_bot, ll_bot = metrics(p_bot_all)

        out["by_k"][int(k)] = {
            # absolute metrics
            "topk_delete_auc": float(auc_top),
            "topk_delete_ll": float(ll_top),
            "randk_delete_auc": float(auc_rnd),
            "randk_delete_ll": float(ll_rnd),
            "botk_delete_auc": float(auc_bot),
            "botk_delete_ll": float(ll_bot),

            # deltas vs baseline
            "delta_auc_topk": float(auc_top - base_auc),
            "delta_auc_rand": float(auc_rnd - base_auc),
            "delta_auc_bottom": float(auc_bot - base_auc),

            "delta_ll_topk": float(ll_top - base_ll),
            "delta_ll_rand": float(ll_rnd - base_ll),
            "delta_ll_bottom": float(ll_bot - base_ll),
        }

    return out


In [None]:
AVAZU_DIR = "runs_avazu_40m_improved_ft_afm"
ckpt = os.path.join(AVAZU_DIR, "improved_ft_afm_best.pth")
schema = load_schema(os.path.join(AVAZU_DIR, "schema.json"))

model = ImprovedFTAFM(
    schema["cat_cards"], len(schema["num_cols"]),
    d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64
).to(DEVICE)

model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
model.eval()

dl = load_test_dl(AVAZU_DIR, batch_size=4096)
faith = eval_faithfulness(model, dl, k_list=(1,3,5,10), chunk=2048)

save_json(faith, os.path.join(AVAZU_DIR, "avazu_faithfulness_afm.json"))
faith


  model.load_state_dict(torch.load(ckpt, map_location=DEVICE))


In [4]:
CRITEO_PREP = "criteo_preprocessed"
CRITEO_CKPT_DIR = "criteo_ft_afm_results"
ckpt = os.path.join(CRITEO_CKPT_DIR, "criteo_ft_afm_best.pth")
schema = load_schema(os.path.join(CRITEO_PREP, "schema.json"))

model = ImprovedFTAFM(
    schema["cat_cards"], len(schema["num_cols"]),
    d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64
).to(DEVICE)

model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
model.eval()

dl = load_test_dl(CRITEO_PREP, batch_size=4096)
faith = eval_faithfulness(model, dl, k_list=(1,3,5,10), chunk=2048)

save_json(faith, os.path.join(CRITEO_CKPT_DIR, "criteo_faithfulness_afm.json"))
faith


  model.load_state_dict(torch.load(ckpt, map_location=DEVICE))


{'base_auc': 0.80764130603302,
 'base_ll': 0.44466287553408423,
 'by_k': {1: {'topk_delete_auc': 0.8075364893723618,
   'topk_delete_ll': 0.4446657568078751,
   'randk_delete_auc': 0.8076409925795068,
   'randk_delete_ll': 0.44466288175293484,
   'botk_delete_auc': 0.8076413059889463,
   'botk_delete_ll': 0.44466287554097345,
   'delta_auc_topk': -0.00010481666065820239,
   'delta_auc_rand': -3.1345351314548964e-07,
   'delta_auc_bottom': -4.4073633631569464e-11,
   'delta_ll_topk': 2.881273790888983e-06,
   'delta_ll_rand': 6.218850601147352e-09,
   'delta_ll_bottom': 6.889211423555253e-12},
  3: {'topk_delete_auc': 0.8074629677364917,
   'topk_delete_ll': 0.4446624416820689,
   'randk_delete_auc': 0.8076403555928955,
   'randk_delete_ll': 0.4446627884293513,
   'botk_delete_auc': 0.8076413060343934,
   'botk_delete_ll': 0.4446628755351303,
   'delta_auc_topk': -0.0001783382965282465,
   'delta_auc_rand': -9.504401244919691e-07,
   'delta_auc_bottom': 1.3734569037637812e-12,
   'delta

In [5]:
OUTBRAIN_PREP = "outbrain_preprocessed_40m"
OUTBRAIN_CKPT_DIR = "outbrain_ft_afm_results_40m"
ckpt = os.path.join(OUTBRAIN_CKPT_DIR, "outbrain_ft_afm_best.pth")
schema = load_schema(os.path.join(OUTBRAIN_PREP, "schema.json"))

model = ImprovedFTAFM(
    schema["cat_cards"], len(schema["num_cols"]),
    d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64
).to(DEVICE)

model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
model.eval()

dl = load_test_dl(OUTBRAIN_PREP, batch_size=4096)
faith = eval_faithfulness(model, dl, k_list=(1,3,5,10), chunk=2048)

save_json(faith, os.path.join(OUTBRAIN_CKPT_DIR, "outbrain_faithfulness_afm.json"))
faith


  model.load_state_dict(torch.load(ckpt, map_location=DEVICE))


{'base_auc': 0.7003085858204137,
 'base_ll': 0.451559541236667,
 'by_k': {1: {'topk_delete_auc': 0.6835491201549444,
   'topk_delete_ll': 0.4602958168446729,
   'randk_delete_auc': 0.7001118569454646,
   'randk_delete_ll': 0.4516338448184922,
   'botk_delete_auc': 0.7003085840110443,
   'botk_delete_ll': 0.451559542134899,
   'delta_auc_topk': -0.016759465665469264,
   'delta_auc_rand': -0.00019672887494903701,
   'delta_auc_bottom': -1.8093693210374795e-09,
   'delta_ll_topk': 0.008736275608005939,
   'delta_ll_rand': 7.430358182525243e-05,
   'delta_ll_bottom': 8.982320442996183e-10},
  3: {'topk_delete_auc': 0.6755131863129705,
   'topk_delete_ll': 0.46429575202786416,
   'randk_delete_auc': 0.6997135056229014,
   'randk_delete_ll': 0.4517826901525592,
   'botk_delete_auc': 0.7003085801707208,
   'botk_delete_ll': 0.4515595443930273,
   'delta_auc_topk': -0.02479539950744314,
   'delta_auc_rand': -0.0005950801975123099,
   'delta_auc_bottom': -5.6496928246829725e-09,
   'delta_ll_to