# Transformer Model with Beam Search

This notebook implements and evaluates the transformer model with **beam search decoding** for improved caption generation.

## Beam Search vs. Greedy/Sampling

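**Greedy decoding**: Picks the most likely token at each step. Fast, but each choice is only locally optimal, so the full sequence may not be the most likely one.

**Sampling**: Draws each token from the predicted distribution, shaped by temperature and top-k/nucleus (top-p) filtering. More diverse, but less coherent.

**Beam search**: Keeps the `beam_width` highest-scoring partial sequences at each step and extends all of them, so several paths are explored before one is chosen. Usually more coherent, but slower.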
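
The difference between greedy decoding and beam search can be illustrated on a toy problem. The sketch below is not the notebook's model: the two probability tables and the `greedy` and `beam` helpers are hypothetical, chosen so that the locally best first token leads to a worse overall sequence.

```python
# Hypothetical two-step toy model: P(first token) and P(second token | first).
first = {"a": 0.6, "b": 0.4}
second = {
    "a": {"x": 0.55, "y": 0.45},
    "b": {"x": 0.9, "y": 0.1},
}

def greedy(first, second):
    """Pick the single most likely token at each step."""
    t1 = max(first, key=first.get)
    t2 = max(second[t1], key=second[t1].get)
    return (t1, t2), first[t1] * second[t1][t2]

def beam(first, second, beam_width):
    """Keep the beam_width best prefixes, then extend every one of them."""
    prefixes = sorted(first, key=first.get, reverse=True)[:beam_width]
    candidates = [
        ((t1, t2), first[t1] * p2)
        for t1 in prefixes
        for t2, p2 in second[t1].items()
    ]
    return max(candidates, key=lambda c: c[1])

greedy_seq, greedy_p = greedy(first, second)
beam_seq, beam_p = beam(first, second, beam_width=2)
print(greedy_seq, round(greedy_p, 2))  # ('a', 'x') 0.33
print(beam_seq, round(beam_p, 2))      # ('b', 'x') 0.36
```

Greedy commits to `a` because it is the best first token, while a beam of width 2 also keeps `b` and finds the more probable sequence. With `beam_width=1` the `beam` helper reduces to greedy decoding.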

## What's Covered:
1. Implementing beam search for the transformer model
2. Training the transformer with memory-efficient settings
3. Comparing beam search vs greedy vs sampling
4. Analyzing results and metrics

## Step 1: Setup and Imports

In [None]:
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
from pathlib import Path

# Add parent directory to path
project_root = Path('..').absolute()
sys.path.insert(0, str(project_root))

import torch
torch.cuda.empty_cache()
import gc
gc.collect()

import json
import numpy as np
import math
from tqdm import tqdm
from src.models import create_model, TransformerModel
from src.dataset import create_dataloaders
from src.trainer import ModelTrainer
from src.utils import load_vocab, set_seed, get_device, count_parameters, make_json_serializable

print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Step 2: Implement Beam Search

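Beam search keeps the `beam_width` highest-scoring partial sequences at each step. Candidates are ranked by their summed log probability divided by a length normalization term, `((5 + length) / 6) ** length_penalty`, and each batch item is decoded independently.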
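
The length normalization used in the cell below can be tried in isolation. This is a minimal sketch: the `length_normalized` helper and the two candidates are hypothetical, and only the formula `((5 + length) / 6) ** length_penalty` comes from the implementation.

```python
def length_normalized(score, length, length_penalty=0.6):
    """Divide a summed log probability by ((5 + length) / 6) ** length_penalty."""
    return score / (((5 + length) / 6) ** length_penalty)

# Hypothetical candidates: (name, summed log probability, length in tokens)
candidates = [("short", -4.0, 4), ("long", -5.0, 10)]

for alpha in (0.0, 0.6, 1.0):
    ranked = sorted(
        candidates,
        key=lambda c: length_normalized(c[1], c[2], alpha),
        reverse=True,
    )
    print(alpha, [name for name, _, _ in ranked])
# 0.0 ['short', 'long']
# 0.6 ['long', 'short']
# 1.0 ['long', 'short']
```

With `length_penalty=0.0` the raw log probability decides and the shorter candidate wins. Because log probabilities are negative, dividing by a larger normalizer moves a long sequence's score closer to zero, so a higher penalty favors longer captions.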

In [None]:
def beam_search_generate(model, mel, beam_width=5, max_len=30, sos_idx=1, eos_idx=2, 
                         length_penalty=0.6, device='cuda'):
    """
    Generate captions using beam search
    
    Args:
        model: TransformerModel instance
        mel: Audio mel spectrogram (batch, 1, 64, 3000)
        beam_width: Number of beams to maintain
        max_len: Maximum caption length
        sos_idx: Start-of-sequence token index
        eos_idx: End-of-sequence token index
        length_penalty: Length normalization penalty (0.6-1.0 typical)
        device: Device to run on
    
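    Returns:
        best_sequences: (batch, longest_len) generated token IDs, with the leading
            <sos> removed and shorter sequences right-padded with 0 (longest_len <= max_len)
        best_scores: (batch,) summed log probabilities of the best sequences
            (not length-normalized)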
    """
    model.eval()
    batch_size = mel.size(0)
    
    # Encode audio once (shared across all beams)
    with torch.no_grad():
        audio_features = model.encode_audio(mel)  # (batch, audio_len, d_model)
    
    # Decode each batch item separately and collect its best sequence and score
    all_best_sequences = []
    all_best_scores = []
    
    for b in range(batch_size):
        # Get audio features for this batch item
        audio_feat = audio_features[b:b+1]  # (1, audio_len, d_model)
        
        # Initialize beams: each beam is (sequence, score)
        beams = [(torch.tensor([sos_idx], device=device), 0.0)]
        completed_beams = []
        
        for step in range(max_len):
            candidates = []
            
            for seq, score in beams:
                # Skip if sequence already ended
                if seq[-1].item() == eos_idx:
                    completed_beams.append((seq, score))
                    continue
                
                # Prepare input: (1, current_len)
                seq_input = seq.unsqueeze(0)
                
                # Generate next token probabilities
                with torch.no_grad():
                    # Embed and add positional encoding
                    embedded = model.embedding(seq_input) * math.sqrt(model.d_model)
                    embedded = model.pos_encoder(embedded)
                    
                    # Create causal mask
                    tgt_mask = model.generate_square_subsequent_mask(seq_input.size(1), device)
                    
                    # Decode
                    output = model.transformer(
                        src=audio_feat,
                        tgt=embedded,
                        tgt_mask=tgt_mask
                    )
                    
                    # Get logits for last token
                    logits = model.output_projection(output[0, -1, :])  # (vocab_size,)
                    log_probs = torch.log_softmax(logits, dim=-1)
                
                # Get top-k next tokens
                top_log_probs, top_indices = torch.topk(log_probs, beam_width)
                
                # Create new candidate beams
                for log_prob, idx in zip(top_log_probs, top_indices):
                    new_seq = torch.cat([seq, idx.unsqueeze(0)])
                    new_score = score + log_prob.item()
                    candidates.append((new_seq, new_score))
            
            # If no candidates, break
            if not candidates:
                break
            
            # Select top beam_width candidates
            # Apply length penalty to scores
            candidates_with_penalty = []
            for seq, score in candidates:
                length_norm = ((5 + len(seq)) / 6) ** length_penalty
                normalized_score = score / length_norm
                candidates_with_penalty.append((seq, score, normalized_score))
            
            # Sort by normalized score and keep top beam_width
            candidates_with_penalty.sort(key=lambda x: x[2], reverse=True)
            beams = [(seq, score) for seq, score, _ in candidates_with_penalty[:beam_width]]
            
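            # Stop once every surviving beam has ended; the beams are added to
            # completed_beams once, after the loop, to avoid duplicate entries
            if all(seq[-1].item() == eos_idx for seq, _ in beams):
                break
        
        # Add remaining beams (finished, or cut off at max_len) to completed
        completed_beams.extend(beams)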
        
        # Select best beam with length normalization
        best_beam = None
        best_normalized_score = float('-inf')
        
        for seq, score in completed_beams:
            length_norm = ((5 + len(seq)) / 6) ** length_penalty
            normalized_score = score / length_norm
            if normalized_score > best_normalized_score:
                best_normalized_score = normalized_score
                best_beam = (seq, score)
        
        # Store best sequence
        if best_beam:
            all_best_sequences.append(best_beam[0])
            all_best_scores.append(best_beam[1])
        else:
            all_best_sequences.append(torch.tensor([sos_idx, eos_idx], device=device))
            all_best_scores.append(0.0)
    
    # Pad sequences to the same length with index 0 (skipped later by decode_tokens)
    max_seq_len = max(len(seq) for seq in all_best_sequences)
    padded_sequences = []
    
    for seq in all_best_sequences:
        if len(seq) < max_seq_len:
            padding = torch.zeros(max_seq_len - len(seq), dtype=torch.long, device=device)
            seq = torch.cat([seq, padding])
        padded_sequences.append(seq)
    
    best_sequences = torch.stack(padded_sequences)
    best_scores = torch.tensor(all_best_scores, device=device)
    
    # Remove <sos> token from output
    best_sequences = best_sequences[:, 1:]
    
    return best_sequences, best_scores


# Add beam search method to TransformerModel
TransformerModel.beam_search_generate = beam_search_generate

print("✓ Beam search implementation ready!")

## Step 3: Setup and Load Data

In [None]:
set_seed(42)
device = get_device()

# Check GPU memory
if torch.cuda.is_available():
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    free_memory = torch.cuda.mem_get_info()[0] / 1024**3
    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"Total GPU memory: {total_memory:.2f} GB")
    print(f"Free GPU memory: {free_memory:.2f} GB")
else:
    print("\n⚠️ No GPU available, using CPU (will be slower)")

In [None]:
print("Loading vocabulary...")
vocab = load_vocab('../vocab.json')
print(f"\n✓ Vocabulary size: {len(vocab)}")
print(f"  <pad>: {vocab['<pad>']}")
print(f"  <sos>: {vocab['<sos>']}")
print(f"  <eos>: {vocab['<eos>']}")
print(f"  <unk>: {vocab['<unk>']}")

In [None]:
print("Creating dataloaders...")
print("Using batch_size=8 for transformer\n")

train_loader, val_loader, eval_dataset = create_dataloaders(
    train_captions='../data/train_captions.json',
    val_captions='../data/val_captions.json',
    eval_captions='../data/eval_captions.json',
    train_features_dir='../features/mel/',
    val_features_dir='../features/mel/',
    eval_features_dir='../features/mel_eval/',
    vocab=vocab,
    batch_size=8,
    num_workers=2
)

print(f"\n✓ Dataloaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Eval samples: {len(eval_dataset)}")

## Step 4: Create or Load Transformer Model

In [None]:
import os.path as osp

# Check if pre-trained model exists
pretrained_path = '../checkpoints/best_transformer_small.pth'

# One configuration for both branches (optimized for 8GB GPU)
model_config = dict(
    d_model=256,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512
)

model = create_model('transformer', vocab_size=len(vocab), **model_config)

if osp.exists(pretrained_path):
    print(f"Loading pre-trained model from {pretrained_path}...\n")
    
    # Load weights
    checkpoint = torch.load(pretrained_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print("✓ Loaded pre-trained model")
    
else:
    print("No pre-trained model found. Creating and training new model...\n")
    print("Model configuration (optimized for 8GB GPU):")
    for key, value in model_config.items():
        print(f"  - {key}: {value}")
    print()

model = model.to(device)

total_params, trainable_params = count_parameters(model)
print(f"\nModel parameters: {total_params:,} total, {trainable_params:,} trainable")

## Step 5: (Optional) Train Model if Not Pre-trained

Skip this if you loaded a pre-trained model above.

In [None]:
# Only run if model is not pre-trained
if not osp.exists(pretrained_path):
    print("Training transformer model...\n")
    print("="*80)
    print("TRANSFORMER MODEL TRAINING")
    print("="*80)
    print("\nExpected training time: 4-6 hours on GPU\n")
    
    trainer = ModelTrainer(
        model=model,
        vocab=vocab,
        device=device,
        model_name='transformer_beam'
    )
    
    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        eval_dataset=eval_dataset,
        num_epochs=40,
        learning_rate=1e-4,
        weight_decay=1e-4,
        patience=5,
        label_smoothing=0.1,
        save_dir='../checkpoints'
    )
    
    print("\n✓ Training complete!")
    
    # Save history (create the output directory if it is missing)
    os.makedirs('../results', exist_ok=True)
    with open('../results/transformer_beam_history.json', 'w') as f:
        json.dump(history, f, indent=2)
else:
    print("Using pre-trained model. Skipping training.")

## Step 6: Test Beam Search on Sample

Let's test beam search on a single sample and compare with greedy/sampling.

In [None]:
def decode_tokens(token_ids, vocab):
    """Convert token IDs to text"""
    idx_to_word = {v: k for k, v in vocab.items()}
    words = []
    for idx in token_ids:
        idx = idx.item() if torch.is_tensor(idx) else idx
        if idx == vocab['<eos>']:
            break
        if idx not in [vocab['<pad>'], vocab['<sos>'], 0]:
            words.append(idx_to_word.get(idx, '<unk>'))
    return ' '.join(words)


# Get a sample
sample_idx = 0
sample = eval_dataset[sample_idx]
mel = sample['mel'].unsqueeze(0).to(device)

print(f"Sample: {sample['fname']}")
print(f"\nReference captions:")
for i, ref in enumerate(sample['captions'], 1):
    print(f"  {i}. {ref}")

print("\n" + "="*80)
print("GENERATION METHODS COMPARISON")
print("="*80)

model.eval()

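# 1. Near-greedy decoding: a very low temperature approximates argmax,
#    but the token is still sampled, so this is not exact greedy decoding
print("\n1. NEAR-GREEDY DECODING (temperature=0.1)")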
with torch.no_grad():
    greedy_ids = model.generate(
        mel, max_len=30, 
        sos_idx=vocab['<sos>'], 
        eos_idx=vocab['<eos>'],
        temperature=0.1,  # Very low temperature ≈ greedy
        top_p=1.0
    )
greedy_caption = decode_tokens(greedy_ids[0], vocab)
print(f"   {greedy_caption}")

# 2. Sampling with temperature, top-k, and nucleus (top-p) filtering
print("\n2. SAMPLING (temperature=0.7, top_k=50, top_p=0.9)")
with torch.no_grad():
    sample_ids = model.generate(
        mel, max_len=30,
        sos_idx=vocab['<sos>'],
        eos_idx=vocab['<eos>'],
        temperature=0.7,
        top_k=50,
        top_p=0.9
    )
sample_caption = decode_tokens(sample_ids[0], vocab)
print(f"   {sample_caption}")

# 3. Beam search
print("\n3. BEAM SEARCH (beam_width=5)")
with torch.no_grad():
    beam_ids, beam_scores = beam_search_generate(
        model, mel,
        beam_width=5,
        max_len=30,
        sos_idx=vocab['<sos>'],
        eos_idx=vocab['<eos>'],
        length_penalty=0.6,
        device=device
    )
beam_caption = decode_tokens(beam_ids[0], vocab)
print(f"   {beam_caption}")
print(f"   Score: {beam_scores[0].item():.4f}")

print("\n" + "="*80)
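
Why a temperature of 0.1 only approximates greedy decoding can be seen on a small example. The sketch below assumes that `model.generate` divides the logits by the temperature before the softmax, which is the usual convention but is not shown in this notebook; the `softmax_with_temperature` helper and the logits are hypothetical.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.5, 0.5]  # hypothetical scores for three tokens
for t in (1.0, 0.7, 0.1):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At temperature 0.1 the top token takes more than 99% of the probability mass, yet the other tokens keep a small nonzero chance, so repeated runs can still differ. Exact greedy decoding would take the argmax instead, which is what beam search with `beam_width=1` does.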

## Step 7: Comprehensive Evaluation with Beam Search

Evaluate on multiple samples using beam search.
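
The evaluation reports a repetition rate and a vocabulary diversity score computed by helpers in `src.evaluation`, whose definitions are not shown here. As an illustration only, the sketch below uses one common way to define such metrics; the functions `repetition_rate` and `vocabulary_diversity` are hypothetical and may differ from the project's own.

```python
def repetition_rate(caption):
    """Fraction of words that repeat an earlier word in the same caption."""
    words = caption.split()
    if not words:
        return 0.0
    return 1.0 - len(set(words)) / len(words)

def vocabulary_diversity(captions):
    """Unique words divided by total words across all captions."""
    words = [w for c in captions for w in c.split()]
    return len(set(words)) / len(words) if words else 0.0

captions = ["a dog barks and a dog runs", "rain falls on a roof"]
print([round(repetition_rate(c), 3) for c in captions])  # [0.286, 0.0]
print(round(vocabulary_diversity(captions), 3))          # 0.75
```

Under these assumed definitions, lower repetition and higher diversity are better, and both depend only on the generated captions, not on the references.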

In [None]:
def evaluate_with_beam_search(model, eval_dataset, vocab, device='cuda', 
                               num_samples=100, beam_width=5, length_penalty=0.6):
    """
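    Evaluate model using beam search on the first num_samples eval items.

    Returns:
        results: dict of repetition, diversity, and caption-length metrics
        generated_captions: list of decoded caption strings
        reference_captions: list of reference-caption lists, one per sample
    """
    from src.evaluation import calculate_repetition_rate, evaluate_diversity, calculate_caption_length_stats
    
    model.eval()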
    
    generated_captions = []
    reference_captions = []
    repetition_scores = []
    
    num_samples = min(num_samples, len(eval_dataset))
    
    print(f"Evaluating with beam search (beam_width={beam_width})...")
    
    for i in tqdm(range(num_samples)):
        item = eval_dataset[i]
        mel = item['mel'].unsqueeze(0).to(device)
        
        # Generate with beam search
        with torch.no_grad():
            ids, scores = beam_search_generate(
                model, mel,
                beam_width=beam_width,
                max_len=30,
                sos_idx=vocab['<sos>'],
                eos_idx=vocab['<eos>'],
                length_penalty=length_penalty,
                device=device
            )
        
        # Decode
        caption = decode_tokens(ids[0], vocab)
        generated_captions.append(caption)
        reference_captions.append(item['captions'])
        repetition_scores.append(calculate_repetition_rate(caption))
    
    # Calculate metrics
    diversity_metrics = evaluate_diversity(generated_captions)
    length_stats = calculate_caption_length_stats(generated_captions)
    avg_repetition = np.mean(repetition_scores)
    
    results = {
        'num_samples': num_samples,
        'beam_width': beam_width,
        'length_penalty': length_penalty,
        'avg_repetition_rate': avg_repetition,
        'vocabulary_diversity': diversity_metrics['diversity'],
        'unique_words_used': diversity_metrics['unique_words'],
        'total_words_generated': diversity_metrics['total_words'],
        'mean_caption_length': length_stats['mean_length'],
        'std_caption_length': length_stats['std_length'],
        'min_caption_length': length_stats['min_length'],
        'max_caption_length': length_stats['max_length']
    }
    
    print("\n" + "="*60)
    print("BEAM SEARCH EVALUATION RESULTS")
    print("="*60)
    for key, value in results.items():
        if isinstance(value, float):
            print(f"{key:.<40} {value:.4f}")
        else:
            print(f"{key:.<40} {value}")
    print("="*60)
    
    return results, generated_captions, reference_captions


# Run evaluation
beam_results, beam_captions, beam_refs = evaluate_with_beam_search(
    model, eval_dataset, vocab, 
    device=device,
    num_samples=100,
    beam_width=5,
    length_penalty=0.6
)

## Step 8: Compare Different Beam Widths

In [None]:
print("Comparing different beam widths...\n")

beam_widths = [1, 3, 5, 10]
comparison_results = {}

for beam_width in beam_widths:
    print(f"\nTesting beam_width={beam_width}...")
    results, captions, refs = evaluate_with_beam_search(
        model, eval_dataset, vocab,
        device=device,
        num_samples=50,  # Use fewer samples for faster comparison
        beam_width=beam_width,
        length_penalty=0.6
    )
    comparison_results[f"beam_{beam_width}"] = results

# Print comparison table
print("\n" + "="*100)
print("BEAM WIDTH COMPARISON")
print("="*100)
print(f"{'Beam Width':<15} {'Repetition':<15} {'Diversity':<15} {'Avg Length':<15} {'Vocab Used':<15}")
print("-"*100)

for name, results in comparison_results.items():
    beam_w = results['beam_width']
    print(f"{beam_w:<15} {results['avg_repetition_rate']:<15.4f} "
          f"{results['vocabulary_diversity']:<15.4f} {results['mean_caption_length']:<15.2f} "
          f"{results['unique_words_used']:<15}")

print("="*100)
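
The cost side of this comparison can be sketched on a toy problem. The example below is not the notebook's model: the transition table `PROBS` and the simplified `beam_search` are hypothetical, scores are plain summed log probabilities without the length normalization used by `beam_search_generate`, and each live beam counts as one model call per step.

```python
import math

EOS = 0
# Hypothetical transition table: key = last token, list index = next token.
# Token 0 is <eos>; every row sums to 1.
PROBS = {
    1: [0.1, 0.2, 0.4, 0.3],
    2: [0.2, 0.3, 0.1, 0.4],
    3: [0.5, 0.2, 0.2, 0.1],
}

def beam_search(start, beam_width, max_len):
    """Return (best_sequence, summed_log_prob, number_of_model_calls)."""
    beams, finished, calls = [([start], 0.0)], [], 0
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            calls += 1  # one "decoder pass" per live beam
            for tok, p in enumerate(PROBS[seq[-1]]):
                candidates.append((seq + [tok], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            # Finished sequences leave the beam; the rest are extended next step
            (finished if seq[-1] == EOS else beams).append((seq, score))
        if not beams:
            break
    best = max(finished + beams, key=lambda c: c[1])
    return best[0], best[1], calls

for width in (1, 3, 5):
    seq, score, calls = beam_search(start=1, beam_width=width, max_len=6)
    print(width, seq, round(score, 3), calls)
```

On this table a beam of width 1 follows the greedy path and stops after three calls, while width 3 needs more calls but finds a sequence with a higher summed log probability. Real decoding behaves similarly in spirit, though the size of the gain depends on the model.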

## Step 9: Sample Predictions with Beam Search

In [None]:
print("\n" + "="*80)
print("SAMPLE PREDICTIONS WITH BEAM SEARCH")
print("="*80)

num_samples_to_show = 10

for i in range(min(num_samples_to_show, len(eval_dataset))):
    item = eval_dataset[i]
    mel = item['mel'].unsqueeze(0).to(device)
    
    # Generate with beam search
    with torch.no_grad():
        beam_ids, beam_scores = beam_search_generate(
            model, mel,
            beam_width=5,
            max_len=30,
            sos_idx=vocab['<sos>'],
            eos_idx=vocab['<eos>'],
            length_penalty=0.6,
            device=device
        )
    
    caption = decode_tokens(beam_ids[0], vocab)
    
    print(f"\nSample {i+1}: {item['fname']}")
    print("-"*80)
    print(f"Generated:  {caption}")
    print(f"Score:      {beam_scores[0].item():.4f}")
    print(f"References:")
    for j, ref in enumerate(item['captions'], 1):
        print(f"  {j}. {ref}")

print("\n" + "="*80)

## Step 10: Save Results

In [None]:
# Save beam search results (create the output directory if it is missing)
os.makedirs('../results', exist_ok=True)
serializable_results = make_json_serializable(beam_results)

with open('../results/transformer_beam_search_results.json', 'w') as f:
    json.dump(serializable_results, f, indent=2)

# Save beam width comparison
serializable_comparison = {k: make_json_serializable(v) for k, v in comparison_results.items()}

with open('../results/beam_width_comparison.json', 'w') as f:
    json.dump(serializable_comparison, f, indent=2)

print("✓ Results saved!")
print("  - ../results/transformer_beam_search_results.json")
print("  - ../results/beam_width_comparison.json")

## Summary

This notebook demonstrated:

1. **Beam search implementation** for transformer models
2. **Comparison** between greedy, sampling, and beam search decoding
3. **Effect of beam width** on generation quality
4. **Comprehensive evaluation** with beam search

### Key Findings:

- **Beam search** generally produces more coherent and natural captions than greedy decoding
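- **Larger beam widths** (5-10) typically give better results but are slower: this implementation runs one decoder pass per live beam per step, so cost grows roughly linearly with beam width
- **Length penalty** (0.6) helps avoid overly short captions by dividing each summed log probability by `((5 + length) / 6) ** 0.6`
- Beam search often has **lower repetition** than sampling, but commonly **lower vocabulary diversity** as well, since it favors high-probability phrasing; Steps 7 and 8 measure these metrics for beam search only, so the comparison with sampling still needs a matching evaluation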

### Trade-offs:

- **Greedy**: Fastest, but locally optimal
- **Sampling**: Creative and diverse, but less coherent
- **Beam search**: Usually the most coherent output, but slower (especially with large beams)

### Next Steps:

- Try different length penalties (0.4, 0.6, 0.8, 1.0)
- Experiment with diverse beam search (group beam search)
- Combine with temperature sampling for controllable diversity
- Compare with reference-based metrics (BLEU, METEOR, CIDEr)