# Data Preprocessing

In [2]:
# ── Cell 1 ───────────────────────────────────────────────────────────
import pandas as pd, numpy as np, torch, random
from sklearn.model_selection import train_test_split
from pathlib import Path

SEED = 42

# Seed every RNG used in this notebook and ask cuDNN for deterministic
# behaviour (benchmark autotuning switched off).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ── user dials
DATA_PATH = Path("tedsa_puf_2020.csv")   # input CSV
SUB1_KEEP = [2, 4, 5, 10]                # SUB1 codes retained as target classes
FRACTION = 1.0                           # values strictly between 0 and 1 subsample the rows

KEEP_COLS = [
    "AGE", "GENDER", "RACE", "ETHNIC", "EDUC", "EMPLOY",
    "LIVARAG", "PRIMINC", "STFIPS", "REGION", "DIVISION",
    "HLTHINS", "SUB1",
]

df = pd.read_csv(DATA_PATH, usecols=KEEP_COLS)

# Optional random subsample (skipped when FRACTION is outside the open interval (0, 1)).
if 0 < FRACTION < 1:
    df = df.sample(frac=FRACTION, random_state=SEED).reset_index(drop=True)

TARGET = "SUB1"
feature_cols = df.columns[df.columns != TARGET].tolist()

# 1 · rows with the sentinel −9 in any predictor column are discarded
# 2 · only rows whose SUB1 code is listed in SUB1_KEEP are retained
has_sentinel = df[feature_cols].eq(-9).any(axis=1)
wanted_class = df[TARGET].isin(SUB1_KEEP)
df = df.loc[~has_sentinel & wanted_class].reset_index(drop=True)

# 3 · SUB1 becomes the right-most column
df = df[feature_cols + [TARGET]]

print(f"Rows after cleaning & filtering: {len(df)}")
df.head()


Rows after cleaning & filtering: 436657


Unnamed: 0,STFIPS,EDUC,EMPLOY,GENDER,LIVARAG,AGE,RACE,ETHNIC,PRIMINC,HLTHINS,DIVISION,REGION,SUB1
0,2,4,4,1,2,5,1,4,1,2,9,4,2
1,2,2,4,2,3,2,1,4,4,2,9,4,2
2,2,4,4,2,3,9,2,4,3,4,9,4,2
3,2,3,4,2,3,11,4,4,5,2,9,4,5
4,2,2,4,2,3,5,7,2,5,2,9,4,5


In [3]:
# ── Cell 2 ───────────────────────────────────────────────────────────
TARGET = "SUB1"
# TARGET_ID is created a few lines below. It is excluded here as well so that
# re-running this cell (when df already carries TARGET_ID) can never promote
# the encoded label to a predictor column.
feature_cols = [c for c in df.columns if c not in (TARGET, "TARGET_ID")]

# 0 · map each kept SUB1 code to 0…k-1 (position of the code in SUB1_KEEP)
sub1_to_idx = {code: idx for idx, code in enumerate(SUB1_KEEP)}
df["TARGET_ID"] = df[TARGET].map(sub1_to_idx).astype("int64")

# 1 · label-encode every predictor
#     Each column is replaced in place by its integer category code; NaNs (if
#     any) are mapped to an extra "UNK" category instead of the -1 code.
for col in feature_cols:
    df[col] = (
        df[col]
          .astype("category")
          .cat.add_categories("UNK")
          .fillna("UNK")          # safeguards genuine NaNs
          .cat.codes
          .astype("int64")
    )

# 2 · schema info
n_classes = len(SUB1_KEEP)
# Number of distinct codes per predictor, in feature_cols order.
cat_cardinalities = [df[c].nunique() for c in feature_cols]

# 3 · stratified train/test split (80 / 20, stratified on the encoded target)
train_df, test_df = train_test_split(
    df, stratify=df["TARGET_ID"], test_size=0.20, random_state=SEED
)

# 4 · numpy tensors for the model
X_train_cat = train_df[feature_cols].to_numpy(np.int64)
X_test_cat  =  test_df[feature_cols].to_numpy(np.int64)
y_train     = train_df["TARGET_ID"].to_numpy(np.int64)
y_test      =  test_df["TARGET_ID"].to_numpy(np.int64)

