In [4]:
# ============================================================
# FT-AFM Training on 40M Samples (8:1:1 split)
# Based on original pipeline parameters
# ============================================================
import os, json, math, random, numpy as np, pandas as pd, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.model_selection import KFold, train_test_split
from torch.amp import GradScaler, autocast

# ============== Repro/Device ==============
def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_GPUS = torch.cuda.device_count()
USE_DDP = NUM_GPUS > 1

print(f"Available GPUs: {NUM_GPUS}")
if USE_DDP:
    print(f"Using DataParallel with {NUM_GPUS} GPUs")
else:
    print(f"Using single device: {DEVICE}")

# ============== Utils ==============
def ensure_dir(p): os.makedirs(p, exist_ok=True)

def save_json(obj, path):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def kfold_target_encode(df_tr, df_va, col, yname, n_splits=5, min_samples=50, prior=None, seed=42):
    if prior is None: prior = float(df_tr[yname].mean())
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    te_tr = pd.Series(np.zeros(len(df_tr), dtype="float32"), index=df_tr.index)
    for tr_idx, hold_idx in kf.split(df_tr):
        cur = df_tr.iloc[tr_idx]
        means = cur.groupby(col)[yname].mean()
        cnts  = cur.groupby(col)[yname].size()
        m = (means*cnts + prior*min_samples) / (cnts + min_samples)
        te_tr.iloc[hold_idx] = df_tr.iloc[hold_idx][col].map(m).fillna(prior).astype("float32")
    means = df_tr.groupby(col)[yname].mean()
    cnts  = df_tr.groupby(col)[yname].size()
    mfull = (means*cnts + prior*min_samples) / (cnts + min_samples)
    te_va = df_va[col].map(mfull).fillna(prior).astype("float32")
    return te_tr, te_va

def temporal_or_stratified_split(df, label, time_cols=None, order_key=None, train=0.8, val=0.1, test=0.1):
    assert abs(train + val + test - 1.0) < 1e-8
    if time_cols and all(c in df.columns for c in time_cols):
        key = order_key if order_key is not None else time_cols
        df_sorted = df.sort_values(key).reset_index(drop=True)
        n = len(df_sorted); n_tr = int(train*n); n_va = int(val*n)
        df_tr = df_sorted.iloc[:n_tr].copy()
        df_va = df_sorted.iloc[n_tr:n_tr+n_va].copy()
        df_te = df_sorted.iloc[n_tr+n_va:].copy()
    else:
        df_tr, df_tmp = train_test_split(df, test_size=(1-train), stratify=df[label], random_state=42)
        df_va, df_te = train_test_split(df_tmp, test_size=(test/(test+val)), stratify=df_tmp[label], random_state=42)
        df_tr, df_va, df_te = df_tr.copy(), df_va.copy(), df_te.copy()
    return df_tr, df_va, df_te

# ============== Models ==============
class AFM(nn.Module):
    def __init__(self, d, attn_dim=32):
        super().__init__()
        self.W = nn.Linear(d, attn_dim, bias=False)
        self.h = nn.Linear(attn_dim, 1, bias=False)
    def forward(self, E):
        B,F,d = E.shape; pairs=[]
        for i in range(F):
            for j in range(i+1, F):
                pairs.append(E[:,i]*E[:,j])
        P = torch.stack(pairs, dim=1)
        A = torch.softmax(self.h(torch.tanh(self.W(P))), dim=1)
        return (A * P).sum(dim=1)

class FeatureTokenizer(nn.Module):
    def __init__(self, cat_cardinalities, n_num, d_model):
        super().__init__()
        self.cat_embs = nn.ModuleList([nn.Embedding(card, d_model) for card in cat_cardinalities])
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_num)])
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)); nn.init.trunc_normal_(self.cls, std=0.02)
    def forward(self, x_cat, x_num):
        B = x_cat.size(0)
        cat_tokens = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embs)]
        num_tokens = [proj(x_num[:, i:i+1]) for i, proj in enumerate(self.num_proj)]
        field_embs = torch.stack(cat_tokens + num_tokens, dim=1)
        cls = self.cls.expand(B, -1, -1)
        tokens = torch.cat([cls, field_embs], dim=1)
        return tokens, field_embs

class FTTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
                                         dropout=dropout, batch_first=True, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
    def forward(self, tokens): return self.encoder(tokens)

class FTWithAFM(nn.Module):
    def __init__(self, cat_cards, n_num, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1):
        super().__init__()
        self.tok = FeatureTokenizer(cat_cards, n_num, d_model)
        self.backbone = FTTransformer(d_model, nhead, ff, n_layers, dropout)
        self.afm = AFM(d_model, attn_dim=32)
        fusion_in = d_model + d_model
        self.head = nn.Sequential(nn.Linear(fusion_in, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, 1))
    def forward(self, x_cat, x_num):
        tokens, field_embs = self.tok(x_cat, x_num)
        H = self.backbone(tokens); cls = H[:,0,:]
        afm_out = self.afm(field_embs)
        z = torch.cat([cls, afm_out], dim=1)
        return self.head(z).squeeze(1)

# ============== Train/Eval ==============
def init_final_bias_to_ctr(module_last_linear, base_ctr):
    with torch.no_grad():
        module_last_linear.bias.fill_(math.log(base_ctr/(1.0-base_ctr)))

def evaluate_model(model, dl):
    model.eval(); ys, ps = [], []
    with torch.no_grad():
        for Xc, Xn, yb in dl:
            Xc, Xn, yb = Xc.to(DEVICE), Xn.to(DEVICE), yb.to(DEVICE)
            logits = model(Xc, Xn)
            if logits.dim()==1: logits = logits.unsqueeze(1)
            probs = torch.sigmoid(logits).squeeze(-1)
            ys.append(yb.squeeze(-1).cpu().numpy()); ps.append(probs.cpu().numpy())
    y_true = np.concatenate(ys); y_prob = np.clip(np.concatenate(ps), 1e-7, 1-1e-7)
    return roc_auc_score(y_true, y_prob), log_loss(y_true, y_prob)

class TrainCfg:
    def __init__(self, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.10,
                 lr=1e-3, weight_decay=1e-4, max_epochs=12, patience=3, clip=1.0):
        self.d_model=d_model; self.nhead=nhead; self.ff=ff; self.n_layers=n_layers
        self.dropout=dropout; self.lr=lr; self.weight_decay=weight_decay
        self.max_epochs=max_epochs; self.patience=patience; self.clip=clip
        self.use_amp=torch.cuda.is_available()

def train_ft_afm(model, tr_dl, va_dl, cfg, outdir):
    ensure_dir(outdir)
    
    # Wrap model with DataParallel if multiple GPUs available
    if USE_DDP and NUM_GPUS > 1:
        print(f"Wrapping model with DataParallel across {NUM_GPUS} GPUs")
        model = nn.DataParallel(model)
    
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scaler = GradScaler(device='cuda', enabled=cfg.use_amp)
    best_ll, best_state, stale = float("inf"), None, 0
    history = []
    
    for ep in range(1, cfg.max_epochs+1):
        model.train(); run = 0.0
        batch_count = 0
        for Xc, Xn, yb in tr_dl:
            Xc, Xn, yb = Xc.to(DEVICE), Xn.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            with autocast(device_type='cuda', enabled=cfg.use_amp):
                logits = model(Xc, Xn)
                if logits.dim()==1: logits = logits.unsqueeze(1)
                loss = F.binary_cross_entropy_with_logits(logits, yb)
            scaler.scale(loss).backward()
            
            # Get actual model for gradient clipping (unwrap DataParallel if needed)
            model_to_clip = model.module if isinstance(model, nn.DataParallel) else model
            torch.nn.utils.clip_grad_norm_(model_to_clip.parameters(), cfg.clip)
            
            scaler.step(opt); scaler.update()
            run += loss.item() * yb.size(0)
            batch_count += 1
            
            # Progress indicator every 100 batches
            if batch_count % 100 == 0:
                print(f"  Epoch {ep}, Batch {batch_count}/{len(tr_dl)}, Loss: {loss.item():.4f}")

        val_auc, val_ll = evaluate_model(model, va_dl)
        train_ll = run / len(tr_dl.dataset)
        improved = val_ll + 1e-9 < best_ll
        if improved:
            best_ll = val_ll; stale = 0
            # Save underlying model state (unwrap DataParallel)
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            best_state = {k: v.detach().cpu().clone() for k,v in model_to_save.state_dict().items()}
        else:
            stale += 1

        history.append({"epoch": ep, "train_ll": float(train_ll), "val_ll": float(val_ll), "val_auc": float(val_auc)})
        print(f"[FT-AFM] ep{ep:02d} train_ll={train_ll:.4f} val_ll={val_ll:.4f} val_auc={val_auc:.4f} "
              f"{'*BEST*' if improved else f'stale {stale}/{cfg.patience}'}")
        if stale >= cfg.patience:
            print(f"[FT-AFM] early-stopped.")
            break

    if best_state is not None:
        # Load back to the underlying model
        model_to_load = model.module if isinstance(model, nn.DataParallel) else model
        model_to_load.load_state_dict(best_state)
        torch.save(best_state, os.path.join(outdir, "ft_afm_40m_best.pth"))
    pd.DataFrame(history).to_csv(os.path.join(outdir, "ft_afm_40m_history.csv"), index=False)
    
    val_auc, val_ll = evaluate_model(model, va_dl)
    save_json({"val_auc": float(val_auc), "val_logloss": float(val_ll)}, 
              os.path.join(outdir, "ft_afm_40m_val_metrics.json"))
    
    # Return the unwrapped model
    return model.module if isinstance(model, nn.DataParallel) else model

