# B2 — Interpretable Semantic Token Prediction (Audio to Set of Meaningful Tokens)

We convert each labeled example into a small set of **interpretable semantic tokens**:
- identity tokens: EMITTER_x, ADDRESSEE_y
- context token: CONTEXT_Feeding / CONTEXT_Sleeping / ...
- action tokens: E_PRE_Fly_in, E_POST_Stay, A_PRE_Present, A_POST_Fly_away, ...

We then train a **multi-label classifier** to predict this token set from acoustic features
(AST + token histograms). This creates a "language-like" intermediate representation that is:
- interpretable
- easy to evaluate (token-level F1, per-slot accuracy)
- a clean stepping stone to later caption generation (B3)

In [None]:
import numpy as np
import torch
import torch.nn as nn
from dataclasses import dataclass
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, classification_report


In [None]:
def _s(x) -> str:
    return str(x).strip()

def normalize_slot_value(slot: str, raw) -> str:
    """Translate a raw slot value into the string used inside a token.

    The raw value is first stringified and stripped. For the ``context`` slot
    and the four pre/post action slots it is then looked up in the matching
    module-level map (``CONTEXT_MAP``, ``PRE_ACTION_MAP``, ``POST_ACTION_MAP``);
    values missing from the map come back as ``"Unknown(<value>)"``.
    Any other slot (e.g. emitter/addressee IDs) is returned as the stripped
    string, unchanged.
    """
    value = _s(raw)

    # Select the lookup table lazily so that ID slots never touch the maps.
    if slot == "context":
        table = CONTEXT_MAP
    elif slot in ("emitter_pre", "addressee_pre"):
        table = PRE_ACTION_MAP
    elif slot in ("emitter_post", "addressee_post"):
        table = POST_ACTION_MAP
    else:
        # emitter/addressee: keep IDs
        return value

    return table.get(value, f"Unknown({value})")

def slots_to_tokens(slots: dict) -> list[str]:
    """
    Turn one example's slot dict into its seven interpretable tokens.

    Each token is a fixed prefix followed by the normalized slot value. The
    returned list is always ordered: emitter, addressee, context, emitter_pre,
    addressee_pre, emitter_post, addressee_post.
    """
    # (token prefix, slot key) pairs, in output order.
    layout = (
        ("EMITTER_", "emitter"),
        ("ADDRESSEE_", "addressee"),
        ("CONTEXT_", "context"),
        ("E_PRE_", "emitter_pre"),
        ("A_PRE_", "addressee_pre"),
        ("E_POST_", "emitter_post"),
        ("A_POST_", "addressee_post"),
    )
    return [
        prefix + normalize_slot_value(key, slots[key])
        for prefix, key in layout
    ]

In [None]:
def build_token_dataset(labels_raw: dict):
    """
    Build the token-level dataset from per-slot label columns.

    ``labels_raw`` maps each slot name in ``SLOTS`` to an indexable column of
    raw values; the number of examples N is taken from the first column.

    Returns:
      token_lists: list[list[str]] of length N (tokens for each example)
      vocab: sorted list[str] of every distinct token
      token_to_id: dict[str,int] mapping token -> position in ``vocab``
      Y: np.ndarray of shape (N, V), float32 multi-hot token matrix
      groups: dict[str, list[int]] vocab indices per slot category, used for
              per-slot decoding
    """
    n_examples = len(next(iter(labels_raw.values())))

    # One token list per example, assembled from that row of every slot column.
    token_lists = [
        slots_to_tokens({name: labels_raw[name][row] for name in SLOTS})
        for row in range(n_examples)
    ]

    # Vocabulary: all distinct tokens, sorted so ids are deterministic.
    vocab = sorted({tok for toks in token_lists for tok in toks})
    token_to_id = {tok: idx for idx, tok in enumerate(vocab)}

    # Multi-hot targets: one row per example, one column per vocab token.
    Y = np.zeros((n_examples, len(vocab)), dtype=np.float32)
    for row, toks in enumerate(token_lists):
        Y[row, [token_to_id[tok] for tok in toks]] = 1.0

    # Vocab indices belonging to each slot, identified by token prefix.
    prefix_by_slot = {
        "emitter": "EMITTER_",
        "addressee": "ADDRESSEE_",
        "context": "CONTEXT_",
        "emitter_pre": "E_PRE_",
        "addressee_pre": "A_PRE_",
        "emitter_post": "E_POST_",
        "addressee_post": "A_POST_",
    }
    groups = {
        slot: [idx for idx, tok in enumerate(vocab) if tok.startswith(prefix)]
        for slot, prefix in prefix_by_slot.items()
    }

    return token_lists, vocab, token_to_id, Y, groups

