In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, f1_score

from f2_preprocessor import Preprocessor

import optuna

from tabm import TabM
from rtdl_num_embeddings import LinearReLUEmbeddings  # simple but good :contentReference[oaicite:1]{index=1}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
num_cols = [
    "year_of_born",
    "email_or_tel_available",
    "safety_rating",
    "annual_income",
    "high_education_ind",
    "address_change_ind",
    "past_num_of_claims",
    "liab_prct",
    "policy_report_filed_ind",
    "claim_est_payout",
    "vehicle_made_year",
    "vehicle_price",
    "vehicle_weight",
    "age_of_DL",
    "vehicle_mileage",
]

cat_cols = [
    "gender",
    "living_status",
    "zip_code",
    "claim_day_of_week",
    "accident_site",
    "witness_present_ind",
    "channel",
    "vehicle_category",
    "vehicle_color",
    "accident_type",
    "in_network_bodyshop",
]

In [3]:
def compute_best_f1(probs: np.ndarray,
                    targets: np.ndarray,
                    thresholds: np.ndarray | None = None):
    """
    probs:    shape (N,), predicted positive probabilities
    targets:  shape (N,), 0/1 labels
    thresholds: optional array of candidate thresholds (0..1)

    Returns (best_f1, best_threshold)
    """
    if thresholds is None:
        # Coarse but reasonable grid; you can make it denser if you want
        thresholds = np.linspace(0.05, 0.95, 19)

    best_f1 = 0.0
    best_t = 0.5

    for t in thresholds:
        preds = (probs >= t).astype(int)
        f1 = f1_score(targets, preds, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_t = t

    return best_f1, best_t


In [4]:
df = pd.read_csv("data/Training_TriGuard.csv")
df = df.dropna(subset=['subrogation'])

df = df.drop(columns=["claim_number"], errors="ignore")

target_col = "subrogation"
y_all = df[target_col].to_numpy().astype(np.int64)

In [5]:
# Compute global pos_weight for BCEWithLogitsLoss
pos_frac = y_all.mean()
neg_frac = 1.0 - pos_frac

# Avoid division by zero – but you almost surely have both classes
pos_weight_value = neg_frac / max(pos_frac, 1e-6)
print("Positive class fraction:", pos_frac)
print("pos_weight:", pos_weight_value)

device = "cpu"

Positive class fraction: 0.2286238124340241
pos_weight: 3.373997569866343


In [6]:
def build_model(n_num_features: int,
                cat_cardinalities,
                trial: optuna.Trial) -> nn.Module:
    # EDIT: tighten but beef up architecture search space
    n_blocks = trial.suggest_int("n_blocks", 3, 5)                 # was 2–5
    d_block  = trial.suggest_int("d_block", 384, 1024, log=True)   # was 256–1024
    dropout  = trial.suggest_float("dropout", 0.0, 0.4)
    k        = trial.suggest_int("k", 8, 32, step=8)               # keep 8–32

    num_emb = LinearReLUEmbeddings(n_num_features)

    model = TabM.make(
        n_num_features=n_num_features,
        num_embeddings=num_emb,
        cat_cardinalities=cat_cardinalities,
        d_out=1,
        n_blocks=n_blocks,
        d_block=d_block,
        dropout=dropout,
        k=k,
    )

    return model.to(device)

In [7]:
def train_one_fold(model,
                   train_loader,
                   valid_loader,
                   n_epochs: int,
                   lr: float,
                   weight_decay: float,
                   pos_weight_eff: float) -> float:
    # EDIT: pass in effective pos_weight (after scaling)
    pos_weight = torch.tensor([pos_weight_eff], device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )

    # -------- TRAIN --------
    for epoch in range(n_epochs):
        model.train()
        for X_num_b, X_cat_b, y_b in train_loader:
            X_num_b = X_num_b.to(device)
            X_cat_b = X_cat_b.to(device)
            y_b = y_b.to(device)   # (B, 1)

            logits = model(X_num_b, X_cat_b)          # (B, k, 1)
            B, k, _ = logits.shape

            # EDIT: use reshape instead of view
            logits_flat = logits.reshape(B * k, 1)    # (B*k, 1)
            y_flat = y_b.repeat_interleave(k, dim=0)  # (B*k, 1)

            loss = criterion(logits_flat, y_flat)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # -------- VALID: compute probs and F1 --------
    model.eval()
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for X_num_b, X_cat_b, y_b in valid_loader:
            X_num_b = X_num_b.to(device)
            X_cat_b = X_cat_b.to(device)
            y_b = y_b.to(device)

            logits = model(X_num_b, X_cat_b)              # (B, k, 1)
            probs = torch.sigmoid(logits).mean(dim=1)     # (B, 1)
            probs = probs.squeeze(-1).cpu().numpy()
            targets = y_b.squeeze(-1).cpu().numpy()

            all_probs.append(probs)
            all_targets.append(targets)

    all_probs = np.concatenate(all_probs)
    all_targets = np.concatenate(all_targets)

    # Optional defensive mask
    mask = ~np.isnan(all_probs) & ~np.isnan(all_targets)
    all_probs = all_probs[mask]
    all_targets = all_targets[mask]

    if len(np.unique(all_targets)) < 2:
        return 0.0

    best_f1, best_t = compute_best_f1(all_probs, all_targets)
    return float(best_f1)


In [12]:
def objective(trial: optuna.Trial) -> float:
    # EDIT: more suitable ranges for TabM on imbalanced data
    lr = trial.suggest_float("lr", 5e-4, 1e-2, log=True)                 # was 1e-4–3e-3
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256]) # was 256–1024
    n_epochs = trial.suggest_int("n_epochs", 20, 40)                     # was 5–10

    # EDIT: let Optuna scale pos_weight
    pos_weight_scale = trial.suggest_float("pos_weight_scale", 0.5, 3.0)
    pos_weight_eff = pos_weight_value * pos_weight_scale

    n_splits = 3
    skf = StratifiedKFold(
        n_splits=n_splits,
        shuffle=True,
        random_state=42,
    )

    fold_f1s = []

    for fold_idx, (train_idx, valid_idx) in enumerate(skf.split(df, y_all)):
        train_df = df.iloc[train_idx].reset_index(drop=True)
        valid_df = df.iloc[valid_idx].reset_index(drop=True)

        preproc = Preprocessor(num_cols=num_cols, cat_cols=cat_cols,
                                   target_col=target_col)

        X_num_train, X_cat_train, y_train = preproc.fit_transform(train_df)
        X_num_valid, X_cat_valid, y_valid = preproc.transform(valid_df)

        n_num_features = X_num_train.shape[1]
        cat_cardinalities = [
            preproc.cat_cardinalities_[c] for c in cat_cols
        ]

        model = build_model(n_num_features, cat_cardinalities, trial)

        X_num_tr_t = torch.from_numpy(X_num_train)
        X_cat_tr_t = torch.from_numpy(X_cat_train)
        y_tr_t = torch.from_numpy(y_train)

        X_num_va_t = torch.from_numpy(X_num_valid)
        X_cat_va_t = torch.from_numpy(X_cat_valid)
        y_va_t = torch.from_numpy(y_valid)

        train_dataset = TensorDataset(X_num_tr_t, X_cat_tr_t, y_tr_t)
        valid_dataset = TensorDataset(X_num_va_t, X_cat_va_t, y_va_t)

        # Imbalance-aware sampler
        y_train_flat = y_train.ravel()
        class_counts = np.bincount(y_train_flat.astype(int))
        class_counts = np.maximum(class_counts, 1)
        class_weights = 1.0 / class_counts
        sample_weights = class_weights[y_train_flat.astype(int)]

        sampler = WeightedRandomSampler(
            weights=torch.from_numpy(sample_weights).float(),
            num_samples=len(sample_weights),
            replacement=True,
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            drop_last=False,
        )
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )

        fold_f1 = train_one_fold(
            model=model,
            train_loader=train_loader,
            valid_loader=valid_loader,
            n_epochs=n_epochs,
            lr=lr,
            weight_decay=weight_decay,
            pos_weight_eff=pos_weight_eff,       # EDIT: pass scaled pos_weight
        )
        fold_f1s.append(fold_f1)

        trial.report(fold_f1, step=fold_idx)
        if trial.should_prune():
            raise optuna.TrialPruned()

    mean_f1 = float(np.mean(fold_f1s))
    return mean_f1