# ============== Dataset ==============
class CTRDataset(Dataset):
    def __init__(self, Xc, Xn, y):
        self.Xc = torch.as_tensor(Xc, dtype=torch.long)
        self.Xn = torch.as_tensor(Xn, dtype=torch.float32)
        self.y  = torch.as_tensor(y,  dtype=torch.float32)
    def __len__(self): return len(self.y)
    def __getitem__(self, i): return self.Xc[i], self.Xn[i], self.y[i].unsqueeze(-1)

# ============== Main Runner ==============
def run_ft_afm_40m(csv_path, outdir, label_col, time_info, base_num_cols, drop_cols, 
                   single_freq_cats, pair_freq_cats, te_targets):
    print(f"\n================= FT-AFM on 40M Samples =================")
    ensure_dir(outdir)

    # Load data
    print("Loading data...")
    if csv_path.endswith('.parquet'):
        df = pd.read_parquet(csv_path)
    else:
        df = pd.read_csv(csv_path)
    print(f"Loaded {len(df):,} rows")
    assert label_col in df.columns

    # Save schema before
    schema_before = {
        "n_rows": int(len(df)),
        "n_cols": int(len(df.columns)),
        "positive_rate": float(df[label_col].mean())
    }
    save_json(schema_before, os.path.join(outdir, "schema_before.json"))

    # Cast base numeric columns
    for c in base_num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)

    # 8:1:1 split
    print("Splitting data (8:1:1)...")
    if time_info is not None:
        df_tr, df_va, df_te = temporal_or_stratified_split(
            df, label_col, time_cols=time_info.get("time_cols"), 
            order_key=time_info.get("order_key"), train=0.8, val=0.1, test=0.1
        )
    else:
        df_tr, df_va, df_te = temporal_or_stratified_split(
            df, label_col, time_cols=None, order_key=None, train=0.8, val=0.1, test=0.1
        )
    print(f"Train: {len(df_tr):,}, Val: {len(df_va):,}, Test: {len(df_te):,}")

    # Feature engineering
    all_cols_tr = df_tr.columns.tolist()
    drop_cols_eff = [c for c in drop_cols if c in all_cols_tr]
    base_num_eff = [c for c in base_num_cols if c in all_cols_tr]
    cat_cols = [c for c in all_cols_tr if c not in ([label_col] + drop_cols_eff + base_num_eff)]
    num_cols = base_num_eff[:]

    # Categorical encoding with UNK
    print("Encoding categoricals...")
    cat_cards, cat_maps = [], {}
    for c in cat_cols:
        uniq = pd.Index(df_tr[c].astype("object").unique())
        mapping = {v:i for i,v in enumerate(uniq)}
        unk_id = len(mapping)
        cat_maps[c] = (mapping, unk_id)
        for d in (df_tr, df_va, df_te):
            d[c] = d[c].map(mapping).fillna(unk_id).astype("int64")
        cat_cards.append(unk_id + 1)

    # Numeric base features
    for c in num_cols:
        for d in (df_tr, df_va, df_te):
            d[c] = pd.to_numeric(d[c], errors="coerce").fillna(0).astype("float32")

    # Single frequency features
    print("Engineering frequency features...")
    cand_freq = [c for c in single_freq_cats if c in df_tr.columns]
    for c in cand_freq:
        vc = df_tr[c].value_counts()
        df_tr[f"{c}_freq"] = df_tr[c].map(vc).astype("float32")
        df_va[f"{c}_freq"] = df_va[c].map(vc).fillna(0).astype("float32")
        df_te[f"{c}_freq"] = df_te[c].map(vc).fillna(0).astype("float32")

    # Pairwise frequency features
    pairs_eff = [(a,b) for (a,b) in pair_freq_cats if set([a,b]).issubset(df_tr.columns)]
    for a,b in pairs_eff:
        key_tr = df_tr[a].astype("int64")*10_000_000 + df_tr[b].astype("int64")
        key_va = df_va[a].astype("int64")*10_000_000 + df_va[b].astype("int64")
        key_te = df_te[a].astype("int64")*10_000_000 + df_te[b].astype("int64")
        vc = key_tr.value_counts()
        name = f"{a}__{b}__freq"
        df_tr[name] = key_tr.map(vc).astype("float32")
        df_va[name] = key_va.map(vc).fillna(0).astype("float32")
        df_te[name] = key_te.map(vc).fillna(0).astype("float32")

    # Target encoding
    print("Target encoding...")
    te_eff = [c for c in te_targets if c in df_tr.columns]
    for c in te_eff:
        te_tr, te_va = kfold_target_encode(df_tr, df_va, c, yname=label_col, n_splits=5, min_samples=50)
        prior = float(df_tr[label_col].mean())
        means = df_tr.groupby(c)[label_col].mean()
        cnts  = df_tr.groupby(c)[label_col].size()
        mfull = (means*cnts + prior*50) / (cnts + 50)
        te_te = df_te[c].map(mfull).fillna(prior).astype("float32")
        df_tr[f"{c}_te"] = te_tr.astype("float32")
        df_va[f"{c}_te"] = te_va.astype("float32")
        df_te[f"{c}_te"] = te_te.astype("float32")

    # Register new numeric columns
    new_num = [f"{c}_freq" for c in cand_freq] + [f"{a}__{b}__freq" for (a,b) in pairs_eff] + [f"{c}_te" for c in te_eff]
    for c in new_num:
        if c not in num_cols: num_cols.append(c)

    # Log1p freq features
    freq_like = [c for c in num_cols if c.endswith("_freq")]
    for c in freq_like:
        for d in (df_tr, df_va, df_te):
            d[c] = np.log1p(pd.to_numeric(d[c], errors="coerce").fillna(0).clip(lower=0).astype("float64"))

    # Standardize all numerics
    print("Standardizing numerics...")
    num_means = {c: float(pd.to_numeric(df_tr[c], errors="coerce").fillna(0).mean()) for c in num_cols}
    num_stds  = {c: float(pd.to_numeric(df_tr[c], errors="coerce").fillna(0).std(ddof=0)) for c in num_cols}
    for c in num_cols:
        mu, sd = num_means[c], (num_stds[c] if num_stds[c] > 1e-8 else 1.0)
        for d in (df_tr, df_va, df_te):
            v = pd.to_numeric(d[c], errors="coerce")
            d[c] = ((v - mu)/sd).replace([np.inf,-np.inf], np.nan).fillna(0).astype("float32")

    # Save schema after
    schema_after = {
        "cat_cols": cat_cols,
        "num_cols": num_cols,
        "cat_cards": cat_cards,
        "splits": {"train": len(df_tr), "val": len(df_va), "test": len(df_te)},
        "train_ctr": float(df_tr[label_col].mean()),
        "val_ctr": float(df_va[label_col].mean()),
        "test_ctr": float(df_te[label_col].mean())
    }
    save_json(schema_after, os.path.join(outdir, "schema_after.json"))

    # Prepare tensors
    print("Preparing data loaders...")
    Xc_tr = df_tr[cat_cols].to_numpy()
    Xn_tr = df_tr[num_cols].to_numpy().astype("float32")
    y_tr = df_tr[label_col].to_numpy().astype("float32")
    
    Xc_va = df_va[cat_cols].to_numpy()
    Xn_va = df_va[num_cols].to_numpy().astype("float32")
    y_va = df_va[label_col].to_numpy().astype("float32")
    
    Xc_te = df_te[cat_cols].to_numpy()
    Xn_te = df_te[num_cols].to_numpy().astype("float32")
    y_te = df_te[label_col].to_numpy().astype("float32")

    # Increase batch size for multi-GPU training
    batch_size_train = 4096 * NUM_GPUS if USE_DDP else 4096
    batch_size_eval = 8192 * NUM_GPUS if USE_DDP else 8192
    
    print(f"Train batch size: {batch_size_train} (across {NUM_GPUS} GPUs)")
    print(f"Eval batch size: {batch_size_eval}")
    
    tr_dl = DataLoader(CTRDataset(Xc_tr, Xn_tr, y_tr), batch_size=batch_size_train, 
                       shuffle=True, num_workers=4, pin_memory=True)
    va_dl = DataLoader(CTRDataset(Xc_va, Xn_va, y_va), batch_size=batch_size_eval, 
                       shuffle=False, num_workers=4, pin_memory=True)
    te_dl = DataLoader(CTRDataset(Xc_te, Xn_te, y_te), batch_size=batch_size_eval, 
                       shuffle=False, num_workers=4, pin_memory=True)

    # Build and train FT-AFM
    base_ctr = float(df_tr[label_col].mean())
    cfg = TrainCfg(d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.10, 
                   lr=1e-3, weight_decay=1e-4, max_epochs=12, patience=3)
    
    print("Building FT-AFM model...")
    model = FTWithAFM(cat_cards, len(num_cols), cfg.d_model, cfg.nhead, 
                      cfg.ff, cfg.n_layers, cfg.dropout).to(DEVICE)
    init_final_bias_to_ctr(model.head[-1], base_ctr)
    
    print(f"Training FT-AFM on {len(df_tr):,} samples...")
    model = train_ft_afm(model, tr_dl, va_dl, cfg, outdir)
    
    # Final test evaluation
    print("Evaluating on test set...")
    test_auc, test_ll = evaluate_model(model, te_dl)
    save_json({"test_auc": float(test_auc), "test_logloss": float(test_ll)},
              os.path.join(outdir, "ft_afm_40m_test_metrics.json"))
    
    print(f"\n{'='*60}")
    print(f"FT-AFM Results on 40M Samples:")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test LogLoss: {test_ll:.4f}")
    print(f"{'='*60}\n")
    
    return test_auc, test_ll