# Build the token dataset once; these module-level names are read by the
# training/evaluation sweep further down (token_lists, vocab, Y_all, groups).
token_lists, vocab, token_to_id, Y_all, groups = build_token_dataset(labels_raw)

# Sanity check: dataset size, vocabulary size and the tokens of the first example.
print("N examples:", len(token_lists))
print("Vocab size:", len(vocab))
print("Example tokens:", token_lists[0])


N examples: 10000
Vocab size: 68
Example tokens: ['EMITTER_216', 'ADDRESSEE_221', 'CONTEXT_General', 'E_PRE_Present', 'A_PRE_Crawl in', 'E_POST_Stay', 'A_POST_Stay']


In [None]:
def compute_pos_weight(Y: np.ndarray, eps: float = 1e-6) -> torch.Tensor:
    """
    Per-token positive-class weights for BCEWithLogitsLoss.

    For each column j of the multi-hot matrix ``Y``:
      pos_weight[j] = (N - pos_count[j]) / pos_count[j]
    where N is the number of rows. Positive counts are floored at ``eps`` so a
    token with no positives does not divide by zero (its weight becomes very
    large instead). The result is a float32 tensor with one entry per column.
    """
    n_rows = Y.shape[0]
    positives = np.clip(Y.sum(axis=0), eps, None)
    ratio = (n_rows - positives) / positives
    return torch.tensor(ratio, dtype=torch.float32)

# pos_weight over the full label matrix (all examples). Note that
# train_eval_b2a computes its own pos_weight from the training split only.
pos_weight = compute_pos_weight(Y_all)

In [None]:
class TokenDataset(torch.utils.data.Dataset):
    """Pairs feature rows with their multi-hot token targets.

    Both arrays are copied into float32 tensors at construction time and
    exposed as ``X`` and ``Y``; item ``i`` is the tuple ``(X[i], Y[i])``.
    """

    def __init__(self, X: np.ndarray, Y: np.ndarray):
        as_f32 = {"dtype": torch.float32}
        self.X = torch.tensor(X, **as_f32)
        self.Y = torch.tensor(Y, **as_f32)

    def __len__(self):
        # Number of examples = size of the leading dimension of X.
        return self.X.size(0)

    def __getitem__(self, idx):
        features = self.X[idx]
        targets = self.Y[idx]
        return features, targets