In [13]:
pruner = optuna.pruners.MedianPruner(n_warmup_steps=1)

study = optuna.create_study(
    direction="maximize",
    study_name="tabm_subrogation_cv_f1",
    pruner=pruner
)
study.optimize(objective, n_trials=50, show_progress_bar=True)

print("Best mean CV F1:", study.best_value)
print("Best params:", study.best_params)


[I 2025-11-13 06:31:37,618] A new study created in memory with name: tabm_subrogation_cv_f1
Best trial: 0. Best value: 0.447372:   2%|▏         | 1/50 [11:00<8:59:04, 660.10s/it]

[I 2025-11-13 06:42:37,718] Trial 0 finished with value: 0.44737241605974015 and parameters: {'lr': 0.0006894437249902676, 'weight_decay': 0.0075725263336484565, 'batch_size': 256, 'n_epochs': 22, 'pos_weight_scale': 1.2057968350977348, 'n_blocks': 4, 'd_block': 803, 'dropout': 0.3888584189356692, 'k': 32}. Best is trial 0 with value: 0.44737241605974015.


Best trial: 0. Best value: 0.447372:   4%|▍         | 2/50 [14:14<5:08:52, 386.10s/it]

[I 2025-11-13 06:45:52,020] Trial 1 finished with value: 0.37216242076746936 and parameters: {'lr': 0.0008640415443663024, 'weight_decay': 0.0007194560755780915, 'batch_size': 128, 'n_epochs': 33, 'pos_weight_scale': 1.6997170578629168, 'n_blocks': 3, 'd_block': 707, 'dropout': 0.38597094914730634, 'k': 8}. Best is trial 0 with value: 0.44737241605974015.