# 5 · report sizes and save CSVs
print(f"Train rows : {len(train_df)}")
print(f"Test  rows : {len(test_df)}\n")

train_df.to_csv("tedsa_train_split.csv", index=False)
test_df.to_csv("tedsa_test_split.csv",  index=False)

print("CSV files written: tedsa_train_split.csv, tedsa_test_split.csv")

Train rows : 349325
Test  rows : 87332

CSV files written: tedsa_train_split.csv, tedsa_test_split.csv


# Tuning

In [None]:
# Cell 3 ──────────────────────────────────────────────────────────────
# Hyper-parameter tuning for FT-Transformer (all-categorical)

import torch, optuna, numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
from rtdl_revisiting_models import FTTransformer

# ── dataset wrapper ─────────────────────────────────────────────────
class CatOnlyDataset(Dataset):
    """Dataset of (categorical feature row, label) pairs held as int64 tensors."""

    def __init__(self, X_cat, y):
        # Both inputs are converted to torch.long tensors once, up front.
        self.X_cat = torch.as_tensor(X_cat, dtype=torch.long)
        self.y = torch.as_tensor(y, dtype=torch.long)

    def __len__(self):
        """Number of samples, taken from the first dimension of the labels."""
        return self.y.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        features = self.X_cat[idx]
        label = self.y[idx]
        return features, label

# ── focal loss implementation ───────────────────────────────────────
def focal_loss(logits, targets, gamma=2.0, weight=None):
    """Mean multi-class focal loss.

    Args:
        logits:  tensor whose dim 1 indexes the classes.
        targets: integer class index per row of ``logits``.
        gamma:   focusing exponent applied to ``1 - p_target``.
        weight:  optional per-class weights, indexed by ``targets``.

    Returns:
        Scalar tensor: the mean over samples of
        ``-(1 - p_target) ** gamma * log(p_target)`` (times the class weight
        when ``weight`` is given).
    """
    log_p = F.log_softmax(logits, dim=1)
    # Pick the log-probability of the true class for every row.
    picked_log_p = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    picked_p = picked_log_p.exp()
    per_sample = -(1.0 - picked_p).pow(gamma) * picked_log_p
    if weight is not None:
        per_sample = per_sample * weight[targets]
    return per_sample.mean()

