# B4: Audio to Text Generation (Feature-conditioned GRU decoder)

**Goal:** Train a neural generator that maps the acoustic features of a single bat call (an AST embedding concatenated with k-means and VQ-VAE token histograms) to a short, grammatical natural-language description of the interaction.

**Why B4 (after previous model designs):**
- Multi-task Structured Classification learned a strong structured semantic representation (emitter/addressee/context/actions) but outputs discrete labels.
- B3 showed we can convert labels into readable sentences, but it’s still classification + templating.
- B4 learns **direct sentence generation** from audio features, producing interpretable text output while preserving the controlled, factual structure of the templated captions.

**Model design:**
- Encoder: small MLP maps 1152-dim audio features → a context vector.
- Decoder: a GRU generates the sentence token by token (trained with teacher forcing).
- This is intentionally simple (interpretable, fast, easy to debug) and matches your dataset scale.

In [None]:
# Lookup tables from annotation codes (stored as strings) to caption phrases.
# Each table is built from an ordered list: the position of a label is its code,
# and code "0" maps to None (no phrase).
CONTEXT_MAP = {
    str(code): label
    for code, label in enumerate([
        None,                       # 0
        "separation",               # 1
        "biting",                   # 2
        "feeding",                  # 3
        "fighting",                 # 4
        "grooming",                 # 5
        "isolation",                # 6
        "kissing",                  # 7
        "landing",                  # 8
        "mating protest",           # 9
        "threat-like interaction",  # 10
        "general interaction",      # 11
        "sleeping",                 # 12
    ])
}

# Actions recorded before the vocalization.
PRE_ACTION_MAP = {
    str(code): label
    for code, label in enumerate([
        None,           # 0
        "flying in",    # 1
        "present",      # 2
        "crawling in",  # 3
    ])
}

# Actions recorded after the vocalization.
POST_ACTION_MAP = {
    str(code): label
    for code, label in enumerate([
        None,            # 0
        "cowered",       # 1
        "flew away",     # 2
        "stayed",        # 3
        "crawled away",  # 4
    ])
}

In [None]:
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

In [None]:
def with_article(phrase: Optional[str]) -> str:
    """Prefix `phrase` with the indefinite article 'a' or 'an'.

    The article is chosen from the first character only ('an' before a, e, i,
    o, u in either case, 'a' otherwise).

    Args:
        phrase: The noun phrase to prefix. `None` or an empty string means
            that no phrase is available.

    Returns:
        "an unspecified interaction" when no phrase is available, otherwise
        the phrase with its article.
    """
    # An empty string has no first character to inspect, so it is handled
    # together with None instead of raising IndexError.
    if not phrase:
        return "an unspecified interaction"
    article = "an" if phrase[0].lower() in "aeiou" else "a"
    return f"{article} {phrase}"

def normalize_key(x):
    """Turn an annotation value into a canonical string key.

    Args:
        x: Any value (for example a Python, NumPy or pandas scalar), or None.

    Returns:
        None when `x` is None. Otherwise the stripped string form of `x`,
        except that values which parse as a whole number are rendered without
        a fractional part (e.g. "11.0" and 11.0 both become "11").
    """
    if x is None:
        return None

    s = str(x).strip()

    # float() on a string raises only ValueError for non-numeric text, so that
    # is the only exception worth catching; anything else would be a real bug.
    try:
        f = float(s)
    except ValueError:
        return s

    # Whole numbers are collapsed to their integer spelling; nan, inf and true
    # fractions are returned as written.
    return str(int(f)) if f.is_integer() else s

def safe_lookup(mapping: dict, key: str) -> str | None:
    """Look `key` up in `mapping` after canonicalising it with `normalize_key`.

    Returns the mapped value, or None when the canonical key is absent. A
    `key` of None is looked up under the string "None".
    """
    canonical = normalize_key(key)
    return mapping.get(str(canonical))

In [None]:
def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and CUDA) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

set_seed(42)

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [None]:
from pathlib import Path

