# 08 – Next-Token Language Model over Acoustic Tokens

Trains a **next-token prediction model** (a small Transformer language model)
over discrete acoustic token sequences derived in `05_Tokenization_Strategies.ipynb`.

We treat each file's token sequence (e.g. wav2vec2 + k-means or VQ-VAE codes) as a "sentence"
and train the model to predict token \(t_{i+1}\) given the history \(t_{\le i}\).

Reports:
- **Validation cross-entropy loss**
- **Perplexity** \(= e^{\text{loss}}\)
- **Token-level accuracy** on the validation set.

We can also generate sample continuations from the trained model to qualitatively inspect its behavior.


In [1]:
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple

import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split

# Project layout: every path is derived from the current working directory.
ROOT = Path.cwd().resolve()
DATA_DIR = ROOT.joinpath("data")
DERIVED_DIR = ROOT.joinpath("derived")
TOKENS_DIR = DERIVED_DIR.joinpath("tokens")
KMEANS_DIR = TOKENS_DIR.joinpath("k_means")
VQ_TOKENS_DIR = TOKENS_DIR.joinpath("vqvae")

# Which tokenization feeds the language model: "kmeans" or "vqvae".
TOKEN_TYPE = "kmeans"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Device: {device}")
print(f"ROOT: {ROOT}")
print(f"Using tokenization: {TOKEN_TYPE}")


Device: cpu
ROOT: /Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code
Using tokenization: kmeans


## Load token sequences

We load per-file token sequences from a subdirectory of `derived/tokens/` (`k_means/` or `vqvae/`, depending on `TOKEN_TYPE`).

- For `TOKEN_TYPE == "kmeans"`, we use `w2v_kmeans_<stem>.npy` (K=128 clusters).
- For `TOKEN_TYPE == "vqvae"`, we use `vqvae_<stem>.npy` (default 256 codes).

Each file contributes a 1D integer sequence of token IDs.


In [2]:
def load_token_sequences(token_type: str = TOKEN_TYPE) -> Tuple[List[np.ndarray], int]:
    """Read every usable token file for one tokenization and infer the vocabulary size.

    Parameters
    ----------
    token_type : str
        Either "kmeans" (files containing ``w2v_kmeans_`` under ``KMEANS_DIR``)
        or "vqvae" (files containing ``vqvae_`` under ``VQ_TOKENS_DIR``).

    Returns
    -------
    sequences : list of np.ndarray
        One 1D int64 array of token IDs per accepted file, in sorted path order.
    vocab_size : int
        Largest token ID seen across all accepted sequences, plus one.

    Raises
    ------
    ValueError
        If ``token_type`` is not one of the supported names.
    FileNotFoundError
        If the token directory does not exist.
    RuntimeError
        If the directory has no ``.npy`` files, or none of them yields a
        usable sequence.
    """

    if token_type == "kmeans":
        name_marker, token_dir = "w2v_kmeans_", KMEANS_DIR
    elif token_type == "vqvae":
        name_marker, token_dir = "vqvae_", VQ_TOKENS_DIR
    else:
        raise ValueError(f"Unsupported TOKEN_TYPE: {token_type}")

    if not token_dir.exists():
        raise FileNotFoundError(
            f"Token directory {token_dir} not found. Run 05_Tokenization_Strategies.ipynb first."
        )

    npy_files = sorted(token_dir.glob("*.npy"))
    if not npy_files:
        raise RuntimeError(f"No token files found in {token_dir}.")

    # Lazily load only the files whose name carries the expected marker.
    candidates = (
        np.load(npy_file).astype(np.int64)
        for npy_file in npy_files
        if name_marker in npy_file.name
    )
    # Keep 1D arrays with at least two tokens (one input plus one target).
    kept: List[np.ndarray] = [
        tokens for tokens in candidates if tokens.ndim == 1 and tokens.size >= 2
    ]

    if not kept:
        raise RuntimeError(f"No valid token sequences loaded from {token_dir}.")

    # The -1 floor mirrors the running-maximum initial value.
    top_token_id = max(-1, *(int(tokens.max()) for tokens in kept))
    inferred_vocab = top_token_id + 1
    print(f"Loaded {len(kept)} sequences from {token_dir} with vocab_size={inferred_vocab}.")
    return kept, inferred_vocab


