# B0 + B1 — Audio-to-Text Description Baselines

These baselines convert structured labels into natural-language descriptions, establishing a
reference format for later generative models.

### B0 — Ground-truth templating
- Deterministically maps annotation fields into a fixed, human-readable sentence.
- Defines the canonical caption structure and controlled vocabulary.

### B1 — Predicted templating (A2 → text)
- Uses the trained A2 multi-task model to predict the structured slots from audio features.
- Renders predictions into the same template, yielding an end-to-end audio→text pipeline.

### Evaluation
- Slot-level performance (accuracy and macro-F1 for each slot)
- End-to-end exact match (fraction of samples with every slot predicted correctly)
- Qualitative comparisons between ground-truth and predicted descriptions


In [None]:
# Lookup tables that translate the dataset's numeric annotation codes into
# human-readable names. The int-keyed tables (*_I) are the source of truth;
# the str-keyed tables are derived from them because raw slot values are
# looked up as strings.

# Context ID -> name
CONTEXT_MAP_I = {
    0: "Unknown",
    1: "Separation",
    2: "Biting",
    3: "Feeding",
    4: "Fighting",
    5: "Grooming",
    6: "Isolation",
    7: "Kissing",
    8: "Landing",
    9: "Mating protest",
    10: "Threat-like",
    11: "General",
    12: "Sleeping",
}

# Pre-vocalization action ID -> name (shared by emitter_pre and addressee_pre)
PRE_ACTION_MAP_I = {
    0: "Unknown",
    1: "Fly in",
    2: "Present",
    3: "Crawl in",
}

# Post-vocalization action ID -> name (shared by emitter_post and addressee_post)
POST_ACTION_MAP_I = {
    0: "Unknown",
    1: "Cower",
    2: "Fly away",
    3: "Stay",
    4: "Crawl away",
}

# String-keyed views (same names, same order) used for raw-label lookups.
CONTEXT_MAP = {str(code): name for code, name in CONTEXT_MAP_I.items()}
PRE_ACTION_MAP = {str(code): name for code, name in PRE_ACTION_MAP_I.items()}
POST_ACTION_MAP = {str(code): name for code, name in POST_ACTION_MAP_I.items()}

In [None]:
def _as_str(x):
    """Return ``x`` as a whitespace-stripped string key; ``None`` maps to ""."""
    # str() covers plain ints, strings and numpy scalars uniformly.
    return "" if x is None else str(x).strip()

def pretty_slot(task: str, raw_value) -> str:
    """
    Turn one raw slot value into the label shown in a description.

    - "context" codes are looked up in CONTEXT_MAP.
    - "*_pre" action codes are looked up in PRE_ACTION_MAP.
    - "*_post" action codes are looked up in POST_ACTION_MAP.
    - For these three, an unmapped code renders as "Unknown(<code>)" and an
      empty/None value as "Unknown".
    - Any other task (e.g. the emitter/addressee IDs) passes through as its
      normalized string, or "?" when empty.
    """
    key = _as_str(raw_value)

    # Task -> code table; looked up per call so the module-level maps are
    # read at call time.
    table = {
        "context": CONTEXT_MAP,
        "emitter_pre": PRE_ACTION_MAP,
        "addressee_pre": PRE_ACTION_MAP,
        "emitter_post": POST_ACTION_MAP,
        "addressee_post": POST_ACTION_MAP,
    }.get(task)

    if table is None:
        # IDs and unrecognized tasks are shown verbatim.
        return key or "?"

    if not key:
        return table.get(key, "Unknown")
    return table.get(key, f"Unknown({key})")

In [None]:
# Slot names rendered by the description template, in template order.
SLOTS = [
    "emitter", "addressee", "context",
    "emitter_pre", "addressee_pre",
    "emitter_post", "addressee_post",
]