def _load_ast_vector(stem: str) -> np.ndarray | None:
    """Load the saved AST array for `stem` as a flat float32 vector.

    Reads `AST_DIR / "ast_<stem>.npy"`. Returns None when that file does not
    exist.
    """
    candidate = AST_DIR / f"ast_{stem}.npy"
    if candidate.exists():
        raw = np.load(candidate)
        # flatten whatever shape was stored into a single 1-D vector
        return np.asarray(raw, dtype=np.float32).reshape(-1)
    return None

def _load_kmeans_hist(stem: str, n_clusters: int = 128) -> np.ndarray | None:
    """Load the k-means token ids for `stem` and return their normalised histogram.

    Reads `KMEANS_DIR / "w2v_kmeans_<stem>.npy"`.

    Args:
        stem: File stem identifying the recording.
        n_clusters: Minimum histogram length. `np.bincount` treats this as a
            lower bound only, so a token id >= `n_clusters` yields a longer
            histogram.

    Returns:
        A float32 histogram that sums to 1 (all zeros when the token file is
        empty), or None when the token file does not exist.
    """
    tok_path = KMEANS_DIR / f"w2v_kmeans_{stem}.npy"
    if not tok_path.exists():
        return None
    # np.bincount only accepts 1-D input; flatten so token arrays saved with
    # extra dimensions (e.g. shape (1, T)) are counted instead of raising.
    tokens = np.load(tok_path).astype(int).reshape(-1)
    hist = np.bincount(tokens, minlength=n_clusters).astype(np.float32)
    total = hist.sum()
    if total > 0:
        hist /= total
    return hist

def _load_vqvae_hist(stem: str, n_codes: int = 256) -> np.ndarray | None:
    """Load the VQ-VAE token ids for `stem` and return their normalised histogram.

    Reads `VQ_TOKENS_DIR / "vqvae_<stem>.npy"`.

    Args:
        stem: File stem identifying the recording.
        n_codes: Minimum histogram length. `np.bincount` treats this as a
            lower bound only, so a token id >= `n_codes` yields a longer
            histogram.

    Returns:
        A float32 histogram that sums to 1 (all zeros when the token file is
        empty), or None when the token file does not exist.
    """
    tok_path = VQ_TOKENS_DIR / f"vqvae_{stem}.npy"
    if not tok_path.exists():
        return None
    # np.bincount only accepts 1-D input; flatten so token grids saved with
    # more than one dimension are counted instead of raising.
    tokens = np.load(tok_path).astype(int).reshape(-1)
    hist = np.bincount(tokens, minlength=n_codes).astype(np.float32)
    total = hist.sum()
    if total > 0:
        hist /= total
    return hist

In [None]:
def collect_features_and_labels_B6(
    ann: pd.DataFrame,
    use_ast: bool = True,
    use_kmeans_tokens: bool = True,
    use_vqvae_tokens: bool = True,
) -> Tuple[np.ndarray, pd.DataFrame]:
    """Build the feature matrix and the annotation rows aligned with it.

    For each row of `ann`, the enabled feature parts are loaded by file stem
    and concatenated in the fixed order AST vector, k-means histogram,
    VQ-VAE histogram. A row is dropped as soon as one enabled part is missing.

    Args:
        ann: Annotation table with a "File Name" column.
        use_ast: Include the AST vector.
        use_kmeans_tokens: Include the 128-bin k-means token histogram.
        use_vqvae_tokens: Include the 256-bin VQ-VAE token histogram.

    Returns:
        `(X, ann_aligned)` where `X` has shape (N, D), D being the summed
        length of the enabled parts, and `ann_aligned` holds the kept rows of
        `ann` in the same order as `X`, with a fresh 0..N-1 index.

    Raises:
        ValueError: If `ann` has rows but every `use_*` flag is False.
        RuntimeError: If no row produced a feature vector.
    """
    stems = ann["File Name"].apply(lambda s: Path(str(s)).stem)

    X_list: List[np.ndarray] = []
    rows: List[int] = []

    # number of rows dropped because an enabled feature file was absent
    missing_any = 0

    for i, stem in enumerate(stems):
        parts: List[np.ndarray] = []

        if use_ast:
            ast_vec = _load_ast_vector(stem)
            if ast_vec is None:
                missing_any += 1
                continue
            parts.append(ast_vec)

        if use_kmeans_tokens:
            km_hist = _load_kmeans_hist(stem, n_clusters=128)
            if km_hist is None:
                missing_any += 1
                continue
            parts.append(km_hist)

        if use_vqvae_tokens:
            vq_hist = _load_vqvae_hist(stem, n_codes=256)  # IMPORTANT: fixed 256
            if vq_hist is None:
                missing_any += 1
                continue
            parts.append(vq_hist)

        # With every flag off there is nothing to concatenate; fail with an
        # actionable message rather than NumPy's generic concatenate error.
        if not parts:
            raise ValueError(
                "At least one of use_ast, use_kmeans_tokens, use_vqvae_tokens must be True."
            )

        feat_vec = np.concatenate(parts).astype(np.float32)
        X_list.append(feat_vec)
        rows.append(i)

    if not X_list:
        raise RuntimeError("No feature vectors constructed.")

    X = np.vstack(X_list)
    # positional selection keeps ann_aligned row k describing X[k]
    ann_aligned = ann.iloc[rows].reset_index(drop=True)

    print(f"Built X for {X.shape[0]} examples; dim={X.shape[1]}. Skipped {missing_any}.")
    return X, ann_aligned