# Load all token sequences for the selected tokenization and infer the vocabulary size.
sequences, vocab_size = load_token_sequences(TOKEN_TYPE)
# Cell output: number of loaded sequences and the inferred vocabulary size.
len(sequences), vocab_size


Loaded 10000 sequences from /Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/derived/tokens/k_means with vocab_size=128.


(10000, 128)

## Dataset for next-token prediction

We turn each token sequence into many training examples of length `seq_len`:

- Input: tokens `[t_i, ..., t_{i+L-1}]`
- Target: tokens `[t_{i+1}, ..., t_{i+L}]`

We step with a configurable `stride` to control how many overlapping windows we create.


In [3]:
class NextTokenDataset(Dataset):
    """Sliding-window dataset of (input, target) pairs for next-token prediction.

    Each item is a pair of LongTensors of length ``seq_len``; the target is
    the input window shifted one token to the right. Sequences that are not
    longer than ``seq_len`` contribute no windows.
    """

    def __init__(
        self,
        sequences: List[np.ndarray],
        seq_len: int = 64,
        stride: int = 4,
    ) -> None:
        """Enumerate all (sequence index, start offset) windows.

        Parameters
        ----------
        sequences : list of np.ndarray
            Token-ID arrays, one per file.
        seq_len : int
            Window length for both input and target.
        stride : int
            Step between consecutive window starts; values below 1 are
            clamped to 1.

        Raises
        ------
        RuntimeError
            If no sequence is long enough to yield a window.
        """
        self.sequences = sequences
        self.seq_len = seq_len
        self.stride = max(1, stride)

        # A window starting at `offset` needs tokens offset .. offset + seq_len
        # inclusive, so valid starts are 0 .. len - seq_len - 1. The range is
        # empty for sequences that are too short.
        windows: List[Tuple[int, int]] = [
            (seq_idx, offset)
            for seq_idx, tokens in enumerate(sequences)
            for offset in range(0, len(tokens) - seq_len, self.stride)
        ]

        if not windows:
            raise RuntimeError("No training windows could be formed. Try shorter seq_len or stride.")

        self.indices = windows
        print(f"Dataset: {len(self.indices)} windows, seq_len={self.seq_len}, stride={self.stride}.")

    def __len__(self) -> int:
        """Return the number of precomputed windows."""
        return len(self.indices)

    def __getitem__(self, idx: int):
        """Return the (input, target) LongTensor pair for window ``idx``."""
        seq_idx, offset = self.indices[idx]
        tokens = self.sequences[seq_idx]
        stop = offset + self.seq_len
        inputs = torch.from_numpy(tokens[offset:stop]).long()
        # Targets are the same window shifted one position to the right.
        targets = torch.from_numpy(tokens[offset + 1 : stop + 1]).long()
        return inputs, targets


# Window hyper-parameters shared by the training and validation datasets.
SEQ_LEN = 64
STRIDE = 4

# Hold out whole sequences (not individual windows) so that overlapping
# windows of one file never end up on both sides of the split.
val_frac = 0.1
num_sequences = len(sequences)
num_val = max(1, int(num_sequences * val_frac))
num_train = num_sequences - num_val

# Seeded in-place shuffle of the sequence indices keeps the split reproducible.
rng = np.random.default_rng(42)
all_indices = list(range(num_sequences))
rng.shuffle(all_indices)
train_indices, val_indices = all_indices[:num_train], all_indices[num_train:]

train_seqs = [sequences[i] for i in train_indices]
val_seqs = [sequences[i] for i in val_indices]

train_dataset = NextTokenDataset(train_seqs, seq_len=SEQ_LEN, stride=STRIDE)
val_dataset = NextTokenDataset(val_seqs, seq_len=SEQ_LEN, stride=STRIDE)

# Cell output: number of training and validation windows.
len(train_dataset), len(val_dataset)


