## 07.2 — Multi-task Structured Classification with a Shared MLP

This notebook introduces **07.2**, a structured multi-task model that jointly predicts emitter and
addressee identity, interaction context, and the pre- and post-vocalization actions of both
individuals from a shared acoustic representation. Rather than training a separate classifier per
label, the model learns a common trunk that feeds seven semantically related prediction heads.

**Key design choices:**
- **Shared MLP trunk with multiple heads** to exploit correlations between identity, context, and behavior.
- **Per-task class-weighted losses** to address severe imbalance, especially for context and action labels.
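- **Fixed-length acoustic features (1152-D: a 768-D AST embedding plus 128-bin k-means and 256-bin VQ-VAE token histograms)**. These are the same inputs as earlier models, so performance differences isolate the effect of the multi-task structure.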
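The training code below optimizes the following mini-batch objective. The notation ($\lambda_t$, $w_{t,c}$, $M_t$) is introduced only for this summary and does not appear in the code.

$$
\mathcal{L} = \sum_{t} \lambda_t \, \mathcal{L}_t,
\qquad
\mathcal{L}_t = \frac{\sum_{i \in M_t} w_{t,y_{it}} \bigl(-\log p_\theta(y_{it} \mid x_i)\bigr)}{\sum_{i \in M_t} w_{t,y_{it}}},
\qquad
w_{t,c} = \frac{n_t}{K_t \, n_{t,c}}
$$

The symbols are defined as follows:

- $M_t$ is the set of batch examples whose label for task $t$ is not missing.
- $K_t$ is the number of classes of head $t$.
- $n_t$ and $n_{t,c}$ are the numbers of labeled training examples overall and in class $c$. Zero counts are clamped to 1.
- $\lambda_t$ is 2.0 for context, 1.5 for the four action heads, and 1.0 for emitter and addressee.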

**Objective:** improve context and action prediction through shared supervision and motivate
sequence-level and language-style modeling in later stages.

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Any

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import f1_score, classification_report, confusion_matrix


## Construct feature matrix and multi-task labels

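This cell builds a fixed-dimensional feature matrix and aligned label arrays for all seven supervised
tasks, so each clip contributes all of its labels at once. A clip is kept only if every enabled feature
source (AST embedding, k-means token histogram, VQ-VAE code histogram) is available and the
concatenated vector has exactly `expected_dim` entries. Skipped clips are counted in the printed summary.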
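The loaders `_load_kmeans_hist` and `_load_vqvae_hist` are not shown in this section. The sketch below illustrates how a variable-length token sequence can become a fixed-size block via `np.bincount(..., minlength=...)`, and how the three blocks add up to 1152 dimensions.

It rests on several assumptions:

- `token_histogram` is a hypothetical helper.
- Normalizing the counts to frequencies is a guess about what the real loaders do.
- The random inputs are stand-ins for real features.

```python
import numpy as np

# Hypothetical helper: map a variable-length sequence of discrete token IDs to a
# fixed-length vector. Normalizing to frequencies is an assumption.
def token_histogram(token_ids: np.ndarray, vocab_size: int) -> np.ndarray:
    counts = np.bincount(token_ids, minlength=vocab_size).astype(np.float32)
    total = counts.sum()
    return counts / total if total > 0 else counts

rng = np.random.default_rng(0)
ast_vec = rng.standard_normal(768).astype(np.float32)  # stand-in for a 768-D AST embedding
km_tokens = rng.integers(0, 128, size=37)              # 37 frame-level k-means IDs in [0, 128)
vq_tokens = rng.integers(0, 256, size=12)              # 12 VQ-VAE code IDs in [0, 256)

feat = np.concatenate([ast_vec, token_histogram(km_tokens, 128), token_histogram(vq_tokens, 256)])
print(feat.shape)  # (1152,)

# minlength is a floor, not a cap: an out-of-range ID silently lengthens the block,
# which is the kind of mismatch an expected_dim check catches
print(token_histogram(np.array([130]), 128).shape)  # (131,)
```

The last line shows why the hard dimension guard in the cell below is useful even when every loader succeeds.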


In [None]:
def collect_features_and_labels_A2(
    ann: pd.DataFrame,
    use_ast: bool = True,
    use_kmeans_tokens: bool = True,
    use_vqvae_tokens: bool = True,
    kmeans_clusters: int = 128,
    vqvae_codes: int = 256,
    expected_dim: int = 1152,  # 768 + 128 + 256
) -> Tuple[np.ndarray, Dict[str, List[Any]]]:
    """
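    Build X and a dict of raw string labels aligned with the annotations_10k.csv columns.
    Token histograms use the A0-style np.bincount(..., minlength=...) so each block has a fixed
    size. Relies on _load_ast_vector, _load_kmeans_hist and _load_vqvae_hist, which are defined
    outside this section.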
    """

    stems = ann["File Name"].apply(lambda s: Path(str(s)).stem)

    X_list: List[np.ndarray] = []
    labels_raw: Dict[str, List[Any]] = {
        "emitter": [],
        "addressee": [],
        "context": [],
        "emitter_pre": [],
        "addressee_pre": [],
        "emitter_post": [],
        "addressee_post": [],
    }

    # Exact column names
    COL_EMITTER = "Emitter"
    COL_ADDRESSEE = "Addressee"
    COL_CONTEXT = "Context"
    COL_E_PRE = "Emitter pre-vocalization action"
    COL_A_PRE = "Addressee pre-vocalization action"
    COL_E_POST = "Emitter post-vocalization action"
    COL_A_POST = "Addressee post-vocalization action"

    skipped_missing = 0
    skipped_dim = 0

    for stem, row in zip(stems, ann.to_dict("records")):
        parts: List[np.ndarray] = []

        if use_ast:
            ast_vec = _load_ast_vector(stem)
            if ast_vec is None:
                skipped_missing += 1
                continue
            parts.append(ast_vec)

        if use_kmeans_tokens:
            km_hist = _load_kmeans_hist(stem, n_clusters=kmeans_clusters)
            if km_hist is None:
                skipped_missing += 1
                continue
            parts.append(km_hist)

        if use_vqvae_tokens:
            vq_hist = _load_vqvae_hist(stem, n_codes=vqvae_codes)
            if vq_hist is None:
                skipped_missing += 1
                continue
            parts.append(vq_hist)

        feat_vec = np.concatenate(parts).astype(np.float32)

        # Hard guard
        if expected_dim is not None and feat_vec.shape[0] != expected_dim:
            skipped_dim += 1
            continue

        X_list.append(feat_vec)

        # Labels as strings (LabelEncoder-friendly)
        labels_raw["emitter"].append(str(row.get(COL_EMITTER, "")))
        labels_raw["addressee"].append(str(row.get(COL_ADDRESSEE, "")))
        labels_raw["context"].append(str(row.get(COL_CONTEXT, "")))

        labels_raw["emitter_pre"].append(str(row.get(COL_E_PRE, "")))
        labels_raw["addressee_pre"].append(str(row.get(COL_A_PRE, "")))
        labels_raw["emitter_post"].append(str(row.get(COL_E_POST, "")))
        labels_raw["addressee_post"].append(str(row.get(COL_A_POST, "")))

    if not X_list:
        raise RuntimeError("No feature vectors constructed. Check feature directories and filenames.")

    X = np.vstack(X_list)
    print(
        f"Built A2 features for {X.shape[0]} examples; dim={X.shape[1]}. "
        f"Skipped missing={skipped_missing}, skipped_dim={skipped_dim}."
    )
    return X, labels_raw


X_all, labels_raw = collect_features_and_labels_A2(
    ann,
    use_ast=True,
    use_kmeans_tokens=True,
    use_vqvae_tokens=True,
    kmeans_clusters=128,
    vqvae_codes=256,
    expected_dim=1152,
)

X_all.shape, {k: len(set(v)) for k, v in labels_raw.items()}


Built A2 features for 10000 examples; dim=1152. Skipped missing=0, skipped_dim=0.


((10000, 1152),
 {'emitter': 10,
  'addressee': 28,
  'context': 12,
  'emitter_pre': 4,
  'addressee_pre': 4,
  'emitter_post': 5,
  'addressee_post': 5})

## Label encoding and class imbalance handling

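This cell encodes string labels as integer IDs and maps missing annotations to `IGNORE_INDEX` (-100).
That is the target value `nn.CrossEntropyLoss(ignore_index=...)` skips, so a clip missing one label still
trains the other heads instead of being discarded.

The cell also computes balanced per-task class weights, n / (k · count_c), which upweight rare classes
to counter the severe imbalance in the context and action labels. Finally, it defines a macro-F1 helper
that scores only labeled examples.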
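Here is a toy walk-through of the encoding and weighting steps. The label strings "a" and "b" are placeholders, not classes from the dataset. The logic mirrors `encode_labels` and `compute_class_weights`, written inline so the example runs on its own.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

IGNORE_INDEX = -100  # same sentinel value the notebook uses

raw = np.array(["a", "b", "a", "", "a"], dtype=object)  # toy labels; "" marks a missing annotation
missing = np.array([v in {"", "nan"} for v in raw])

le = LabelEncoder().fit(raw[~missing])                 # classes_: ['a', 'b']
y = np.full(len(raw), IGNORE_INDEX, dtype=np.int64)
y[~missing] = le.transform(raw[~missing])              # [0, 1, 0, -100, 0]

# Balanced weights n / (k * count_c), computed over labeled rows only
labeled = y[y != IGNORE_INDEX]
counts = np.bincount(labeled, minlength=len(le.classes_)).astype(np.float32)
weights = len(labeled) / (len(le.classes_) * counts)
print(weights)  # approximately [0.667, 2.0]
```

With 3 examples of "a" and 1 of "b", the rare class receives three times the weight of the common one. The missing row does not count toward n.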


In [None]:
IGNORE_INDEX = -100

def encode_labels(y_raw, missing_values=None):
    """
    Encode labels with LabelEncoder; map missing_values to IGNORE_INDEX.
    """
    y_raw = np.asarray(y_raw, dtype=object)

    missing_mask = np.zeros(len(y_raw), dtype=bool)
    if missing_values is not None:
        missing_values = set(missing_values)
        missing_mask = np.array([v in missing_values for v in y_raw], dtype=bool)

    le = LabelEncoder()
    y_fit = y_raw[~missing_mask]

    if len(y_fit) == 0:
        le.classes_ = np.array([], dtype=object)
        y_enc = np.full(len(y_raw), IGNORE_INDEX, dtype=np.int64)
        return y_enc, le

    le.fit(y_fit)
    y_enc = np.full(len(y_raw), IGNORE_INDEX, dtype=np.int64)
    y_enc[~missing_mask] = le.transform(y_raw[~missing_mask]).astype(np.int64)
    return y_enc, le


def compute_class_weights(y_enc: np.ndarray, num_classes: int, device: torch.device) -> torch.Tensor:
    """
    Balanced weights: n / (k * count_c), ignoring IGNORE_INDEX.
    """
    y = y_enc[y_enc != IGNORE_INDEX]
    counts = np.bincount(y, minlength=num_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    n = float(len(y))
    k = float(num_classes)
    weights = n / (k * counts)
    return torch.tensor(weights, dtype=torch.float32, device=device)


def macro_f1_ignore_missing(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    mask = (y_true != IGNORE_INDEX)
    if mask.sum() == 0:
        return float("nan")
    return f1_score(y_true[mask], y_pred[mask], average="macro")

## PyTorch dataset for multi-task learning

This dataset wraps the shared feature vectors and returns a dictionary of task-specific labels for
each example. The structure supports a single forward pass with independent loss computation per
task head.
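Because each example is a dict, PyTorch's default collate function stacks every key separately. A batch therefore arrives as a dict of tensors that the training loop can index by task name.

The toy dataset below is a stand-in that mimics `MultiTaskDataset` output with made-up features and labels.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyMultiTask(Dataset):
    """Stand-in that mimics MultiTaskDataset's dict-per-example output."""
    def __init__(self, n: int = 10, dim: int = 3):
        self.X = torch.randn(n, dim)
        self.context = torch.arange(n) % 3                           # toy labels 0, 1, 2, 0, ...
        self.emitter_pre = torch.full((n,), -100, dtype=torch.long)  # every label missing

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return {"X": self.X[idx], "context": self.context[idx], "emitter_pre": self.emitter_pre[idx]}

# The default collate_fn stacks each dict key separately, so a batch is a dict of tensors
loader = DataLoader(ToyMultiTask(), batch_size=4, shuffle=False)
batch = next(iter(loader))
print({k: tuple(v.shape) for k, v in batch.items()})  # {'X': (4, 3), 'context': (4,), 'emitter_pre': (4,)}
```

Note that a head can receive a batch in which every label equals -100, as `emitter_pre` does here.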


In [None]:
class MultiTaskDataset(Dataset):
    def __init__(self, X: np.ndarray, y_dict: dict):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y_dict = {k: torch.tensor(v, dtype=torch.long) for k, v in y_dict.items()}

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        out = {"X": self.X[idx]}
        for k in self.y_dict:
            out[k] = self.y_dict[k][idx]
        return out

## Multi-head MLP architecture

This model uses a shared MLP trunk to learn a common acoustic representation, followed by
task-specific linear heads. The design encourages shared structure while allowing specialization
across identity, context, and action labels.
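To see how much capacity is actually shared, one can count parameters for the single-hidden-layer configuration (`hidden_dims=(512,)`). The head sizes below are taken from the raw class counts printed earlier and are illustrative only. The encoded counts can be smaller once missing-value strings are dropped.

```python
import torch.nn as nn

# Illustrative head sizes (raw class counts from the feature-construction output)
head_dims = {"emitter": 10, "addressee": 28, "context": 12,
             "emitter_pre": 4, "addressee_pre": 4, "emitter_post": 5, "addressee_post": 5}

# Same layer pattern as MultiHeadMLP with hidden_dims=(512,)
trunk = nn.Sequential(nn.Linear(1152, 512), nn.ReLU(), nn.Dropout(0.3))
heads = nn.ModuleDict({k: nn.Linear(512, c) for k, c in head_dims.items()})

n_trunk = sum(p.numel() for p in trunk.parameters())
n_heads = sum(p.numel() for p in heads.parameters())
print(n_trunk, n_heads)                                      # 590336 34884
print(f"{n_trunk / (n_trunk + n_heads):.1%} in the trunk")   # 94.4% in the trunk
```

With this configuration nearly all parameters live in the shared trunk. Each head is just a linear read-out of the common 512-D representation.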


In [None]:
class MultiHeadMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dims=(512,), dropout=0.3, head_dims=None):
        super().__init__()
        if head_dims is None:
            head_dims = {}

        layers = []
        prev = input_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev = h

        self.trunk = nn.Sequential(*layers)
        self.heads = nn.ModuleDict({k: nn.Linear(prev, c) for k, c in head_dims.items()})

    def forward(self, x):
        z = self.trunk(x)
        return {k: head(z) for k, head in self.heads.items()}

## Training and evaluation loops

This cell defines the multi-task training loop with weighted, task-specific losses and an evaluation
routine that reports macro-F1 while ignoring missing labels. Metrics are computed per task to
diagnose where learning is most effective.
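The per-task loss relies on `nn.CrossEntropyLoss` with both `weight` and `ignore_index`. In mean reduction, this is a weighted average over labeled rows, normalized by the sum of their class weights rather than by the row count.

The sketch below checks that equivalence on random logits with arbitrary toy weights. It then shows the edge case in which every label for a head is missing from a batch. There the normalizer is zero, and in recent PyTorch releases the result is NaN, which would propagate into the summed multi-task loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IGNORE_INDEX = -100
torch.manual_seed(0)

logits = torch.randn(5, 3)
y = torch.tensor([0, 2, IGNORE_INDEX, 1, 2])
w = torch.tensor([0.5, 2.0, 1.0])  # toy class weights
ce = nn.CrossEntropyLoss(weight=w, ignore_index=IGNORE_INDEX)
builtin = ce(logits, y)

# Manual version: sum of w_y * NLL over labeled rows, divided by the sum of w_y
mask = y != IGNORE_INDEX
y_l = y[mask]
nll = -F.log_softmax(logits[mask], dim=1).gather(1, y_l.unsqueeze(1)).squeeze(1)
manual = (w[y_l] * nll).sum() / w[y_l].sum()
print(torch.allclose(builtin, manual))  # True

# Every label missing for this head: the weighted mean has a zero denominator
all_missing = torch.full((5,), IGNORE_INDEX, dtype=torch.long)
print(ce(logits, all_missing))
```

Skipping a head when a batch has no labeled examples for it avoids this case with sparse action labels.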


In [None]:
def train_one_epoch(model, loader, optim, criterions, loss_weights, device):
    model.train()
    total = 0.0
    n = 0

    for batch in loader:
        X = batch["X"].to(device)
        optim.zero_grad()

        logits = model(X)
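        loss = None

        for task, logit in logits.items():
            y = batch[task].to(device)
            # A weighted, mean-reduced CE over zero labeled targets is 0/0 (NaN),
            # so skip heads whose labels are all missing in this batch
            if not (y != IGNORE_INDEX).any():
                continue
            term = loss_weights.get(task, 1.0) * criterions[task](logit, y)
            loss = term if loss is None else loss + term

        if loss is None:  # every head was unlabeled for this batch
            continue

        loss.backward()
        optim.step()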

        bs = X.size(0)
        total += loss.item() * bs
        n += bs

    return total / max(n, 1)


@torch.no_grad()
def evaluate(model, loader, tasks, device):
    model.eval()

    all_true = {t: [] for t in tasks}
    all_pred = {t: [] for t in tasks}

    for batch in loader:
        X = batch["X"].to(device)
        logits = model(X)

        for t in tasks:
            y = batch[t].cpu().numpy()
            pred = torch.argmax(logits[t], dim=1).cpu().numpy()
            all_true[t].append(y)
            all_pred[t].append(pred)

    out = {}
    for t in tasks:
        y_true = np.concatenate(all_true[t])
        y_pred = np.concatenate(all_pred[t])
        out[t] = {
            "macro_f1": macro_f1_ignore_missing(y_true, y_pred),
            "y_true": y_true,
            "y_pred": y_pred,
        }
    return out


## Run multi-task experiments

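This function performs a train/test split stratified on context, standardizes the features, and runs a
small grid search over trunk size and learning rate. The best configuration is selected by context
macro-F1, the most challenging and informative task. Passing `select_by="avg"` selects by the mean
macro-F1 across heads instead.

Configurations are scored on the same held-out split that is reported at the end, so the best score is an
optimistic estimate of generalization.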
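`run_a2_multitask_mlp` scores each grid configuration on the same split it finally reports. A common alternative is to carve a validation set out of the training portion, choose hyperparameters on it, and touch the test set once.

The sketch below assumes this approach. `three_way_split` is a hypothetical helper, and the 10%/20% fractions and toy labels are arbitrary choices.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def three_way_split(strat: np.ndarray, val_frac: float = 0.1, test_frac: float = 0.2, seed: int = 42):
    """Hypothetical helper: stratified train/val/test split of row indices.

    Grid-search configurations would be scored on `val`; `test` is evaluated once at the end.
    """
    n = len(strat)
    idx = np.arange(n)
    # Integer sizes avoid rounding surprises when rescaling fractions between the two splits
    n_test, n_val = int(round(test_frac * n)), int(round(val_frac * n))
    idx_trval, idx_test = train_test_split(idx, test_size=n_test, random_state=seed, stratify=strat)
    idx_train, idx_val = train_test_split(idx_trval, test_size=n_val, random_state=seed,
                                          stratify=strat[idx_trval])
    return idx_train, idx_val, idx_test

# Toy stratification labels; -100 (missing context) simply acts as one more stratum
strat = np.repeat([0, 1, 2, -100], 250)
tr, va, te = three_way_split(strat)
print(len(tr), len(va), len(te))  # 700 100 200
```

The returned index arrays can slice `X_all` and each encoded label array in the same way `idx_train` and `idx_test` are used in the function below.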


In [None]:
def run_a2_multitask_mlp(
    X_all: np.ndarray,
    labels_raw: dict,
    hidden_grid=((512,), (512, 256)),
    lr_grid=(1e-3, 3e-4),
    dropout=0.3,
    batch_size=128,
    epochs=30,
    random_state=42,
    select_by="context",  # "context" or "avg"
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Labels arrive as strings, so pandas NaN becomes "nan" and must be listed explicitly
    missing_vals = {"", "nan", "NaN", "NA", "None", None, np.nan}

    # Encode labels per head
    y_enc = {}
    encoders = {}
    for k, y in labels_raw.items():
        yk, le = encode_labels(y, missing_values=missing_vals)
        y_enc[k] = yk
        encoders[k] = le

    # Stratify on context (primary target for improvement)
    y_strat = y_enc["context"]
    X_train, X_test, idx_train, idx_test = train_test_split(
        X_all, np.arange(len(X_all)),
        test_size=0.2,
        random_state=random_state,
        stratify=y_strat,
    )

    y_train = {k: v[idx_train] for k, v in y_enc.items()}
    y_test  = {k: v[idx_test]  for k, v in y_enc.items()}

    # Scale features for MLP stability
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test  = scaler.transform(X_test)

    train_ds = MultiTaskDataset(X_train, y_train)
    test_ds  = MultiTaskDataset(X_test, y_test)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    input_dim = X_train.shape[1]

    head_dims = {k: len(encoders[k].classes_) for k in y_enc.keys()}
    tasks = list(head_dims.keys())

    # Loss weights: emphasize context + action heads
    loss_weights = {t: 1.0 for t in tasks}
    loss_weights["context"] = 2.0
    for t in tasks:
        if t.endswith(("_pre", "_post")):  # the four action heads
            loss_weights[t] = 1.5

    print("\n" + "=" * 80)
    print("A2: Multi-task structured classifier (shared MLP trunk + 7 heads)")
    print("Tasks:", tasks)
    print("Loss weights:", loss_weights)
    print("Selecting best config by:", select_by)

    best_score = -1.0
    best_state = None
    best_cfg = None

    for hidden_dims in hidden_grid:
        for lr in lr_grid:
            torch.manual_seed(random_state)  # identical init and batch order for every config
            model = MultiHeadMLP(
                input_dim=input_dim,
                hidden_dims=hidden_dims,
                dropout=dropout,
                head_dims=head_dims
            ).to(device)

            # Class-weighted CE per head
            criterions = {}
            for t in tasks:
                w = compute_class_weights(y_train[t], head_dims[t], device=device)
                criterions[t] = nn.CrossEntropyLoss(weight=w, ignore_index=IGNORE_INDEX)

            optim = torch.optim.AdamW(model.parameters(), lr=lr)

            for _ in range(epochs):
                train_one_epoch(model, train_loader, optim, criterions, loss_weights, device)

            metrics = evaluate(model, test_loader, tasks, device)

            if select_by == "context":
                score = metrics["context"]["macro_f1"]
            else:
                vals = [metrics[t]["macro_f1"] for t in tasks if not np.isnan(metrics[t]["macro_f1"])]
                score = float(np.mean(vals)) if vals else float("nan")

            print(f"hidden={hidden_dims}, lr={lr:.1e} → score={score:.4f}")

            if score > best_score:
                best_score = score
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                best_cfg = (hidden_dims, lr)

    if best_cfg is None:
        raise RuntimeError("No configuration produced a finite selection score.")

    # Reload best
    hidden_dims, lr = best_cfg
    best_model = MultiHeadMLP(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        dropout=dropout,
        head_dims=head_dims
    ).to(device)
    best_model.load_state_dict(best_state)

    metrics = evaluate(best_model, test_loader, tasks, device)

    print("\nBest config:", {"hidden_dims": hidden_dims, "lr": lr})
    print("Best selection score:", best_score)

    # Reports per head
    for t in tasks:
        le = encoders[t]
        y_true = metrics[t]["y_true"]
        y_pred = metrics[t]["y_pred"]
        mask = (y_true != IGNORE_INDEX)

        print("\n" + "-" * 80)
        print(f"Task: {t}")
        print(f"Macro-F1: {metrics[t]['macro_f1']:.4f}")

        # Use encoded ints for the report and CM, passing every class index so that rows and
        # columns line up with le.classes_ even if a class is absent from the test split
        y_true_i = y_true[mask]
        y_pred_i = y_pred[mask]
        labels_idx = list(range(len(le.classes_)))

        print("\nClassification report (all classes):")

        print(classification_report(
            y_true_i,
            y_pred_i,
            labels=labels_idx,
            target_names=[str(c) for c in le.classes_],
            zero_division=0
        ))

        print("Confusion matrix (all classes):")
        print(confusion_matrix(
            y_true_i,
            y_pred_i,
            labels=labels_idx
        ))


    return {
        "model": best_model,
        "scaler": scaler,
        "encoders": encoders,
        "tasks": tasks,
        "best_cfg": {"hidden_dims": hidden_dims, "lr": lr, "score": best_score},
        "metrics": metrics,
    }


## Execute training and evaluation

This cell runs the full 07.2 experiment, selects the best model, and reports detailed metrics and
confusion matrices for all tasks. The results guide interpretation of which communicative variables
are acoustically grounded and motivate later modeling stages.
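The dictionary returned by the next cell bundles the model, the fitted scaler, and the per-task encoders, which is everything needed to label new clips.

The sketch below is a hypothetical inference helper, `predict_labels`, not part of the notebook. It is exercised on a small stand-in model, `TwoHead`, rather than the trained `MultiHeadMLP`.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder, StandardScaler

@torch.no_grad()
def predict_labels(model: nn.Module, scaler: StandardScaler, encoders: dict, X_new: np.ndarray) -> dict:
    """Hypothetical helper: scale features, run one forward pass, decode each head's argmax."""
    model.eval()  # disable dropout
    device = next(model.parameters()).device
    x = torch.as_tensor(scaler.transform(X_new), dtype=torch.float32, device=device)
    logits = model(x)
    return {t: encoders[t].inverse_transform(logits[t].argmax(dim=1).cpu().numpy()) for t in logits}

# Toy check with a stand-in two-head model (not the notebook's MultiHeadMLP)
class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleDict({"context": nn.Linear(4, 3), "emitter": nn.Linear(4, 2)})

    def forward(self, x):
        return {k: h(x) for k, h in self.heads.items()}

rng = np.random.default_rng(0)
X_fit = rng.standard_normal((20, 4))
encoders = {"context": LabelEncoder().fit(["c0", "c1", "c2"]),
            "emitter": LabelEncoder().fit(["e0", "e1"])}
preds = predict_labels(TwoHead(), StandardScaler().fit(X_fit), encoders, X_fit[:5])
print({k: list(v) for k, v in preds.items()})
```

With the real output, the call would be `predict_labels(out["model"], out["scaler"], out["encoders"], X_new)`. Here `X_new` is a hypothetical array of 1152-D vectors built the same way as `X_all`.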


In [None]:
out = run_a2_multitask_mlp(
    X_all=X_all,
    labels_raw=labels_raw,
    hidden_grid=((512,), (512, 256)),  # small grid: A1 suggested a bigger trunk is not always better
    lr_grid=(1e-3, 3e-4),
    dropout=0.3,
    epochs=30,
    batch_size=128,
    select_by="context",
)


A2: Multi-task structured classifier (shared MLP trunk + 7 heads)
Tasks: ['emitter', 'addressee', 'context', 'emitter_pre', 'addressee_pre', 'emitter_post', 'addressee_post']
Loss weights: {'emitter': 1.0, 'addressee': 1.0, 'context': 2.0, 'emitter_pre': 1.5, 'addressee_pre': 1.5, 'emitter_post': 1.5, 'addressee_post': 1.5}
Selecting best config by: context
hidden=(512,), lr=1.0e-03 → score=0.3617
hidden=(512,), lr=3.0e-04 → score=0.3782
hidden=(512, 256), lr=1.0e-03 → score=0.3776
hidden=(512, 256), lr=3.0e-04 → score=0.3585

Best config: {'hidden_dims': (512,), 'lr': 0.0003}
Best selection score: 0.3782362041893505

--------------------------------------------------------------------------------
Task: emitter
Macro-F1: 0.6662

Classification report (all classes):
              precision    recall  f1-score   support

         111       0.69      0.82      0.75       202
         210       0.59      0.49      0.54       218
         211       0.59      0.58      0.58       205
      