In [None]:
def map_or_unknown(m: Dict[str, str], key: str) -> str:
    """Return `m[str(key)]`, or the placeholder "unknown(<key>)" when the key is absent."""
    lookup = str(key)
    return m.get(lookup, f"unknown({lookup})")


In [None]:
def build_caption_from_row(row: pd.Series) -> str:
    """Render one annotation row as a three-sentence caption.

    Sentence 1 names emitter, addressee and context; sentence 2 describes the
    pre-vocalization actions; sentence 3 the post-vocalization actions. When
    neither bat has a mapped action, the sentence falls back to a fixed
    "no clear movement" wording.
    """
    speaker = str(row["Emitter"])
    listener = str(row["Addressee"])

    setting = safe_lookup(CONTEXT_MAP, row["Context"])
    emitter_before = safe_lookup(PRE_ACTION_MAP, row["Emitter pre-vocalization action"])
    addressee_before = safe_lookup(PRE_ACTION_MAP, row["Addressee pre-vocalization action"])
    emitter_after = safe_lookup(POST_ACTION_MAP, row["Emitter post-vocalization action"])
    addressee_after = safe_lookup(POST_ACTION_MAP, row["Addressee post-vocalization action"])

    # Sentence 1: who vocalized toward whom, and in which context
    opening = (
        f"Bat {speaker} vocalized toward bat {listener} "
        f"during {with_article(setting)}."
    )

    # Sentence 2: keep only the clauses whose action was actually mapped
    before_clauses = [
        clause
        for action, clause in (
            (emitter_before, f"the emitter was {emitter_before}"),
            (addressee_before, f"the addressee was {addressee_before}"),
        )
        if action
    ]
    if before_clauses:
        before = "Before vocalizing, " + " and ".join(before_clauses) + "."
    else:
        before = "Before vocalizing, no clear movement was observed."

    # Sentence 3: same idea for what happened afterwards
    after_clauses = [
        clause
        for action, clause in (
            (emitter_after, f"the emitter {emitter_after}"),
            (addressee_after, f"the addressee {addressee_after}"),
        )
        if action
    ]
    if after_clauses:
        after = "Afterward, " + " while ".join(after_clauses) + "."
    else:
        after = "Afterward, no clear movement was observed."

    return f"{opening} {before} {after}"

def build_all_captions(ann_aligned: pd.DataFrame) -> List[str]:
    """Return one caption per row of `ann_aligned`, in row order."""
    captions: List[str] = []
    for pos in range(len(ann_aligned)):
        # positional access, so the frame's index labels are irrelevant
        captions.append(build_caption_from_row(ann_aligned.iloc[pos]))
    return captions


In [None]:
# Reserved vocabulary tokens; their order in SPECIAL fixes ids 0..3.
PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"
SPECIAL = [PAD, BOS, EOS, UNK]