In [None]:
class B2aMLP(nn.Module):
    """Feed-forward network producing one logit per vocabulary token.

    The trunk is a stack of ``Linear -> ReLU -> Dropout`` stages, one per entry
    of ``hidden_dims``; a final linear layer maps the last hidden width (or
    ``input_dim`` when ``hidden_dims`` is empty) to ``vocab_size`` logits.
    """

    def __init__(self, input_dim: int, hidden_dims=(512,), dropout=0.2, vocab_size=100):
        super().__init__()
        # Consecutive pairs of widths define each Linear stage of the trunk.
        widths = [input_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        self.trunk = nn.Sequential(*stages)
        self.out = nn.Linear(widths[-1], vocab_size)

    def forward(self, x):
        # Returns raw logits of shape (N, V); no sigmoid is applied here.
        return self.out(self.trunk(x))

In [None]:
@torch.no_grad()
def decode_per_slot_from_logits(logits: np.ndarray, vocab: list[str], groups: dict) -> dict:
    """
    Decode one token per slot by taking the arg-max logit inside each group.

    ``logits`` is indexed as (N, V); ``groups`` maps a slot name to the vocab
    column indices belonging to that slot.
    Returns: dict slot -> list[str] with the predicted token for each of the
    N rows.
    """

    def _best_tokens(member_ids):
        # Arg-max is relative to the group, so map it back to a vocab index.
        winners = np.argmax(logits[:, member_ids], axis=1)
        return [vocab[member_ids[w]] for w in winners]

    return {
        slot_name: _best_tokens(member_ids)
        for slot_name, member_ids in groups.items()
    }

def per_slot_accuracy(gt_tokens_by_slot: dict, pred_tokens_by_slot: dict) -> dict:
    """Return, for every slot in the ground truth, the fraction of examples
    whose predicted token string equals the ground-truth token string."""

    def _match_rate(slot_name):
        truth = np.asarray(gt_tokens_by_slot[slot_name], dtype=str)
        guess = np.asarray(pred_tokens_by_slot[slot_name], dtype=str)
        return float((truth == guess).mean())

    return {slot_name: _match_rate(slot_name) for slot_name in gt_tokens_by_slot}

def build_gt_tokens_by_slot(token_lists: list[list[str]]) -> dict:
    """
    Regroup per-example token lists into per-slot ground-truth lists.

    Each token is routed to its slot by prefix (``EMITTER_``, ``ADDRESSEE_``,
    ``CONTEXT_``, ``E_PRE_``, ``A_PRE_``, ``E_POST_``, ``A_POST_``). Tokens that
    match none of these prefixes are ignored.

    Returns: dict slot -> list[str]. All seven slots are always present (empty
    lists when ``token_lists`` is empty), in the order emitter, addressee,
    context, emitter_pre, addressee_pre, emitter_post, addressee_post. When
    every example carries exactly one token per slot, each list has one entry
    per example, in example order.
    """
    # The slot set is defined locally rather than read from a module-level
    # `groups` variable, so the function works on any token lists without
    # depending on notebook cell execution order.
    prefix_to_slot = (
        ("EMITTER_", "emitter"),
        ("ADDRESSEE_", "addressee"),
        ("CONTEXT_", "context"),
        ("E_PRE_", "emitter_pre"),
        ("A_PRE_", "addressee_pre"),
        ("E_POST_", "emitter_post"),
        ("A_POST_", "addressee_post"),
    )
    gt = {slot: [] for _, slot in prefix_to_slot}
    for toks in token_lists:
        for t in toks:
            for prefix, slot in prefix_to_slot:
                if t.startswith(prefix):
                    gt[slot].append(t)
                    break  # prefixes are mutually exclusive
    return gt

def train_eval_b2a(
    X_all: np.ndarray,
    Y_all: np.ndarray,
    token_lists: list[list[str]],
    vocab: list[str],
    groups: dict,
    hidden_dims=(512,),
    lr=3e-4,
    dropout=0.2,
    batch_size=256,
    epochs=12,
    random_state=0,
    stratify_on_context=True,
):
    """
    Train a B2aMLP multi-label token classifier and evaluate it on a held-out split.

    Steps:
      1. 80/20 train/test split (optionally stratified on the context token).
      2. Standardize features with a StandardScaler fitted on the training split.
      3. Train with AdamW and BCEWithLogitsLoss, using pos_weight computed from
         the training labels; the training loss is printed on the first epoch,
         the last epoch and every third epoch.
      4. Evaluate on the test split: token-level micro/macro F1 (a token counts
         as predicted when its sigmoid probability is >= 0.5) and per-slot
         accuracy (arg-max within each slot group).

    Args:
      X_all: 2-D feature matrix, one row per example.
      Y_all: multi-hot token matrix aligned with ``vocab`` columns.
      token_lists: per-example token lists, aligned with the rows of ``X_all``.
      vocab: token strings, one per column of ``Y_all``.
      groups: slot name -> vocab column indices for that slot.
      hidden_dims, dropout: B2aMLP architecture settings.
      lr, batch_size, epochs: optimisation settings.
      random_state: seed for the train/test split.
      stratify_on_context: stratify the split on each example's context token.

    Returns:
      dict with keys ``model``, ``scaler``, ``micro_f1_tokens``,
      ``macro_f1_tokens``, ``slot_accuracy``, ``vocab`` and ``groups``.
    """
    # Split (optionally stratify by context token for stability)
    gt_by_slot = build_gt_tokens_by_slot(token_lists)
    strat = None
    if stratify_on_context:
        strat = np.array(gt_by_slot["context"], dtype=str)

    X_tr, X_te, Y_tr, Y_te, toks_tr, toks_te = train_test_split(
        X_all, Y_all, token_lists,
        test_size=0.2, random_state=random_state, stratify=strat
    )

    # Scale like A-models (fit on train only, then apply to test)
    scaler = StandardScaler()
    X_tr = scaler.fit_transform(X_tr).astype(np.float32)
    X_te = scaler.transform(X_te).astype(np.float32)

    # DataLoaders
    tr_ds = TokenDataset(X_tr, Y_tr)
    te_ds = TokenDataset(X_te, Y_te)
    tr_loader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size, shuffle=True)
    te_loader = torch.utils.data.DataLoader(te_ds, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = B2aMLP(input_dim=X_all.shape[1], hidden_dims=hidden_dims, dropout=dropout, vocab_size=len(vocab)).to(device)

    # pos_weight for imbalance (computed from the training labels only)
    pos_weight = compute_pos_weight(Y_tr)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    # Train
    for ep in range(1, epochs + 1):
        model.train()
        total = 0.0
        for xb, yb in tr_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            opt.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is per-example.
            total += loss.item() * xb.size(0)

        avg_loss = total / len(tr_ds)
        if ep == 1 or ep == epochs or ep % 3 == 0:
            print(f"epoch {ep:02d}/{epochs}  train_loss={avg_loss:.4f}")

    # Eval: no autograd graph is needed, so run the forward passes under no_grad.
    model.eval()
    all_logits = []
    all_Y = []
    with torch.no_grad():
        for xb, yb in te_loader:
            xb = xb.to(device)
            all_logits.append(model(xb).cpu().numpy())
            all_Y.append(yb.numpy())

    logits_te = np.vstack(all_logits)
    Y_te = np.vstack(all_Y)

    # Multi-label token F1 (threshold at 0.5 on sigmoid).
    # sigmoid(z) >= 0.5 exactly when z >= 0, so threshold the logits directly;
    # this avoids overflow in exp() for large-magnitude logits.
    Y_hat = (logits_te >= 0.0).astype(int)

    micro = f1_score(Y_te, Y_hat, average="micro", zero_division=0)
    macro = f1_score(Y_te, Y_hat, average="macro", zero_division=0)

    # Per-slot accuracy via argmax within each group
    pred_by_slot = decode_per_slot_from_logits(logits_te, vocab, groups)

    # Build GT by slot for *test* subset
    gt_te_by_slot = build_gt_tokens_by_slot(toks_te)
    slot_acc = per_slot_accuracy(gt_te_by_slot, pred_by_slot)

    out = {
        "model": model,
        "scaler": scaler,
        "micro_f1_tokens": float(micro),
        "macro_f1_tokens": float(macro),
        "slot_accuracy": slot_acc,
        "vocab": vocab,
        "groups": groups,
    }
    return out

In [None]:
# Small grid over trunk architecture and learning rate; every other
# hyperparameter is held fixed in the call below.
configs = [
    {"hidden_dims": (512,), "lr": 3e-4},
    {"hidden_dims": (512,), "lr": 1e-3},
    {"hidden_dims": (512, 256), "lr": 3e-4},
    {"hidden_dims": (512, 256), "lr": 1e-3},
]

# Train and evaluate one model per config, collecting (config, result) pairs.
results = []
for cfg in configs:
    print("\n" + "="*80)
    print("B2a config:", cfg)
    # random_state is the same for every config, so the train/test split is identical.
    out_b2a = train_eval_b2a(
        X_all=X_all,
        Y_all=Y_all,
        token_lists=token_lists,
        vocab=vocab,
        groups=groups,
        hidden_dims=cfg["hidden_dims"],
        lr=cfg["lr"],
        dropout=0.2,
        epochs=12,
        batch_size=256,
        random_state=0,
        stratify_on_context=True,
    )
    print(f"Token micro-F1: {out_b2a['micro_f1_tokens']:.4f} | macro-F1: {out_b2a['macro_f1_tokens']:.4f}")
    print("Per-slot accuracy:")
    for k,v in out_b2a["slot_accuracy"].items():
        print(f"  {k:14s}: {v:.4f}")
    results.append((cfg, out_b2a))

# Pick best by context slot accuracy (mirrors your A2 selection logic)
# NOTE(review): slot accuracy is measured on the held-out test split inside
# train_eval_b2a, so selecting on it makes the reported "best" numbers
# optimistic — consider a separate validation split for model selection.
best_cfg, best_out = max(results, key=lambda x: x[1]["slot_accuracy"]["context"])
print("\n" + "#"*80)
print("BEST (by context slot accuracy):", best_cfg)
print("Best per-slot accuracy:", best_out["slot_accuracy"])
print(f"Best token micro-F1={best_out['micro_f1_tokens']:.4f}, macro-F1={best_out['macro_f1_tokens']:.4f}")


B2a config: {'hidden_dims': (512,), 'lr': 0.0003}
epoch 01/12  train_loss=1.1372
epoch 03/12  train_loss=0.7576
epoch 06/12  train_loss=0.6111
epoch 09/12  train_loss=0.5390
epoch 12/12  train_loss=0.4912
Token micro-F1: 0.5400 | macro-F1: 0.3218
Per-slot accuracy:
  emitter       : 0.6275
  addressee     : 0.4875
  context       : 0.5715
  emitter_pre   : 0.7030
  addressee_pre : 0.7265
  emitter_post  : 0.7130
  addressee_post: 0.7165

B2a config: {'hidden_dims': (512,), 'lr': 0.001}
epoch 01/12  train_loss=1.1050
epoch 03/12  train_loss=0.6078
epoch 06/12  train_loss=0.4789
epoch 09/12  train_loss=0.4095
epoch 12/12  train_loss=0.3588
Token micro-F1: 0.5870 | macro-F1: 0.3439
Per-slot accuracy:
  emitter       : 0.6625
  addressee     : 0.4605
  context       : 0.5705
  emitter_pre   : 0.7330
  addressee_pre : 0.7230
  emitter_post  : 0.7405
  addressee_post: 0.7195

B2a config: {'hidden_dims': (512, 256), 'lr': 0.0003}
epoch 01/12  train_loss=1.1908
epoch 03/12  train_loss=0.8441


In [None]:
@torch.no_grad()
def predict_b2a_tokens(best_out: dict, X: np.ndarray, topk_per_group: int = 1):
    """Predict a space-joined "token caption" for every row of ``X``.

    ``best_out`` is a result dict providing ``model``, ``scaler``, ``vocab``
    and ``groups``. Features are scaled with the stored scaler, passed through
    the model in a single batch, and decoded by taking the top token of each
    slot group. Each caption lists the tokens in the order emitter, addressee,
    context, emitter_pre, addressee_pre, emitter_post, addressee_post.

    ``topk_per_group`` is accepted but not used: exactly one token per slot is
    always emitted.
    """
    net, scaler, vocab, groups = (
        best_out[key] for key in ("model", "scaler", "vocab", "groups")
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net.eval()

    scaled = scaler.transform(X).astype(np.float32)
    batch = torch.tensor(scaled, dtype=torch.float32, device=device)
    logits = net(batch).detach().cpu().numpy()

    by_slot = decode_per_slot_from_logits(logits, vocab, groups)

    # combine into a "token caption", one fixed-order token per slot
    slot_order = (
        "emitter",
        "addressee",
        "context",
        "emitter_pre",
        "addressee_pre",
        "emitter_post",
        "addressee_post",
    )
    return [
        " ".join(by_slot[slot][row] for slot in slot_order)
        for row in range(len(X))
    ]

# Caption every example with the selected model. X_all includes the rows the
# model was trained on, so these captions are not a held-out evaluation.
b2a_token_captions = predict_b2a_tokens(best_out, X_all)
print("Example B2a token caption:\n", b2a_token_captions[0])


Example B2a token caption:
 EMITTER_111 ADDRESSEE_221 CONTEXT_Fighting E_PRE_Present A_PRE_Present E_POST_Fly away A_POST_Stay
