# Beam Search Decoding

## Problem Statement

Implement beam search, a heuristic breadth-first search that keeps only the B most promising partial sequences (the beam) at each decoding step.

## Background

### Why Not Greedy Decoding?

**Greedy decoding** selects the highest-probability token at each step:
- Fast (O(V) per step)
- But can miss globally optimal sequences
- Example: "The cat sat" can have a higher total probability than every continuation of "The dog", even if P(dog|The) > P(cat|The)

**Exhaustive search** considers all possible sequences:
- Guarantees optimal solution
- But exponential: O(V^T) for vocabulary V and length T
- Infeasible for realistic vocabularies and sequence lengths
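To put the exhaustive cost in perspective, here is a back-of-the-envelope count. The values of V and T are illustrative assumptions, not taken from any particular model.

```python
# Illustrative sizes only; not taken from any particular model.
V = 50_000  # vocabulary size
T = 20      # number of generated tokens

greedy_evaluations = V * T       # one size-V distribution scored per step
exhaustive_sequences = V ** T    # every possible length-T continuation

print(f"greedy:     {greedy_evaluations:,} token scores")
print(f"exhaustive: {exhaustive_sequences:.2e} sequences")
```

Greedy scores one million tokens, while exhaustive search would have to score roughly 9.54e+93 complete sequences.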

### Beam Search: A Middle Ground

**Beam search** maintains B (beam width) best partial sequences:
- At each step, expand all B sequences by all V tokens = B*V candidates
- Keep only top B by cumulative log probability
- Complexity: O(B*V*T) - linear in sequence length!
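A tiny, hypothetical model makes the "middle ground" concrete. The sketch below hard-codes next-token probabilities that depend only on the last token (an assumption for illustration), then compares greedy decoding (B=1), a beam of width 2, and brute-force enumeration. The helper names `toy_beam_search` and `toy_exhaustive` are made up for this example.

```python
import math

# Hypothetical toy model: the next-token distribution depends only on the last token.
NEXT = {
    "<s>": {"cat": 0.40, "dog": 0.50, "sat": 0.05, "ran": 0.05},
    "cat": {"sat": 0.90, "ran": 0.10},
    "dog": {"sat": 0.45, "ran": 0.55},
    "sat": {"cat": 0.50, "dog": 0.50},
    "ran": {"cat": 0.50, "dog": 0.50},
}
NUM_STEPS = 2  # number of tokens to generate


def expand(seq, log_p):
    """Yield every one-token extension of seq with its cumulative log-probability."""
    for tok, p in NEXT[seq[-1]].items():
        yield seq + (tok,), log_p + math.log(p)


def toy_beam_search(beam_width):
    """Keep the beam_width best prefixes per step; beam_width=1 is greedy decoding."""
    beams = [(("<s>",), 0.0)]
    for _ in range(NUM_STEPS):
        candidates = [c for seq, lp in beams for c in expand(seq, lp)]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0]


def toy_exhaustive():
    """Score every possible continuation (feasible only because the toy is tiny)."""
    frontier = [(("<s>",), 0.0)]
    for _ in range(NUM_STEPS):
        frontier = [c for seq, lp in frontier for c in expand(seq, lp)]
    return max(frontier, key=lambda c: c[1])


for name, (seq, lp) in [("greedy (B=1)", toy_beam_search(1)),
                        ("beam (B=2)", toy_beam_search(2)),
                        ("exhaustive", toy_exhaustive())]:
    print(f"{name:13s} {' '.join(seq[1:])}  P={math.exp(lp):.3f}")
```

Greedy commits to `dog` (0.50 > 0.40) and ends at P = 0.275, while a beam of width 2 keeps `cat` alive long enough to reach `cat sat` (P = 0.36), which exhaustive enumeration confirms is the best sequence in this toy.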

### Key Concepts

1. **Beam Width (B)**: Number of candidates to keep (typically 4-10)
2. **Log Probabilities**: Sum log-probabilities instead of multiplying probabilities to avoid numerical underflow
3. **Length Normalization**: Longer sequences have lower log-probs; normalize to compare fairly
4. **Early Stopping**: Stop expanding a beam once it generates the EOS token, and stop decoding when every beam has finished
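The underflow problem behind point 2 is easy to reproduce. This sketch assumes an illustrative sequence of 200 tokens, each with probability 0.01, and compares the running product of probabilities with the sum of log-probabilities.

```python
import math

# Illustrative: 200 tokens, each generated with probability 0.01.
probs = [0.01] * 200

product = 1.0
for p in probs:
    product *= p  # the true value, 1e-400, is far below the smallest positive float64 (~5e-324)

log_prob = sum(math.log(p) for p in probs)

print(product)   # 0.0 (underflow)
print(log_prob)  # about -921.03, easily representable
```

Once the product hits 0.0, every such sequence looks equally (im)probable, whereas the log-domain scores remain finite and comparable.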

## Mathematical Formulation

At step $t$, we have $B$ partial sequences $\{y^{(1)}_{1:t}, ..., y^{(B)}_{1:t}\}$.

For each sequence $y^{(b)}$:
1. Compute next-token distribution: $P(y_{t+1} | y^{(b)}_{1:t})$
2. Score all extensions: $\log P(y^{(b)}_{1:t}) + \log P(y_{t+1} | y^{(b)}_{1:t})$

Select top $B$ from all $B \times V$ candidates based on cumulative score.
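Assuming PyTorch (already used in this notebook), one expansion step can be written with broadcasting, a single `topk` over the flattened $B \times V$ scores, and integer division to recover each winner's parent beam and token. The probabilities below are made up for illustration.

```python
import torch

# Illustrative values: B = 2 live beams, V = 4 tokens.
beam_log_probs = torch.tensor([-0.5, -1.0])   # log P(y^(b)_{1:t}), shape (B,)
next_log_probs = torch.log(torch.tensor([
    [0.1, 0.6, 0.2, 0.1],                     # P(y_{t+1} | beam 0)
    [0.7, 0.1, 0.1, 0.1],                     # P(y_{t+1} | beam 1)
]))                                           # shape (B, V)

B, V = next_log_probs.shape
scores = beam_log_probs.unsqueeze(-1) + next_log_probs   # (B, V) cumulative scores
top_scores, flat_idx = torch.topk(scores.reshape(-1), k=B)  # best B of all B*V
parent = torch.div(flat_idx, V, rounding_mode="floor")   # which beam was extended
token = flat_idx % V                                     # which token was appended
print(parent.tolist(), token.tolist())
```

Here the two survivors are beam 0 extended with token 1 and beam 1 extended with token 0, so each old beam contributes one child.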

**Length Normalization**:
$$\text{score}(y) = \frac{\log P(y)}{|y|^\alpha}$$

where $\alpha \in [0, 1]$ controls normalization strength: $\alpha = 0$ disables it and $\alpha = 1$ divides by the full length (typical values are 0.6-0.7).
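The sketch below shows how normalization can flip a ranking between a short and a long hypothesis. The two hypotheses are illustrative. For comparison it also defines the smoothed penalty used in GNMT (Wu et al., 2016), $((5 + |y|)/(5 + 1))^\alpha$, which is a variant of, not the same as, the formula above.

```python
def simple_penalty(length: int, alpha: float) -> float:
    """|y|^alpha from the formula above; alpha = 0 disables normalization."""
    return length ** alpha


def gnmt_penalty(length: int, alpha: float) -> float:
    """Smoothed variant from Wu et al. (2016): ((5 + |y|) / (5 + 1))^alpha."""
    return ((5 + length) / 6) ** alpha


# Two illustrative finished hypotheses: (log P(y), |y|).
short_hyp = (-2.0, 2)
long_hyp = (-3.0, 6)

for alpha in (0.0, 0.6, 1.0):
    s_short = short_hyp[0] / simple_penalty(short_hyp[1], alpha)
    s_long = long_hyp[0] / simple_penalty(long_hyp[1], alpha)
    winner = "short" if s_short > s_long else "long"
    print(f"alpha={alpha}: short={s_short:.3f} long={s_long:.3f} -> {winner}")
```

Raw log-probability (alpha=0) prefers the short hypothesis, while at alpha=0.6 and alpha=1.0 the longer one wins. The smoothed GNMT penalty is gentler at small lengths: with the same alpha=0.6 the short hypothesis still wins (about -1.82 vs -2.09).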

## Learning Objectives

1. Understand why beam search often finds higher-probability sequences than greedy decoding
2. Implement basic beam search with log probabilities
3. Add length normalization for fair comparison
4. Handle EOS tokens and early stopping
5. Know the tradeoffs of different beam widths

## Requirements

1. `beam_search_step()` - One step of beam search expansion
2. `beam_search()` - Full beam search decoding, with optional length normalization via `length_penalty`
3. `beam_search_efficient()` - Vectorized beam search that keeps beams in tensors

## Hints

1. Use `torch.topk()` to get top-B candidates efficiently
2. Track both token sequences and cumulative log-probs
3. Use `-float('inf')` for terminated beams to exclude them
4. Consider returning top-N final sequences, not just top-1
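Hint 1 can be pushed further: a candidate can only reach the global top B if it is also among its own beam's top B (otherwise B siblings already beat it). So each beam can be pruned to B children before the final ranking, shrinking the work from B*V to B*B candidates. A quick numerical check with random, illustrative scores:

```python
import torch

torch.manual_seed(0)
B, V = 4, 1000
beam_log_probs = torch.randn(B)                          # cumulative scores of live beams
next_log_probs = torch.log_softmax(torch.randn(B, V), dim=-1)
scores = beam_log_probs.unsqueeze(-1) + next_log_probs   # (B, V) candidate scores

# Full search: top-B over all B*V candidates.
full_top, _ = torch.topk(scores.reshape(-1), k=B)

# Pruned search: each beam contributes at most B candidates, then re-rank the B*B survivors.
per_beam_top, _ = torch.topk(scores, k=B, dim=-1)        # (B, B)
pruned_top, _ = torch.topk(per_beam_top.reshape(-1), k=B)

print(torch.equal(full_top, pruned_top))  # True
```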

In [None]:
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional, Callable
from dataclasses import dataclass

torch.manual_seed(42)

## Implementation

In [None]:
@dataclass
class BeamHypothesis:
    """A single beam hypothesis (partial or complete sequence)."""
    tokens: torch.Tensor  # Token IDs
    log_prob: float       # Cumulative log probability
    is_finished: bool     # Whether EOS was generated
    
    def score(self, length_penalty: float = 1.0) -> float:
        """Compute length-normalized score."""
        length = len(self.tokens)
        if length_penalty == 0:
            return self.log_prob
        return self.log_prob / (length ** length_penalty)

In [None]:
def beam_search_step(
    model_fn: Callable[[torch.Tensor], torch.Tensor],
    beams: List[BeamHypothesis],
    beam_width: int,
    vocab_size: int,
    eos_token_id: Optional[int] = None,
) -> List[BeamHypothesis]:
    """
    Perform one step of beam search expansion.
    
    Args:
        model_fn: Function that takes token sequences and returns logits
                  Input: (batch_size, seq_len) -> Output: (batch_size, seq_len, vocab_size)
        beams: Current list of beam hypotheses
        beam_width: Number of beams to keep
        vocab_size: Size of vocabulary
        eos_token_id: End-of-sequence token ID (optional)
        
    Returns:
        List of new beam hypotheses after expansion
    """
    # Separate finished and active beams
    finished_beams = [b for b in beams if b.is_finished]
    active_beams = [b for b in beams if not b.is_finished]
    
    if not active_beams:
        return finished_beams
    
    # Stack all active beam tokens for batched inference
    # Pad to same length for batching
    max_len = max(len(b.tokens) for b in active_beams)
    batch_tokens = torch.zeros(len(active_beams), max_len, dtype=torch.long)
    for i, beam in enumerate(active_beams):
        batch_tokens[i, :len(beam.tokens)] = beam.tokens
    
    # Get logits from model
    with torch.no_grad():
        logits = model_fn(batch_tokens)  # (batch, seq_len, vocab)
    
    # Get last token logits and convert to log probs
    # Use actual sequence lengths
    next_token_logits = torch.stack([
        logits[i, len(active_beams[i].tokens) - 1] 
        for i in range(len(active_beams))
    ])  # (num_active, vocab)
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    
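    # A candidate can only reach the global top beam_width if it is among its own
    # beam's top beam_width tokens, so prune each beam with topk before expanding.
    k = min(beam_width, vocab_size)
    top_log_probs, top_token_ids = torch.topk(log_probs, k, dim=-1)  # (num_active, k)
    
    all_candidates = []
    for beam_idx, beam in enumerate(active_beams):
        for log_p, token_id in zip(top_log_probs[beam_idx].tolist(),
                                   top_token_ids[beam_idx].tolist()):
            new_log_prob = beam.log_prob + log_p
            new_tokens = torch.cat([beam.tokens, torch.tensor([token_id])])
            is_finished = (eos_token_id is not None and token_id == eos_token_id)
            
            all_candidates.append(BeamHypothesis(
                tokens=new_tokens,
                log_prob=new_log_prob,
                is_finished=is_finished,
            ))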
    
    # Add finished beams to candidates (they don't expand)
    all_candidates.extend(finished_beams)
    
    # Sort by log probability and keep top beam_width
    all_candidates.sort(key=lambda x: x.log_prob, reverse=True)
    
    return all_candidates[:beam_width]

In [None]:
def beam_search(
    model_fn: Callable[[torch.Tensor], torch.Tensor],
    prompt_tokens: torch.Tensor,
    beam_width: int,
    max_length: int,
    vocab_size: int,
    eos_token_id: Optional[int] = None,
    length_penalty: float = 0.0,
) -> List[BeamHypothesis]:
    """
    Perform beam search decoding.
    
    Args:
        model_fn: Model function for next-token prediction
        prompt_tokens: Initial token sequence (1D tensor)
        beam_width: Number of beams to maintain
        max_length: Maximum generation length
        vocab_size: Vocabulary size
        eos_token_id: EOS token ID (optional)
        length_penalty: Length normalization factor (0 = no normalization)
        
    Returns:
        List of final beam hypotheses, sorted by score
    """
    # Initialize with single beam containing prompt
    beams = [BeamHypothesis(
        tokens=prompt_tokens.clone(),
        log_prob=0.0,
        is_finished=False,
    )]
    
    # Generate tokens
    for step in range(max_length):
        beams = beam_search_step(
            model_fn=model_fn,
            beams=beams,
            beam_width=beam_width,
            vocab_size=vocab_size,
            eos_token_id=eos_token_id,
        )
        
        # Early stop if all beams are finished
        if all(b.is_finished for b in beams):
            break
    
    # Sort by score (with length penalty)
    beams.sort(key=lambda x: x.score(length_penalty), reverse=True)
    
    return beams

In [None]:
def beam_search_efficient(
    model_fn: Callable[[torch.Tensor], torch.Tensor],
    prompt_tokens: torch.Tensor,
    beam_width: int,
    max_length: int,
    vocab_size: int,
    eos_token_id: Optional[int] = None,
    length_penalty: float = 0.0,
) -> List[Tuple[torch.Tensor, float]]:
    """
    Efficient beam search using tensor operations.
    Maintains beams as tensors rather than lists.
    
    Returns:
        List of (tokens, score) tuples
    """
    device = prompt_tokens.device
    batch_size = 1  # Single sequence beam search
    
    # Initialize beams: (beam_width, seq_len)
    # Start with single beam containing prompt, replicate for beam_width
    beam_tokens = prompt_tokens.unsqueeze(0).expand(beam_width, -1).clone()
    beam_log_probs = torch.zeros(beam_width, device=device)
    beam_log_probs[1:] = float('-inf')  # Only first beam is valid initially
    
    finished_beams = []  # Store (tokens, score) for finished sequences
    
    for step in range(max_length):
        # Get logits for all beams
        with torch.no_grad():
            logits = model_fn(beam_tokens)  # (beam_width, seq_len, vocab)
        
        # Get next-token log probs
        next_log_probs = F.log_softmax(logits[:, -1, :], dim=-1)  # (beam_width, vocab)
        
        # Compute scores for all beam x vocab combinations
        # (beam_width, 1) + (beam_width, vocab) = (beam_width, vocab)
        candidate_scores = beam_log_probs.unsqueeze(-1) + next_log_probs
        
        # Flatten and get top-k
        flat_scores = candidate_scores.view(-1)  # (beam_width * vocab)
        top_scores, top_indices = torch.topk(flat_scores, beam_width)
        
        # Decode indices back to (beam_idx, token_id)
        beam_indices = torch.div(top_indices, vocab_size, rounding_mode='floor')
        token_indices = top_indices % vocab_size
        
        # Create new beams
        new_beam_tokens = torch.cat([
            beam_tokens[beam_indices],
            token_indices.unsqueeze(-1)
        ], dim=-1)
        
        # Check for EOS
        if eos_token_id is not None:
            # Ignore -inf candidates: they extend invalid or already-finished beams
            eos_mask = (token_indices == eos_token_id) & torch.isfinite(top_scores)
            for i in range(beam_width):
                if eos_mask[i]:
                    length = new_beam_tokens[i].shape[0]
                    score = top_scores[i].item() / (length ** length_penalty) if length_penalty > 0 else top_scores[i].item()
                    finished_beams.append((new_beam_tokens[i].clone(), score))
                    top_scores[i] = float('-inf')  # Mark as done so this beam is not expanded
        
        beam_tokens = new_beam_tokens
        beam_log_probs = top_scores
        
        # Early stop if all beams finished
        if (beam_log_probs == float('-inf')).all():
            break
    
    # Add remaining active beams to finished
    for i in range(beam_width):
        if beam_log_probs[i] > float('-inf'):
            length = beam_tokens[i].shape[0]
            score = beam_log_probs[i].item() / (length ** length_penalty) if length_penalty > 0 else beam_log_probs[i].item()
            finished_beams.append((beam_tokens[i].clone(), score))
    
    # Sort by score
    finished_beams.sort(key=lambda x: x[1], reverse=True)
    
    return finished_beams

## Testing

In [None]:
# Create a simple mock language model for testing
class MockLanguageModel:
    """
    A simple mock LM that has deterministic preferences.
    Vocabulary: 0='<pad>', 1='<eos>', 2='the', 3='cat', 4='dog', 5='sat', 6='ran'
    
    Preferences:
    - After 'the': prefers 'cat' > 'dog'
    - After 'cat': prefers 'sat' > 'ran'
    - After 'dog': prefers 'ran' > 'sat'
    - After 'sat'/'ran': prefers '<eos>'
    """
    def __init__(self):
        self.vocab_size = 7
        self.eos_token_id = 1
    
    def __call__(self, tokens: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = tokens.shape
        logits = torch.zeros(batch_size, seq_len, self.vocab_size)
        
        for b in range(batch_size):
            for t in range(seq_len):
                last_token = tokens[b, t].item()
                
                # Set logits based on previous token
                if last_token == 2:  # 'the'
                    logits[b, t, 3] = 2.0  # 'cat'
                    logits[b, t, 4] = 1.5  # 'dog'
                elif last_token == 3:  # 'cat'
                    logits[b, t, 5] = 2.0  # 'sat'
                    logits[b, t, 6] = 1.0  # 'ran'
                elif last_token == 4:  # 'dog'
                    logits[b, t, 6] = 2.0  # 'ran'
                    logits[b, t, 5] = 1.0  # 'sat'
                elif last_token in [5, 6]:  # 'sat' or 'ran'
                    logits[b, t, 1] = 3.0  # '<eos>'
                else:
                    logits[b, t, 2] = 1.0  # 'the'
        
        return logits

model = MockLanguageModel()
print(f"Vocabulary size: {model.vocab_size}")
print(f"Token mapping: 0='<pad>', 1='<eos>', 2='the', 3='cat', 4='dog', 5='sat', 6='ran'")

In [None]:
# Test basic beam search
print("=== Testing Basic Beam Search ===")

prompt = torch.tensor([2])  # Start with 'the'

# Run beam search with beam_width=2
results = beam_search(
    model_fn=model,
    prompt_tokens=prompt,
    beam_width=2,
    max_length=3,
    vocab_size=model.vocab_size,
    eos_token_id=model.eos_token_id,
)

token_names = {0: '<pad>', 1: '<eos>', 2: 'the', 3: 'cat', 4: 'dog', 5: 'sat', 6: 'ran'}

print(f"\nTop {len(results)} beams:")
for i, beam in enumerate(results):
    tokens = [token_names[t.item()] for t in beam.tokens]
    print(f"  {i+1}. {' '.join(tokens)} (log_prob={beam.log_prob:.3f}, finished={beam.is_finished})")

# The best sequence should be "the cat sat <eos>" 
# because cat is preferred after 'the', and sat is preferred after 'cat'
best_tokens = [token_names[t.item()] for t in results[0].tokens]
print(f"\nBest sequence: {' '.join(best_tokens)}")
assert 'cat' in best_tokens, "Best sequence should contain 'cat'"
print("Basic beam search test passed!")

In [None]:
# Test greedy vs beam search
print("\n=== Greedy vs Beam Search ===")

# Greedy is equivalent to beam_width=1
greedy_results = beam_search(
    model_fn=model,
    prompt_tokens=prompt,
    beam_width=1,
    max_length=3,
    vocab_size=model.vocab_size,
    eos_token_id=model.eos_token_id,
)

beam_results = beam_search(
    model_fn=model,
    prompt_tokens=prompt,
    beam_width=4,
    max_length=3,
    vocab_size=model.vocab_size,
    eos_token_id=model.eos_token_id,
)

greedy_tokens = [token_names[t.item()] for t in greedy_results[0].tokens]
beam_tokens = [token_names[t.item()] for t in beam_results[0].tokens]

print(f"Greedy (beam=1): {' '.join(greedy_tokens)} (score={greedy_results[0].log_prob:.3f})")
print(f"Beam (beam=4):   {' '.join(beam_tokens)} (score={beam_results[0].log_prob:.3f})")

# Here greedy and beam search agree: the locally best token at each step ('cat', then 'sat') also lies on the highest-probability path
print("Greedy vs Beam comparison complete!")

In [None]:
# Test efficient beam search
print("\n=== Testing Efficient Beam Search ===")

efficient_results = beam_search_efficient(
    model_fn=model,
    prompt_tokens=prompt,
    beam_width=4,
    max_length=3,
    vocab_size=model.vocab_size,
    eos_token_id=model.eos_token_id,
    length_penalty=0.0,
)

print(f"\nTop {min(4, len(efficient_results))} sequences:")
for i, (tokens, score) in enumerate(efficient_results[:4]):
    token_str = [token_names[t.item()] for t in tokens]
    print(f"  {i+1}. {' '.join(token_str)} (score={score:.3f})")

print("Efficient beam search test passed!")

In [None]:
# Test length penalty effect
print("\n=== Testing Length Penalty ===")

# Without length normalization, beams are ranked by raw cumulative log-prob
results_no_penalty = beam_search(
    model_fn=model,
    prompt_tokens=prompt,
    beam_width=4,
    max_length=5,
    vocab_size=model.vocab_size,
    eos_token_id=None,  # No early stopping
    length_penalty=0.0,
)

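# With length normalization (alpha=0.6). Because eos_token_id=None, every beam has
# the same length here, so normalization rescales the scores but cannot reorder them.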
results_with_penalty = beam_search(
    model_fn=model,
    prompt_tokens=prompt,
    beam_width=4,
    max_length=5,
    vocab_size=model.vocab_size,
    eos_token_id=None,
    length_penalty=0.6,
)

print("Without length penalty (raw log-prob):")
for beam in results_no_penalty[:2]:
    tokens = [token_names[t.item()] for t in beam.tokens]
    print(f"  {' '.join(tokens)} (raw={beam.log_prob:.3f})")

print("\nWith length penalty (log-prob / length^0.6):")
for beam in results_with_penalty[:2]:
    tokens = [token_names[t.item()] for t in beam.tokens]
    print(f"  {' '.join(tokens)} (normalized={beam.score(0.6):.3f})")

print("Length penalty test complete!")

In [None]:
# Demonstrate beam width effect
print("\n=== Effect of Beam Width ===")

for beam_width in [1, 2, 4, 8]:
    results = beam_search(
        model_fn=model,
        prompt_tokens=prompt,
        beam_width=beam_width,
        max_length=3,
        vocab_size=model.vocab_size,
        eos_token_id=model.eos_token_id,
    )
    
    best_tokens = [token_names[t.item()] for t in results[0].tokens]
    unique_results = len(set(tuple(b.tokens.tolist()) for b in results))
    print(f"Beam width={beam_width}: Best='{' '.join(best_tokens)}', Unique sequences={unique_results}")

print("\nLarger beam width explores more alternatives but takes more compute.")

In [None]:
print("\n" + "=" * 50)
print("All Beam Search tests passed!")
print("=" * 50)

## Summary

### Key Concepts

1. **Beam Search** maintains top-B candidates at each step
   - Trade-off between quality and computation
   - Complexity: O(B * V * T) vs O(V^T) for exhaustive

2. **Log Probabilities** prevent numerical underflow
   - Sum log-probs instead of multiplying probs
   - P(seq) = P(t1) * P(t2|t1) * ... becomes log P(t1) + log P(t2|t1) + ...

3. **Length Normalization** enables fair comparison
   - Without it, shorter sequences tend to score higher, since every extra token adds a non-positive log-probability
   - Typical alpha: 0.6-0.7

4. **Early Stopping** on EOS improves efficiency
   - Finished beams don't expand further
   - Stop when all beams reach EOS

### Beam Width Selection

| Beam Width | Use Case | Notes |
|------------|----------|-------|
| 1 (greedy) | Fast inference, low diversity | May miss better sequences |
| 4-5 | Standard translation/generation | Good quality/speed balance |
| 10-20 | High-quality translation | Diminishing returns beyond 10 |

### Common Variations

- **Diverse Beam Search**: Penalize similar beams
- **Constrained Beam Search**: Force certain tokens
- **Top-k/Top-p + Beam**: Combine sampling with search
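As one concrete way to "penalize similar beams", the sketch below follows the sibling-rank idea of Li et al. (2016) from the references: each candidate's score is reduced by gamma times its rank among the children of the same parent, so one strong parent cannot fill the whole beam. The function names, the value of gamma, and the scores are illustrative assumptions; this is not a complete diverse beam search.

```python
import torch


def sibling_rank_penalty(scores: torch.Tensor, gamma: float) -> torch.Tensor:
    """Subtract gamma * (rank among siblings) from each candidate's score.

    scores: (B, V) cumulative log-probs; row b holds the children of beam b.
    Each beam's best child gets rank 1, its second-best rank 2, and so on.
    """
    order = scores.argsort(dim=-1, descending=True)  # token ids, best first
    ranks = order.argsort(dim=-1) + 1                # 1-based rank of each token
    return scores - gamma * ranks


def top_parents(scores: torch.Tensor, k: int) -> list:
    """Parent-beam index of each of the k best (beam, token) candidates."""
    _, flat = torch.topk(scores.reshape(-1), k=k)
    return torch.div(flat, scores.size(-1), rounding_mode="floor").tolist()


# Illustrative scores: parent log-probs (0.0, -1.0) plus each parent's next-token log-probs.
scores = torch.tensor([[0.0], [-1.0]]) + torch.log(torch.tensor([
    [0.50, 0.40, 0.10],  # beam 0: two strong children
    [0.30, 0.20, 0.50],  # beam 1: weaker parent
]))

print(top_parents(scores, 2))                                   # [0, 0]
print(top_parents(sibling_rank_penalty(scores, gamma=1.0), 2))  # [0, 1]
```

Without the penalty both surviving beams descend from beam 0; with gamma=1.0 the second slot goes to beam 1's best child instead.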

## Interview Tips

1. **Why log probabilities?** Numerical stability - probabilities become tiny for long sequences

2. **Greedy vs Beam?** Greedy is O(V*T) and beam is O(B*V*T), but beam search usually finds higher-probability sequences (though this is not guaranteed)

3. **Optimal beam width?** Diminishing returns past 5-10; depends on task

4. **Length normalization?** Without it, beam search prefers short sequences

5. **When NOT to use beam search?** Creative tasks where diversity matters (use sampling instead)

6. **Time complexity?** O(B * V * T) where B=beam width, V=vocab, T=length

7. **Space complexity?** O(B * T) for storing beam tokens
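For tip 5, nucleus (top-p) sampling is one common alternative to beam search. A minimal sketch, assuming PyTorch and a single vector of next-token logits; the function name `top_p_sample` and the example logits are illustrative.

```python
import torch


def top_p_sample(logits: torch.Tensor, p: float, generator=None) -> int:
    """Nucleus (top-p) sampling: draw from the smallest set of top tokens whose mass reaches p."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_ids = probs.sort(descending=True)
    # Keep a token if the mass of the tokens ranked above it is still below p,
    # so the single most likely token is always kept.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    kept = sorted_probs * (mass_before < p)
    choice = torch.multinomial(kept, num_samples=1, generator=generator)
    return sorted_ids[choice].item()


g = torch.Generator().manual_seed(0)
logits = torch.tensor([3.0, 1.0, 0.5, 0.0])  # illustrative next-token logits
print({top_p_sample(logits, p=0.5, generator=g) for _ in range(20)})  # {0}
print({top_p_sample(logits, p=0.9, generator=g) for _ in range(200)})
```

Token 0 alone carries about 79% of the probability mass, so p=0.5 always returns it; with p=0.9 the nucleus grows to tokens 0, 1 and 2, and the output varies from call to call.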

## References

1. [Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215) - Sutskever et al., 2014
2. [Google's Neural Machine Translation System](https://arxiv.org/abs/1609.08144) - Wu et al., 2016 (length penalty)
3. [A Simple, Fast Diverse Decoding Algorithm for Neural Generation](https://arxiv.org/abs/1611.08562) - Li et al., 2016 (diverse decoding)
4. [HuggingFace - Generation Strategies](https://huggingface.co/docs/transformers/generation_strategies)