# ================== RUN: AVAZU 40M ==================
AVAZU_40M_PATH = "/home/elicer/ctr_project/data/avazu_full/avazu_full.parquet"
OUT_AVAZU_40M = "runs_avazu_40m_ft_afm"

print("\n" + "="*80)
print("Starting FT-AFM training on Avazu 40M dataset")
print("="*80)

test_auc, test_ll = run_ft_afm_40m(
    csv_path=AVAZU_40M_PATH,
    outdir=OUT_AVAZU_40M,
    label_col="click",
    time_info={"time_cols":["hour"], "order_key":"hour"},
    base_num_cols=["hour"],
    drop_cols=["id"],
    single_freq_cats=["C14","C17","site_id","app_id","device_model"],
    pair_freq_cats=[("site_id","app_id"),("device_model","hour")],
    te_targets=["C14","site_id"]
)

print("\n" + "="*80)
print("FINAL RESULTS - Avazu 40M FT-AFM")
print("="*80)
print(f"Test AUC:     {test_auc:.6f}")
print(f"Test LogLoss: {test_ll:.6f}")
print("="*80)
print(f"\nResults saved to: {OUT_AVAZU_40M}/")
print("✅ Training complete!")

Available GPUs: 4
Using DataParallel with 4 GPUs

Starting FT-AFM training on Avazu 40M dataset

Loading data...
Loaded 40,428,967 rows
Splitting data (8:1:1)...
Train: 32,343,173, Val: 4,042,896, Test: 4,042,898
Encoding categoricals...
Engineering frequency features...
Target encoding...
Standardizing numerics...
Preparing data loaders...
Train batch size: 16384 (across 4 GPUs)
Eval batch size: 32768
Building FT-AFM model...




Training FT-AFM on 32,343,173 samples...
Wrapping model with DataParallel across 4 GPUs
  Epoch 1, Batch 100/1975, Loss: 0.4097
  Epoch 1, Batch 200/1975, Loss: 0.4107
  Epoch 1, Batch 300/1975, Loss: 0.4027
  Epoch 1, Batch 400/1975, Loss: 0.4027
  Epoch 1, Batch 500/1975, Loss: 0.4064
  Epoch 1, Batch 600/1975, Loss: 0.4026
  Epoch 1, Batch 700/1975, Loss: 0.3987
  Epoch 1, Batch 800/1975, Loss: 0.3978
  Epoch 1, Batch 900/1975, Loss: 0.3984
  Epoch 1, Batch 1000/1975, Loss: 0.3970
  Epoch 1, Batch 1100/1975, Loss: 0.3975
  Epoch 1, Batch 1200/1975, Loss: 0.3967
  Epoch 1, Batch 1300/1975, Loss: 0.3980
  Epoch 1, Batch 1400/1975, Loss: 0.3978
  Epoch 1, Batch 1500/1975, Loss: 0.4044
  Epoch 1, Batch 1600/1975, Loss: 0.3968
  Epoch 1, Batch 1700/1975, Loss: 0.3980
  Epoch 1, Batch 1800/1975, Loss: 0.3982
  Epoch 1, Batch 1900/1975, Loss: 0.3987
[FT-AFM] ep01 train_ll=0.4009 val_ll=0.3874 val_auc=0.7422 *BEST*
  Epoch 2, Batch 100/1975, Loss: 0.3991
  Epoch 2, Batch 200/1975, Loss: 0.3

In [1]:
# ============================================================
# IMPROVED FT-AFM Training on 40M Samples
# Enhancements:
# - Larger model capacity (d=192, layers=3, ff=512)
# - Enhanced feature engineering (time features, more TE, more crosses)
# - Multi-head AFM with dropout
# - Learning rate scheduling (cosine annealing + warmup)
# - Better regularization
# - Longer training with patience
# ============================================================
import os, json, math, random, numpy as np, pandas as pd, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.model_selection import KFold, train_test_split
from torch.amp import GradScaler, autocast

# ============== Repro/Device ==============
def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_GPUS = torch.cuda.device_count()
USE_DDP = NUM_GPUS > 1

print(f"Available GPUs: {NUM_GPUS}")
if USE_DDP:
    print(f"Using DataParallel with {NUM_GPUS} GPUs")
else:
    print(f"Using single device: {DEVICE}")
    
RUN_FT_AFM = False


# ============== Utils ==============
def ensure_dir(p): os.makedirs(p, exist_ok=True)

def save_json(obj, path):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def kfold_target_encode(df_tr, df_va, col, yname, n_splits=5, min_samples=50, prior=None, seed=42):
    if prior is None: prior = float(df_tr[yname].mean())
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    te_tr = pd.Series(np.zeros(len(df_tr), dtype="float32"), index=df_tr.index)
    for tr_idx, hold_idx in kf.split(df_tr):
        cur = df_tr.iloc[tr_idx]
        means = cur.groupby(col)[yname].mean()
        cnts  = cur.groupby(col)[yname].size()
        m = (means*cnts + prior*min_samples) / (cnts + min_samples)
        te_tr.iloc[hold_idx] = df_tr.iloc[hold_idx][col].map(m).fillna(prior).astype("float32")
    means = df_tr.groupby(col)[yname].mean()
    cnts  = df_tr.groupby(col)[yname].size()
    mfull = (means*cnts + prior*min_samples) / (cnts + min_samples)
    te_va = df_va[col].map(mfull).fillna(prior).astype("float32")
    return te_tr, te_va

def temporal_or_stratified_split(df, label, time_cols=None, order_key=None, train=0.8, val=0.1, test=0.1):
    assert abs(train + val + test - 1.0) < 1e-8
    if time_cols and all(c in df.columns for c in time_cols):
        key = order_key if order_key is not None else time_cols
        df_sorted = df.sort_values(key).reset_index(drop=True)
        n = len(df_sorted); n_tr = int(train*n); n_va = int(val*n)
        df_tr = df_sorted.iloc[:n_tr].copy()
        df_va = df_sorted.iloc[n_tr:n_tr+n_va].copy()
        df_te = df_sorted.iloc[n_tr+n_va:].copy()
    else:
        df_tr, df_tmp = train_test_split(df, test_size=(1-train), stratify=df[label], random_state=42)
        df_va, df_te = train_test_split(df_tmp, test_size=(test/(test+val)), stratify=df_tmp[label], random_state=42)
        df_tr, df_va, df_te = df_tr.copy(), df_va.copy(), df_te.copy()
    return df_tr, df_va, df_te

# ============== IMPROVED Models ==============
class ImprovedAFM(nn.Module):
    """AFM with larger attention dimension and dropout"""
    def __init__(self, d, attn_dim=64, dropout=0.1):
        super().__init__()
        self.W = nn.Linear(d, attn_dim, bias=False)
        self.h = nn.Linear(attn_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self, E):
        B,F,d = E.shape; pairs=[]
        for i in range(F):
            for j in range(i+1, F):
                pairs.append(E[:,i]*E[:,j])
        P = torch.stack(pairs, dim=1)
        P = self.dropout(P)
        A = torch.softmax(self.h(torch.tanh(self.W(P))), dim=1)
        return (A * P).sum(dim=1)

class FeatureTokenizer(nn.Module):
    def __init__(self, cat_cardinalities, n_num, d_model):
        super().__init__()
        self.cat_embs = nn.ModuleList([nn.Embedding(card, d_model) for card in cat_cardinalities])
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_num)])
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)); nn.init.trunc_normal_(self.cls, std=0.02)
    def forward(self, x_cat, x_num):
        B = x_cat.size(0)
        cat_tokens = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embs)]
        num_tokens = [proj(x_num[:, i:i+1]) for i, proj in enumerate(self.num_proj)]
        field_embs = torch.stack(cat_tokens + num_tokens, dim=1)
        cls = self.cls.expand(B, -1, -1)
        tokens = torch.cat([cls, field_embs], dim=1)
        return tokens, field_embs