Best trial: 2. Best value: 0.514259:   6%|▌         | 3/50 [17:55<4:03:34, 310.94s/it]

[I 2025-11-13 06:49:33,516] Trial 2 finished with value: 0.5142590882550425 and parameters: {'lr': 0.0027058591983204284, 'weight_decay': 1.2756419998056076e-05, 'batch_size': 64, 'n_epochs': 32, 'pos_weight_scale': 2.2054918833171677, 'n_blocks': 3, 'd_block': 719, 'dropout': 0.199983380294276, 'k': 8}. Best is trial 2 with value: 0.5142590882550425.


Best trial: 3. Best value: 0.574809:   8%|▊         | 4/50 [21:32<3:29:45, 273.60s/it]

[I 2025-11-13 06:53:09,874] Trial 3 finished with value: 0.5748090059157763 and parameters: {'lr': 0.0030627853899652877, 'weight_decay': 0.0007249030562731436, 'batch_size': 256, 'n_epochs': 31, 'pos_weight_scale': 2.990947861811148, 'n_blocks': 4, 'd_block': 403, 'dropout': 0.07578104283428044, 'k': 16}. Best is trial 3 with value: 0.5748090059157763.


Best trial: 3. Best value: 0.574809:  10%|█         | 5/50 [26:13<3:27:21, 276.48s/it]

[I 2025-11-13 06:57:51,455] Trial 4 finished with value: 0.4454387280353422 and parameters: {'lr': 0.0026650710919662447, 'weight_decay': 0.004928211803621021, 'batch_size': 256, 'n_epochs': 36, 'pos_weight_scale': 2.1790653788417216, 'n_blocks': 3, 'd_block': 559, 'dropout': 0.2996527018617813, 'k': 16}. Best is trial 3 with value: 0.5748090059157763.


Best trial: 3. Best value: 0.574809:  12%|█▏        | 6/50 [37:38<5:04:24, 415.10s/it]

[I 2025-11-13 07:09:15,657] Trial 5 pruned. 


Best trial: 3. Best value: 0.574809:  14%|█▍        | 7/50 [38:45<3:36:02, 301.45s/it]

[I 2025-11-13 07:10:23,105] Trial 6 pruned. 


Best trial: 7. Best value: 0.588873:  16%|█▌        | 8/50 [49:35<4:48:35, 412.27s/it]

[I 2025-11-13 07:21:12,657] Trial 7 finished with value: 0.5888730821923267 and parameters: {'lr': 0.003427141435105846, 'weight_decay': 1.0364658836025416e-05, 'batch_size': 128, 'n_epochs': 36, 'pos_weight_scale': 2.8358930891064613, 'n_blocks': 4, 'd_block': 511, 'dropout': 0.026685800430954833, 'k': 32}. Best is trial 7 with value: 0.5888730821923267.


Best trial: 8. Best value: 0.589681:  18%|█▊        | 9/50 [57:49<4:59:18, 438.02s/it]

