# Tutorial 3: Inference and Beam Search

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/commanderjcc/torchlingo/blob/main/docs/docs/tutorials/03-inference-and-beamsearch.ipynb)

Generate translations using greedy and beam search decoding strategies.

**⚠️ Prerequisites:** This tutorial requires the model checkpoint from Tutorial 2. Run Tutorial 2 first!

**⚡ Running in Google Colab?** Make sure to:
1. Go to **Runtime → Change runtime type → GPU** (optional)
2. Uncomment and run the `%pip install torchlingo` cell below
3. Run Tutorial 2 first to create the model checkpoint

In [None]:
# Install TorchLingo (uncomment in Google Colab)
# %pip install torchlingo

# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Set device (will use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else 
                      "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Import TorchLingo modules
from pathlib import Path
import torch.nn.functional as F
from torchlingo.models import SimpleTransformer

print("✓ Imports successful!")

In [None]:
# Load the model from Tutorial 2
ckpt_path = Path("checkpoints/tiny_model.pt")

if not ckpt_path.exists():
    print("⚠️ No checkpoint found. Please run Tutorial 2 first!")
else:
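    # The checkpoint stores vocabulary objects as well as tensors, so it must be
    # unpickled in full (PyTorch >= 2.6 defaults to weights_only=True).
    # Only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)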
    
    model = SimpleTransformer(
        src_vocab_size=len(checkpoint['src_vocab']),
        tgt_vocab_size=len(checkpoint['tgt_vocab']),
        **checkpoint['config'],
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    src_vocab = checkpoint['src_vocab']
    tgt_vocab = checkpoint['tgt_vocab']
    
    print("✓ Model loaded!")

## Part 1: Greedy Decoding (Review)

Greedy decoding picks the most likely token at each step.

In [None]:
def greedy_decode(model, src_sentence, src_vocab, tgt_vocab, device, max_len=20):
    """Generate translation using greedy decoding.
    
    At each step, pick the single most likely next token.
    """
    model.eval()
    
    # Encode source
    src_indices = src_vocab.encode(src_sentence, add_special_tokens=True)
    src_tensor = torch.tensor([src_indices]).to(device)
    
    with torch.no_grad():
        memory = model.encode(src_tensor)
    
    # Decode
    output_indices = [tgt_vocab.sos_idx]
    
    for _ in range(max_len):
        tgt_tensor = torch.tensor([output_indices]).to(device)
        
        with torch.no_grad():
            logits = model.decode(tgt_tensor, memory)
        
        # Greedy: pick argmax
        next_token = logits[0, -1, :].argmax().item()
        output_indices.append(next_token)
        
        if next_token == tgt_vocab.eos_idx:
            break
    
    return tgt_vocab.decode(output_indices, skip_special_tokens=True)

In [None]:
# Test greedy decoding
test_sentence = "Hello world"
translation = greedy_decode(model, test_sentence, src_vocab, tgt_vocab, device)
print(f"Greedy: '{test_sentence}' → '{translation}'")

### Greedy Limitations

Greedy decoding commits to one token per step, so a locally best choice can lead to a sentence with lower overall probability:

```
Step 1: P("El") = 0.4, P("La") = 0.35, P("Un") = 0.25
        → Pick "El" (highest)
        
Step 2: P("gato"|"El") = 0.3, P("perro"|"El") = 0.25, ...
        → But maybe "La casa" would have been better overall!
```

Greedy decoding follows a single path and never revisits an earlier choice.
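
The effect can be reproduced with a small worked example. The sketch below uses a made-up two-step probability table (the numbers and the extra words are assumptions chosen for illustration, not output from the tutorial's model) and compares the greedy path with the best path found by scoring every two-word sequence.

```python
# Hypothetical next-token distributions, invented for illustration only.
first_step = {"El": 0.40, "La": 0.35, "Un": 0.25}
second_step = {
    "El": {"gato": 0.30, "perro": 0.25, "libro": 0.25, "coche": 0.20},
    "La": {"casa": 0.60, "mesa": 0.25, "luna": 0.15},
    "Un": {"libro": 0.50, "coche": 0.50},
}


def greedy_path(first, second):
    """Pick the most likely token at each step, never looking ahead."""
    w1 = max(first, key=first.get)
    w2 = max(second[w1], key=second[w1].get)
    return (w1, w2), first[w1] * second[w1][w2]


def best_path(first, second):
    """Score every two-token sequence and return the most probable one."""
    scored = {
        (w1, w2): p1 * p2
        for w1, p1 in first.items()
        for w2, p2 in second[w1].items()
    }
    best = max(scored, key=scored.get)
    return best, scored[best]


for name, fn in [("greedy", greedy_path), ("best", best_path)]:
    path, prob = fn(first_step, second_step)
    print(f"{name:>6}: {' '.join(path)}  p={prob:.3f}")
```

With these numbers, greedy commits to "El" and ends at probability 0.12, while "La casa" reaches 0.21. Exhaustive scoring is only feasible for toy tables like this one, because the number of sequences grows exponentially with length.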

## Part 2: Beam Search

Beam search keeps the `beam_size` highest-scoring partial translations ("beams") at every step and returns the best-scoring finished sequence. With `beam_size=1` it reduces to greedy decoding.

In [None]:
def beam_search_decode(
    model, src_sentence, src_vocab, tgt_vocab, device,
    beam_size=3, max_len=20, length_penalty=0.6
):
    """Generate translation using beam search.
    
    Keeps beam_size hypotheses at each step and returns the best one.
    
    Args:
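        beam_size: Number of hypotheses to keep at each step
        length_penalty: Exponent alpha for length normalization. The final
            score is log_prob_sum / len(seq) ** alpha, so 0 disables
            normalization and larger values favor longer sequences.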
    """
    model.eval()
    
    # Encode source
    src_indices = src_vocab.encode(src_sentence, add_special_tokens=True)
    src_tensor = torch.tensor([src_indices]).to(device)
    
    with torch.no_grad():
        memory = model.encode(src_tensor)
    
    # Initialize beams: (sequence, log_prob)
    beams = [([tgt_vocab.sos_idx], 0.0)]
    completed = []
    
    for _ in range(max_len):
        all_candidates = []
        
        for seq, score in beams:
            # Set aside hypotheses that have already produced EOS
            if seq[-1] == tgt_vocab.eos_idx:
                completed.append((seq, score))
                continue
            
            # Get probabilities for next token
            tgt_tensor = torch.tensor([seq]).to(device)
            
            with torch.no_grad():
                logits = model.decode(tgt_tensor, memory)
            
            log_probs = F.log_softmax(logits[0, -1, :], dim=-1)
            
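            # Get top beam_size candidates (capped at the vocabulary size,
            # since topk raises an error if k exceeds it)
            k = min(beam_size, log_probs.size(-1))
            topk_log_probs, topk_indices = log_probs.topk(k)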
            
            for log_prob, idx in zip(topk_log_probs, topk_indices):
                new_seq = seq + [idx.item()]
                new_score = score + log_prob.item()
                all_candidates.append((new_seq, new_score))
        
        # Keep top beam_size candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:beam_size]
        
        # Stop if all beams are completed
        if not beams:
            break
    
    # Add any remaining beams to completed
    completed.extend(beams)
    
    # Apply length penalty and pick best
    def score_with_length_penalty(seq, score):
        length = len(seq)
        return score / (length ** length_penalty)
    
    best_seq, best_score = max(
        completed, 
        key=lambda x: score_with_length_penalty(x[0], x[1])
    )
    
    return tgt_vocab.decode(best_seq, skip_special_tokens=True)

In [None]:
# Test beam search
test_sentence = "Hello world"

greedy_result = greedy_decode(model, test_sentence, src_vocab, tgt_vocab, device)
beam_result = beam_search_decode(model, test_sentence, src_vocab, tgt_vocab, device, beam_size=3)

print(f"Input:  '{test_sentence}'")
print(f"Greedy: '{greedy_result}'")
print(f"Beam-3: '{beam_result}'")
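
To see why the final ranking divides by `length ** length_penalty`, here is a minimal sketch that applies the same formula to two made-up hypotheses. The lengths and summed log-probabilities are assumptions chosen for illustration, and `normalized_score` and `pick_best` are hypothetical helper names, not part of TorchLingo.

```python
def normalized_score(log_prob_sum, length, alpha):
    """Length-normalized score: log_prob_sum / length ** alpha."""
    return log_prob_sum / (length ** alpha)


# Hypothetical hypotheses: name -> (token count, summed log-probability).
# The longer one has a lower raw score simply because it has more terms.
hypotheses = {"short": (3, -2.0), "long": (6, -2.8)}


def pick_best(alpha):
    """Return the name of the hypothesis with the highest normalized score."""
    return max(
        hypotheses,
        key=lambda name: normalized_score(
            hypotheses[name][1], hypotheses[name][0], alpha
        ),
    )


for alpha in (0.0, 0.6, 1.0):
    scores = {
        name: normalized_score(lp, n, alpha)
        for name, (n, lp) in hypotheses.items()
    }
    summary = ", ".join(f"{name}={s:.3f}" for name, s in scores.items())
    print(f"alpha={alpha}: {summary} -> {pick_best(alpha)}")
```

With `alpha=0.0` the raw log-probability decides and the shorter hypothesis wins. At `alpha=0.6` and `alpha=1.0` the longer hypothesis wins, because its average per-token score is better. Note that `beam_search_decode` measures length with `len(seq)`, which counts the start and end tokens as well.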

## Part 3: Comparing Strategies

In [None]:
# Compare on multiple sentences
test_sentences = [
    "Hello world",
    "Good morning",
    "Thank you",
    "I love you",
    "The cat sleeps",
]

print(f"{'Input':<20} {'Greedy':<20} {'Beam-3':<20}")
print("-" * 60)

for src in test_sentences:
    greedy = greedy_decode(model, src, src_vocab, tgt_vocab, device)
    beam = beam_search_decode(model, src, src_vocab, tgt_vocab, device)
    print(f"{src:<20} {greedy:<20} {beam:<20}")

In [None]:
# Effect of beam size
test_sentence = "Hello world"

print(f"Input: '{test_sentence}'")
print("-" * 40)
for beam_size in [1, 2, 3, 5, 10]:
    result = beam_search_decode(
        model, test_sentence, src_vocab, tgt_vocab, device,
        beam_size=beam_size
    )
    print(f"Beam-{beam_size:2d}: '{result}'")

## Part 4: BLEU Score Evaluation

BLEU (Bilingual Evaluation Understudy) measures translation quality by comparing n-gram overlap.
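
As a rough illustration of what "n-gram overlap" means, the sketch below computes the clipped n-gram precision that BLEU is built on, using whitespace tokenization. It is a simplified teaching aid with hypothetical helper names, not how `sacrebleu` is implemented: the real metric also applies its own tokenizer, combines the precisions for n = 1 to 4 with a geometric mean, and multiplies by a brevity penalty.

```python
from collections import Counter


def ngrams(tokens, n):
    """Count every contiguous run of n tokens."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def clipped_precision(hypothesis, reference, n):
    """Fraction of hypothesis n-grams that also occur in the reference.

    Each n-gram is credited at most as many times as it appears in the
    reference, so repeating a correct word does not inflate the score.
    """
    hyp = ngrams(hypothesis.split(), n)
    ref = ngrams(reference.split(), n)
    total = sum(hyp.values())
    if total == 0:
        return 0.0  # hypothesis is shorter than n tokens
    matched = sum(min(count, ref[gram]) for gram, count in hyp.items())
    return matched / total


hyp = "the the cat"
ref = "the cat sleeps"
print(f"1-gram precision: {clipped_precision(hyp, ref, 1):.3f}")
print(f"2-gram precision: {clipped_precision(hyp, ref, 2):.3f}")
```

Here the repeated "the" is credited only once, giving a unigram precision of 2/3, and one of the two bigrams ("the cat") matches, giving 1/2.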

In [None]:
# Install sacrebleu if needed
try:
    from sacrebleu.metrics import BLEU
    print("sacrebleu is installed!")
except ImportError:
    print("Installing sacrebleu...")
    %pip install sacrebleu
    from sacrebleu.metrics import BLEU

In [None]:
from sacrebleu.metrics import BLEU

# Our test data
sources = [
    "Hello world",
    "Good morning",
    "Thank you",
    "I love you",
]

references = [
    "Hola mundo",
    "Buenos días",
    "Gracias",
    "Te amo",
]

# Generate translations
greedy_translations = [greedy_decode(model, s, src_vocab, tgt_vocab, device) for s in sources]
beam_translations = [beam_search_decode(model, s, src_vocab, tgt_vocab, device) for s in sources]

# Calculate BLEU
bleu = BLEU()

greedy_bleu = bleu.corpus_score(greedy_translations, [references])
beam_bleu = bleu.corpus_score(beam_translations, [references])

print("BLEU Scores:")
print(f"  Greedy: {greedy_bleu.score:.2f}")
print(f"  Beam-3: {beam_bleu.score:.2f}")

In [None]:
# Detailed comparison
print(f"{'Source':<20} {'Reference':<20} {'Greedy':<20} {'Beam':<20}")
print("-" * 80)

for src, ref, greedy, beam in zip(sources, references, greedy_translations, beam_translations):
    print(f"{src:<20} {ref:<20} {greedy:<20} {beam:<20}")

## Understanding BLEU

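BLEU combines n-gram precision with a brevity penalty and is reported here on a 0-100 scale. As a rough guide (the thresholds vary by language pair and test set):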

| Score | Quality |
|-------|------|
| < 10 | Almost unusable |
| 10-20 | Gist is clear |
| 20-30 | Understandable |
| 30-40 | Good quality |
| 40-50 | High quality |
| > 50 | Very high quality |

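⚠️ Our toy model, trained on tiny data, won't achieve great BLEU scores. That's expected! The one- and two-word references used here also contain almost no 3-grams or 4-grams to match, which pushes the score down further.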

## Summary

You've learned:

1. **Greedy decoding**: Fast but can miss better translations
2. **Beam search**: Explores multiple paths, often better results
3. **Length penalty**: Offsets beam search's bias toward short sequences, which arises because every added token lowers the summed log-probability
4. **BLEU score**: Standard metric for translation quality

## Key Takeaways

- **Beam size 3-5** is usually sufficient (diminishing returns after)
- **Length penalty** around 0.6-1.0 works well
- **BLEU** is useful but not perfect: it rewards n-gram overlap with the reference, so a valid paraphrase can score low even when a human would accept it

## What's Next?

- Explore the [API Reference](../reference/index.md) for more details
- Learn about [SentencePiece](../reference/preprocessing/sentencepiece.md) for better tokenization
- Try training on a real dataset!