class FTTransformer(nn.Module):
    def __init__(self, d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
                                         dropout=dropout, batch_first=True, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
    def forward(self, tokens): return self.encoder(tokens)

class ImprovedFTAFM(nn.Module):
    """Improved FT+AFM with larger capacity and better fusion"""
    def __init__(self, cat_cards, n_num, d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64):
        super().__init__()
        self.tok = FeatureTokenizer(cat_cards, n_num, d_model)
        self.backbone = FTTransformer(d_model, nhead, ff, n_layers, dropout)
        self.afm = ImprovedAFM(d_model, afm_attn_dim, dropout=dropout)
        
        # Larger head
        fusion_dim = d_model + d_model
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )
        
    def forward(self, x_cat, x_num):
        tokens, field_embs = self.tok(x_cat, x_num)
        H = self.backbone(tokens); h_cls = H[:,0,:]
        v_afm = self.afm(field_embs)
        z = torch.cat([h_cls, v_afm], dim=1)
        return self.head(z).squeeze(1)

# ============== Train/Eval ==============
def init_final_bias_to_ctr(module_last_linear, base_ctr):
    with torch.no_grad():
        module_last_linear.bias.fill_(math.log(base_ctr/(1.0-base_ctr)))

def evaluate_model(model, dl, use_smaller_batches=False):
    """Evaluate with optional smaller batches to avoid OOM"""
    model.eval(); ys, ps = [], []
    
    # Unwrap DataParallel if present to avoid issues
    model_eval = model.module if isinstance(model, nn.DataParallel) else model
    
    with torch.no_grad():
        for Xc, Xn, yb in dl:
            # If OOM during eval, process in smaller chunks
            if use_smaller_batches and Xc.size(0) > 2048:
                chunk_size = 2048
                batch_probs = []
                for i in range(0, Xc.size(0), chunk_size):
                    Xc_chunk = Xc[i:i+chunk_size].to(DEVICE)
                    Xn_chunk = Xn[i:i+chunk_size].to(DEVICE)
                    logits = model_eval(Xc_chunk, Xn_chunk)
                    if logits.dim()==1: logits = logits.unsqueeze(1)
                    probs = torch.sigmoid(logits).squeeze(-1)
                    batch_probs.append(probs.cpu())
                    # Clear cache after each chunk
                    del Xc_chunk, Xn_chunk, logits, probs
                    torch.cuda.empty_cache()
                probs = torch.cat(batch_probs)
            else:
                Xc, Xn = Xc.to(DEVICE), Xn.to(DEVICE)
                logits = model_eval(Xc, Xn)
                if logits.dim()==1: logits = logits.unsqueeze(1)
                probs = torch.sigmoid(logits).squeeze(-1).cpu()
            
            ys.append(yb.squeeze(-1).numpy())
            ps.append(probs.numpy())
            
            # Clear GPU memory
            torch.cuda.empty_cache()
    
    y_true = np.concatenate(ys); y_prob = np.clip(np.concatenate(ps), 1e-7, 1-1e-7)
    return roc_auc_score(y_true, y_prob), log_loss(y_true, y_prob)

class TrainCfg:
    def __init__(self, d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15,
                 lr=1e-3, weight_decay=5e-5, max_epochs=20, patience=5, clip=1.0, warmup_epochs=2):
        self.d_model=d_model; self.nhead=nhead; self.ff=ff; self.n_layers=n_layers
        self.dropout=dropout; self.lr=lr; self.weight_decay=weight_decay
        self.max_epochs=max_epochs; self.patience=patience; self.clip=clip
        self.warmup_epochs=warmup_epochs
        self.use_amp=torch.cuda.is_available()