[I 2025-11-13 07:29:27,289] Trial 8 finished with value: 0.5896811151700324 and parameters: {'lr': 0.0032548628650264796, 'weight_decay': 0.0003218770404710231, 'batch_size': 256, 'n_epochs': 36, 'pos_weight_scale': 2.7024638111485357, 'n_blocks': 5, 'd_block': 443, 'dropout': 0.13818567440588492, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  20%|██        | 10/50 [1:08:26<5:33:00, 499.50s/it]

[I 2025-11-13 07:40:04,465] Trial 9 pruned. 


Best trial: 8. Best value: 0.589681:  22%|██▏       | 11/50 [1:25:23<7:07:34, 657.80s/it]

[I 2025-11-13 07:57:01,190] Trial 10 pruned. 


Best trial: 8. Best value: 0.589681:  24%|██▍       | 12/50 [1:32:46<6:15:16, 592.53s/it]

[I 2025-11-13 08:04:24,450] Trial 11 finished with value: 0.5855799675606025 and parameters: {'lr': 0.0013227319286114238, 'weight_decay': 4.8354219891694756e-06, 'batch_size': 128, 'n_epochs': 27, 'pos_weight_scale': 2.8980165632147137, 'n_blocks': 5, 'd_block': 505, 'dropout': 0.01073325533386861, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  26%|██▌       | 13/50 [1:39:46<5:33:03, 540.08s/it]

[I 2025-11-13 08:11:23,851] Trial 12 pruned. 


Best trial: 8. Best value: 0.589681:  28%|██▊       | 14/50 [1:47:23<5:08:58, 514.96s/it]

[I 2025-11-13 08:19:00,753] Trial 13 finished with value: 0.441769269946914 and parameters: {'lr': 0.004623948021302192, 'weight_decay': 0.0007361975855827028, 'batch_size': 128, 'n_epochs': 28, 'pos_weight_scale': 2.5393709432518077, 'n_blocks': 4, 'd_block': 469, 'dropout': 0.11068505284382774, 'k': 32}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  30%|███       | 15/50 [1:59:02<5:32:47, 570.50s/it]

[I 2025-11-13 08:30:39,955] Trial 14 finished with value: 0.5843159211926904 and parameters: {'lr': 0.0015650119399879328, 'weight_decay': 4.087671005482316e-06, 'batch_size': 64, 'n_epochs': 34, 'pos_weight_scale': 2.582653079901241, 'n_blocks': 5, 'd_block': 588, 'dropout': 0.1791149974692297, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  32%|███▏      | 16/50 [2:02:02<4:16:47, 453.15s/it]

[I 2025-11-13 08:33:40,599] Trial 15 pruned. 


Best trial: 8. Best value: 0.589681:  34%|███▍      | 17/50 [2:09:12<4:05:23, 446.15s/it]

[I 2025-11-13 08:40:50,483] Trial 16 pruned. 


Best trial: 8. Best value: 0.589681:  36%|███▌      | 18/50 [2:14:35<3:38:12, 409.15s/it]

[I 2025-11-13 08:46:13,506] Trial 17 pruned. 


Best trial: 8. Best value: 0.589681:  38%|███▊      | 19/50 [2:18:46<3:06:44, 361.44s/it]

[I 2025-11-13 08:50:23,798] Trial 18 pruned. 


Best trial: 8. Best value: 0.589681:  40%|████      | 20/50 [2:22:27<2:39:43, 319.44s/it]

[I 2025-11-13 08:54:05,332] Trial 19 finished with value: 0.5866627062254125 and parameters: {'lr': 0.001879814192250539, 'weight_decay': 5.0202170124308736e-05, 'batch_size': 64, 'n_epochs': 20, 'pos_weight_scale': 2.794715867814151, 'n_blocks': 4, 'd_block': 388, 'dropout': 0.09748267428969998, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  42%|████▏     | 21/50 [2:27:01<2:27:47, 305.78s/it]

[I 2025-11-13 08:58:39,259] Trial 20 pruned. 


Best trial: 8. Best value: 0.589681:  44%|████▍     | 22/50 [2:30:40<2:10:34, 279.80s/it]

[I 2025-11-13 09:02:18,478] Trial 21 finished with value: 0.5868780993955817 and parameters: {'lr': 0.0018851654941083775, 'weight_decay': 4.598401087819166e-05, 'batch_size': 64, 'n_epochs': 20, 'pos_weight_scale': 2.787153586009761, 'n_blocks': 4, 'd_block': 386, 'dropout': 0.11087530550429407, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  46%|████▌     | 23/50 [2:35:50<2:09:58, 288.82s/it]

[I 2025-11-13 09:07:28,353] Trial 22 finished with value: 0.5884526979205732 and parameters: {'lr': 0.0011256567670956763, 'weight_decay': 1.129348505677873e-05, 'batch_size': 64, 'n_epochs': 25, 'pos_weight_scale': 2.99814571619338, 'n_blocks': 4, 'd_block': 429, 'dropout': 0.12451363153780429, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 8. Best value: 0.589681:  48%|████▊     | 24/50 [2:40:59<2:07:43, 294.74s/it]

[I 2025-11-13 09:12:36,909] Trial 23 finished with value: 0.5891005073400414 and parameters: {'lr': 0.0011669420713321954, 'weight_decay': 5.343022660990404e-06, 'batch_size': 64, 'n_epochs': 25, 'pos_weight_scale': 2.9738420693000664, 'n_blocks': 4, 'd_block': 435, 'dropout': 0.13719724078533624, 'k': 24}. Best is trial 8 with value: 0.5896811151700324.


Best trial: 24. Best value: 0.592961:  50%|█████     | 25/50 [2:48:48<2:24:36, 347.07s/it]

[I 2025-11-13 09:20:26,053] Trial 24 finished with value: 0.5929613710472844 and parameters: {'lr': 0.0010500074002639942, 'weight_decay': 3.116059262324199e-06, 'batch_size': 64, 'n_epochs': 27, 'pos_weight_scale': 2.4122588304716297, 'n_blocks': 4, 'd_block': 470, 'dropout': 0.14738306190566103, 'k': 32}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  52%|█████▏    | 26/50 [2:52:05<2:00:53, 302.21s/it]

[I 2025-11-13 09:23:43,608] Trial 25 finished with value: 0.5856407817007506 and parameters: {'lr': 0.0010070954012964624, 'weight_decay': 2.8881629403343535e-06, 'batch_size': 64, 'n_epochs': 26, 'pos_weight_scale': 2.3927880708455618, 'n_blocks': 3, 'd_block': 467, 'dropout': 0.23153857508201933, 'k': 16}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  54%|█████▍    | 27/50 [2:58:06<2:02:29, 319.56s/it]

[I 2025-11-13 09:29:43,658] Trial 26 finished with value: 0.5919070646690239 and parameters: {'lr': 0.0007793022417980855, 'weight_decay': 2.3467616459061777e-06, 'batch_size': 64, 'n_epochs': 30, 'pos_weight_scale': 2.001279626733373, 'n_blocks': 4, 'd_block': 423, 'dropout': 0.15697942458882605, 'k': 24}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  56%|█████▌    | 28/50 [3:05:13<2:08:59, 351.81s/it]

[I 2025-11-13 09:36:50,712] Trial 27 pruned. 


Best trial: 24. Best value: 0.592961:  58%|█████▊    | 29/50 [3:12:10<2:09:58, 371.36s/it]

[I 2025-11-13 09:43:47,672] Trial 28 finished with value: 0.5900443778209673 and parameters: {'lr': 0.0005694748172241117, 'weight_decay': 2.4078932638035155e-05, 'batch_size': 64, 'n_epochs': 29, 'pos_weight_scale': 1.6256495964670281, 'n_blocks': 5, 'd_block': 416, 'dropout': 0.23152102706877298, 'k': 24}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  60%|██████    | 30/50 [3:21:49<2:24:36, 433.85s/it]

[I 2025-11-13 09:53:27,319] Trial 29 pruned. 


Best trial: 24. Best value: 0.592961:  62%|██████▏   | 31/50 [3:25:43<1:58:23, 373.85s/it]

[I 2025-11-13 09:57:21,195] Trial 30 finished with value: 0.5893235291368971 and parameters: {'lr': 0.0006376852053473165, 'weight_decay': 6.827590413206087e-06, 'batch_size': 64, 'n_epochs': 27, 'pos_weight_scale': 1.670743960986234, 'n_blocks': 4, 'd_block': 416, 'dropout': 0.26458224281345855, 'k': 16}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  64%|██████▍   | 32/50 [3:34:10<2:04:07, 413.74s/it]

[I 2025-11-13 10:05:48,007] Trial 31 finished with value: 0.5916380993151306 and parameters: {'lr': 0.0007918407885989136, 'weight_decay': 0.00030952955709413597, 'batch_size': 64, 'n_epochs': 31, 'pos_weight_scale': 1.8996995160943286, 'n_blocks': 5, 'd_block': 472, 'dropout': 0.18798197244231782, 'k': 24}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  66%|██████▌   | 33/50 [3:42:38<2:05:16, 442.14s/it]

[I 2025-11-13 10:14:16,417] Trial 32 finished with value: 0.5922824879183252 and parameters: {'lr': 0.0008045880791527643, 'weight_decay': 9.342866591198316e-05, 'batch_size': 64, 'n_epochs': 31, 'pos_weight_scale': 1.9806094047461842, 'n_blocks': 5, 'd_block': 469, 'dropout': 0.19805151055504655, 'k': 24}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  68%|██████▊   | 34/50 [3:51:25<2:04:38, 467.42s/it]

[I 2025-11-13 10:23:02,826] Trial 33 finished with value: 0.5913774526615051 and parameters: {'lr': 0.0008274927963727642, 'weight_decay': 2.107127965089021e-06, 'batch_size': 64, 'n_epochs': 32, 'pos_weight_scale': 2.0169853040227643, 'n_blocks': 5, 'd_block': 469, 'dropout': 0.1942082096728207, 'k': 24}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  70%|███████   | 35/50 [4:01:42<2:08:03, 512.24s/it]

[I 2025-11-13 10:33:19,637] Trial 34 finished with value: 0.588932016158647 and parameters: {'lr': 0.0009637264643223816, 'weight_decay': 0.00010068950175294807, 'batch_size': 64, 'n_epochs': 31, 'pos_weight_scale': 1.908532322490604, 'n_blocks': 4, 'd_block': 534, 'dropout': 0.20277570804745146, 'k': 32}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  72%|███████▏  | 36/50 [4:10:05<1:58:55, 509.65s/it]

[I 2025-11-13 10:41:43,247] Trial 35 finished with value: 0.4475628225372607 and parameters: {'lr': 0.000688434263639781, 'weight_decay': 0.009016854374956244, 'batch_size': 64, 'n_epochs': 32, 'pos_weight_scale': 2.1153135559715595, 'n_blocks': 5, 'd_block': 632, 'dropout': 0.3978080120934157, 'k': 16}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  74%|███████▍  | 37/50 [4:15:02<1:36:36, 445.85s/it]

[I 2025-11-13 10:46:40,229] Trial 36 pruned. 


Best trial: 24. Best value: 0.592961:  76%|███████▌  | 38/50 [4:24:54<1:37:54, 489.52s/it]

[I 2025-11-13 10:56:31,658] Trial 37 finished with value: 0.5906832466075683 and parameters: {'lr': 0.0013697804523425373, 'weight_decay': 0.0012251732932800697, 'batch_size': 64, 'n_epochs': 33, 'pos_weight_scale': 1.8042416384349975, 'n_blocks': 3, 'd_block': 576, 'dropout': 0.15969747058882125, 'k': 32}. Best is trial 24 with value: 0.5929613710472844.


Best trial: 24. Best value: 0.592961:  78%|███████▊  | 39/50 [4:28:37<1:15:06, 409.71s/it]

[I 2025-11-13 11:00:15,134] Trial 38 pruned. 


Best trial: 24. Best value: 0.592961:  80%|████████  | 40/50 [4:31:44<57:09, 342.97s/it]  

[I 2025-11-13 11:03:22,388] Trial 39 pruned. 


Best trial: 40. Best value: 0.59381:  82%|████████▏ | 41/50 [4:39:01<55:38, 370.96s/it] 

[I 2025-11-13 11:10:38,668] Trial 40 finished with value: 0.5938096041958789 and parameters: {'lr': 0.0011183425690718294, 'weight_decay': 0.004187499814600279, 'batch_size': 64, 'n_epochs': 33, 'pos_weight_scale': 2.159985493642329, 'n_blocks': 4, 'd_block': 455, 'dropout': 0.18355459979095332, 'k': 24}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381:  84%|████████▍ | 42/50 [4:43:45<46:00, 345.10s/it]

[I 2025-11-13 11:15:23,422] Trial 41 pruned. 


Best trial: 40. Best value: 0.59381:  86%|████████▌ | 43/50 [4:50:59<43:22, 371.82s/it]

[I 2025-11-13 11:22:37,590] Trial 42 finished with value: 0.5912987843062151 and parameters: {'lr': 0.0010969475613838737, 'weight_decay': 0.0020572609386457397, 'batch_size': 64, 'n_epochs': 31, 'pos_weight_scale': 2.2407443355934378, 'n_blocks': 4, 'd_block': 496, 'dropout': 0.15208053612741457, 'k': 24}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381:  88%|████████▊ | 44/50 [4:59:42<41:41, 416.89s/it]

[I 2025-11-13 11:31:19,643] Trial 43 finished with value: 0.5901855441195368 and parameters: {'lr': 0.00086359610127574, 'weight_decay': 0.0005348783936036363, 'batch_size': 64, 'n_epochs': 34, 'pos_weight_scale': 1.8334239355824478, 'n_blocks': 4, 'd_block': 533, 'dropout': 0.17908432247068556, 'k': 24}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381:  90%|█████████ | 45/50 [5:06:31<34:33, 414.71s/it]

[I 2025-11-13 11:38:09,252] Trial 44 finished with value: 0.5887840541995292 and parameters: {'lr': 0.0012570942179154824, 'weight_decay': 1.7376012339025193e-06, 'batch_size': 64, 'n_epochs': 32, 'pos_weight_scale': 2.4515094141345335, 'n_blocks': 4, 'd_block': 456, 'dropout': 0.2093181464000144, 'k': 24}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381:  92%|█████████▏| 46/50 [5:12:02<25:57, 389.47s/it]

[I 2025-11-13 11:43:39,849] Trial 45 finished with value: 0.5868489859725768 and parameters: {'lr': 0.0006126385079250058, 'weight_decay': 0.0011444297741283209, 'batch_size': 64, 'n_epochs': 28, 'pos_weight_scale': 2.082389889466775, 'n_blocks': 4, 'd_block': 419, 'dropout': 0.09194114684477378, 'k': 24}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381:  94%|█████████▍| 47/50 [5:19:19<20:11, 403.88s/it]

[I 2025-11-13 11:50:57,355] Trial 46 finished with value: 0.5912904720531627 and parameters: {'lr': 0.0007697454082563673, 'weight_decay': 8.092220400676234e-06, 'batch_size': 64, 'n_epochs': 30, 'pos_weight_scale': 2.284561945503801, 'n_blocks': 4, 'd_block': 402, 'dropout': 0.28232331249160114, 'k': 32}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381:  96%|█████████▌| 48/50 [5:24:25<12:28, 374.41s/it]

[I 2025-11-13 11:56:02,985] Trial 47 pruned. 


Best trial: 40. Best value: 0.59381:  98%|█████████▊| 49/50 [5:29:11<05:48, 348.05s/it]

[I 2025-11-13 12:00:49,548] Trial 48 finished with value: 0.5912304099386172 and parameters: {'lr': 0.0014896832680515405, 'weight_decay': 3.318081098209326e-06, 'batch_size': 64, 'n_epochs': 32, 'pos_weight_scale': 1.3741210222490037, 'n_blocks': 5, 'd_block': 611, 'dropout': 0.1439113240821146, 'k': 8}. Best is trial 40 with value: 0.5938096041958789.


Best trial: 40. Best value: 0.59381: 100%|██████████| 50/50 [5:34:31<00:00, 401.43s/it]

[I 2025-11-13 12:06:08,994] Trial 49 pruned. 
Best mean CV F1: 0.5938096041958789
Best params: {'lr': 0.0011183425690718294, 'weight_decay': 0.004187499814600279, 'batch_size': 64, 'n_epochs': 33, 'pos_weight_scale': 2.159985493642329, 'n_blocks': 4, 'd_block': 455, 'dropout': 0.18355459979095332, 'k': 24}





In [25]:
best_params = study.best_params
print(best_params)

{'lr': 0.0011183425690718294, 'weight_decay': 0.004187499814600279, 'batch_size': 64, 'n_epochs': 33, 'pos_weight_scale': 2.159985493642329, 'n_blocks': 4, 'd_block': 455, 'dropout': 0.18355459979095332, 'k': 24}


In [26]:
# Refit preprocessor on the full training data
preproc_final = Preprocessor(num_cols=num_cols, cat_cols=cat_cols, target_col=target_col)
X_num_full, X_cat_full, y_full = preproc_final.fit_transform(df)

# Prepare tensors
X_num_full_t = torch.from_numpy(X_num_full)
X_cat_full_t = torch.from_numpy(X_cat_full)
y_full_t     = torch.from_numpy(y_full)

full_dataset = TensorDataset(X_num_full_t, X_cat_full_t, y_full_t)

full_loader = DataLoader(
    full_dataset,
    batch_size=best_params["batch_size"],
    shuffle=True,
    drop_last=False
)

In [27]:
# Build final TabM model from best params
model_final = build_model(
    n_num_features=X_num_full.shape[1],
    cat_cardinalities=[preproc_final.cat_cardinalities_[c] for c in cat_cols],
    trial=optuna.trial.FixedTrial(best_params)
)

In [28]:
# No weight, handle imbalance through threshold optimizing
criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.AdamW(
    model_final.parameters(),
    lr=best_params["lr"],
    weight_decay=best_params["weight_decay"],
)

n_epochs = best_params["n_epochs"]

for epoch in range(n_epochs):
    model_final.train()
    running_loss = 0.0
    for Xn, Xc, yb in full_loader:
        Xn = Xn.to(device)
        Xc = Xc.to(device)
        yb = yb.to(device)

        logits = model_final(Xn, Xc)        # (B, k, 1)
        B, k, _ = logits.shape
        logits_flat = logits.reshape(B * k, 1)
        y_flat = yb.repeat_interleave(k, dim=0)

        loss = criterion(logits_flat, y_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    if (epoch+1) % 5 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, loss={running_loss:.4f}")


Epoch 5/33, loss=140.5011
Epoch 10/33, loss=121.5093
Epoch 15/33, loss=113.7620
Epoch 20/33, loss=112.7051
Epoch 25/33, loss=111.7766
Epoch 30/33, loss=111.8216


In [37]:
# Diagnostic Code

from sklearn.metrics import f1_score, confusion_matrix

model_final.eval()
with torch.no_grad():
    Xn_tr = X_num_full_t.to(device)
    Xc_tr = X_cat_full_t.to(device)
    logits_tr = model_final(Xn_tr, Xc_tr)                 # (N, k, 1)
    probs_tr = torch.sigmoid(logits_tr).mean(dim=1)       # (N, 1)
    probs_tr = probs_tr.squeeze(-1).cpu().numpy()

print("Train probs min/mean/max:", probs_tr.min(), probs_tr.mean(), probs_tr.max())

y_train_true = y_full.ravel().astype(int)
y_train_pred_03 = (probs_tr >= 0.3).astype(int)

print("Train F1 @0.3:", f1_score(y_train_true, y_train_pred_03))
print("Train confusion @0.5:\n", confusion_matrix(y_train_true, y_train_pred_05))


Train probs min/mean/max: 0.0005119745 0.25603554 0.842113
Train F1 @0.3: 0.5988591317799478
Train confusion @0.5:
 [[12912   972]
 [ 2268  1847]]


In [31]:
best_f1_train, best_t_train = compute_best_f1(probs_tr, y_full.ravel().astype(int))
print("Best F1 on train:", best_f1_train)
print("Best threshold:", best_t_train)

Best F1 on train: 0.5988591317799478
Best threshold: 0.3


In [32]:
real_test = pd.read_csv("data/Testing_TriGuard.csv")

# Keep columns needed for domain features
claim_numbers = real_test["claim_number"].copy()

# Use the final preprocessor (fit on full training data!)
X_num_test, X_cat_test, _ = preproc_final.transform(real_test)

X_num_test_t = torch.from_numpy(X_num_test)
X_cat_test_t = torch.from_numpy(X_cat_test)


In [33]:
model_final.eval()
all_probs = []

with torch.no_grad():
    Xn = X_num_test_t.to(device)
    Xc = X_cat_test_t.to(device)

    logits = model_final(Xn, Xc)                 # (N, k, 1)
    probs = torch.sigmoid(logits).mean(dim=1)    # (N, 1)
    probs = probs.squeeze(-1).cpu().numpy()

    all_probs = probs


In [34]:
real_pred_proba = all_probs
real_pred_label = (real_pred_proba >= best_t_train).astype(int)

In [35]:
# Diagnostic Code

print("Min prob:", real_pred_proba.min())
print("Max prob:", real_pred_proba.max())
print("Mean prob:", real_pred_proba.mean())
print("First 20 probs:", real_pred_proba[:20])
print("First 20 labels:", real_pred_label[:20])
print("Unique labels:", np.unique(real_pred_label))


Min prob: 0.0007758692
Max prob: 0.82720715
Mean prob: 0.25367633
First 20 probs: [0.21576667 0.3347236  0.06319809 0.609448   0.36384344 0.03800942
 0.45161465 0.06597612 0.23697914 0.66205806 0.09418901 0.47217855
 0.2660694  0.2500148  0.22387505 0.58648896 0.054098   0.04736322
 0.15361674 0.4636856 ]
First 20 labels: [0 1 0 1 1 0 1 0 0 1 0 1 0 0 0 1 0 0 0 1]
Unique labels: [0 1]


In [36]:
prediction = pd.DataFrame({
    "claim_number": claim_numbers,
    "subrogation": real_pred_label
})

prediction.to_csv("results/tabm_5938_prediction.csv", index=False)

print("Saved:", "results/tabm_5938_prediction.csv")


Saved: results/tabm_5938_prediction.csv


Model Saving Pipeline

In [38]:
from TabM_save_load import TabM_save_load

In [40]:
# 1. Create the pipeline wrapper
pipeline = TabM_save_load(
    model=model_final,
    preprocessor=preproc_final,
    threshold=best_t_train,
    best_params=best_params,
    num_cols=num_cols,
    cat_cols=cat_cols,
    device=device
)

# 2. Save everything to a folder
save_dir = "models/tabm_full_pipeline_5938"
pipeline.save(save_dir)

TabM pipeline saved successfully to models/tabm_full_pipeline_5938