def simple_tokenize(text: str) -> List[str]:
    """Lower-case `text` and split it on whitespace, with . , : ; as separate tokens."""
    # one pass: surround each punctuation mark with spaces so split() isolates it
    spaced = {ord(mark): f" {mark} " for mark in ".,:;"}
    return text.lower().strip().translate(spaced).split()

@dataclass
class Vocab:
    """Token <-> id tables plus shortcuts to the ids of the special tokens."""

    stoi: Dict[str, int]  # token -> id
    itos: List[str]       # id -> token

    @property
    def pad_id(self):
        """Id of the padding token."""
        return self.stoi[PAD]

    @property
    def bos_id(self):
        """Id of the begin-of-sequence token."""
        return self.stoi[BOS]

    @property
    def eos_id(self):
        """Id of the end-of-sequence token."""
        return self.stoi[EOS]

    @property
    def unk_id(self):
        """Id of the unknown-token placeholder."""
        return self.stoi[UNK]

def build_vocab(texts: List[str], min_freq: int = 2) -> Vocab:
    """Build a vocabulary from `texts`.

    The special tokens come first, followed by every token seen at least
    `min_freq` times, ordered by descending count and then alphabetically.
    """
    counts: Dict[str, int] = {}
    for text in texts:
        for token in simple_tokenize(text):
            counts[token] = counts.setdefault(token, 0) + 1

    # most frequent first; ties broken by the token itself for determinism
    ranked = sorted(counts, key=lambda token: (-counts[token], token))
    kept = [
        token for token in ranked
        if counts[token] >= min_freq and token not in SPECIAL
    ]

    itos = [*SPECIAL, *kept]
    stoi = dict(zip(itos, range(len(itos))))
    return Vocab(stoi=stoi, itos=itos)

def encode_text(vocab: Vocab, text: str, max_len: int = 64) -> List[int]:
    """Encode `text` as exactly `max_len` token ids.

    The token sequence is wrapped in BOS/EOS and unknown tokens map to the
    UNK id. Short sequences are right-padded with the PAD id; sequences of
    `max_len` or more are cut to `max_len` and their last id is set to EOS.
    """
    tokens = [BOS, *simple_tokenize(text), EOS]
    ids = [vocab.stoi.get(token, vocab.unk_id) for token in tokens]

    shortfall = max_len - len(ids)
    if shortfall > 0:
        return ids + [vocab.pad_id] * shortfall

    clipped = ids[:max_len]
    # a truncated sequence must still terminate
    clipped[-1] = vocab.eos_id
    return clipped

def decode_ids(vocab: Vocab, ids: List[int]) -> str:
    """Convert token ids back into a readable string.

    BOS and PAD tokens are skipped, decoding stops at the first EOS, and ids
    outside the vocabulary are rendered as the UNK token.

    Args:
        vocab: Vocabulary whose `itos` list maps ids to tokens.
        ids: Token ids to decode.

    Returns:
        The tokens joined by single spaces, with punctuation re-attached to
        the preceding word.
    """
    toks = []
    for i in ids:
        tok = vocab.itos[i] if 0 <= i < len(vocab.itos) else UNK
        if tok in (BOS, PAD):
            continue
        if tok == EOS:
            break
        toks.append(tok)

    out = " ".join(toks)
    # simple_tokenize splits off all four of these marks, so re-attach all
    # four (previously ":" and ";" were left with a stray leading space).
    for mark in (".", ",", ":", ";"):
        out = out.replace(f" {mark}", mark)
    return out

In [None]:
class B6Dataset(Dataset):
    """Pairs each feature row with its caption encoded to `max_len` token ids."""

    def __init__(self, X: np.ndarray, captions: List[str], vocab: Vocab, max_len: int = 64):
        self.X = torch.from_numpy(X).float()
        # every caption is encoded to the same length, so the ids stack into one tensor
        encoded = [encode_text(vocab, caption, max_len=max_len) for caption in captions]
        self.y = torch.tensor(encoded, dtype=torch.long)
        self.max_len = max_len

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        features, token_ids = self.X[idx], self.y[idx]
        return features, token_ids