def slots_to_text(slots: dict) -> str:
    """
    Render one sample's slot dict as the fixed B0/B1 description sentence.

    Each slot in SLOTS is read with ``slots.get`` (a missing key yields None)
    and passed through ``pretty_slot``, which names context/action codes and
    leaves emitter/addressee IDs as-is.
    """
    shown = {name: pretty_slot(name, slots.get(name)) for name in SLOTS}

    return (
        f"Emitter {shown['emitter']} vocalizes to Addressee {shown['addressee']} in context: {shown['context']}. "
        f"Emitter action: pre={shown['emitter_pre']}, post={shown['emitter_post']}. "
        f"Addressee action: pre={shown['addressee_pre']}, post={shown['addressee_post']}."
    )

In [None]:
def build_b0_reference_texts(labels_raw: dict) -> list[str]:
    """
    Render the ground-truth (B0) description for every sample.

    Parameters
    ----------
    labels_raw : dict
        The dict built in A2: task -> list[str] aligned with X_all. Must
        contain every key in SLOTS; any extra keys are ignored.

    Returns
    -------
    list[str]
        One templated description per sample, in input order (empty when the
        slot columns are empty).

    Raises
    ------
    KeyError
        If a slot in SLOTS is missing from ``labels_raw``.
    ValueError
        If the slot columns do not all have the same length.
    """
    columns = [labels_raw[slot] for slot in SLOTS]
    # Derive the sample count from the slot columns themselves (not from an
    # arbitrary first dict value) and fail loudly on misalignment instead of
    # hitting IndexError mid-way or silently truncating.
    lengths = {slot: len(col) for slot, col in zip(SLOTS, columns)}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Slot columns have mismatched lengths: {lengths}")
    return [slots_to_text(dict(zip(SLOTS, row))) for row in zip(*columns)]

b0_texts = build_b0_reference_texts(labels_raw)
print("Example B0 reference text:\n", b0_texts[0])

Example B0 reference text:
 Emitter 216 vocalizes to Addressee 221 in context: General. Emitter action: pre=Present, post=Stay. Addressee action: pre=Crawl in, post=Stay.


In [None]:
import torch
import numpy as np

IGNORE_INDEX = -100  # keep consistent with your A2 code (not used in this cell)