def get_warmup_cosine_scheduler(optimizer, warmup_epochs, max_epochs):
    """Warmup for first few epochs, then cosine decay"""
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        else:
            progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
            return 0.5 * (1 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

def train_improved_ft_afm(model, tr_dl, va_dl, cfg, outdir):
    ensure_dir(outdir)
    
    # Wrap model with DataParallel if multiple GPUs available
    if USE_DDP and NUM_GPUS > 1:
        print(f"Wrapping model with DataParallel across {NUM_GPUS} GPUs")
        model = nn.DataParallel(model)
    
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = get_warmup_cosine_scheduler(opt, cfg.warmup_epochs, cfg.max_epochs)
    scaler = GradScaler(device='cuda', enabled=cfg.use_amp)
    best_ll, best_state, stale = float("inf"), None, 0
    history = []
    
    for ep in range(1, cfg.max_epochs+1):
        model.train(); run = 0.0
        batch_count = 0
        for Xc, Xn, yb in tr_dl:
            Xc, Xn, yb = Xc.to(DEVICE), Xn.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            with autocast(device_type='cuda', enabled=cfg.use_amp):
                logits = model(Xc, Xn)
                if logits.dim()==1: logits = logits.unsqueeze(1)
                loss = F.binary_cross_entropy_with_logits(logits, yb)
            scaler.scale(loss).backward()
            
            model_to_clip = model.module if isinstance(model, nn.DataParallel) else model
            torch.nn.utils.clip_grad_norm_(model_to_clip.parameters(), cfg.clip)
            
            scaler.step(opt); scaler.update()
            run += loss.item() * yb.size(0)
            batch_count += 1
            
            if batch_count % 100 == 0:
                print(f"  Epoch {ep}, Batch {batch_count}/{len(tr_dl)}, Loss: {loss.item():.4f}, LR: {opt.param_groups[0]['lr']:.6f}")
        
        scheduler.step()  # Update learning rate

        val_auc, val_ll = evaluate_model(model, va_dl, use_smaller_batches=True)  # Changed to True
        train_ll = run / len(tr_dl.dataset)
        improved = val_ll + 1e-9 < best_ll
        if improved:
            best_ll = val_ll; stale = 0
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            best_state = {k: v.detach().cpu().clone() for k,v in model_to_save.state_dict().items()}
        else:
            stale += 1

        history.append({"epoch": ep, "train_ll": float(train_ll), "val_ll": float(val_ll), 
                       "val_auc": float(val_auc), "lr": float(opt.param_groups[0]['lr'])})
        print(f"[IMPROVED FT-AFM] ep{ep:02d} train_ll={train_ll:.4f} val_ll={val_ll:.4f} val_auc={val_auc:.4f} "
              f"lr={opt.param_groups[0]['lr']:.6f} {'*BEST*' if improved else f'stale {stale}/{cfg.patience}'}")
        if stale >= cfg.patience:
            print(f"[IMPROVED FT-AFM] early-stopped.")
            break

    if best_state is not None:
        model_to_load = model.module if isinstance(model, nn.DataParallel) else model
        model_to_load.load_state_dict(best_state)
        torch.save(best_state, os.path.join(outdir, "improved_ft_afm_best.pth"))
    pd.DataFrame(history).to_csv(os.path.join(outdir, "improved_ft_afm_history.csv"), index=False)
    
    val_auc, val_ll = evaluate_model(model, va_dl, use_smaller_batches=True)  # Changed to True
    save_json({"val_auc": float(val_auc), "val_logloss": float(val_ll)}, 
              os.path.join(outdir, "improved_ft_afm_val_metrics.json"))
    
    return model.module if isinstance(model, nn.DataParallel) else model

# ============== Dataset ==============
class CTRDataset(Dataset):
    def __init__(self, Xc, Xn, y):
        self.Xc = torch.as_tensor(Xc, dtype=torch.long)
        self.Xn = torch.as_tensor(Xn, dtype=torch.float32)
        self.y  = torch.as_tensor(y,  dtype=torch.float32)
    def __len__(self): return len(self.y)
    def __getitem__(self, i): return self.Xc[i], self.Xn[i], self.y[i].unsqueeze(-1)

# ============== Enhanced Feature Engineering ==============
def extract_time_features(df, hour_col='hour'):
    """Extract time-based features from YYYYMMDDHH format"""
    print("Extracting time features...")
    df[hour_col] = pd.to_numeric(df[hour_col], errors='coerce').fillna(0).astype('int64')
    
    # Hour of day (0-23)
    df['hour_of_day'] = (df[hour_col] % 100).astype('int32')
    
    # Convert to datetime for more features
    df['_temp_datetime'] = pd.to_datetime(df[hour_col], format='%y%m%d%H', errors='coerce')
    df['day_of_week'] = df['_temp_datetime'].dt.dayofweek.fillna(0).astype('int32')
    df['day_of_month'] = df['_temp_datetime'].dt.day.fillna(1).astype('int32')
    df['is_weekend'] = df['day_of_week'].isin([5,6]).astype('int32')
    
    # Drop temp column
    df.drop('_temp_datetime', axis=1, inplace=True)
    
    return ['hour_of_day', 'day_of_week', 'day_of_month', 'is_weekend']

# ============== Main Runner ==============
def run_improved_ft_afm_40m(csv_path, outdir, label_col, time_info, base_num_cols, drop_cols, 
                            single_freq_cats, pair_freq_cats, te_targets):
    print(f"\n{'='*80}")
    print("IMPROVED FT-AFM on 40M Samples")
    print(f"{'='*80}")
    ensure_dir(outdir)

    # Load data
    print("Loading data...")
    if csv_path.endswith('.parquet'):
        df = pd.read_parquet(csv_path)
    else:
        df = pd.read_csv(csv_path)
    print(f"Loaded {len(df):,} rows")
    assert label_col in df.columns

    # Save schema before
    schema_before = {
        "n_rows": int(len(df)),
        "n_cols": int(len(df.columns)),
        "positive_rate": float(df[label_col].mean())
    }
    save_json(schema_before, os.path.join(outdir, "schema_before.json"))

    # Cast base numeric columns
    for c in base_num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)

    # Extract time features BEFORE splitting
    if 'hour' in df.columns:
        time_features = extract_time_features(df, 'hour')
        base_num_cols = base_num_cols + time_features
        print(f"Added time features: {time_features}")

    # 8:1:1 split
    print("Splitting data (8:1:1)...")
    if time_info is not None:
        df_tr, df_va, df_te = temporal_or_stratified_split(
            df, label_col, time_cols=time_info.get("time_cols"), 
            order_key=time_info.get("order_key"), train=0.8, val=0.1, test=0.1
        )
    else:
        df_tr, df_va, df_te = temporal_or_stratified_split(
            df, label_col, time_cols=None, order_key=None, train=0.8, val=0.1, test=0.1
        )
    print(f"Train: {len(df_tr):,}, Val: {len(df_va):,}, Test: {len(df_te):,}")

    # Feature lists
    all_cols_tr = df_tr.columns.tolist()
    drop_cols_eff = [c for c in drop_cols if c in all_cols_tr]
    base_num_eff = [c for c in base_num_cols if c in all_cols_tr]
    cat_cols = [c for c in all_cols_tr if c not in ([label_col] + drop_cols_eff + base_num_eff)]
    num_cols = base_num_eff[:]

    # Categorical encoding with rare category handling
    print("Encoding categoricals with rare category handling...")
    cat_cards, cat_maps = [], {}
    MIN_FREQ = 10  # Categories with < 10 occurrences -> UNK
    
    for c in cat_cols:
        vc = df_tr[c].value_counts()
        frequent = vc[vc >= MIN_FREQ].index
        uniq = pd.Index(frequent)
        mapping = {v:i for i,v in enumerate(uniq)}
        unk_id = len(mapping)
        cat_maps[c] = (mapping, unk_id)
        
        for d in (df_tr, df_va, df_te):
            d[c] = d[c].map(mapping).fillna(unk_id).astype("int64")
        cat_cards.append(unk_id + 1)
        print(f"  {c}: {len(mapping)} frequent categories, {unk_id+1} total (with UNK)")

    # Numeric base features
    for c in num_cols:
        for d in (df_tr, df_va, df_te):
            d[c] = pd.to_numeric(d[c], errors="coerce").fillna(0).astype("float32")

    # Single frequency features
    print("Engineering frequency features...")
    cand_freq = [c for c in single_freq_cats if c in df_tr.columns]
    for c in cand_freq:
        vc = df_tr[c].value_counts()
        df_tr[f"{c}_freq"] = df_tr[c].map(vc).astype("float32")
        df_va[f"{c}_freq"] = df_va[c].map(vc).fillna(0).astype("float32")
        df_te[f"{c}_freq"] = df_te[c].map(vc).fillna(0).astype("float32")

    # Pairwise frequency features
    pairs_eff = [(a,b) for (a,b) in pair_freq_cats if set([a,b]).issubset(df_tr.columns)]
    print(f"Creating {len(pairs_eff)} pairwise frequency features...")
    for a,b in pairs_eff:
        key_tr = df_tr[a].astype("int64")*10_000_000 + df_tr[b].astype("int64")
        key_va = df_va[a].astype("int64")*10_000_000 + df_va[b].astype("int64")
        key_te = df_te[a].astype("int64")*10_000_000 + df_te[b].astype("int64")
        vc = key_tr.value_counts()
        name = f"{a}__{b}__freq"
        df_tr[name] = key_tr.map(vc).astype("float32")
        df_va[name] = key_va.map(vc).fillna(0).astype("float32")
        df_te[name] = key_te.map(vc).fillna(0).astype("float32")

    # Target encoding
    te_eff = [c for c in te_targets if c in df_tr.columns]
    print(f"Target encoding {len(te_eff)} features: {te_eff}")
    for c in te_eff:
        te_tr, te_va = kfold_target_encode(df_tr, df_va, c, yname=label_col, n_splits=5, min_samples=50)
        prior = float(df_tr[label_col].mean())
        means = df_tr.groupby(c)[label_col].mean()
        cnts  = df_tr.groupby(c)[label_col].size()
        mfull = (means*cnts + prior*50) / (cnts + 50)
        te_te = df_te[c].map(mfull).fillna(prior).astype("float32")
        df_tr[f"{c}_te"] = te_tr.astype("float32")
        df_va[f"{c}_te"] = te_va.astype("float32")
        df_te[f"{c}_te"] = te_te.astype("float32")

    # Register engineered numerics
    new_num = [f"{c}_freq" for c in cand_freq] + [f"{a}__{b}__freq" for (a,b) in pairs_eff] + [f"{c}_te" for c in te_eff]
    for c in new_num:
        if c not in num_cols: num_cols.append(c)

    # Log1p freq features
    freq_like = [c for c in num_cols if c.endswith("_freq")]
    for c in freq_like:
        for d in (df_tr, df_va, df_te):
            d[c] = np.log1p(pd.to_numeric(d[c], errors="coerce").fillna(0).clip(lower=0).astype("float64"))

    # Standardize all numerics
    print("Standardizing numerics...")
    num_means = {c: float(pd.to_numeric(df_tr[c], errors="coerce").fillna(0).mean()) for c in num_cols}
    num_stds  = {c: float(pd.to_numeric(df_tr[c], errors="coerce").fillna(0).std(ddof=0)) for c in num_cols}
    for c in num_cols:
        mu, sd = num_means[c], (num_stds[c] if num_stds[c] > 1e-8 else 1.0)
        for d in (df_tr, df_va, df_te):
            v = pd.to_numeric(d[c], errors="coerce")
            d[c] = ((v - mu)/sd).replace([np.inf,-np.inf], np.nan).fillna(0).astype("float32")

    # Save schema after
    schema_after = {
        "cat_cols": cat_cols,
        "num_cols": num_cols,
        "cat_cards": cat_cards,
        "n_features": len(cat_cols) + len(num_cols),
        "splits": {"train": len(df_tr), "val": len(df_va), "test": len(df_te)},
        "train_ctr": float(df_tr[label_col].mean()),
        "val_ctr": float(df_va[label_col].mean()),
        "test_ctr": float(df_te[label_col].mean())
    }
    save_json(schema_after, os.path.join(outdir, "schema_after.json"))
    print(f"Total features: {len(cat_cols)} categorical + {len(num_cols)} numeric = {len(cat_cols)+len(num_cols)}")

    # Prepare tensors
    print("Preparing data loaders...")
    Xc_tr = df_tr[cat_cols].to_numpy()
    Xn_tr = df_tr[num_cols].to_numpy().astype("float32")
    y_tr = df_tr[label_col].to_numpy().astype("float32")
    
    Xc_va = df_va[cat_cols].to_numpy()
    Xn_va = df_va[num_cols].to_numpy().astype("float32")
    y_va = df_va[label_col].to_numpy().astype("float32")
    
    Xc_te = df_te[cat_cols].to_numpy()
    Xn_te = df_te[num_cols].to_numpy().astype("float32")
    y_te = df_te[label_col].to_numpy().astype("float32")
    
    # Right after you have Xc_tr, Xn_tr, y_tr, etc.
    # and BEFORE you create DataLoaders

    print("Saving preprocessed data for ablation studies...")
    save_dir = "runs_avazu_40m_improved_ft_afm"
    os.makedirs(save_dir, exist_ok=True)

    # Save arrays
    np.save(os.path.join(save_dir, 'Xc_train.npy'), Xc_tr)
    np.save(os.path.join(save_dir, 'Xn_train.npy'), Xn_tr)
    np.save(os.path.join(save_dir, 'y_train.npy'), y_tr)

    np.save(os.path.join(save_dir, 'Xc_val.npy'), Xc_va)
    np.save(os.path.join(save_dir, 'Xn_val.npy'), Xn_va)
    np.save(os.path.join(save_dir, 'y_val.npy'), y_va)

    np.save(os.path.join(save_dir, 'Xc_test.npy'), Xc_te)
    np.save(os.path.join(save_dir, 'Xn_test.npy'), Xn_te)
    np.save(os.path.join(save_dir, 'y_test.npy'), y_te)

    # Save schema info
    schema = {
       'cat_cards': cat_cards,
       'num_cols': num_cols
    }
    with open(os.path.join(save_dir, 'schema.json'), 'w') as f:
        json.dump(schema, f)

    print(f"✅ Preprocessed data saved to {save_dir}/")


    # REDUCED batch sizes to avoid OOM with many features
    batch_size_train = 2048 * NUM_GPUS if USE_DDP else 2048  # Reduced from 4096
    batch_size_eval = 4096 * NUM_GPUS if USE_DDP else 4096   # Reduced from 8192
    
    print(f"Train batch size: {batch_size_train} (reduced for memory)")
    print(f"Eval batch size: {batch_size_eval} (reduced for memory)")
    
    tr_dl = DataLoader(CTRDataset(Xc_tr, Xn_tr, y_tr), batch_size=batch_size_train, 
                       shuffle=True, num_workers=4, pin_memory=True)
    va_dl = DataLoader(CTRDataset(Xc_va, Xn_va, y_va), batch_size=batch_size_eval, 
                       shuffle=False, num_workers=4, pin_memory=True)
    te_dl = DataLoader(CTRDataset(Xc_te, Xn_te, y_te), batch_size=batch_size_eval, 
                       shuffle=False, num_workers=4, pin_memory=True)

    # Build improved model
    base_ctr = float(df_tr[label_col].mean())
    cfg = TrainCfg(
        d_model=192,        # Increased from 128
        nhead=8, 
        ff=512,             # Increased from 256
        n_layers=3,         # Increased from 2
        dropout=0.15,       # Increased from 0.10
        lr=1e-3, 
        weight_decay=5e-5,  # Decreased from 1e-4
        max_epochs=20,      # Increased from 12
        patience=5,         # Increased from 3
        warmup_epochs=2
    )
    
    print(f"\nModel Configuration:")
    print(f"  d_model: {cfg.d_model}, n_layers: {cfg.n_layers}, ff: {cfg.ff}")
    print(f"  dropout: {cfg.dropout}, lr: {cfg.lr}, weight_decay: {cfg.weight_decay}")
    print(f"  max_epochs: {cfg.max_epochs}, patience: {cfg.patience}, warmup: {cfg.warmup_epochs}")
    
    print("\nBuilding Improved FT-AFM model...")
    model = ImprovedFTAFM(cat_cards, len(num_cols), cfg.d_model, cfg.nhead, 
                          cfg.ff, cfg.n_layers, cfg.dropout, afm_attn_dim=64).to(DEVICE)
    init_final_bias_to_ctr(model.head[-1], base_ctr)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    if RUN_FT_AFM:
        print(f"\nTraining Improved FT-AFM on {len(df_tr):,} samples...")
        model = train_improved_ft_afm(model, tr_dl, va_dl, cfg, outdir)
    
        # Final test evaluation with smaller batches to avoid OOM
        print("\nEvaluating on test set (using memory-efficient batching)...")
        torch.cuda.empty_cache()  # Clear cache before final eval
        test_auc, test_ll = evaluate_model(model, te_dl, use_smaller_batches=True)
        save_json({"test_auc": float(test_auc), "test_logloss": float(test_ll)},
              os.path.join(outdir, "improved_ft_afm_test_metrics.json"))
    
        print(f"\n{'='*80}")
        print(f"IMPROVED FT-AFM Results on 40M Samples:")
        print(f"Test AUC:     {test_auc:.6f}")
        print(f"Test LogLoss: {test_ll:.6f}")
        print(f"{'='*80}\n")
    
        return test_auc, test_ll
    else:
        print("⏭️ Skipping FT-AFM training (preprocessing only).")
        return None, None