def make_loaders(X: np.ndarray, captions: List[str], test_size=0.2, seed=42, batch_size=128, max_len=64):
    """Split the data, build the vocabulary on the training captions and wrap both splits.

    Returns `(dl_tr, dl_te, vocab, tr_idx, te_idx)`: the shuffled training
    loader, the unshuffled test loader, the vocabulary, and the row indices
    of each split.
    """
    all_rows = np.arange(len(X))
    tr_idx, te_idx = train_test_split(
        all_rows, test_size=test_size, random_state=seed, shuffle=True
    )

    def _captions_at(rows):
        return [captions[i] for i in rows]

    # the vocabulary only ever sees training captions
    train_texts = _captions_at(tr_idx)
    vocab = build_vocab(train_texts, min_freq=2)

    ds_tr = B6Dataset(X[tr_idx], train_texts, vocab, max_len=max_len)
    ds_te = B6Dataset(X[te_idx], _captions_at(te_idx), vocab, max_len=max_len)

    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_te = DataLoader(ds_te, batch_size=batch_size, shuffle=False, drop_last=False)

    return dl_tr, dl_te, vocab, tr_idx, te_idx

In [None]:
class B6Captioner(nn.Module):
    """Feature-conditioned caption generator.

    A two-layer MLP encodes the feature vector into the initial hidden state
    of a single-layer GRU, which then predicts the caption one token at a time.

    Args:
        x_dim: Length of the input feature vector.
        vocab_size: Number of tokens in the vocabulary.
        enc_hidden: Width of the encoder's hidden layer.
        dec_hidden: GRU hidden size (also the encoder's output width).
        emb_dim: Token embedding size.
        dropout: Dropout probability inside the encoder.
    """

    def __init__(self, x_dim: int, vocab_size: int, enc_hidden: int = 512,
                 dec_hidden: int = 512, emb_dim: int = 256, dropout: float = 0.2):
        super().__init__()

        # encoder maps X -> context vector
        self.encoder = nn.Sequential(
            nn.Linear(x_dim, enc_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(enc_hidden, dec_hidden),
            nn.ReLU(),
        )

        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=dec_hidden, batch_first=True)
        self.out = nn.Linear(dec_hidden, vocab_size)

    def forward(self, X: torch.Tensor, y_inp: torch.Tensor):
        """Teacher-forced pass.

        Args:
            X: (B, x_dim) feature vectors.
            y_inp: (B, T) input token ids starting with BOS; the logits at
                position t are the prediction for the token that follows.

        Returns:
            (B, T, V) logits over the vocabulary.
        """
        h0 = self.encoder(X).unsqueeze(0)         # (1, B, H)
        emb = self.emb(y_inp)                     # (B, T, E)
        out, _ = self.gru(emb, h0)                # (B, T, H)
        logits = self.out(out)                    # (B, T, V)
        return logits

    @torch.no_grad()
    def generate(self, X: torch.Tensor, bos_id: int, eos_id: int, max_len: int = 64):
        """Greedy decoding.

        Args:
            X: (B, x_dim) feature vectors.
            bos_id: Id used as the first token of every sequence.
            eos_id: Id that terminates a sequence.
            max_len: Maximum number of tokens per sequence, BOS included.

        Returns:
            (B, L) token ids with L <= max_len. Every row starts with
            `bos_id`; once a row has produced `eos_id`, all of its later
            positions are `eos_id` as well.
        """
        B = X.size(0)
        h = self.encoder(X).unsqueeze(0)

        ys = torch.full((B, 1), bos_id, dtype=torch.long, device=X.device)
        # rows that have already emitted EOS at some earlier step
        finished = torch.zeros(B, dtype=torch.bool, device=X.device)

        for _ in range(max_len - 1):
            emb = self.emb(ys[:, -1:])            # (B, 1, E)
            out, h = self.gru(emb, h)             # (B, 1, H)
            logits = self.out(out[:, -1, :])      # (B, V)
            nxt = torch.argmax(logits, dim=-1)    # greedy, (B,)
            # A finished row keeps emitting EOS instead of whatever the
            # decoder would predict after its own end token.
            nxt = nxt.masked_fill(finished, eos_id)
            ys = torch.cat([ys, nxt.unsqueeze(1)], dim=1)
            finished = finished | (nxt == eos_id)
            # Stop once every row has ended, even if they ended on different steps.
            if finished.all():
                break
        return ys