@torch.no_grad()
def predict_slots_with_a2(
    a2_out: dict, X: np.ndarray, batch_size: int = 256
) -> tuple[dict, dict]:
    """
    Predict every A2 task for a feature matrix and decode the predictions.

    Parameters
    ----------
    a2_out : dict
        A2 training output with keys:
        "model"    -- callable returning a dict task -> logits tensor whose
                      class axis is dim 1;
        "scaler"   -- fitted object with ``transform`` (the one used in training);
        "encoders" -- task -> fitted label encoder with ``inverse_transform``;
        "tasks"    -- task names to predict.
    X : np.ndarray
        Unscaled feature matrix, one row per sample.
    batch_size : int
        Rows per forward pass; must be positive.

    Returns
    -------
    (pred_raw, pred_int) : tuple[dict, dict]
        pred_raw: task -> list[str] of decoded predicted class labels.
        pred_int: task -> np.ndarray[int] of encoded predicted IDs.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model = a2_out["model"]
    scaler = a2_out["scaler"]
    encoders = a2_out["encoders"]
    tasks = a2_out["tasks"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # For an nn.Module, .to() moves the caller's model in place.
    model = model.to(device)
    model.eval()

    n_samples = len(X)
    if n_samples == 0:
        # A fitted scaler may reject 0-row input and np.concatenate([]) raises,
        # so return empty, correctly typed results directly.
        empty_int = {t: np.empty(0, dtype=np.int64) for t in tasks}
        empty_raw = {t: [] for t in tasks}
        return empty_raw, empty_int

    # Apply the SAME scaler used in training A2.
    Xs = scaler.transform(X).astype(np.float32)

    chunks = {t: [] for t in tasks}
    for start in range(0, n_samples, batch_size):
        # as_tensor avoids an extra copy when the batch is already float32 on CPU.
        xb = torch.as_tensor(Xs[start:start + batch_size], dtype=torch.float32, device=device)
        logits = model(xb)
        for t in tasks:
            # No .detach() needed: autograd is disabled by @torch.no_grad().
            chunks[t].append(torch.argmax(logits[t], dim=1).cpu().numpy())

    pred_int = {t: np.concatenate(chunks[t]) for t in tasks}

    # Decode ints back to the original string labels.
    pred_raw = {
        t: encoders[t].inverse_transform(pred_int[t]).astype(str).tolist()
        for t in tasks
    }

    return pred_raw, pred_int

In [None]:
def build_b1_predicted_texts(pred_labels_raw: dict) -> list[str]:
    """
    Render the predicted (B1) description for every sample.

    Uses the same ``slots_to_text`` template as B0, so predicted and
    reference descriptions are directly comparable.

    Parameters
    ----------
    pred_labels_raw : dict
        task -> list[str] of decoded predictions (the ``pred_raw`` output of
        ``predict_slots_with_a2``). Must contain every key in SLOTS.

    Returns
    -------
    list[str]
        One templated description per sample, in input order (empty when the
        slot columns are empty).

    Raises
    ------
    KeyError
        If a slot in SLOTS is missing from ``pred_labels_raw``.
    ValueError
        If the slot columns do not all have the same length.
    """
    columns = [pred_labels_raw[slot] for slot in SLOTS]
    # Take the sample count from the slot columns and reject misaligned input
    # rather than raising StopIteration/IndexError or truncating silently.
    lengths = {slot: len(col) for slot, col in zip(SLOTS, columns)}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Slot columns have mismatched lengths: {lengths}")
    return [slots_to_text(dict(zip(SLOTS, row))) for row in zip(*columns)]

# Run B1 predictions
pred_labels_raw, pred_labels_int = predict_slots_with_a2(out, X_all, batch_size=256)
b1_texts = build_b1_predicted_texts(pred_labels_raw)

print("Example B1 predicted text:\n", b1_texts[0])

Example B1 predicted text:
 Emitter 216 vocalizes to Addressee 208 in context: Feeding. Emitter action: pre=Present, post=Crawl away. Addressee action: pre=Crawl in, post=Stay.


In [None]:
from sklearn.metrics import accuracy_score, f1_score

def encode_ground_truth_for_eval(a2_out: dict, labels_raw: dict) -> dict:
    """
    Map ground-truth string labels to A2's integer class IDs.

    Uses the fitted encoders in ``a2_out["encoders"]`` (the same ones A2 was
    trained with), so the IDs line up with the model's predicted IDs. Only
    the tasks listed in ``a2_out["tasks"]`` are encoded. Labels are cast to
    str first. A label the encoder never saw makes ``transform`` raise, which
    usually signals that labels and features are misaligned.
    """
    fitted = a2_out["encoders"]
    return {
        task: fitted[task].transform(np.array(labels_raw[task], dtype=str))
        for task in a2_out["tasks"]
    }

def evaluate_b1_slots(a2_out: dict, labels_raw: dict, pred_int: dict) -> dict:
    """
    Score B1's per-slot predictions against the ground truth.

    Returns a dict holding, for each task in ``a2_out["tasks"]``,
    ``{"accuracy": float, "macro_f1": float}``, plus the extra key
    ``"all_slots_correct_rate"``: the fraction of samples for which every
    task's predicted ID equals the ground-truth ID (strict exact match).
    """
    tasks = a2_out["tasks"]
    gold = encode_ground_truth_for_eval(a2_out, labels_raw)

    metrics = {}
    for task in tasks:
        truth, guess = gold[task], pred_int[task]
        metrics[task] = {
            "accuracy": float(accuracy_score(truth, guess)),
            "macro_f1": float(f1_score(truth, guess, average="macro")),
        }

    # Strict end-to-end score: a sample counts only if every slot is right.
    n_samples = len(next(iter(gold.values())))
    exact = np.ones(n_samples, dtype=bool)
    for task in tasks:
        exact &= gold[task] == pred_int[task]
    metrics["all_slots_correct_rate"] = float(exact.mean())

    return metrics

slot_metrics = evaluate_b1_slots(out, labels_raw, pred_labels_int)

print("B1 slot metrics (accuracy, macro-F1):")
for t in out["tasks"]:
    print(f"  {t:14s}  acc={slot_metrics[t]['accuracy']:.3f}  macroF1={slot_metrics[t]['macro_f1']:.3f}")
print(f"\nAll-slots-correct rate: {slot_metrics['all_slots_correct_rate']:.3f}")

B1 slot metrics (accuracy, macro-F1):
  emitter         acc=0.764  macroF1=0.763
  addressee       acc=0.552  macroF1=0.522
  context         acc=0.641  macroF1=0.600
  emitter_pre     acc=0.869  macroF1=0.655
  addressee_pre   acc=0.838  macroF1=0.710
  emitter_post    acc=0.835  macroF1=0.565
  addressee_post  acc=0.813  macroF1=0.704

All-slots-correct rate: 0.185


In [None]:
import random

def show_examples(b0_texts, b1_texts, labels_raw, pred_labels_raw, k=8, seed=0):
    """
    Print k randomly chosen samples with their GT/predicted slots and texts.

    For each sampled index this prints the ground-truth slots, the predicted
    slots (both restricted to SLOTS), the B0 reference description and the
    B1 predicted description.

    Parameters
    ----------
    b0_texts, b1_texts : list[str]
        Reference and predicted descriptions, aligned by index.
    labels_raw, pred_labels_raw : dict
        task -> per-sample raw slot values, aligned with the text lists.
    k : int
        Number of examples to show; clamped to ``len(b0_texts)``.
    seed : int
        Seed for a private RNG, so the selection is reproducible without
        resetting the global ``random`` module state. It picks the same
        indices as ``random.seed(seed); random.sample(...)`` would.
    """
    rng = random.Random(seed)  # local RNG: leave global random state untouched
    idxs = rng.sample(range(len(b0_texts)), min(k, len(b0_texts)))
    for i in idxs:
        print("="*90)
        print(f"Index: {i}")
        print("GT slots:", {s: labels_raw[s][i] for s in SLOTS})
        print("PR slots:", {s: pred_labels_raw[s][i] for s in SLOTS})
        print("\nB0 (GT text):")
        print(b0_texts[i])
        print("\nB1 (Pred text):")
        print(b1_texts[i])
        print()

show_examples(b0_texts, b1_texts, labels_raw, pred_labels_raw, k=8, seed=42)

Index: 1824
GT slots: {'emitter': '211', 'addressee': '208', 'context': '12', 'emitter_pre': '2', 'addressee_pre': '2', 'emitter_post': '3', 'addressee_post': '3'}
PR slots: {'emitter': '231', 'addressee': '208', 'context': '12', 'emitter_pre': '2', 'addressee_pre': '2', 'emitter_post': '3', 'addressee_post': '3'}

B0 (GT text):
Emitter 211 vocalizes to Addressee 208 in context: Sleeping. Emitter action: pre=Present, post=Stay. Addressee action: pre=Present, post=Stay.

B1 (Pred text):
Emitter 231 vocalizes to Addressee 208 in context: Sleeping. Emitter action: pre=Present, post=Stay. Addressee action: pre=Present, post=Stay.

Index: 409
GT slots: {'emitter': '231', 'addressee': '221', 'context': '12', 'emitter_pre': '2', 'addressee_pre': '2', 'emitter_post': '3', 'addressee_post': '3'}
PR slots: {'emitter': '231', 'addressee': '221', 'context': '12', 'emitter_pre': '2', 'addressee_pre': '2', 'emitter_post': '3', 'addressee_post': '3'}

B0 (GT text):
Emitter 231 vocalizes to Addressee 