# ── epoch loop ──────────────────────────────────────────────────────
def run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model:     called as ``model(None, x_cat)`` and expected to return logits.
        loader:    iterable of ``(x_cat, y)`` batches.
        loss_fn:   callable ``loss_fn(logits, y)`` returning a scalar tensor.
        optimizer: when given, the pass is a training pass (``model.train()``,
                   backward + optimizer step per batch); when ``None`` the
                   pass is evaluation only (``model.eval()``).

    Returns:
        Tuple ``(loss_sum / total, correct / total)`` where the loss is
        weighted by batch size and ``total`` is the number of samples seen.

    Batches are moved to the module-level ``device``.
    """
    train = optimizer is not None
    model.train() if train else model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    # Autograd is only needed when we are going to call backward(); disabling
    # it for evaluation passes avoids building a graph for every batch.
    with torch.set_grad_enabled(train):
        for x_cat, y in loader:
            x_cat, y = x_cat.to(device), y.to(device)
            logits   = model(None, x_cat)
            loss     = loss_fn(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * y.size(0)
            correct  += (logits.argmax(1) == y).sum().item()
            total    += y.size(0)
    return loss_sum / total, correct / total

# ── Optuna objective ────────────────────────────────────────────────
# Compute device shared by every training / evaluation helper in this notebook.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-class loss weights from scikit-learn's "balanced" heuristic, computed on
# the full training labels and kept on `device` as float32.
class_labels = np.unique(y_train)
base_wts = torch.tensor(
    compute_class_weight("balanced", classes=class_labels, y=y_train),
    dtype=torch.float32,
    device=device,
)

def objective(trial):
    """Optuna objective: score one hyper-parameter configuration by 5-fold CV.

    A configuration (architecture sizes, dropouts, learning rate, batch size
    and loss type) is sampled from ``trial``. For every stratified fold of
    ``X_train_cat`` / ``y_train`` a fresh FTTransformer is trained for up to
    100 epochs with early stopping on validation accuracy (patience 10).

    Reads the module-level names ``base_wts``, ``SEED``, ``X_train_cat``,
    ``y_train``, ``cat_cardinalities``, ``n_classes`` and ``device``.

    Returns:
        float: mean over the folds of the best validation accuracy reached
        in each fold.
    """
    # ── hyper-parameters to search ──────────────────────────────────
    d_block          = trial.suggest_categorical("d_block",  [128, 192, 256, 320])
    n_blocks         = trial.suggest_int        ("n_blocks", 2, 6)
    n_heads          = trial.suggest_categorical("attention_n_heads", [4, 8])
    attn_dropout     = trial.suggest_float      ("attention_dropout", 0.0, 0.4)
    ffn_dropout      = trial.suggest_float      ("ffn_dropout",       0.0, 0.4)
    residual_dropout = trial.suggest_float      ("residual_dropout",  0.0, 0.3)
    ffn_mult         = trial.suggest_float      ("ffn_d_hidden_multiplier", 1.0, 6.0)
    lr               = trial.suggest_float      ("lr", 1e-4, 1e-3, log=True)
    batch_size       = trial.suggest_categorical("batch_size", [128, 256, 512])
    loss_type        = trial.suggest_categorical("loss_type", ["ce", "focal"])

    # pick loss function object once per trial
    # Both variants apply the class weights in base_wts.
    # NOTE(review): training batches are also drawn with a class-balanced
    # sampler below, so minority classes are up-weighted twice — confirm this
    # is intended.
    if loss_type == "ce":
        loss_fn = lambda logits, y: F.cross_entropy(
            logits, y, weight=base_wts, label_smoothing=0.1
        )
    else:  # focal
        loss_fn = lambda logits, y: focal_loss(logits, y, gamma=2.0, weight=base_wts)

    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    fold_accs = []  # best validation accuracy of each fold

    for tr_idx, va_idx in cv.split(X_train_cat, y_train):
        # ── Datasets ────────────────────────────────────────────────
        X_tr, y_tr = X_train_cat[tr_idx], y_train[tr_idx]
        X_va, y_va = X_train_cat[va_idx], y_train[va_idx]
        ds_tr = CatOnlyDataset(X_tr, y_tr)
        ds_va = CatOnlyDataset(X_va, y_va)

        # balanced sampler for the training fold
        # Each sample is weighted by 1 / (size of its class in this fold) and
        # drawn with replacement, len(fold) draws per epoch.
        class_counts = np.bincount(y_tr)
        samp_wts     = 1.0 / class_counts[y_tr]
        sampler      = WeightedRandomSampler(
            weights=samp_wts, num_samples=len(samp_wts), replacement=True
        )

        dl_tr = DataLoader(ds_tr, batch_size=batch_size, sampler=sampler)
        dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False)

        # ── Model ───────────────────────────────────────────────────
        # A new, untrained model is built for every fold.
        model = FTTransformer(
            n_cont_features=0,
            cat_cardinalities=cat_cardinalities,
            d_out=n_classes,
            n_blocks=n_blocks,
            d_block=d_block,
            attention_n_heads=n_heads,
            attention_dropout=attn_dropout,
            ffn_d_hidden_multiplier=ffn_mult,
            ffn_dropout=ffn_dropout,
            residual_dropout=residual_dropout,
        ).to(device)

        # AdamW + cosine annealing schedule (no warm-up scheduler is created here)
        # NOTE(review): T_max=20 while the loop below may run up to 100
        # epochs — confirm the intended learning-rate behaviour past epoch 20.
        opt = torch.optim.AdamW(model.make_parameter_groups(), lr=lr, weight_decay=1e-2)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)

        # Early stopping: stop after `patience` consecutive epochs without a
        # new best validation accuracy.
        best_acc, patience, wait = 0.0, 10, 0
        for epoch in range(100):
            run_epoch(model, dl_tr, loss_fn, opt)
            sched.step()
            _, val_acc = run_epoch(model, dl_va, loss_fn)
            if val_acc > best_acc:
                best_acc, wait = val_acc, 0
            else:
                wait += 1
            if wait >= patience:
                break

        fold_accs.append(best_acc)

    return float(np.mean(fold_accs))

# ── run the Optuna study ────────────────────────────────────────────
# Seed the sampler so the hyper-parameter proposals are repeatable, like the
# other RNGs in this notebook (random / NumPy / torch are seeded with SEED).
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)

# Stops after 16 trials or once the 1-hour timeout has elapsed, whichever
# comes first. A trial that is already running is not interrupted: in the
# logged run only two trials finished, the second one about three hours after
# the study started. Raise n_trials / timeout for a broader search.
study.optimize(objective, n_trials=16, timeout=60*60)

print("Best mean-CV accuracy:", study.best_value)
print("Best hyper-parameters:", study.best_params)

[I 2025-06-19 17:29:31,466] A new study created in memory with name: no-name-3703176d-1e8c-4a82-9225-169c521d72c8
[I 2025-06-19 18:01:32,868] Trial 0 finished with value: 0.47447323907136835 and parameters: {'d_block': 192, 'n_blocks': 2, 'attention_n_heads': 8, 'attention_dropout': 0.049086378176820225, 'ffn_dropout': 0.3251215549423043, 'residual_dropout': 0.10062536302844134, 'ffn_d_hidden_multiplier': 3.271520661705991, 'lr': 0.00016224821579518352, 'batch_size': 256, 'loss_type': 'focal'}. Best is trial 0 with value: 0.47447323907136835.
[I 2025-06-19 20:31:35,646] Trial 1 finished with value: 0.4768343067973525 and parameters: {'d_block': 256, 'n_blocks': 6, 'attention_n_heads': 8, 'attention_dropout': 0.12443808572882027, 'ffn_dropout': 0.19462667261713232, 'residual_dropout': 0.24851363274917035, 'ffn_d_hidden_multiplier': 5.069675871882938, 'lr': 0.0001650455464576478, 'batch_size': 128, 'loss_type': 'focal'}. Best is trial 1 with value: 0.4768343067973525.


Best mean-CV accuracy: 0.4768343067973525
Best hyper-parameters: {'d_block': 256, 'n_blocks': 6, 'attention_n_heads': 8, 'attention_dropout': 0.12443808572882027, 'ffn_dropout': 0.19462667261713232, 'residual_dropout': 0.24851363274917035, 'ffn_d_hidden_multiplier': 5.069675871882938, 'lr': 0.0001650455464576478, 'batch_size': 128, 'loss_type': 'focal'}


# Evaluation

In [None]:
# Cell 4 ──────────────────────────────────────────────────────────────
best = study.best_params

# Rebuild the loss selected by the best trial; both variants use base_wts.
if best["loss_type"] == "ce":
    def loss_fn(logits, y):
        return F.cross_entropy(logits, y, weight=base_wts, label_smoothing=0.1)
else:
    def loss_fn(logits, y):
        return focal_loss(logits, y, weight=base_wts)

# Architecture hyper-parameters are forwarded to the model under the same names.
arch_keys = (
    "n_blocks", "d_block", "attention_n_heads", "attention_dropout",
    "ffn_d_hidden_multiplier", "ffn_dropout", "residual_dropout",
)
model = FTTransformer(
    n_cont_features=0,
    cat_cardinalities=cat_cardinalities,
    d_out=n_classes,
    **{key: best[key] for key in arch_keys},
).to(device)

opt = torch.optim.AdamW(
    model.make_parameter_groups(), lr=best["lr"], weight_decay=1e-2
)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)

# balanced sampler on entire training set: every sample is weighted by the
# inverse size of its class and drawn with replacement
counts = np.bincount(y_train)
weights = 1.0 / counts[y_train]
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
loader_tr = DataLoader(
    CatOnlyDataset(X_train_cat, y_train),
    batch_size=best["batch_size"],
    sampler=sampler,
)

# Fixed 50-epoch training run on the full training split.
for epoch in range(50):
    run_epoch(model, loader_tr, loss_fn, opt)
    sched.step()
print("Final training done.")


Final training done.


In [None]:
# Cell 5 ──────────────────────────────────────────────────────────────
# Evaluate the trained model on the held-out test split: accuracy, F1,
# one-vs-one macro ROC-AUC, per-class report and a confusion-matrix heatmap.
from sklearn.metrics import (
    accuracy_score, roc_auc_score, confusion_matrix,
    classification_report, f1_score
)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# ── inference ───────────────────────────────────────────────────────
model.eval()
test_loader = DataLoader(
    CatOnlyDataset(X_test_cat, y_test),
    batch_size=best["batch_size"],
    shuffle=False
)

all_logits, all_labels = [], []
with torch.no_grad():
    for x_cat, y in test_loader:
        logits = model(None, x_cat.to(device))
        all_logits.append(logits.cpu())
        all_labels.append(y)
logits = torch.cat(all_logits).numpy()
y_true = torch.cat(all_labels).numpy()
y_pred = logits.argmax(1)

# ── metrics ---------------------------------------------------------
acc      = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average="macro")
micro_f1 = f1_score(y_true, y_pred, average="micro")

# ROC-AUC (one-vs-one) needs class probabilities rather than raw logits
probs = torch.softmax(torch.tensor(logits), 1).numpy()
try:
    auc_macro = roc_auc_score(y_true, probs, multi_class="ovo", average="macro")
except ValueError:
    auc_macro = float("nan")          # occurs if a class absent in y_true

# Encoded class ids 0…k-1 in SUB1_KEEP order. Passing them explicitly keeps
# the report / matrix at k rows even if some class never shows up in y_true
# or y_pred, so they always line up with the SUB1_KEEP axis labels.
class_ids = list(range(len(SUB1_KEEP)))

print(f"Accuracy      : {acc:.4f}")
print(f"Macro F1      : {macro_f1:.4f}")
print(f"Micro F1      : {micro_f1:.4f}")
print(f"Macro ROC-AUC : {auc_macro:.4f}\n")
print(classification_report(
    y_true, y_pred,
    labels=class_ids,
    target_names=[str(code) for code in SUB1_KEEP],   # show original SUB1 codes
))

# ── confusion matrix -----------------------------------------------
cm     = confusion_matrix(y_true, y_pred, labels=class_ids)
cm_df  = pd.DataFrame(cm, index=SUB1_KEEP, columns=SUB1_KEEP)

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm_df,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    square=True,
    linewidths=0.5,
    linecolor="white",
    annot_kws={"size": 14}
)
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.title("Confusion Matrix")
plt.yticks(rotation=0)                        # keep y-axis labels horizontal
plt.tight_layout()
plt.show()


In [None]:
# Cell 6 ──────────────────────────────────────────────────────────────
# Feature importance via permutation on the test set
from copy import deepcopy
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ── baseline performance: macro-F1 of the unpermuted test predictions ──
baseline_f1 = f1_score(y_true=y_true, y_pred=y_pred, average="macro")

def _predict_labels(model, X, batch_size):
    """Return the arg-max class index for every row of the integer array ``X``."""
    chunks = []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            xb = torch.as_tensor(X[start:start + batch_size], dtype=torch.long)
            chunks.append(model(None, xb.to(device)).cpu())
    return torch.cat(chunks).numpy().argmax(1)


def perm_importance(model, X_base, y, metric, batch_size):
    """Return drop in metric for every feature after permutation.

    Args:
        model:      called as ``model(None, x_cat)``; switched to eval mode here.
        X_base:     2-D integer array of features; it is copied, never modified.
        y:          true labels, passed as first argument to ``metric``.
        metric:     callable ``metric(y_true, y_pred) -> float`` (higher = better).
        batch_size: number of rows per forward pass.

    Returns:
        1-D NumPy array with one entry per column of ``X_base``:
        ``metric`` on the unpermuted data minus ``metric`` after shuffling
        that single column.

    The baseline is computed here with the same ``metric`` so the reported
    drops stay consistent whatever metric the caller passes. Shuffling uses
    NumPy's global RNG. Batches are moved to the module-level ``device``.
    """
    model.eval()                                  # make sure dropout is inactive
    X_perm = np.array(X_base, copy=True)
    baseline = metric(y, _predict_labels(model, X_perm, batch_size))

    drops = []
    for col_idx in range(X_perm.shape[1]):
        original_col = X_perm[:, col_idx].copy()

        # shuffle a single column in place, then re-evaluate
        np.random.shuffle(X_perm[:, col_idx])
        score = metric(y, _predict_labels(model, X_perm, batch_size))
        drops.append(baseline - score)

        # restore the column for the next round
        X_perm[:, col_idx] = original_col

    return np.array(drops)

def _macro_f1_metric(y, yhat):
    """Macro-averaged F1, in the ``metric(y_true, y_pred)`` form perm_importance expects."""
    return f1_score(y, yhat, average="macro")


drops = perm_importance(
    model,
    X_test_cat.copy(),
    y_test,
    _macro_f1_metric,
    best["batch_size"],
)

# ── tidy & plot ─────────────────────────────────────────────────────
# One row per feature, largest macro-F1 drop first.
imp_df = pd.DataFrame(
    {"feature": feature_cols, "drop_in_macro_f1": drops}
).sort_values("drop_in_macro_f1", ascending=False)

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(imp_df["feature"], imp_df["drop_in_macro_f1"])
ax.invert_yaxis()                      # most important feature at the top
ax.set_xlabel("Macro-F1 drop when permuted (higher = more important)")
ax.set_title("Permutation feature importance — FT-Transformer")
fig.tight_layout()
plt.show()

imp_df.head()


# Save model

In [None]:
# ── Cell 7 ───────────────────────────────────────────────────────────
# Save the trained FT-Transformer and its metadata

import torch
from datetime import datetime
import json
from pathlib import Path

# directory + filename (one timestamped checkpoint per run)
save_dir  = Path("checkpoints")
save_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
ckpt_path = save_dir / f"ft_transformer_{timestamp}.pt"

# model weights + essential metadata
checkpoint = {
    "model_state": model.state_dict(),
    "cat_cardinalities": cat_cardinalities,
    # cat_cardinalities is positional; store the matching column order so a
    # loader knows which input column each entry (and embedding) belongs to.
    "feature_cols": list(feature_cols),
    "sub1_to_idx": sub1_to_idx,          # label mapping
    "best_params": best,                 # hyper-parameters from Optuna
    "seed": SEED,
}

torch.save(checkpoint, ckpt_path)
print(f"Model saved to: {ckpt_path.resolve()}")

# (optional) persist the label mapping as a JSON side-car
# Note: JSON object keys are strings, so the integer SUB1 codes come back as
# "2", "4", … when this file is read again.
json_path = ckpt_path.with_suffix(".json")
with open(json_path, "w") as fp:
    json.dump({"sub1_to_idx": sub1_to_idx}, fp, indent=2)
print(f"Label map saved to: {json_path.resolve()}")


In [None]:
# Example (kept commented out): rebuild the model from a saved checkpoint file

# import torch
# from rtdl_revisiting_models import FTTransformer

# ckpt = torch.load("checkpoints/ft_transformer_YYYYMMDD_HHMMSS.pt", map_location="cpu")

# model = FTTransformer(
#     n_cont_features=0,
#     cat_cardinalities=ckpt["cat_cardinalities"],
#     d_out=len(ckpt["sub1_to_idx"]),
#     **{k: v for k, v in ckpt["best_params"].items()
#        if k in ("n_blocks", "d_block", "attention_n_heads",
#                 "attention_dropout", "ffn_d_hidden_multiplier",
#                 "ffn_dropout", "residual_dropout")}
# )
# model.load_state_dict(ckpt["model_state"])
# model.eval()
