In [1]:
!pip install xgboost




In [4]:
# minimal, end-to-end training on weak multi-label data with masks
# deps: pip install torch scikit-learn joblib numpy

import os, json, math, random
from pathlib import Path
import numpy as np
import pandas as pd
import joblib
import torch
import torch.nn as nn
import xgboost as xgb
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ------------------------------ data loading ------------------------------

def _load_bucket_flat(bucket_dir, polarity):
    """load a flat bucket (positives/negatives) -> list of (emb, fname, label, polarity)"""
    bucket_dir = Path(bucket_dir)
    embs = joblib.load(bucket_dir / "embeddings.joblib")      # list/array of vectors
    labels = json.loads((bucket_dir / "labels.json").read_text())     # list[str]
    fnames = json.loads((bucket_dir / "filenames.json").read_text())  # list[str]
    assert len(embs) == len(labels) == len(fnames), "mismatched lengths in bucket"
    out = []
    for emb, lab, fn in zip(embs, labels, fnames):
        emb = np.asarray(emb, dtype=np.float32)
        out.append((emb, fn, lab, polarity))
    return out

def scan_root(root, pos_dir_name='tp', neg_dir_name='fp'):
    """scan flat positives/negatives structure and return combined records and class list"""
    root = Path(root)
    pos = _load_bucket_flat(root / pos_dir_name, polarity=1.0) if (root / pos_dir_name).exists() else []
    neg = _load_bucket_flat(root / neg_dir_name, polarity=0.0) if (root / neg_dir_name).exists() else []
    records = pos + neg

    # collect unique class names
    class_names = sorted({lab for (_, _, lab, _) in records})
    class_to_idx = {c: i for i, c in enumerate(class_names)}
    return records, class_names, class_to_idx

def build_multilabel_table(root):
    """
    dedupe by filename and build X, Y with weak multi-label targets
    y in {1, 0, -1}; -1 means unknown (masked)
    positive wins on conflicts; negative only sets if currently unknown
    """
    records, class_names, class_to_idx = scan_root(root)
    C = len(class_names)
    by_file = {}   # fname -> {'embedding': np.array, 'y': np.full(C, -1)}

    emb_dim = None
    for emb, fname, label, pol in records:
        if emb_dim is None: emb_dim = int(emb.shape[-1])
        ci = class_to_idx[label]
        slot = by_file.get(fname)
        if slot is None:
            y = np.full(C, -1.0, dtype=np.float32)
            y[ci] = float(pol)
            by_file[fname] = {"embedding": emb.astype(np.float32), "y": y}
        else:
            cur = slot["y"][ci]
            if pol == 1.0:
                slot["y"][ci] = 1.0
            elif cur == -1.0:
                slot["y"][ci] = 0.0

    X, Y, F = [], [], []
    for fname, rec in by_file.items():
        X.append(rec["embedding"])
        Y.append(rec["y"])
        F.append(fname)
    X = np.stack(X).astype(np.float32)
    Y = np.stack(Y).astype(np.float32)
    return X, Y, F, class_names, emb_dim


# ------------------------------ dataset ------------------------------

class WeakMultiLabelDataset(Dataset):
    def __init__(self, X, Y):
        self.X = torch.from_numpy(X)
        self.Y = torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

def split_indices(n, y, val_frac=0.2, test_frac=0.0, seed=0):
    """partition dataset based on class and label (pos or neg)"""
    # create class+label label
    tmp = np.array([str(i) + '_' + str(int(j)) for i, j in zip(y.argmax(axis=1), y.max(axis=1))])
    val_size = int(val_frac * n)
    test_size = int(test_frac * n)
    if val_frac > 0:
        idx_train, idx_val = train_test_split(
            np.arange(n), 
            test_size=val_size, 
            stratify=tmp
        )
    else:
        idx_val = []
    if test_frac > 0:
        idx_train, idx_test = train_test_split(
            idx_train, 
            test_size=test_size, 
            stratify=tmp[idx_train]
        )
    else:
        idx_test = []
    return idx_train, idx_val, idx_test

# ------------------------------ model ------------------------------

class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden=256, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )
    def forward(self, x): return self.net(x)  # logits

# ------------------------------ masked loss ------------------------------