# ================== RUN: AVAZU 40M IMPROVED ==================
AVAZU_40M_PATH = "/home/elicer/ctr_project/data/avazu_full/avazu_full.parquet"
OUT_AVAZU_IMPROVED = "runs_avazu_40m_improved_ft_afm"

print("\n" + "="*80)
print("Starting IMPROVED FT-AFM training on Avazu 40M dataset")
print("IMPROVEMENTS:")
print("  • Larger model: d=192, layers=3, ff=512")
print("  • Time features: hour_of_day, day_of_week, is_weekend")
print("  • More target encoding: C14, C17, site_id, app_id, device_model")
print("  • Enhanced AFM: attn_dim=64 with dropout")
print("  • Learning rate schedule: warmup + cosine annealing")
print("  • Longer training: 20 epochs, patience=5")
print("  • Better regularization: dropout=0.15, weight_decay=5e-5")
print("  • More cross features")
print("="*80)

test_auc, test_ll = run_improved_ft_afm_40m(
    csv_path=AVAZU_40M_PATH,
    outdir=OUT_AVAZU_IMPROVED,
    label_col="click",
    time_info={"time_cols":["hour"], "order_key":"hour"},
    base_num_cols=["hour"],
    drop_cols=["id"],
    single_freq_cats=["C14","C17","C18","C19","C20","C21","site_id","app_id","device_model"],  # Added more C fields
    pair_freq_cats=[
        ("site_id","app_id"),
        ("device_model","hour"),
        ("C14","site_id"),          # NEW
        ("C17","app_id"),           # NEW
        ("site_id","C14"),          # NEW
        ("app_id","device_model"),  # NEW
    ],
    te_targets=["C14","C17","site_id","app_id","device_model"]  # Added C17, app_id, device_model
)

print("\n" + "="*80)
print("FINAL RESULTS - Avazu 40M IMPROVED FT-AFM")
print("="*80)
print(f"Test AUC:     {test_auc:.6f}")
print(f"Test LogLoss: {test_ll:.6f}")
print("="*80)
print(f"\nResults saved to: {OUT_AVAZU_IMPROVED}/")
print("\n✅ IMPROVED Training complete!")
print("\nIMPROVEMENTS SUMMARY:")
print("  ✓ Model capacity increased (128→192 dim, 2→3 layers, 256→512 FF)")
print("  ✓ Time-based features extracted (hour_of_day, day_of_week, weekend)")
print("  ✓ More target encoding (5 features vs 2)")
print("  ✓ More cross features (6 pairs vs 2)")
print("  ✓ Larger AFM attention (64 vs 32)")
print("  ✓ Learning rate scheduling (warmup + cosine decay)")
print("  ✓ Longer training (20 epochs vs 12)")
print("  ✓ Better regularization tuning")
print("="*80)

Available GPUs: 4
Using DataParallel with 4 GPUs

Starting IMPROVED FT-AFM training on Avazu 40M dataset
IMPROVEMENTS:
  • Larger model: d=192, layers=3, ff=512
  • Time features: hour_of_day, day_of_week, is_weekend
  • More target encoding: C14, C17, site_id, app_id, device_model
  • Enhanced AFM: attn_dim=64 with dropout
  • Learning rate schedule: warmup + cosine annealing
  • Longer training: 20 epochs, patience=5
  • Better regularization: dropout=0.15, weight_decay=5e-5
  • More cross features

IMPROVED FT-AFM on 40M Samples
Loading data...
Loaded 40,428,967 rows
Extracting time features...
Added time features: ['hour_of_day', 'day_of_week', 'day_of_month', 'is_weekend']
Splitting data (8:1:1)...
Train: 32,343,173, Val: 4,042,896, Test: 4,042,898
Encoding categoricals with rare category handling...
  C1: 7 frequent categories, 8 total (with UNK)
  banner_pos: 7 frequent categories, 8 total (with UNK)
  site_id: 3005 frequent categories, 3006 total (with UNK)
  site_domain: 3294 



Total parameters: 99,325,761
Trainable parameters: 99,325,761
⏭️ Skipping FT-AFM training (preprocessing only).

FINAL RESULTS - Avazu 40M IMPROVED FT-AFM


TypeError: unsupported format string passed to NoneType.__format__

In [2]:
# ============================================================
# QUICK ABLATION STUDY - Run 3 variants in 3 days
# You already have FT-AFM (0.739), need: FT-only, AFM-only, FT+FM
# ============================================================
import os, json, math, random, numpy as np, pandas as pd, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.model_selection import KFold
from torch.amp import GradScaler, autocast

# ===================== ADD THESE FLAGS (near the top, after imports) =====================
RUN_FT_ONLY = False
RUN_AFM_ONLY = False
RUN_FT_FM = False
RUN_FT_AFM_GATED = True     # <-- run ONLY this now
# =======================================================================================

# ============== Setup ==============
def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
set_seed(42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_GPUS = torch.cuda.device_count()

# ============== Models ==============
class AFM(nn.Module):
    def __init__(self, d, attn_dim=32):
        super().__init__()
        self.W = nn.Linear(d, attn_dim, bias=False)
        self.h = nn.Linear(attn_dim, 1, bias=False)
    def forward(self, E):
        B,F,d = E.shape; pairs=[]
        for i in range(F):
            for j in range(i+1, F):
                pairs.append(E[:,i]*E[:,j])
        P = torch.stack(pairs, dim=1)
        A = torch.softmax(self.h(torch.tanh(self.W(P))), dim=1)
        return (A * P).sum(dim=1)

def fm_interaction(E):
    sum_embed = E.sum(dim=1)
    square_sum = sum_embed.pow(2)
    sum_square = (E.pow(2)).sum(dim=1)
    return (0.5 * (square_sum - sum_square)).sum(dim=1, keepdim=True)

class FeatureTokenizer(nn.Module):
    def __init__(self, cat_cardinalities, n_num, d_model):
        super().__init__()
        self.cat_embs = nn.ModuleList([nn.Embedding(card, d_model) for card in cat_cardinalities])
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_num)])
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)); nn.init.trunc_normal_(self.cls, std=0.02)
    def forward(self, x_cat, x_num):
        B = x_cat.size(0)
        cat_tokens = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embs)]
        num_tokens = [proj(x_num[:, i:i+1]) for i, proj in enumerate(self.num_proj)]
        field_embs = torch.stack(cat_tokens + num_tokens, dim=1)
        cls = self.cls.expand(B, -1, -1)
        tokens = torch.cat([cls, field_embs], dim=1)
        return tokens, field_embs

class FTTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
                                         dropout=dropout, batch_first=True, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
    def forward(self, tokens): return self.encoder(tokens)
    
# ===================== ADD THIS GATED MODEL (paste with your other model classes) =======
class ImprovedFTAFM_Gated(nn.Module):
    """FT + AFM with gated fusion (adaptive AFM contribution)."""
    def __init__(self, cat_cards, n_num, d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64):
        super().__init__()
        self.tok = FeatureTokenizer(cat_cards, n_num, d_model)
        self.backbone = FTTransformer(d_model, nhead, ff, n_layers, dropout)
        self.afm = AFM(d_model, attn_dim=afm_attn_dim)  # uses your AFM class above

        # gate: conditioned on CLS, outputs per-dimension (0..1)
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )

        fusion_dim = d_model + d_model
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

    def forward(self, x_cat, x_num):
        tokens, field_embs = self.tok(x_cat, x_num)
        H = self.backbone(tokens)
        h_cls = H[:, 0, :]                 # [B, d_model]

        v_afm = self.afm(field_embs)       # [B, d_model]

        g = self.gate(h_cls)               # [B, d_model] in (0,1)
        v_afm = g * v_afm                  # gated AFM signal

        z = torch.cat([h_cls, v_afm], dim=1)
        return self.head(z).squeeze(1)

# ============== Variant 1: FT-only ==============
class FTOnly(nn.Module):
    def __init__(self, cat_cards, n_num, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1):
        super().__init__()
        self.tok = FeatureTokenizer(cat_cards, n_num, d_model)
        self.backbone = FTTransformer(d_model, nhead, ff, n_layers, dropout)
        self.head = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, 1))
    def forward(self, x_cat, x_num):
        tokens, _ = self.tok(x_cat, x_num)
        H = self.backbone(tokens)
        return self.head(H[:,0,:]).squeeze(1)

# ============== Variant 2: AFM-only ==============
class AFMOnly(nn.Module):
    def __init__(self, cat_cards, n_num, d_model=64):
        super().__init__()
        self.cat_embs = nn.ModuleList([nn.Embedding(card, d_model) for card in cat_cards])
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_num)])
        self.afm = AFM(d_model, attn_dim=32)
        self.head = nn.Linear(d_model, 1)
    def forward(self, x_cat, x_num):
        cat_embs = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embs)]
        num_embs = [proj(x_num[:, i:i+1]) for i, proj in enumerate(self.num_proj)]
        field_embs = torch.stack(cat_embs + num_embs, dim=1)
        afm_out = self.afm(field_embs)
        return self.head(afm_out).squeeze(1)

# ============== Variant 3: FT+FM ==============
class FTFM(nn.Module):
    def __init__(self, cat_cards, n_num, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1):
        super().__init__()
        self.tok = FeatureTokenizer(cat_cards, n_num, d_model)
        self.backbone = FTTransformer(d_model, nhead, ff, n_layers, dropout)
        self.head = nn.Sequential(nn.Linear(d_model + 1, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, 1))
    def forward(self, x_cat, x_num):
        tokens, field_embs = self.tok(x_cat, x_num)
        H = self.backbone(tokens)
        fm_out = fm_interaction(field_embs)
        z = torch.cat([H[:,0,:], fm_out], dim=1)
        return self.head(z).squeeze(1)

# ============== Training ==============
class CTRDataset(Dataset):
    def __init__(self, Xc, Xn, y):
        self.Xc = torch.as_tensor(Xc, dtype=torch.long)
        self.Xn = torch.as_tensor(Xn, dtype=torch.float32)
        self.y  = torch.as_tensor(y,  dtype=torch.float32)
    def __len__(self): return len(self.y)
    def __getitem__(self, i): return self.Xc[i], self.Xn[i], self.y[i].unsqueeze(-1)

def evaluate(model, dl):
    model.eval(); ys, ps = [], []
    with torch.no_grad():
        for Xc, Xn, yb in dl:
            # Process in chunks to avoid OOM
            batch_probs = []
            for i in range(0, Xc.size(0), 2048):
                Xc_chunk = Xc[i:i+2048].to(DEVICE)
                Xn_chunk = Xn[i:i+2048].to(DEVICE)
                logits = model(Xc_chunk, Xn_chunk)
                if logits.dim()==1: logits = logits.unsqueeze(1)
                probs = torch.sigmoid(logits).squeeze(-1).cpu()
                batch_probs.append(probs)
                del Xc_chunk, Xn_chunk, logits, probs
                torch.cuda.empty_cache()
            probs = torch.cat(batch_probs)
            ys.append(yb.squeeze(-1).numpy())
            ps.append(probs.numpy())
    y_true = np.concatenate(ys); y_prob = np.clip(np.concatenate(ps), 1e-7, 1-1e-7)
    return roc_auc_score(y_true, y_prob), log_loss(y_true, y_prob)

def train_variant(model, tr_dl, va_dl, name, outdir, max_epochs=12, patience=3):
    os.makedirs(outdir, exist_ok=True)
    if NUM_GPUS > 1:
        model = nn.DataParallel(model)
    
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scaler = GradScaler(device='cuda', enabled=torch.cuda.is_available())
    best_ll, best_state, stale = float("inf"), None, 0
    
    for ep in range(1, max_epochs+1):
        model.train(); run = 0.0
        for Xc, Xn, yb in tr_dl:
            Xc, Xn, yb = Xc.to(DEVICE), Xn.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                logits = model(Xc, Xn)
                if logits.dim()==1: logits = logits.unsqueeze(1)
                loss = F.binary_cross_entropy_with_logits(logits, yb)
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            run += loss.item() * yb.size(0)
        
        val_auc, val_ll = evaluate(model, va_dl)
        train_ll = run / len(tr_dl.dataset)
        improved = val_ll < best_ll
        if improved:
            best_ll = val_ll; stale = 0
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            best_state = {k: v.cpu().clone() for k,v in model_to_save.state_dict().items()}
        else:
            stale += 1
        
        print(f"[{name}] ep{ep:02d} train_ll={train_ll:.4f} val_ll={val_ll:.4f} val_auc={val_auc:.4f} {'*' if improved else f'stale {stale}/{patience}'}")
        if stale >= patience:
            print(f"[{name}] early-stopped.")
            break
    
    if best_state:
        model_to_load = model.module if isinstance(model, nn.DataParallel) else model
        model_to_load.load_state_dict(best_state)
        torch.save(best_state, os.path.join(outdir, f"{name}_best.pth"))
    
    return model.module if isinstance(model, nn.DataParallel) else model