Dataset: 71935 windows, seq_len=64, stride=4.
Dataset: 7637 windows, seq_len=64, stride=4.


(71935, 7637)

## Transformer language model

Define a small Transformer-based language model that operates on token sequences.
It uses learned token and positional embeddings plus a stack of Transformer encoder layers
with a causal mask so each position can only attend to previous positions.


In [4]:
class TransformerLM(nn.Module):
    """Small causal Transformer language model over discrete token IDs.

    Token and learned positional embeddings are summed, passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks under a causal attention
    mask, layer-normalised, and projected to vocabulary logits.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_heads: int = 4,
        num_layers: int = 4,
        dim_feedforward: int = 512,
        max_seq_len: int = 256,
        dropout: float = 0.1,
    ) -> None:
        """Build the embeddings, encoder stack, final norm and output head.

        Parameters
        ----------
        vocab_size : int
            Number of distinct token IDs (embedding rows and output logits).
        d_model : int
            Embedding / hidden width.
        n_heads : int
            Attention heads per encoder layer.
        num_layers : int
            Number of stacked encoder layers.
        dim_feedforward : int
            Width of each layer's feed-forward sub-block.
        max_seq_len : int
            Number of learned positions; longer inputs are rejected in ``forward``.
        dropout : float
            Dropout probability inside the encoder layers.
        """
        super().__init__()
        self.vocab_size, self.d_model, self.max_seq_len = vocab_size, d_model, max_seq_len

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a (seq_len, seq_len) bool mask, True strictly above the diagonal.

        True marks positions that must not be attended to (the future).
        """
        return torch.ones((seq_len, seq_len), dtype=torch.bool, device=device).triu(diagonal=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits for every position.

        Parameters
        ----------
        x : LongTensor of shape (batch, seq_len)

        Returns
        -------
        logits : FloatTensor of shape (batch, seq_len, vocab_size)

        Raises
        ------
        ValueError
            If ``seq_len`` exceeds ``max_seq_len``.
        """

        _, length = x.shape
        if length > self.max_seq_len:
            raise ValueError(f"seq_len={length} exceeds max_seq_len={self.max_seq_len}")

        # Positional embeddings are (length, d_model); unsqueeze broadcasts them over the batch.
        position_ids = torch.arange(length, device=x.device)
        hidden = self.token_emb(x) + self.pos_emb(position_ids).unsqueeze(0)

        causal = self._generate_causal_mask(length, device=x.device)
        hidden = self.encoder(hidden, mask=causal)
        return self.head(self.ln_f(hidden))


# max_seq_len is tied to SEQ_LEN, so the model has exactly one learned
# position per token of a training window.
model = TransformerLM(
    vocab_size=vocab_size,
    d_model=256,
    n_heads=4,
    num_layers=4,
    dim_feedforward=512,
    max_seq_len=SEQ_LEN,
).to(device)
# Cell output: the module summary.
model


TransformerLM(
  (token_emb): Embedding(128, 256)
  (pos_emb): Embedding(64, 256)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=256, out_features=128, bias=False)
)

## Training and evaluation

We train the model to minimize cross-entropy loss on the next-token prediction task.
After each epoch we compute validation loss, perplexity (exp of loss), and
token-level accuracy on the validation set.


In [None]:
# Optimisation hyper-parameters.
BATCH_SIZE = 64
N_EPOCHS = 5
LR = 3e-4

# Training batches are shuffled and the last partial batch is dropped;
# validation keeps the original order and every window.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:
    """Compute token-averaged loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch loss from the module-level ``criterion`` is weighted by the
    number of tokens in that batch before averaging.

    Parameters
    ----------
    model : nn.Module
        Maps a (batch, seq_len) LongTensor to (batch, seq_len, vocab) logits.
    loader : DataLoader
        Yields ``(inputs, targets)`` batches.

    Returns
    -------
    avg_loss : float
        Loss averaged over all tokens seen.
    acc : float
        Fraction of tokens whose arg-max prediction equals the target.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    tokens_seen = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            scores = model(inputs)  # (batch, length, vocab)
            batch, length, vocab = scores.shape
            n_batch_tokens = batch * length

            # Flatten batch and time so every token is one classification example.
            flat_scores = scores.view(n_batch_tokens, vocab)
            flat_targets = targets.view(n_batch_tokens)

            batch_loss = criterion(flat_scores, flat_targets)
            loss_sum += batch_loss.item() * n_batch_tokens
            tokens_seen += n_batch_tokens

            hits += (flat_scores.argmax(dim=-1) == flat_targets).sum().item()

    # max(1, ...) avoids a division by zero when the loader is empty.
    denominator = max(1, tokens_seen)
    return loss_sum / denominator, hits / denominator


# Track the lowest validation loss and a snapshot of the weights that achieved it.
best_val_loss = float("inf")
# Reset explicitly so a stale snapshot left over from an earlier run of this
# cell can never be reloaded by mistake.
best_state = None

for epoch in range(1, N_EPOCHS + 1):
    model.train()
    running_loss = 0.0
    running_tokens = 0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        B, L, V = logits.shape
        # Flatten batch and time so every token is one classification example.
        loss = criterion(logits.view(B * L, V), y.view(B * L))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight the batch loss by its token count to get a per-token average.
        running_loss += loss.item() * (B * L)
        running_tokens += B * L

    train_loss = running_loss / max(1, running_tokens)
    val_loss, val_acc = evaluate(model, val_loader)
    # math.exp overflows for very large losses (diverged training); report inf instead.
    val_ppl = math.exp(val_loss) if val_loss < 700.0 else float("inf")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() tensors share storage with the live parameters, and
        # .cpu() returns the same tensor when it is already on the CPU.
        # Clone so later optimizer steps cannot overwrite the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    print(f"Epoch {epoch}/{N_EPOCHS} - train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_ppl={val_ppl:.2f} | val_acc={val_acc:.4f}")

# Restore the weights from the epoch with the lowest validation loss, if any.
if best_state is not None:
    model.load_state_dict(best_state)
    model.to(device)
    print("Loaded best model state (by val loss).")


## Sampling / generation

We can generate new token sequences by autoregressively sampling from the model's
predicted distribution, starting from an initial context window.


In [None]:
@torch.no_grad()
def generate_tokens(
    model: nn.Module,
    start_tokens: torch.Tensor,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = 50,
    max_seq_len: int = 256,
) -> torch.Tensor:
    """Generate a continuation of tokens from a starting context.

    Parameters
    ----------
    start_tokens : LongTensor of shape (seq_len,)
        Initial context tokens.
    max_seq_len : int
        Maximum context length to feed into the model.
    """

    model.eval()
    generated = start_tokens.clone().to(device).unsqueeze(0)  # (1, T)

    for _ in range(max_new_tokens):
        # Ensure we respect the model's maximum sequence length
        if generated.size(1) > max_seq_len:
            context = generated[:, -max_seq_len :]
        else:
            context = generated

        logits = model(context)  # (1, T, V)
        logits = logits[:, -1, :] / max(1e-6, temperature)  # (1, V)

        k = top_k
        if k is not None and k > 0:
            k = min(k, int(logits.size(-1)))
            values, indices = torch.topk(logits, k)
            probs = torch.zeros_like(logits).scatter_(-1, indices, values)
            probs = torch.softmax(probs, dim=-1)
        else:
            probs = torch.softmax(logits, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
        generated = torch.cat([generated, next_token], dim=1)

    return generated.squeeze(0).cpu()


# Example: prompt the model with the beginning of the first held-out
# validation sequence and sample a continuation.
example_seq = val_seqs[0]
start = example_seq[:SEQ_LEN]
# The sequence can be shorter than SEQ_LEN, so split the output at the
# actual prompt length rather than at SEQ_LEN.
prompt_len = len(start)

generated = generate_tokens(
    model,
    torch.from_numpy(start).long(),
    max_new_tokens=50,
    temperature=1.0,
    top_k=50,
    max_seq_len=SEQ_LEN,
)

# Cell output: (prompt tokens, sampled continuation).
generated[:prompt_len], generated[prompt_len:]