class MaskedBCEWithLogits(nn.Module):
    def __init__(self, pos_weight=None, reduction='mean', mask_val=-1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(weight=None, pos_weight=pos_weight, reduction='none')
        self.reduction = reduction
        self.mask_val = mask_val

    def forward(self, logits, targets):
        # targets in {1,0,-1}; mask unknowns
        mask = (targets != self.mask_val).float()
        # fill unknowns with 0 to avoid nan in BCE; they will be masked out
        safe_targets = torch.where(mask.bool(), targets, torch.zeros_like(targets))
        per_elem = self.bce(logits, safe_targets)
        per_elem = per_elem * mask
        # reduce over classes then over batch
        # small epsilon to avoid div-by-zero if a batch has all-masked for a class
        denom = mask.sum(dim=1).clamp_min(1.0)
        per_sample = per_elem.sum(dim=1) / denom
        if self.reduction == 'mean':
            return per_sample.mean()
        elif self.reduction == 'sum':
            return per_sample.sum()
        else:
            return per_sample

# ------------------------------ metrics ------------------------------

@torch.no_grad()
def evaluate(model, loader, device, class_names):
    model.eval()
    all_logits, all_targets = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        all_logits.append(model(xb).cpu())
        all_targets.append(yb.cpu())
    logits = torch.cat(all_logits, 0).numpy()
    targets = torch.cat(all_targets, 0).numpy()  # [N, C], values in {1,0,-1}

    per_class = {}
    ap_list, auc_list = [], []
    for ci, cname in enumerate(class_names):
        mask = targets[:, ci] != -1.0
        n_total = int(mask.sum())
        n_pos = int((targets[mask, ci] == 1.0).sum())
        n_neg = int((targets[mask, ci] == 0.0).sum())

        if n_total < 2 or len(np.unique(targets[mask, ci])) < 2:
            per_class[cname] = {
                "AP": None,
                "AUROC": None,
                "n_total": n_total,
                "n_pos": n_pos,
                "n_neg": n_neg,
            }
            continue

        y_true = targets[mask, ci]
        y_score = logits[mask, ci]
        try:
            ap = float(average_precision_score(y_true, y_score))
        except Exception:
            ap = None
        try:
            auc = float(roc_auc_score(y_true, y_score))
        except Exception:
            auc = None

        per_class[cname] = {
            "AP": ap,
            "AUROC": auc,
            "n_total": n_total,
            "n_pos": n_pos,
            "n_neg": n_neg,
        }
        if ap is not None: ap_list.append(ap)
        if auc is not None: auc_list.append(auc)

    macro_ap = float(np.mean(ap_list)) if ap_list else None
    macro_auc = float(np.mean(auc_list)) if auc_list else None
    return {"macro_AP": macro_ap, "macro_AUROC": macro_auc, "per_class": per_class}

# ------------------------------ training loop ------------------------------

def compute_pos_weight(Y, mask_val=-1.0):
    # pos_weight = (num_neg / num_pos) per class using only known labels
    C = Y.shape[1]
    pos_weight = np.zeros(C, dtype=np.float32)
    for c in range(C):
        known = Y[:, c] != mask_val
        if known.sum() == 0:
            pos_weight[c] = 1.0
        else:
            pos = (Y[known, c] == 1.0).sum()
            neg = (Y[known, c] == 0.0).sum()
            pos_weight[c] = float(neg / max(pos, 1))
    return torch.tensor(pos_weight, dtype=torch.float32)

def train_mlp(
    X, Y, class_names, idx_train, idx_val=None, idx_test=None,
    batch_size=128,
    hidden=256,
    dropout=0.1,
    lr=3e-4,
    epochs=30,
    seed=0,
    device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=0,
):
    """
    trains an mlp with masked bce and optional early stopping on val.
    unknowns (-1) are ignored by the loss. returns dict(model=..., metrics={...}).
    """

    rng = np.random.RandomState(seed)

    # split matrices
    Xtr, Ytr = X[idx_train], Y[idx_train]
    Xva = Yva = None
    Xte = Yte = None
    if idx_val is not None and len(idx_val) > 0:
        Xva, Yva = X[idx_val], Y[idx_val]
    if idx_test is not None and len(idx_test) > 0:
        Xte, Yte = X[idx_test], Y[idx_test]

    emb_dim = X.shape[1]
    C = Y.shape[1]

    # build model + optimizer
    pos_weight = compute_pos_weight(Ytr)  # tensor [C]
    model = MLP(emb_dim, C, hidden=hidden, dropout=dropout).to(device)
    criterion = MaskedBCEWithLogits(pos_weight=pos_weight.to(device))
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=max(epochs, 1))

    # loaders
    ds_tr = WeakMultiLabelDataset(Xtr, Ytr)
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)

    if Xva is not None:
        ds_va = WeakMultiLabelDataset(Xva, Yva)
        dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False)
    else:
        ds_va = dl_va = None

    if Xte is not None:
        ds_te = WeakMultiLabelDataset(Xte, Yte)
        dl_te = DataLoader(ds_te, batch_size=batch_size, shuffle=False)
    else:
        ds_te = dl_te = None

    # training loop with early stopping on val loss when available
    best_va = math.inf
    best_state = None

    for ep in tqdm(range(1, epochs + 1), desc='MLP training'):
        model.train()
        running = 0.0
        n_seen = 0
        for xb, yb in dl_tr:
            xb = xb.to(device)
            yb = yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)
            optim.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip helps with small/imbalanced data
            optim.step()
            running += loss.item() * xb.size(0)
            n_seen += xb.size(0)
        sched.step()
        tr_loss = running / max(n_seen, 1)

        if dl_va is not None:
            with torch.no_grad():
                model.eval()
                va_loss = 0.0
                n_va = 0
                for xb, yb in dl_va:
                    xb = xb.to(device)
                    yb = yb.to(device)
                    va_loss += criterion(model(xb), yb).item() * xb.size(0)
                    n_va += xb.size(0)
                va_loss /= max(n_va, 1)

            if va_loss < best_va:
                best_va = va_loss
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

            if verbose:
                print(f"epoch {ep:02d} | train_loss {tr_loss:.4f} | val_loss {va_loss:.4f}")
        elif verbose:
            print(f"epoch {ep:02d} | train_loss {tr_loss:.4f}")

    # restore best val checkpoint if we had a val split
    if best_state is not None:
        model.load_state_dict(best_state)

    # metrics: prefer test if provided; also report val if provided
    metrics = {}
    if dl_va is not None:
        metrics["val"] = evaluate(model, dl_va, device, class_names)
    if dl_te is not None:
        metrics["test"] = evaluate(model, dl_te, device, class_names)
    if not metrics:
        # if neither val nor test given, evaluate on train just so caller gets something
        dl_tmp = DataLoader(ds_tr, batch_size=batch_size, shuffle=False)
        metrics["train"] = evaluate(model, dl_tmp, device, class_names)

    return {"model": model, "metrics": metrics}