In [None]:
def seq_ce_loss(logits: torch.Tensor, targets: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Cross-entropy over all positions of a batch, ignoring padding targets.

    logits:  (B, T, V) scores per position
    targets: (B, T) next-token ids aligned with `logits`; positions equal to
             `pad_id` do not contribute to the loss
    """
    batch, steps, n_classes = logits.shape
    n_positions = batch * steps
    flat_logits = logits.reshape(n_positions, n_classes)
    flat_targets = targets.reshape(n_positions)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=pad_id)

@torch.no_grad()
def token_accuracy(pred: torch.Tensor, gold: torch.Tensor, pad_id: int) -> float:
    """Fraction of non-pad gold positions where `pred` equals `gold` (0.0 if there are none)."""
    keep = gold.ne(pad_id)
    if not keep.any():
        return 0.0
    hits = pred[keep].eq(gold[keep])
    return hits.float().mean().item()

@torch.no_grad()
def exact_match(pred: torch.Tensor, gold: torch.Tensor, pad_id: int) -> float:
    """Fraction of rows whose prediction reproduces the gold sequence.

    For each row, the number of non-pad gold tokens gives the length of the
    prefix that is compared; a row matches when `pred` and `gold` agree on
    that whole prefix. Rows whose gold is entirely padding never match but
    still count in the denominator.

    Args:
        pred: (B, T) predicted token ids.
        gold: (B, T) reference token ids, padded with `pad_id`.
        pad_id: Padding token id.

    Returns:
        matches / B, or 0.0 for a batch with no rows.
    """
    B = pred.size(0)
    # An empty batch previously raised ZeroDivisionError; report 0.0 instead,
    # mirroring what token_accuracy does when there is nothing to score.
    if B == 0:
        return 0.0

    matches = 0
    for p, g in zip(pred, gold):
        g_len = (g != pad_id).sum().item()
        if g_len == 0:
            continue
        if torch.equal(p[:g_len], g[:g_len]):
            matches += 1
    return matches / B

@torch.no_grad()
def _evaluate_greedy(model, loader, vocab):
    """Return (token accuracy, exact match) of greedy generations, pooled over `loader`."""
    model.eval()
    tok_hits = 0.0
    tok_total = 0
    seq_hits = 0.0
    seq_total = 0

    for Xb, yb in loader:
        Xb = Xb.to(device)
        yb = yb.to(device)

        gen = model.generate(Xb, bos_id=vocab.bos_id, eos_id=vocab.eos_id, max_len=yb.size(1))
        # gen and yb both start with BOS; right-pad (or cut) gen to yb's width
        if gen.size(1) < yb.size(1):
            pad = torch.full((gen.size(0), yb.size(1) - gen.size(1)), vocab.pad_id,
                             device=device, dtype=torch.long)
            gen = torch.cat([gen, pad], dim=1)
        else:
            gen = gen[:, :yb.size(1)]

        # Weight each batch by what it contains so that a short final batch
        # does not count as much as a full one.
        n_tok = int((yb != vocab.pad_id).sum().item())
        n_seq = yb.size(0)
        tok_hits += token_accuracy(gen, yb, pad_id=vocab.pad_id) * n_tok
        seq_hits += exact_match(gen, yb, pad_id=vocab.pad_id) * n_seq
        tok_total += n_tok
        seq_total += n_seq

    token_acc = tok_hits / tok_total if tok_total else 0.0
    exact = seq_hits / seq_total if seq_total else 0.0
    return token_acc, exact


def train_one_config(
    dl_tr, dl_te, vocab: Vocab,
    x_dim: int = 1152,
    enc_hidden: int = 512,
    dec_hidden: int = 512,
    lr: float = 1e-3,
    dropout: float = 0.2,
    epochs: int = 12,
):
    """Train one B6Captioner with teacher forcing and return it.

    Args:
        dl_tr: Training loader yielding (features, token ids) batches.
        dl_te: Loader used for the periodic evaluation printed as `val_*`.
        vocab: Vocabulary providing the size and the special-token ids.
        x_dim: Length of the input feature vector.
        enc_hidden: Encoder hidden width.
        dec_hidden: GRU hidden size.
        lr: AdamW learning rate.
        dropout: Encoder dropout probability.
        epochs: Number of passes over `dl_tr`.

    Returns:
        The trained model, left in eval mode by the final evaluation.
    """
    model = B6Captioner(x_dim=x_dim, vocab_size=len(vocab.itos),
                        enc_hidden=enc_hidden, dec_hidden=dec_hidden,
                        emb_dim=256, dropout=dropout).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    for ep in range(1, epochs + 1):
        model.train()
        total = 0.0
        for Xb, yb in dl_tr:
            Xb = Xb.to(device)
            yb = yb.to(device)

            # teacher forcing: input tokens are y[:, :-1], predict y[:, 1:]
            y_inp = yb[:, :-1]
            y_tgt = yb[:, 1:]

            logits = model(Xb, y_inp)
            loss = seq_ce_loss(logits, y_tgt, pad_id=vocab.pad_id)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # sum of per-example losses, turned into a mean when printed
            total += loss.item() * Xb.size(0)

        # evaluate on a fixed schedule and always after the last epoch
        if ep in (1, 3, 6, 9, epochs):
            token_acc, exact = _evaluate_greedy(model, dl_te, vocab)
            print(f"epoch {ep:02d}/{epochs} train_loss={total/len(dl_tr.dataset):.4f}  "
                  f"val_token_acc={token_acc:.4f}  val_exact_match={exact:.4f}")

    return model

In [None]:
# Build the feature matrix and the annotation rows aligned with it.
# `ann` is the annotation table; it is not defined in this part of the notebook.
X_all, ann_aligned = collect_features_and_labels_B6(
    ann,
    use_ast=True,
    use_kmeans_tokens=True,
    use_vqvae_tokens=True,
)

# Build one templated caption per aligned row (the generation targets)
captions = build_all_captions(ann_aligned)

# Train/test DataLoaders, the vocabulary built from the training captions,
# and the row indices of each split
dl_tr, dl_te, vocab, tr_idx, te_idx = make_loaders(
    X_all, captions,
    test_size=0.2,
    seed=42,
    batch_size=128,
    max_len=64
)

# Quick sanity check of the vocabulary size and the first caption
print("Vocab size:", len(vocab.itos))
print("Example target:", captions[0])


Built X for 10000 examples; dim=1152. Skipped 0.
Vocab size: 76
Example target: Bat 216 vocalized toward bat 221 during a general interaction. Before vocalizing, the emitter was present and the addressee was crawling in. Afterward, the emitter stayed while the addressee stayed.


In [None]:
# Candidate hyper-parameter settings; each one is trained from scratch.
configs = [
    {"enc_hidden": 512, "dec_hidden": 512, "lr": 1e-3, "dropout": 0.2},
    {"enc_hidden": 512, "dec_hidden": 512, "lr": 3e-4, "dropout": 0.2},
]

best_model = None
best_score = -1.0

for cfg in configs:
    print("\n================================================================================")
    print("B6 config:", cfg)

    model = train_one_config(
        dl_tr, dl_te, vocab,
        x_dim=X_all.shape[1],
        enc_hidden=cfg["enc_hidden"],
        dec_hidden=cfg["dec_hidden"],
        lr=cfg["lr"],
        dropout=cfg["dropout"],
        epochs=12,
    )

    # Token accuracy of greedy generations on dl_te, pooled over all non-pad
    # tokens so that a short final batch is not over-weighted.
    model.eval()
    tok_hits = 0.0
    tok_total = 0
    for Xb, yb in dl_te:
        Xb = Xb.to(device)
        yb = yb.to(device)
        gen = model.generate(Xb, vocab.bos_id, vocab.eos_id, max_len=yb.size(1))
        # right-pad (or cut) the generation to the width of the targets
        if gen.size(1) < yb.size(1):
            pad = torch.full((gen.size(0), yb.size(1)-gen.size(1)), vocab.pad_id, device=device, dtype=torch.long)
            gen = torch.cat([gen, pad], dim=1)
        else:
            gen = gen[:, :yb.size(1)]
        n_tok = int((yb != vocab.pad_id).sum().item())
        tok_hits += token_accuracy(gen, yb, vocab.pad_id) * n_tok
        tok_total += n_tok
    # an empty loader scores 0.0 instead of dividing by zero
    score = float(tok_hits / tok_total) if tok_total else 0.0
    print(f"B6 token-accuracy (test): {score:.4f}")

    # NOTE(review): the model is selected on the same split that is reported
    # as "test"; a separate validation split would avoid this.
    if score > best_score:
        best_score = score
        best_model = model

print("\n################################################################################")
print("BEST B6 token-accuracy:", best_score)



B6 config: {'enc_hidden': 512, 'dec_hidden': 512, 'lr': 0.001, 'dropout': 0.2}
epoch 01/12 train_loss=0.6330  val_token_acc=0.6955  val_exact_match=0.1474
epoch 03/12 train_loss=0.1471  val_token_acc=0.7420  val_exact_match=0.1975
epoch 06/12 train_loss=0.1252  val_token_acc=0.7526  val_exact_match=0.2103
epoch 09/12 train_loss=0.1163  val_token_acc=0.7632  val_exact_match=0.2336
epoch 12/12 train_loss=0.1087  val_token_acc=0.7678  val_exact_match=0.2447
B6 token-accuracy (test): 0.7678

B6 config: {'enc_hidden': 512, 'dec_hidden': 512, 'lr': 0.0003, 'dropout': 0.2}
epoch 01/12 train_loss=1.3406  val_token_acc=0.6752  val_exact_match=0.0735
epoch 03/12 train_loss=0.1780  val_token_acc=0.6873  val_exact_match=0.1552
epoch 06/12 train_loss=0.1413  val_token_acc=0.7370  val_exact_match=0.1902
epoch 09/12 train_loss=0.1297  val_token_acc=0.7520  val_exact_match=0.1923
epoch 12/12 train_loss=0.1233  val_token_acc=0.7565  val_exact_match=0.2355
B6 token-accuracy (test): 0.7565

############

In [None]:
# Qualitative check: print the templated target next to the generated caption
best_model.eval()

# the first five row indices of the test split (they index X_all and captions)
sample_ids = list(te_idx[:5])
for idx in sample_ids:
    # slice idx:idx+1 keeps a leading batch dimension of size 1
    X = torch.from_numpy(X_all[idx:idx+1]).float().to(device)
    gen = best_model.generate(X, vocab.bos_id, vocab.eos_id, max_len=64)[0].tolist()

    gt = captions[idx]
    # decoded text is lower-case because the tokenizer lower-cases its input
    pred = decode_ids(vocab, gen)

    print("\n================================================================================")
    print("Index:", idx)
    print("GT  :", gt)
    print("PRED:", pred)


Index: 6252
GT  : Bat 230 vocalized toward bat 207 during a sleeping. Before vocalizing, the emitter was present and the addressee was present. Afterward, the emitter stayed while the addressee stayed.
PRED: bat 230 vocalized toward bat 207 during a sleeping. before vocalizing, the emitter was present and the addressee was present. afterward, the emitter stayed while the addressee stayed.

Index: 4684
GT  : Bat 230 vocalized toward bat 207 during a sleeping. Before vocalizing, the emitter was present and the addressee was present. Afterward, the emitter stayed while the addressee stayed.
PRED: bat 230 vocalized toward bat 207 during a sleeping. before vocalizing, the emitter was present and the addressee was present. afterward, the emitter stayed while the addressee stayed.

Index: 1731
GT  : Bat 216 vocalized toward bat 233 during a general interaction. Before vocalizing, no clear movement was observed. Afterward, the emitter flew away while the addressee stayed.
PRED: bat 111 vocali