# Tokenization Strategies for Bat Vocalizations

This notebook sets up three representation / tokenization strategies on top of the 10k subset (using `data/annotations.csv` and derived features):

1. **Self-supervised speech encoders (wav2vec 2.0 / HuBERT) + k-means** to produce discrete "bio-tokens".
2. **VQ-VAE on mel-spectrograms** to learn a discrete codebook of acoustic units.
3. **Continuous-feature encoders (AST)** that operate on spectrograms without discretization.

Classifiers or sequence models can later be trained on top of the saved representations.

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans

# Resolve the project root. Inside a Jupyter kernel `__file__` is not defined
# (referencing it raises NameError), so fall back to the kernel's working
# directory, which is normally the notebook's own folder.
try:
    ROOT = Path(__file__).resolve().parent  # when exported and run as a .py script
except NameError:
    ROOT = Path.cwd()

DATA_DIR = ROOT / 'data'
DERIVED_DIR = ROOT / 'derived'
AUDIO_DIR = DATA_DIR / 'audio'
MELS_48K_DIR = DERIVED_DIR / 'mels_48k'
TOKENS_DIR = DERIVED_DIR / 'tokens'
AST_DIR = DERIVED_DIR / 'ast_features'

TOKENS_DIR.mkdir(exist_ok=True, parents=True)
AST_DIR.mkdir(exist_ok=True, parents=True)

# Seed the global RNGs so network initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (k-means is seeded via `random_state`).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device


## 1. wav2vec 2.0 / HuBERT embeddings + k-means clustering

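We use a pretrained self-supervised speech model to get frame-level embeddings, then learn a k-means codebook to derive discrete token sequences per call. These models were pretrained on 16 kHz speech, so every recording is resampled to 16 kHz first; note that this discards all content above 8 kHz (the new Nyquist frequency), which may include much of the energy in bat calls.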

In [None]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import librosa

# AutoModel picks the architecture from the checkpoint's config, so both wav2vec 2.0
# and HuBERT checkpoints load into the matching model class (Wav2Vec2Model is not
# guaranteed to load HuBERT weights faithfully). Both families use
# Wav2Vec2FeatureExtractor for preprocessing.
from transformers import AutoModel

W2V_MODEL_NAME = 'facebook/wav2vec2-base'  # or a HuBERT checkpoint, e.g. 'facebook/hubert-base-ls960'

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(W2V_MODEL_NAME)
w2v_model = AutoModel.from_pretrained(W2V_MODEL_NAME).to(device).eval()

ann_small = pd.read_csv(DATA_DIR / 'annotations.csv')
# Every downstream cell locates files through this column.
assert 'File Name' in ann_small.columns, (
    f'annotations.csv has no "File Name" column; found {list(ann_small.columns)}'
)

def load_audio_for_w2v(path: Path, target_sr: int = 16_000) -> np.ndarray:
    """Load ``path`` as float32 audio sampled at ``target_sr`` Hz.

    The file is first read at its native sampling rate and is resampled only
    when that rate differs from ``target_sr``.

    NOTE(review): at the default 16 kHz only content below 8 kHz (the new
    Nyquist frequency) survives; bat vocalizations may carry most of their
    energy above that — confirm this is acceptable for the analysis.
    """
    waveform, native_sr = librosa.load(path, sr=None)
    if native_sr == target_sr:
        return waveform.astype(np.float32)
    resampled = librosa.resample(waveform, orig_sr=native_sr, target_sr=target_sr)
    return resampled.astype(np.float32)

def extract_w2v_embeddings(
    wav: np.ndarray, sr: int = 16_000, min_samples: int = 400
) -> np.ndarray:
    """Return frame-level hidden states of the loaded speech encoder for one clip.

    Args:
        wav: 1-D waveform sampled at ``sr``.
        sr: sampling rate passed to the feature extractor; it must match the
            rate the model was trained on (16 kHz for the checkpoints used here).
        min_samples: clips shorter than this are zero-padded at the end. 400
            samples (25 ms at 16 kHz) is the receptive field of the convolutional
            front end of wav2vec2-base / HuBERT-base; shorter inputs leave the conv
            stack with no output frames and the forward pass raises. Very short
            bat calls can easily fall below this after resampling.

    Returns:
        Array of shape (T, hidden_size), one row per encoder frame.
    """
    if len(wav) < min_samples:
        wav = np.pad(wav, (0, min_samples - len(wav)))
    inputs = feature_extractor(wav, sampling_rate=sr, return_tensors='pt')
    with torch.no_grad():
        out = w2v_model(inputs.input_values.to(device))
    # last_hidden_state: (1, T, hidden_size) -> drop the batch axis
    return out.last_hidden_state.squeeze(0).cpu().numpy()

# Extract embeddings for a small subset of annotated files
N_W2V_FILES = 128  # raise for a larger k-means training set
subset = ann_small['File Name'].iloc[:N_W2V_FILES]
all_frames: List[np.ndarray] = []
# file name -> slice of rows in `frame_matrix` that hold that file's frames
file2frame_indices: Dict[str, slice] = {}
start = 0
for fn in subset:
    path = AUDIO_DIR / fn
    # If a file appears in several annotation rows, embed it only once so it is
    # not over-weighted in the k-means fit and its slice stays unique.
    if fn in file2frame_indices or not path.exists():
        continue
    wav = load_audio_for_w2v(path)
    emb = extract_w2v_embeddings(wav)  # (T, D)
    end = start + emb.shape[0]
    all_frames.append(emb)
    file2frame_indices[fn] = slice(start, end)
    start = end

# np.concatenate([]) would otherwise fail with an opaque "need at least one array" error.
if not all_frames:
    raise FileNotFoundError(
        f'None of the first {len(subset)} annotated files were found under {AUDIO_DIR}'
    )

frame_matrix = np.concatenate(all_frames, axis=0)  # (total_T, D)
frame_matrix.shape


In [None]:
# Fit k-means on frame-level embeddings to create a codebook
N_CLUSTERS = 128
# Token IDs are stored as int16 below, so every ID in [0, N_CLUSTERS) must fit.
assert N_CLUSTERS <= np.iinfo(np.int16).max + 1
assert len(frame_matrix) >= N_CLUSTERS, (
    f'only {len(frame_matrix)} frames for {N_CLUSTERS} clusters; '
    'embed more files or lower N_CLUSTERS'
)
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=0)
kmeans.fit(frame_matrix)

# Convert each file's frames into a sequence of discrete token IDs
file2tokens: Dict[str, np.ndarray] = {}
for fn, sl in file2frame_indices.items():
    tokens = kmeans.predict(frame_matrix[sl]).astype(np.int16)
    file2tokens[fn] = tokens
    # Persist the same int16 array kept in memory, matching the int16 VQ-VAE token files.
    np.save(TOKENS_DIR / f'w2v_kmeans_{Path(fn).stem}.npy', tokens)

len(file2tokens)


## 2. VQ-VAE over mel-spectrograms

We now sketch a simple VQ-VAE model that operates on log-mel spectrogram patches loaded from `derived/mels_48k`.
This code is a starting point; tune the architecture and training hyperparameters to match the available compute.

In [None]:
import torch.nn as nn

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantiser with a straight-through gradient.

    Every C-dimensional vector of an encoder map ``z`` shaped (B, C, T, F) is
    replaced by the closest codebook entry under squared Euclidean distance.

    Args:
        num_codes: codebook size K.
        code_dim: length of each code vector; must equal the channel count C of ``z``.
        beta: weight of the commitment term in the returned loss.
    """

    def __init__(self, num_codes: int, code_dim: int, beta: float = 0.25):
        super().__init__()
        self.code_dim = code_dim
        self.embeddings = nn.Embedding(num_codes, code_dim)
        init_range = 1.0 / num_codes
        self.embeddings.weight.data.uniform_(-init_range, init_range)
        self.beta = beta

    def forward(self, z: torch.Tensor):
        """Quantise ``z``.

        Returns:
            A tuple ``(z_q, codes, loss)``: ``z_q`` matches ``z`` in shape and
            passes gradients straight through to ``z``; ``codes`` is (B, T*F)
            code indices in row-major (T, F) order; ``loss`` is
            ``beta * commitment + codebook``.
        """
        channels_last = z.permute(0, 2, 3, 1).contiguous()  # (B, T, F, C)
        vectors = channels_last.view(-1, self.code_dim)  # (B*T*F, C)
        codebook = self.embeddings.weight  # (K, C)
        # ||v - e||^2 expanded as ||v||^2 - 2 v.e + ||e||^2 for all (vector, code) pairs
        sq_dist = (
            vectors.pow(2).sum(dim=1, keepdim=True)
            - 2 * vectors @ codebook.t()
            + codebook.pow(2).sum(dim=1)
        )
        nearest = torch.argmin(sq_dist, dim=1)
        quantized = (
            self.embeddings(nearest)
            .view(*channels_last.shape)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        # Commitment pulls the encoder towards its codes; the codebook term pulls codes
        # towards the encoder output. detach() decides which side receives the gradient.
        commitment_term = self.beta * (quantized.detach() - z).pow(2).mean()
        codebook_term = (quantized - z.detach()).pow(2).mean()
        straight_through = z + (quantized - z).detach()
        return straight_through, nearest.view(z.size(0), -1), commitment_term + codebook_term

class SimpleVQVAE(nn.Module):
    """Small convolutional VQ-VAE over single-channel spectrogram patches.

    The encoder halves both spatial axes twice (two stride-2 convs), the bottleneck
    is quantised per spatial position by ``VectorQuantizer``, and the decoder doubles
    both axes twice with stride-2 transposed convs.

    Args:
        hidden_dim: channel width of the encoder/decoder and codebook vector length.
        num_codes: codebook size.
    """

    def __init__(self, hidden_dim: int = 128, num_codes: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, hidden_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.vq = VectorQuantizer(num_codes=num_codes, code_dim=hidden_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x: torch.Tensor):  # x: (B, 1, T, F)
        """Encode, quantise and reconstruct ``x``.

        Returns:
            A tuple ``(recon, codes, loss)``: ``recon`` has the same shape as ``x``;
            ``codes`` holds (B, ceil(T/4) * ceil(F/4)) code indices; ``loss`` is the
            MSE reconstruction loss plus the VQ loss.
        """
        z = self.encoder(x)
        z_q, codes, vq_loss = self.vq(z)
        recon = self.decoder(z_q)
        # Each stride-2 conv maps L -> ceil(L/2) and each transposed conv maps L -> 2L,
        # so the decoder returns 4*ceil(L/4) >= L along T and F. Unless both axes are
        # multiples of 4 it is larger than x and the loss would fail to broadcast;
        # crop the surplus rows/columns so any input size works.
        recon = recon[..., : x.size(-2), : x.size(-1)]
        recon_loss = (x - recon).pow(2).mean()
        return recon, codes, recon_loss + vq_loss


In [None]:
# Minimal data loader for mel patches
class MelDataset(torch.utils.data.Dataset):
    """Precomputed mel-spectrograms, one ``<stem>.npy`` file per audio file.

    Args:
        mels_dir: directory holding ``<stem>.npy`` arrays of shape (n_mels, T).
        file_names: audio file names; only their stem is used to locate the mel.
            Names whose mel file does not exist are silently dropped.
    """

    def __init__(self, mels_dir: Path, file_names: List[str]):
        candidates = [mels_dir / (Path(fn).stem + '.npy') for fn in file_names]
        self.mel_paths = [p for p in candidates if p.exists()]

    def __len__(self) -> int:
        return len(self.mel_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return mel ``idx`` as a contiguous float32 tensor of shape (1, T, n_mels)."""
        mel = np.load(self.mel_paths[idx])  # (n_mels, T)
        # The model's conv weights are float32: casting avoids a dtype mismatch if the
        # mels were saved as float64, and also makes the transposed array contiguous.
        x = torch.from_numpy(np.ascontiguousarray(mel.T, dtype=np.float32)).unsqueeze(0)  # (1, T, F)
        return x

import torch.nn.functional as F

# Total downsampling factor of SimpleVQVAE's encoder (two stride-2 convs). Padding
# T and F up to a multiple of it keeps encoder and decoder grids aligned without
# changing the number of codes (ceil(L/4) either way).
MEL_PAD_MULTIPLE = 4


def _round_up(n: int, multiple: int) -> int:
    """Smallest multiple of ``multiple`` that is >= ``n``."""
    return -(-n // multiple) * multiple


def _pad_mel(mel: torch.Tensor, t_len: int, f_len: int) -> torch.Tensor:
    """Right-pad a (1, T, F) mel to (1, t_len, f_len) with the mel's own minimum.

    The minimum is the quietest value in the clip; a constant 0 would instead be
    the loudest value in a dB-scaled (max-referenced) mel.
    """
    pad = (0, f_len - mel.size(-1), 0, t_len - mel.size(-2))
    return F.pad(mel, pad, value=mel.min().item())


def _mel_pad_collate(batch: List[torch.Tensor]) -> torch.Tensor:
    """Collate variable-length (1, T, F) mels into one float32 (B, 1, T_pad, F_pad) batch.

    The default collate calls ``torch.stack`` and fails as soon as two mels in a
    batch have different numbers of frames.
    """
    t_len = _round_up(max(m.size(-2) for m in batch), MEL_PAD_MULTIPLE)
    f_len = _round_up(max(m.size(-1) for m in batch), MEL_PAD_MULTIPLE)
    return torch.stack([_pad_mel(m.float(), t_len, f_len) for m in batch])


subset_fns = ann_small['File Name'].iloc[:256].tolist()
mel_ds = MelDataset(MELS_48K_DIR, subset_fns)
# A shuffling DataLoader over an empty dataset raises an unhelpful sampler error.
assert len(mel_ds) > 0, f'no mel files for the first {len(subset_fns)} annotations in {MELS_48K_DIR}'
mel_dl = torch.utils.data.DataLoader(mel_ds, batch_size=8, shuffle=True, collate_fn=_mel_pad_collate)

vqvae = SimpleVQVAE().to(device)
opt = torch.optim.Adam(vqvae.parameters(), lr=1e-3)

# Very small warm-up training loop (extend as needed).
# NOTE: padded regions are part of the reconstruction target.
MAX_WARMUP_STEPS = 50
vqvae.train()
for step, x in enumerate(mel_dl):
    x = x.to(device)
    recon, codes, loss = vqvae(x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 10 == 0:
        print(f'step {step}: loss={loss.item():.4f}')
    if step >= MAX_WARMUP_STEPS:  # keep short by default
        break

# Save code indices for a few examples
vq_tokens_dir = TOKENS_DIR / 'vqvae'
vq_tokens_dir.mkdir(exist_ok=True, parents=True)

vqvae.eval()
with torch.no_grad():
    for fn in subset_fns[:64]:
        mel_path = MELS_48K_DIR / (Path(fn).stem + '.npy')
        if not mel_path.exists():
            continue
        mel = torch.from_numpy(np.ascontiguousarray(np.load(mel_path).T, dtype=np.float32))  # (T, F)
        mel = mel.unsqueeze(0)  # (1, T, F)
        # Same padding as in training so any T/F works with the stride-4 encoder/decoder.
        mel = _pad_mel(
            mel,
            _round_up(mel.size(-2), MEL_PAD_MULTIPLE),
            _round_up(mel.size(-1), MEL_PAD_MULTIPLE),
        )
        x = mel.unsqueeze(0).to(device)  # (1, 1, T_pad, F_pad)
        _, codes, _ = vqvae(x)
        codes_np = codes.squeeze(0).cpu().numpy().astype(np.int16)
        np.save(vq_tokens_dir / f'vqvae_{Path(fn).stem}.npy', codes_np)


## 3. Continuous representations with AST (Audio Spectrogram Transformer)

We obtain continuous embeddings from an Audio Spectrogram Transformer (AST); the checkpoint used here was fine-tuned on AudioSet.
These can be used directly for downstream classification or captioning without discretization.

In [None]:
from transformers import AutoFeatureExtractor, ASTModel

# AudioSet-fine-tuned AST; the same checkpoint name provides weights and preprocessing config.
AST_MODEL_NAME = 'MIT/ast-finetuned-audioset-10-10-0.4593'
ast_extractor = AutoFeatureExtractor.from_pretrained(AST_MODEL_NAME)
ast_model = ASTModel.from_pretrained(AST_MODEL_NAME)
ast_model = ast_model.to(device).eval()  # inference only

def load_mel_as_ast_input(mel_path: Path) -> Dict[str, torch.Tensor]:
    """Turn a saved mel-spectrogram into the ``input_values`` batch ASTModel expects.

    ``ASTFeatureExtractor`` only accepts raw waveforms at its own sampling rate
    (16 kHz). It computes its own filterbank features and raises ``ValueError``
    when ``sampling_rate`` differs, so a precomputed 48 kHz mel cannot be routed
    through it. Instead the mel is adapted to the model's fixed input geometry here:

    1. transpose to (T, n_mels) and linearly resize the frequency axis to
       ``config.num_mel_bins``;
    2. zero-pad or truncate the time axis to ``config.max_length`` frames, as the
       extractor does;
    3. apply the extractor's normalisation ``(x - mean) / (2 * std)``.

    NOTE(review): AST's statistics and 10 ms frame rate come from AudioSet
    log-filterbanks. The scale (dB vs. natural log) and hop of the ``mels_48k``
    features are not visible here and likely differ — confirm before trusting the
    embeddings.

    Args:
        mel_path: ``.npy`` file holding an array of shape (n_mels, T).

    Returns:
        ``{'input_values': tensor}`` of shape (1, max_length, num_mel_bins) on ``device``.
    """
    import torch.nn.functional as F

    n_bins = ast_model.config.num_mel_bins
    max_frames = ast_model.config.max_length

    mel = np.load(mel_path)  # (n_mels, T)
    fbank = torch.from_numpy(np.ascontiguousarray(mel.T, dtype=np.float32))  # (T, n_mels)
    if fbank.size(1) != n_bins:
        # Bilinear resize with the time size unchanged == linear interpolation over frequency.
        fbank = F.interpolate(
            fbank[None, None], size=(fbank.size(0), n_bins), mode='bilinear', align_corners=False
        )[0, 0]

    n_frames = fbank.size(0)
    if n_frames < max_frames:
        fbank = F.pad(fbank, (0, 0, 0, max_frames - n_frames))
    else:
        fbank = fbank[:max_frames]

    if getattr(ast_extractor, 'do_normalize', True):
        fbank = (fbank - ast_extractor.mean) / (ast_extractor.std * 2)

    return {'input_values': fbank.unsqueeze(0).to(device)}

# Embed the first 128 annotated files with AST: one pooled vector per available mel.
for fn in ann_small['File Name'].iloc[:128]:
    stem = Path(fn).stem
    mel_path = MELS_48K_DIR / f'{stem}.npy'
    if mel_path.exists():
        inputs = load_mel_as_ast_input(mel_path)
        with torch.no_grad():
            out = ast_model(**inputs)
        pooled = out.pooler_output.squeeze(0).cpu().numpy().astype(np.float32)
        np.save(AST_DIR / f'ast_{stem}.npy', pooled)

# Number of AST feature files now on disk (includes any left from earlier runs).
sum(1 for _ in AST_DIR.glob('*.npy'))