# ------------------------- helpers -------------------------

def _eval_split(scores, y_true):
    # compute AP/AUROC per class using only known labels
    C = y_true.shape[1]
    per_class = {}
    ap_list, auc_list = [], []
    for ci in range(C):
        mask = y_true[:, ci] != -1.0
        if mask.sum() < 2 or len(np.unique(y_true[mask, ci])) < 2:
            per_class[ci] = {
                "AP": None,
                "AUROC": None,
                "n_pos": int((y_true[mask, ci] == 1.0).sum()),
                "n_neg": int((y_true[mask, ci] == 0.0).sum()),
            }
            continue
        y = y_true[mask, ci]
        s = scores[mask, ci]
        ap = average_precision_score(y, s)
        auc = roc_auc_score(y, s)
        per_class[ci] = {
            "AP": float(ap),
            "AUROC": float(auc),
            "n_pos": int((y == 1.0).sum()),
            "n_neg": int((y == 0.0).sum()),
        }
        ap_list.append(ap); auc_list.append(auc)
    macro_ap = float(np.mean(ap_list)) if ap_list else None
    macro_auc = float(np.mean(auc_list)) if auc_list else None
    return {"macro_AP": macro_ap, "macro_AUROC": macro_auc, "per_class": per_class}


# ------------------------- logistic regression -------------------------