# ============== Main Runner ==============
# ===================== REPLACE YOUR run_ablations() WITH THIS PATCHED VERSION ===========
def run_ablations(preprocessed_data_dir="runs_avazu_40m_improved_ft_afm", output_dir="ablation_results"):
    """
    Loads saved preprocessed splits and runs selected ablation variants.
    You can skip already-finished variants using the RUN_* flags.
    """

    print("\n" + "="*80)
    print("RUNNING ABLATION STUDIES - SELECTED VARIANTS")
    print("="*80)

    # ---- Load preprocessed data ----
    print("\nLoading preprocessed data...")
    Xc_tr = np.load(os.path.join(preprocessed_data_dir, "Xc_train.npy"))
    Xn_tr = np.load(os.path.join(preprocessed_data_dir, "Xn_train.npy"))
    y_tr  = np.load(os.path.join(preprocessed_data_dir, "y_train.npy"))

    Xc_va = np.load(os.path.join(preprocessed_data_dir, "Xc_val.npy"))
    Xn_va = np.load(os.path.join(preprocessed_data_dir, "Xn_val.npy"))
    y_va  = np.load(os.path.join(preprocessed_data_dir, "y_val.npy"))

    Xc_te = np.load(os.path.join(preprocessed_data_dir, "Xc_test.npy"))
    Xn_te = np.load(os.path.join(preprocessed_data_dir, "Xn_test.npy"))
    y_te  = np.load(os.path.join(preprocessed_data_dir, "y_test.npy"))

    # ---- Schema file name: adjust ONLY if your file is named differently ----
    schema_path = os.path.join(preprocessed_data_dir, "schema.json")
    if not os.path.exists(schema_path):
        alt = os.path.join(preprocessed_data_dir, "preprocessed_schema.json")
        if os.path.exists(alt):
            schema_path = alt

    with open(schema_path) as f:
        schema = json.load(f)

    cat_cards = schema["cat_cards"]
    n_num = len(schema["num_cols"])

    print(f"Loaded: {len(y_tr):,} train, {len(y_va):,} val, {len(y_te):,} test")
    print(f"Features: {len(cat_cards)} categorical, {n_num} numerical")
    print(f"Schema: {schema_path}")

    # ---- Dataloaders ----
    batch_size = 2048 * NUM_GPUS if NUM_GPUS > 1 else 2048
    tr_dl = DataLoader(CTRDataset(Xc_tr, Xn_tr, y_tr), batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True)
    va_dl = DataLoader(CTRDataset(Xc_va, Xn_va, y_va), batch_size=batch_size*2, shuffle=False, num_workers=4, pin_memory=True)
    te_dl = DataLoader(CTRDataset(Xc_te, Xn_te, y_te), batch_size=batch_size*2, shuffle=False, num_workers=4, pin_memory=True)

    os.makedirs(output_dir, exist_ok=True)

    # ---- Load existing summary if present so we don't lose old results ----
    summary_path = os.path.join(output_dir, "ablation_summary.csv")
    results = []
    if os.path.exists(summary_path):
        try:
            prev = pd.read_csv(summary_path)
            results = [tuple(x) for x in prev[["Model", "Test AUC", "Test LogLoss"]].values.tolist()]
            print(f"\nFound existing summary with {len(results)} rows: {summary_path}")
        except Exception as e:
            print(f"\n⚠️ Could not read existing summary ({summary_path}): {e}")
            results = []

    # Helper to avoid duplicate rows
    def upsert_result(model_name, auc, ll):
        nonlocal results
        results = [r for r in results if r[0] != model_name]
        results.append((model_name, float(auc), float(ll)))

    # =================== 1) FT-only (skip if already done) ===================
    if RUN_FT_ONLY:
        print("\n" + "="*80)
        print("VARIANT 1: FT-only")
        print("="*80)
        model = FTOnly(cat_cards, n_num, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1).to(DEVICE)
        model = train_variant(model, tr_dl, va_dl, "ft_only", output_dir)
        test_auc, test_ll = evaluate(model, te_dl)
        upsert_result("FT-only", test_auc, test_ll)
        print(f"FT-only Test: AUC={test_auc:.4f}, LogLoss={test_ll:.4f}")
    else:
        print("\n⏭️ Skipping FT-only (already done)")

    # =================== 2) AFM-only (skip if already done) ===================
    if RUN_AFM_ONLY:
        print("\n" + "="*80)
        print("VARIANT 2: AFM-only")
        print("="*80)
        model = AFMOnly(cat_cards, n_num, d_model=64).to(DEVICE)
        model = train_variant(model, tr_dl, va_dl, "afm_only", output_dir)
        test_auc, test_ll = evaluate(model, te_dl)
        upsert_result("AFM-only", test_auc, test_ll)
        print(f"AFM-only Test: AUC={test_auc:.4f}, LogLoss={test_ll:.4f}")
    else:
        print("⏭️ Skipping AFM-only (already done)")

    # =================== 3) FT+FM (skip if already done) ===================
    if RUN_FT_FM:
        print("\n" + "="*80)
        print("VARIANT 3: FT+FM")
        print("="*80)
        model = FTFM(cat_cards, n_num, d_model=128, nhead=8, ff=256, n_layers=2, dropout=0.1).to(DEVICE)
        model = train_variant(model, tr_dl, va_dl, "ft_fm", output_dir)
        test_auc, test_ll = evaluate(model, te_dl)
        upsert_result("FT+FM", test_auc, test_ll)
        print(f"FT+FM Test: AUC={test_auc:.4f}, LogLoss={test_ll:.4f}")
    else:
        print("⏭️ Skipping FT+FM (already done)")

    # =================== 4) FT+AFM (GATED) (run this now) ===================
    if RUN_FT_AFM_GATED:
        print("\n" + "="*80)
        print("VARIANT 4: FT+AFM (GATED)")
        print("="*80)

        # Use same capacity as your main model for fairness
        model = ImprovedFTAFM_Gated(
            cat_cards, n_num,
            d_model=192, nhead=8, ff=512, n_layers=3, dropout=0.15, afm_attn_dim=64
        ).to(DEVICE)

        model = train_variant(model, tr_dl, va_dl, "ft_afm_gated", output_dir, max_epochs=20, patience=5)
        test_auc, test_ll = evaluate(model, te_dl)
        upsert_result("FT+AFM (gated)", test_auc, test_ll)
        print(f"FT+AFM (gated) Test: AUC={test_auc:.4f}, LogLoss={test_ll:.4f}")
    else:
        print("⏭️ Skipping FT+AFM (gated)")

    # =================== Summary ===================
    print("\n" + "="*80)
    print("ABLATION RESULTS SUMMARY (UPDATED)")
    print("="*80)

    df = pd.DataFrame(results, columns=["Model", "Test AUC", "Test LogLoss"])
    df = df.sort_values(by="Test AUC", ascending=False)
    df.to_csv(summary_path, index=False)
    print(df.to_string(index=False))

    print(f"\n✅ Saved updated summary to: {summary_path}")
    return df

# ============== USAGE ==============
# First, save your preprocessed data from the FT-AFM run
# Then run this:
run_ablations("runs_avazu_40m_improved_ft_afm", "ablation_results")

print("="*80)
print("ABLATION STUDY SCRIPT READY")
print("="*80)
print("\nBefore running, you need to save preprocessed data as numpy arrays.")
print("Add this to the end of your FT-AFM training script:")
print("""
# After preprocessing, before training:
np.save('Xc_train.npy', Xc_tr)
np.save('Xn_train.npy', Xn_tr)
np.save('y_train.npy', y_tr)
np.save('Xc_val.npy', Xc_va)
np.save('Xn_val.npy', Xn_va)  
np.save('y_val.npy', y_va)
np.save('Xc_test.npy', Xc_te)
np.save('Xn_test.npy', Xn_te)
np.save('y_test.npy', y_te)
with open('schema.json', 'w') as f:
    json.dump({'cat_cards': cat_cards, 'num_cols': num_cols}, f)
""")
print("\nThen run: run_ablations('path/to/saved/data', 'ablation_results')")


RUNNING ABLATION STUDIES - SELECTED VARIANTS

Loading preprocessed data...
Loaded: 32,343,173 train, 4,042,896 val, 4,042,898 test
Features: 21 categorical, 25 numerical
Schema: runs_avazu_40m_improved_ft_afm/schema.json

Found existing summary with 4 rows: ablation_results/ablation_summary.csv

⏭️ Skipping FT-only (already done)
⏭️ Skipping AFM-only (already done)
⏭️ Skipping FT+FM (already done)

VARIANT 4: FT+AFM (GATED)




[ft_afm_gated] ep01 train_ll=0.3936 val_ll=0.3878 val_auc=0.7372 *
[ft_afm_gated] ep02 train_ll=0.3833 val_ll=0.3878 val_auc=0.7384 *
[ft_afm_gated] ep03 train_ll=0.3762 val_ll=0.3892 val_auc=0.7352 stale 1/5
[ft_afm_gated] ep04 train_ll=0.3712 val_ll=0.3906 val_auc=0.7351 stale 2/5
[ft_afm_gated] ep05 train_ll=0.3674 val_ll=0.3941 val_auc=0.7294 stale 3/5
[ft_afm_gated] ep06 train_ll=0.3641 val_ll=0.3924 val_auc=0.7341 stale 4/5
[ft_afm_gated] ep07 train_ll=0.3612 val_ll=0.3986 val_auc=0.7308 stale 5/5
[ft_afm_gated] early-stopped.
FT+AFM (gated) Test: AUC=0.7283, LogLoss=0.4081

ABLATION RESULTS SUMMARY (UPDATED)
         Model  Test AUC  Test LogLoss
       FT-only  0.742573      0.398276
FT+AFM (yours)  0.739000      0.398800
         FT+FM  0.735403      0.410631
FT+AFM (gated)  0.728305      0.408130
      AFM-only  0.720308      0.409382

✅ Saved updated summary to: ablation_results/ablation_summary.csv
ABLATION STUDY SCRIPT READY

Before running, you need to save preprocessed d

In [2]:
import pandas as pd

# Load FT-AFM history
history = pd.read_csv("runs_avazu_40m_improved_ft_afm/improved_ft_afm_history.csv")

print("="*60)
print("FT-AFM Training Summary:")
print("="*60)
print(f"Best val AUC: {history['val_auc'].max():.4f}")
print(f"Best val LL: {history['val_ll'].min():.4f}")
print(f"Final val AUC: {history['val_auc'].iloc[-1]:.4f}")
print(f"Stopped at epoch: {len(history)}")
print(f"\nFull history:")
print(history[['epoch', 'val_auc', 'val_ll']])

FT-AFM Training Summary:
Best val AUC: 0.7477
Best val LL: 0.3823
Final val AUC: 0.7467
Stopped at epoch: 7

Full history:
   epoch   val_auc    val_ll
0      1  0.746883  0.382866
1      2  0.747685  0.382297
2      3  0.746501  0.383027
3      4  0.746505  0.383146
4      5  0.746465  0.382589
5      6  0.746195  0.382757
6      7  0.746650  0.382608


In [3]:
import pandas as pd
import glob

# Find all history files
history_files = glob.glob("ablation_results/*history.csv")

for file in history_files:
    model_name = file.split('/')[-1].replace('_history.csv', '')
    history = pd.read_csv(file)
    best_val = history['val_auc'].max()
    print(f"{model_name:15} | Best Val AUC: {best_val:.4f}")

In [None]:
import pandas as pd
p = "ctr_project/data/criteo/train.csv"   # change to your actual split file
df = pd.read_csv(p, nrows=5)
print(df.columns.tolist())
df.head()
