
# 04 — Train MIL Head (Presence/Absence)

**Goal:** Train a lightweight multiple-instance learning (MIL) head on top of cached DINOv2 features.
- Pool tile logits per image (max-pool by default).
- Handle class imbalance with `pos_weight`.
- Log basic metrics and save `mil_head.pt`.


In [None]:

%pip -q install --extra-index-url https://download.pytorch.org/whl/cu121   torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0
%pip -q install numpy pandas pyarrow scikit-learn tqdm mlflow==2.14.3


In [None]:

import os, json, math, numpy as np, pandas as pd, torch
from torch import nn, optim
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_auc_score
import mlflow

# Filesystem layout. BASE is the working root (change if needed); cached
# embeddings are read from CACHE_DIR and trained artifacts go to MODEL_DIR.
BASE = Path('/content')
CACHE_DIR = BASE/'cache/embeddings'
MODEL_DIR = BASE/'models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# MLflow tracking: keep a URI already present in the environment, otherwise
# fall back to a local file store under BASE.
if "MLFLOW_TRACKING_URI" not in os.environ:
    os.environ["MLFLOW_TRACKING_URI"] = f"file:{BASE/'runs'}"
mlflow.set_tracking_uri(os.environ["MLFLOW_TRACKING_URI"])

# Model / optimisation hyperparameters.
FEAT_DIM = 384    # DINOv2 ViT-S/14 feature dim
POOLING = "max"   # one of "max", "lse", "mean"
POS_WEIGHT = 3.0  # positive-class weight; tune to the class imbalance
LR = 1e-3
WD = 1e-4
EPOCHS = 10

# Run on the GPU when one is visible to torch, else on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [None]:

import glob
import pyarrow.parquet as pq

def load_split(split, cache_dir=None):
    """Load every cached embedding shard of one split into a single DataFrame.

    Shards are the files matching ``emb_{split}_*.parquet`` inside the cache
    directory. They are read in sorted path order and concatenated row-wise
    with a fresh ``0..N-1`` index.

    Args:
        split: Split name as it appears in the shard file names
            (e.g. ``'train'``).
        cache_dir: Directory to search for shards. When omitted, the
            module-level ``CACHE_DIR`` is used.

    Returns:
        The concatenated DataFrame, or an empty (column-less) DataFrame when
        no shard matches.
    """
    root = Path(CACHE_DIR if cache_dir is None else cache_dir)
    shard_paths = sorted(glob.glob(str(root/f"emb_{split}_*.parquet")))
    if not shard_paths:
        return pd.DataFrame()
    shards = [pd.read_parquet(p) for p in shard_paths]
    return pd.concat(shards, axis=0, ignore_index=True)

# Read the three splits from the embedding cache.
train_df = load_split('train')
val_df = load_split('val')
test_df = load_split('test')

# Print each split's shape and, when a 'label' column exists, its class counts.
splits = {'train': train_df, 'val': val_df, 'test': test_df}
for name, d in splits.items():
    if 'label' in d.columns:
        label_counts = d['label'].value_counts().to_dict()
    else:
        label_counts = {}
    print(name, d.shape, label_counts)

# Training cannot proceed without any training tiles.
assert len(train_df), "No training tiles found. Did you run previous notebooks?"


In [None]:

# Group tiles by image_id, stack embeddings per image, keep image-level label
def to_groups(df):
    """Bundle tile rows into one bag of features per image.

    Args:
        df: Tile-level DataFrame with an ``image_id`` column, a ``label``
            column and embedding columns whose names start with ``emb_``.
            May be empty (for instance a split with no cached shards).

    Returns:
        Dict mapping each image id to ``{'feats': tensor, 'label': int}``.
        ``feats`` is a float32 tensor with one row per tile of that image and
        one column per ``emb_`` column; ``label`` is taken from the image's
        first tile row. An empty ``df`` yields an empty dict.
    """
    # load_split() returns a column-less DataFrame when a split has no shards;
    # grouping that by 'image_id' would raise KeyError, so short-circuit here.
    if df.empty:
        return {}
    emb_cols = [c for c in df.columns if c.startswith('emb_')]
    return {
        img_id: {
            'feats': torch.tensor(tiles[emb_cols].values, dtype=torch.float32),
            'label': int(tiles['label'].iloc[0]),
        }
        for img_id, tiles in df.groupby('image_id')
    }

# Build the per-image bags for every split.
train_groups, val_groups, test_groups = (
    to_groups(split_df) for split_df in (train_df, val_df, test_df)
)
# Cell output: number of images in each split.
len(train_groups), len(val_groups), len(test_groups)


In [None]:

class MILHead(nn.Module):
    """Linear multiple-instance head: score each tile, then pool to one image logit.

    Args:
        d: Dimensionality of each tile feature vector.
    """

    # Sharpness of the log-sum-exp soft maximum used by "lse" pooling.
    LSE_SCALE = 10.0

    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, 1)

    def forward(self, tile_feats, pooling="max"):
        """Score the tiles of one image and pool them into an image logit.

        Args:
            tile_feats: Tile features of a single image, shape ``[T, D]``.
            pooling: How tile logits are reduced: ``"max"`` (largest tile
                logit), ``"lse"`` (scaled log-sum-exp, a smooth maximum) or
                ``"mean"`` (average tile logit).

        Returns:
            ``(img_logit, logits)``: the pooled scalar logit and the
            per-tile logits of shape ``[T]``.

        Raises:
            ValueError: If ``pooling`` is not one of the supported modes.
        """
        logits = self.fc(tile_feats).squeeze(-1)  # [T]
        if pooling == "max":
            img_logit = logits.max()
        elif pooling == "lse":
            scale = self.LSE_SCALE
            img_logit = torch.logsumexp(logits * scale, dim=0) / scale
        elif pooling == "mean":
            img_logit = logits.mean()
        else:
            # Fail loudly: a misspelled mode used to fall through to mean
            # pooling without any warning.
            raise ValueError(
                f"Unknown pooling {pooling!r}; expected 'max', 'lse' or 'mean'"
            )
        return img_logit, logits

def eval_groups(head, groups):
    """Score every image bag with ``head`` and compute image-level metrics.

    The head is switched to eval mode (and left there) and gradients are
    disabled while scoring. Each bag is moved to the module-level ``device``
    and pooled with the module-level ``POOLING`` mode; the pooled logit is
    turned into a probability with a sigmoid.

    Args:
        head: Model called as ``head(feats, pooling=...)`` and returning
            ``(image_logit, tile_logits)``.
        groups: Dict of bags, each a dict with ``'feats'`` (tensor) and
            ``'label'`` entries.

    Returns:
        ``(y_true, y_prob, ap, auc)``: label and probability arrays in
        iteration order of ``groups``, the average precision, and the ROC AUC
        (``nan`` when sklearn cannot compute it).
    """
    head.eval()
    y_true, y_prob = [], []
    with torch.no_grad():
        for g in groups.values():
            logit, _ = head(g['feats'].to(device), pooling=POOLING)
            y_prob.append(torch.sigmoid(logit).item())
            y_true.append(g['label'])
    ap = average_precision_score(y_true, y_prob)
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        # sklearn raises ValueError when ROC AUC is undefined (e.g. only one
        # class present in y_true); report nan instead of aborting the run.
        auc = float('nan')
    return np.array(y_true), np.array(y_prob), ap, auc


In [None]:

# Train the MIL head one image (bag) per optimisation step, track metrics in
# MLflow, evaluate on the test split and save the weights plus a model card.
mlflow.set_experiment("tortoise_mil_presence")

with mlflow.start_run() as run:
    head = MILHead(FEAT_DIM).to(device)
    # pos_weight scales the loss on positive images to counter class imbalance.
    crit = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=device))
    opt = optim.AdamW(head.parameters(), lr=LR, weight_decay=WD)

    for epoch in range(EPOCHS):
        # eval_groups() leaves the head in eval mode, so re-enable train mode.
        head.train()
        total = 0.0
        for g in tqdm(train_groups.values(), desc=f"epoch {epoch}"):
            feats = g['feats'].to(device)
            y = torch.tensor([g['label']], dtype=torch.float32, device=device)  # [1]
            logit, _ = head(feats, pooling=POOLING)
            # The pooled logit is a scalar; unsqueeze it to [1] so it matches
            # the [1]-shaped target. BCEWithLogitsLoss raises ValueError when
            # input and target sizes differ, so the target must not be
            # unsqueezed to [1, 1].
            loss = crit(logit.unsqueeze(0), y)
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        mean_loss = total/len(train_groups)
        yv, pv, ap, auc = eval_groups(head, val_groups)
        print(f"Epoch {epoch} loss={mean_loss:.4f}  val AP={ap:.4f} AUC={auc:.4f}")
        mlflow.log_metric("train_loss", mean_loss, step=epoch)
        mlflow.log_metric("val_ap", ap, step=epoch)
        # AUC is nan when it is undefined for the split; skip logging then.
        if not math.isnan(auc): mlflow.log_metric("val_auc", auc, step=epoch)

    # Final eval on test
    yt, pt, ap_t, auc_t = eval_groups(head, test_groups)
    print(f"TEST:  AP={ap_t:.4f}  AUC={auc_t:.4f}")
    mlflow.log_metric("test_ap", ap_t)
    if not math.isnan(auc_t): mlflow.log_metric("test_auc", auc_t)

    # Save model weights (state dict only) and attach them to the run
    out_path = MODEL_DIR/'mil_head.pt'
    torch.save(head.state_dict(), out_path)
    mlflow.log_artifact(str(out_path))
    # Record the hyperparameters needed to rebuild and apply the head
    with open(MODEL_DIR/'model_card.json','w') as f:
        json.dump({
            "feat_dim": FEAT_DIM,
            "pooling": POOLING,
            "pos_weight": POS_WEIGHT,
            "epochs": EPOCHS,
            "lr": LR,
            "weight_decay": WD
        }, f, indent=2)
    mlflow.log_artifact(str(MODEL_DIR/'model_card.json'))