def train_linear(X, Y, class_names, idx_train, idx_val=None, idx_test=None, Cs=(0.1, 0.3, 1.0, 3.0, 10.0)):
    """
    trains per-class logistic regression with simple C selection on val.
    unknowns (-1) are ignored per class per split.
    returns dict(models=..., metrics={'val':..., 'test':...})
    """
    C_list = list(Cs)
    C_default = C_list[len(C_list)//2]

    models = [None] * len(class_names)

    # fit per class with val-based C if provided
    for ci, _ in tqdm(enumerate(class_names), total=len(class_names), desc='Logisting training'):
        # build train set for this class
        m_tr = Y[idx_train, ci] != -1.0
        Xtr = X[idx_train][m_tr]; ytr = Y[idx_train, ci][m_tr]
        if len(np.unique(ytr)) < 2:
            # not enough label diversity
            models[ci] = None
            continue

        # choose C using val AP (fallback to default if no val or degenerate)
        best_C = C_default
        if idx_val is not None and len(idx_val) > 0:
            best_ap, best_C = -1.0, C_default
            m_va_all = Y[idx_val, ci] != -1.0
            if m_va_all.any() and len(np.unique(Y[idx_val, ci][m_va_all])) > 1:
                for C in C_list:
                    clf = LogisticRegression(
                        penalty="l2", C=C, max_iter=2000, solver="saga", class_weight="balanced"
                    )
                    clf.fit(Xtr, ytr)
                    m_va = m_va_all
                    s = clf.predict_proba(X[idx_val][m_va])[:, 1]
                    ap = average_precision_score(Y[idx_val, ci][m_va], s)
                    if ap > best_ap:
                        best_ap, best_C = ap, C

        # fit final model with chosen C on train
        clf = LogisticRegression(
            penalty="l2", C=best_C, max_iter=2000, solver="saga", class_weight="balanced"
        )
        clf.fit(Xtr, ytr)
        models[ci] = clf

    # scores on val/test
    metrics = {}
    if idx_val is not None and len(idx_val) > 0:
        S_val = np.zeros((len(idx_val), len(class_names)), dtype=np.float32)
        for ci, clf in enumerate(models):
            if clf is None:
                S_val[:, ci] = np.nan
                continue
            m_va = Y[idx_val, ci] != -1.0
            S_val[m_va, ci] = clf.predict_proba(X[idx_val][m_va])[:, 1]
        metrics["val"] = _eval_split(S_val, Y[idx_val])

    if idx_test is not None and len(idx_test) > 0:
        S_te = np.zeros((len(idx_test), len(class_names)), dtype=np.float32)
        for ci, clf in enumerate(models):
            if clf is None:
                S_te[:, ci] = np.nan
                continue
            m_te = Y[idx_test, ci] != -1.0
            S_te[m_te, ci] = clf.predict_proba(X[idx_test][m_te])[:, 1]
        metrics["test"] = _eval_split(S_te, Y[idx_test])

    return {"models": models, "metrics": metrics}

# ------------------------- xgboost -------------------------

def train_xgb(
    X, Y, class_names, idx_train, idx_val=None, idx_test=None,
    param_grid=(
        {"max_depth": 3, "min_child_weight": 1},
        {"max_depth": 4, "min_child_weight": 2},
        {"max_depth": 5, "min_child_weight": 3},
    ),
    base_params=None,
    num_boost_round=200,
    early_stopping_rounds=30
):
    """
    trains per-class xgboost with light val tuning + early stopping.
    unknowns (-1) are ignored per class per split.
    returns dict(models=..., metrics={'val':..., 'test':...})
    """
    if base_params is None:
        base_params = {
            "objective": "binary:logistic",
            "eval_metric": "auc",
            "tree_method": "hist",
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "lambda": 1.0,
        }

    models = [None] * len(class_names)
    best_iters = [None] * len(class_names)

    for ci, _ in tqdm(enumerate(class_names), total=len(class_names), desc='XGB training'):
        m_tr = Y[idx_train, ci] != -1.0
        Xtr = X[idx_train][m_tr]; ytr = Y[idx_train, ci][m_tr]
        if len(np.unique(ytr)) < 2:
            models[ci] = None
            continue

        dtr = xgb.DMatrix(Xtr, label=ytr)

        # if val exists for this class, pick params by val AP with early stopping
        chosen_params = dict(base_params)
        chosen_n = num_boost_round
        if idx_val is not None and len(idx_val) > 0:
            m_va = Y[idx_val, ci] != -1.0
            has_val = m_va.any() and len(np.unique(Y[idx_val, ci][m_va])) > 1
            if has_val:
                Xva = X[idx_val][m_va]; yva = Y[idx_val, ci][m_va]
                dva = xgb.DMatrix(Xva, label=yva)
                best_ap, best_combo = -1.0, None
                for combo in param_grid:
                    params = dict(base_params, **combo)
                    bst = xgb.train(
                        params, dtr,
                        num_boost_round=num_boost_round,
                        evals=[(dtr, "train"), (dva, "val")],
                        early_stopping_rounds=early_stopping_rounds,
                        verbose_eval=False,
                    )
                    # use val AP (more aligned with retrieval) instead of built-in auc
                    y_score = bst.predict(dva, iteration_range=(0, bst.best_iteration+1))
                    ap = average_precision_score(yva, y_score)
                    if ap > best_ap:
                        best_ap, best_combo = ap, (params, bst.best_iteration+1)
                if best_combo is not None:
                    chosen_params, chosen_n = best_combo

        # train final model on train with chosen settings
        bst_final = xgb.train(
            chosen_params, dtr, num_boost_round=chosen_n, verbose_eval=False
        )
        models[ci] = bst_final
        best_iters[ci] = chosen_n

    # scores on val/test
    metrics = {}
    if idx_val is not None and len(idx_val) > 0:
        S_val = np.zeros((len(idx_val), len(class_names)), dtype=np.float32)
        for ci, bst in enumerate(models):
            if bst is None:
                S_val[:, ci] = np.nan
                continue
            m_va = Y[idx_val, ci] != -1.0
            if m_va.any():
                dva = xgb.DMatrix(X[idx_val][m_va])
                S_val[m_va, ci] = bst.predict(dva)
        metrics["val"] = _eval_split(S_val, Y[idx_val])

    if idx_test is not None and len(idx_test) > 0:
        S_te = np.zeros((len(idx_test), len(class_names)), dtype=np.float32)
        for ci, bst in enumerate(models):
            if bst is None:
                S_te[:, ci] = np.nan
                continue
            m_te = Y[idx_test, ci] != -1.0
            if m_te.any():
                dte = xgb.DMatrix(X[idx_test][m_te])
                S_te[m_te, ci] = bst.predict(dte)
        metrics["test"] = _eval_split(S_te, Y[idx_test])

    return {"models": models, "metrics": metrics}


if __name__ == "__main__":
    # dataset root path
    ROOT = "/home/ec2-user/SageMaker/projects/experimental/bio-ae-eval/artifacts/embeddings/20250811-075752/perch_bird/"

    # prepare data
    X, Y, F, class_names, emb_dim = build_multilabel_table(ROOT)
    n = X.shape[0]
    tr_idx, va_idx, te_idx = split_indices(n, Y, val_frac=0.25, test_frac=0.25, seed=13)
    
    # train and compare models
    results = train_mlp(X, Y, class_names, tr_idx, va_idx, te_idx)
    print(f"mlp mAP: {results['metrics']['test']['macro_AP']}")
    results = train_xgb(X, Y, class_names, tr_idx, va_idx, te_idx)
    print(f"xgb mAP: {results['metrics']['test']['macro_AP']}")
    results = train_linear(X, Y, class_names, tr_idx, va_idx, te_idx)
    print(f"lin mAP: {results['metrics']['test']['macro_AP']}")


MLP training: 100%|██████████| 30/30 [00:05<00:00,  5.99it/s]


mlp mAP: 0.8987727452817657


XGB training: 100%|██████████| 26/26 [00:10<00:00,  2.44it/s]


xgb mAP: 0.8459730442030611


Logisting training: 100%|██████████| 26/26 [01:36<00:00,  3.72s/it]


lin mAP: 0.9089192378258985


In [3]:
# def count_pos_neg(Y: np.ndarray, class_names):
#     # count positives (1) per column
#     pos = np.sum(Y == 1, axis=0)
#     # count negatives (0) per column
#     neg = np.sum(Y == 0, axis=0)
#     return pd.DataFrame({'class': class_names, 'npos': pos, 'nneg': neg})

