# Advanced Text Generation: The Art and Science of Language

Training a model is only half the battle. The real magic happens during generation - turning trained weights into coherent, creative, and controllable text.

## The Generation Challenge

Your model outputs probabilities over 50,000+ tokens. How do you turn this into meaningful text? The naive approach - always picking the highest probability token - produces boring, repetitive output.

## The Physics of Language Generation

Text generation is fundamentally about **sampling from probability distributions**. But not all sampling is equal:

**Deterministic Sampling** (greedy decoding): Always pick the highest-probability token
- Predictable but boring
- Often gets stuck in loops
- No creativity or surprise

**Random Sampling**: Pick tokens randomly according to probabilities
- Creative but often incoherent
- Can generate nonsense
- Hard to control quality

**Smart Sampling**: Balance creativity with coherence
- Temperature scaling for controlled randomness
- Top-k and nucleus sampling for quality control
- Beam search for structured exploration

## The Information Theory Foundation

Good text generation requires understanding **entropy** and **surprise**:

**Entropy H(p) = -Σ p(x) log p(x)**
- Low entropy: Model is confident (deterministic)
- High entropy: Model is uncertain (random)
- Sweet spot: Controlled uncertainty for creativity

**Perplexity = exp(H)** (with H measured in nats): Measures how "surprised" the model is
- Lower perplexity = better predictions
- But some surprise is needed for interesting text
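
To make the entropy formula concrete, here is a minimal pure-Python sketch that compares a confident distribution with a uniform one. The two distributions are made-up illustrations, not model outputs, and `entropy` is a helper name introduced only for this example.

```python
import math

def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Illustrative distributions over a 4-token vocabulary (assumed values)
peaked = [0.97, 0.01, 0.01, 0.01]   # model is confident
uniform = [0.25, 0.25, 0.25, 0.25]  # model is maximally uncertain

for name, dist in [("peaked", peaked), ("uniform", uniform)]:
    h = entropy(dist)
    # exp(H) can be read as the "effective number of choices"
    print(f"{name}: H = {h:.3f} nats, exp(H) = {math.exp(h):.2f}")
```

For the uniform case, exp(H) comes out as 4, the number of equally likely tokens. The peaked case lands close to 1, meaning the model is effectively choosing among a single option.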

## What You'll Master

1. **Temperature scaling** for controlling creativity
2. **Top-k and nucleus sampling** for quality control
3. **Beam search** for structured exploration
4. **Quality metrics** for evaluating generated text
5. **Controllable generation** for specific outputs

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import math
import heapq

from src.model.transformer import GPTModel, create_model_config
from src.data.tokenizer import create_tokenizer

torch.manual_seed(42)
np.random.seed(42)

plt.style.use('default')
sns.set_palette("husl")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Advanced generation laboratory ready! 🎨")

## 1. Temperature: Controlling the Heat of Creativity

Temperature is the most fundamental knob for controlling text generation. It transforms the raw probability distribution from your model.

### The Mathematical Foundation

**Softmax with Temperature**:
```
p_i = exp(logit_i / T) / Σ_j exp(logit_j / T)
```

Where T is temperature:
- **T = 1**: Normal softmax (no modification)
- **T → 0**: Becomes deterministic (always picks highest logit)
- **T → ∞**: Becomes uniform random (all tokens equally likely)
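
The three limits above can be checked numerically. The following is a small pure-Python sketch, separate from the notebook's own helpers. The logits are arbitrary illustrative values, and `softmax_with_temperature` is a name assumed just for this example.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / T, computed stably by subtracting the max."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # illustrative values, not from a real model
for t in (0.1, 1.0, 10.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T = {t}: {[round(p, 3) for p in probs]}")
```

At T = 0.1 nearly all the mass sits on the largest logit, while at T = 10 the three probabilities are close to one another, which matches the T → 0 and T → ∞ behavior described above.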

### The Physics Analogy

Temperature comes from statistical mechanics:

**Low Temperature (T < 1)**:
- Like a cold system with low kinetic energy
- Particles settle into lowest energy states
- Text becomes more deterministic and predictable

**High Temperature (T > 1)**:
- Like a hot system with high kinetic energy
- Particles can access higher energy states
- Text becomes more random and creative

### Practical Effects

- **T = 0.1**: Very conservative, often repetitive
- **T = 0.7**: Good balance for most applications
- **T = 1.0**: Raw model probabilities
- **T = 1.2**: More creative and diverse
- **T = 2.0**: Often incoherent but surprising

Let's implement and visualize temperature effects:

In [None]:
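# Create a small, randomly initialized model for generation experiments
# (no weights are loaded, so outputs illustrate sampling mechanics, not language quality)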
config = {
    'vocab_size': 1000,
    'd_model': 128,
    'n_heads': 8,
    'n_layers': 4,
    'd_ff': 256,
    'max_seq_len': 64,
    'dropout': 0.1
}

model = GPTModel(**config).to(device)
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

def apply_temperature(logits, temperature):
    """Apply temperature scaling to logits."""
    if temperature == 0:
        # Special case: make it very low temperature instead of zero
        temperature = 1e-10
    return logits / temperature

def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Sample next token from logits with various strategies."""
    # Apply temperature
    scaled_logits = apply_temperature(logits, temperature)
    
    # Convert to probabilities
    probs = F.softmax(scaled_logits, dim=-1)
    
    # Apply top-k if specified
    if top_k is not None:
        # Get top-k indices
        top_k_probs, top_k_indices = torch.topk(probs, top_k)
        # Create mask for top-k
        mask = torch.zeros_like(probs)
        mask.scatter_(-1, top_k_indices, 1)
        probs = probs * mask
        # Renormalize
        probs = probs / probs.sum(dim=-1, keepdim=True)
    
    # Apply nucleus (top-p) sampling if specified
    if top_p is not None:
        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        # Calculate cumulative probabilities
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
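        # Create mask for nucleus: keep tokens until the cumulative probability
        # reaches top_p, including the token that crosses the threshold
        nucleus_mask = (cumsum_probs - sorted_probs) < top_p
        # Include at least one token
        nucleus_mask[..., 0] = True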
        # Apply mask
        sorted_probs[~nucleus_mask] = 0
        # Renormalize
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        # Scatter back to original order
        probs = torch.zeros_like(probs)
        probs.scatter_(-1, sorted_indices, sorted_probs)
    
    # Sample from the distribution
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs

def visualize_temperature_effects():
    """Visualize how temperature affects probability distributions."""
    
    # Create example logits (before softmax)
    logits = torch.tensor([2.0, 1.5, 1.0, 0.5, 0.2, -0.5, -1.0, -2.0])
    temperatures = [0.1, 0.5, 1.0, 1.5, 2.0]
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()
    
    token_labels = [f'Token_{i}' for i in range(len(logits))]
    
    for i, temp in enumerate(temperatures):
        if i < len(axes):
            # Apply temperature and softmax
            scaled_logits = apply_temperature(logits, temp)
            probs = F.softmax(scaled_logits, dim=-1)
            
            # Create bar plot
            bars = axes[i].bar(token_labels, probs.numpy(), 
                              color=plt.cm.viridis(i / len(temperatures)), 
                              alpha=0.8, edgecolor='black', linewidth=1)
            
            # Add probability values on bars
            for bar, prob in zip(bars, probs.numpy()):
                if prob > 0.01:  # Only show significant probabilities
                    axes[i].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                                f'{prob:.3f}', ha='center', va='bottom', fontsize=9)
            
            axes[i].set_title(f'Temperature = {temp}\n(Entropy: {-torch.sum(probs * torch.log(probs + 1e-10)):.2f})', 
                             fontsize=12, weight='bold')
            axes[i].set_ylabel('Probability')
            axes[i].set_ylim(0, 1.0)
            axes[i].tick_params(axis='x', rotation=45)
            axes[i].grid(True, alpha=0.3)
    
    # Add entropy vs temperature plot
    temp_range = np.linspace(0.1, 3.0, 50)
    entropies = []
    
    for temp in temp_range:
        scaled_logits = apply_temperature(logits, temp)
        probs = F.softmax(scaled_logits, dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
        entropies.append(entropy)
    
    axes[5].plot(temp_range, entropies, linewidth=3, color='red', marker='o', markersize=4)
    axes[5].set_title('Entropy vs Temperature\n(Creativity Control)', fontsize=12, weight='bold')
    axes[5].set_xlabel('Temperature')
    axes[5].set_ylabel('Entropy (nats)')
    axes[5].grid(True, alpha=0.3)
    
    # Add annotations for practical ranges
    axes[5].axvspan(0.7, 1.2, alpha=0.2, color='green', label='Sweet Spot')
    axes[5].axvspan(0.1, 0.5, alpha=0.2, color='blue', label='Conservative')
    axes[5].axvspan(1.5, 3.0, alpha=0.2, color='orange', label='Creative')
    axes[5].legend()
    
    plt.tight_layout()
    plt.show()
    
    print("🌡️ TEMPERATURE INSIGHTS:")
    print("• Low temp (0.1-0.5): Predictable, often repetitive")
    print("• Medium temp (0.7-1.2): Good balance of coherence and creativity")
    print("• High temp (1.5+): Creative but potentially incoherent")
    print("• Entropy measures 'surprise' - key for controlling creativity")

visualize_temperature_effects()

In [None]:
# Demonstrate temperature effects with actual text generation

def generate_text(model, prompt_tokens, max_length=20, temperature=1.0, top_k=None, top_p=None):
    """Generate text using specified sampling parameters."""
    model.eval()
    generated = prompt_tokens.clone()
    
    with torch.no_grad():
        for _ in range(max_length):
            # Get model predictions
            outputs = model(generated)
            next_token_logits = outputs[0, -1, :]  # Last token predictions
            
            # Sample next token
            next_token, probs = sample_from_logits(
                next_token_logits, temperature=temperature, top_k=top_k, top_p=top_p
            )
            
            # Append to sequence
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            
            # Stop if we hit max sequence length
            if generated.size(1) >= model.max_seq_len:
                break
    
    return generated

# Create a simple prompt (random tokens for demonstration)
prompt = torch.randint(0, config['vocab_size'], (1, 5), device=device)
print(f"Starting prompt tokens: {prompt[0].tolist()}")

# Test different temperatures
temperatures = [0.1, 0.7, 1.0, 1.5, 2.0]

print("\n🎨 TEMPERATURE COMPARISON:")
print("Generating text with different temperature settings...\n")

for temp in temperatures:
    generated = generate_text(model, prompt, max_length=15, temperature=temp)
    new_tokens = generated[0, len(prompt[0]):].tolist()
    
    print(f"Temperature {temp}:")
    print(f"  Generated tokens: {new_tokens}")
    print(f"  Uniqueness: {len(set(new_tokens))}/{len(new_tokens)} unique tokens")
    
    # Calculate repetition rate
    if len(new_tokens) > 1:
        repetitions = sum(1 for i in range(1, len(new_tokens)) if new_tokens[i] == new_tokens[i-1])
        repetition_rate = repetitions / (len(new_tokens) - 1)
        print(f"  Repetition rate: {repetition_rate:.2f}")
    print()

print("🔍 OBSERVATIONS:")
print("• Low temperature: More repetitive, predictable patterns")
print("• High temperature: More diverse, potentially chaotic patterns")
print("• Medium temperature: Best balance for most applications")

## 2. Top-k and Nucleus Sampling: Quality Control

Temperature alone isn't enough. Even with perfect temperature, you might sample a very low-probability token that ruins coherence. Top-k and nucleus (top-p) sampling provide quality control.

### Top-k Sampling: Fixed Vocabulary Filtering

**Algorithm**:
1. Sort tokens by probability (highest first)
2. Keep only the top k most likely tokens
3. Set all other probabilities to zero
4. Renormalize and sample

**Effect**: Prevents sampling from the "long tail" of unlikely tokens

**Typical values**: k = 40-100 for most applications
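
The four steps can be sketched on a plain Python list. This is a minimal illustration under assumed values (the probabilities and `k` are made up), and `top_k_filter` is a hypothetical helper, not part of the notebook's code.

```python
def top_k_filter(probs, k):
    """Keep the k largest probabilities, zero the rest, then renormalize."""
    if k <= 0:
        raise ValueError("k must be positive")
    # Indices of the k most likely tokens (ties broken by lower index)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = set(order[:k])
    filtered = [p if i in kept else 0.0 for i, p in enumerate(probs)]
    total = sum(filtered)
    return [p / total for p in filtered]

probs = [0.5, 0.2, 0.15, 0.1, 0.05]  # illustrative distribution
filtered = top_k_filter(probs, k=2)
print([round(p, 3) for p in filtered])
```

With k = 2 only the first two tokens survive, and their probabilities are rescaled from 0.5 and 0.2 to 0.5/0.7 and 0.2/0.7 so the result still sums to 1.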

### Nucleus (Top-p) Sampling: Adaptive Vocabulary

**Algorithm**:
1. Sort tokens by probability (highest first)
2. Add tokens to "nucleus" until cumulative probability ≥ p
3. Set the probability of every token outside the nucleus to zero
4. Renormalize and sample from the nucleus

**Effect**: Dynamic vocabulary size - smaller when model is confident, larger when uncertain

**Typical values**: p = 0.9-0.95 for most applications
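
To see the adaptive vocabulary in action, the sketch below applies the same threshold to a confident and an uncertain distribution. It is pure Python with made-up probabilities, and `nucleus_filter` is a name introduced only for this illustration.

```python
def nucleus_filter(probs, top_p):
    """Keep the smallest set of top tokens whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)            # the token that crosses the threshold is included
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / total  # renormalize within the nucleus
    return filtered, len(kept)

# Illustrative distributions (assumed values)
confident = [0.92, 0.04, 0.02, 0.01, 0.01]
uncertain = [0.25, 0.22, 0.20, 0.18, 0.15]

_, n_confident = nucleus_filter(confident, top_p=0.9)
_, n_uncertain = nucleus_filter(uncertain, top_p=0.9)
print(f"nucleus size: confident={n_confident}, uncertain={n_uncertain}")
```

With p = 0.9 the confident distribution needs a single token to reach the threshold, while the uncertain one needs all five.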

### The Information Theory Perspective

Both methods are about **controlling the effective vocabulary size**:

**High Confidence Distributions** (low entropy):
- Few tokens have significant probability
- A fixed k can be too permissive, keeping unlikely tail tokens
- Nucleus adapts to use a smaller vocabulary

**Low Confidence Distributions** (high entropy):
- Many tokens have similar probability
- A fixed k can be too restrictive, cutting off plausible tokens
- Nucleus adapts to use a larger vocabulary

Let's implement and compare these approaches:

In [None]:
def analyze_sampling_strategies():
    """Compare different sampling strategies with various probability distributions."""
    
    # Create different types of probability distributions
    vocab_size = 100
    
    # High confidence (peaked distribution)
    high_conf_logits = torch.zeros(vocab_size)
    high_conf_logits[0] = 3.0  # Very likely token
    high_conf_logits[1] = 1.0  # Somewhat likely
    high_conf_logits[2:10] = 0.2  # Slightly likely
    high_conf_logits[10:] = -2.0  # Very unlikely
    
    # Low confidence (flat distribution)
    low_conf_logits = torch.randn(vocab_size) * 0.5
    
    # Medium confidence (moderate peak)
    med_conf_logits = torch.zeros(vocab_size)
    med_conf_logits[0] = 1.5
    med_conf_logits[1:5] = 1.0
    med_conf_logits[5:20] = 0.5
    med_conf_logits[20:] = -1.0
    
    distributions = {
        'High Confidence': high_conf_logits,
        'Medium Confidence': med_conf_logits,
        'Low Confidence': low_conf_logits
    }
    
    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    
    for row, (dist_name, logits) in enumerate(distributions.items()):
        # Original distribution
        probs = F.softmax(logits, dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
        
        # Plot original
        axes[row, 0].bar(range(min(20, vocab_size)), probs[:20].numpy(), 
                        alpha=0.7, color='blue', edgecolor='black')
        axes[row, 0].set_title(f'{dist_name}\nOriginal (H={entropy:.2f})', weight='bold')
        axes[row, 0].set_ylabel('Probability')
        
        # Top-k = 10
        _, top_k_probs = sample_from_logits(logits, temperature=1.0, top_k=10)
        k_entropy = -torch.sum(top_k_probs * torch.log(top_k_probs + 1e-10)).item()
        axes[row, 1].bar(range(min(20, vocab_size)), top_k_probs[:20].numpy(), 
                        alpha=0.7, color='green', edgecolor='black')
        axes[row, 1].set_title(f'Top-k=10\n(H={k_entropy:.2f})', weight='bold')
        
        # Nucleus p=0.9
        _, nucleus_probs = sample_from_logits(logits, temperature=1.0, top_p=0.9)
        n_entropy = -torch.sum(nucleus_probs * torch.log(nucleus_probs + 1e-10)).item()
        axes[row, 2].bar(range(min(20, vocab_size)), nucleus_probs[:20].numpy(), 
                        alpha=0.7, color='orange', edgecolor='black')
        axes[row, 2].set_title(f'Nucleus p=0.9\n(H={n_entropy:.2f})', weight='bold')
        
        # Combined: Top-k + Nucleus
        _, combined_probs = sample_from_logits(logits, temperature=1.0, top_k=10, top_p=0.9)
        c_entropy = -torch.sum(combined_probs * torch.log(combined_probs + 1e-10)).item()
        axes[row, 3].bar(range(min(20, vocab_size)), combined_probs[:20].numpy(), 
                        alpha=0.7, color='red', edgecolor='black')
        axes[row, 3].set_title(f'Top-k + Nucleus\n(H={c_entropy:.2f})', weight='bold')
        
        # Calculate effective vocabulary size (tokens with prob > 0.01)
        orig_vocab = (probs > 0.01).sum().item()
        k_vocab = (top_k_probs > 0.01).sum().item()
        n_vocab = (nucleus_probs > 0.01).sum().item()
        c_vocab = (combined_probs > 0.01).sum().item()
        
        print(f"\n{dist_name}:")
        print(f"  Original entropy: {entropy:.3f}, effective vocab: {orig_vocab}")
        print(f"  Top-k entropy: {k_entropy:.3f}, effective vocab: {k_vocab}")
        print(f"  Nucleus entropy: {n_entropy:.3f}, effective vocab: {n_vocab}")
        print(f"  Combined entropy: {c_entropy:.3f}, effective vocab: {c_vocab}")
    
    # Set common formatting
    for i in range(3):
        for j in range(4):
            axes[i, j].set_xlabel('Token Index')
            axes[i, j].grid(True, alpha=0.3)
            axes[i, j].set_ylim(0, 1.0)
    
    plt.tight_layout()
    plt.show()
    
    return distributions

print("🎯 ANALYZING SAMPLING STRATEGIES")
print("Comparing how different methods handle various probability distributions...")

distributions = analyze_sampling_strategies()

print("\n🔍 KEY INSIGHTS:")
print("• Top-k: Fixed vocabulary size, good for consistent filtering")
print("• Nucleus: Adaptive vocabulary, matches model confidence")
print("• Combined: Best of both - consistent bounds with adaptivity")
print("• Entropy measures randomness - lower = more focused sampling")

In [None]:
# Practical comparison of sampling strategies

def compare_sampling_methods():
    """Generate text with different sampling methods and compare results."""
    
    # Generation settings
    prompt = torch.randint(0, config['vocab_size'], (1, 3), device=device)
    generation_length = 20
    num_samples = 5  # Generate multiple samples for each method
    
    methods = {
        'Near-greedy (T=0.1)': {'temperature': 0.1, 'top_k': None, 'top_p': None},
        'High temperature (T=2.0)': {'temperature': 2.0, 'top_k': None, 'top_p': None},
        'Top-k=40': {'temperature': 1.0, 'top_k': 40, 'top_p': None},
        'Nucleus p=0.9': {'temperature': 1.0, 'top_k': None, 'top_p': 0.9},
        'Balanced (T=0.8, k=50, p=0.95)': {'temperature': 0.8, 'top_k': 50, 'top_p': 0.95}
    }
    
    results = {}
    
    print(f"🎲 SAMPLING METHOD COMPARISON")
    print(f"Starting with prompt tokens: {prompt[0].tolist()}")
    print(f"Generating {generation_length} tokens with each method...\n")
    
    for method_name, params in methods.items():
        print(f"📝 {method_name}:")
        
        samples = []
        diversities = []
        repetition_rates = []
        
        for sample_idx in range(num_samples):
            # Generate text
            generated = generate_text(
                model, prompt, 
                max_length=generation_length, 
                **params
            )
            
            # Extract new tokens
            new_tokens = generated[0, len(prompt[0]):].tolist()
            samples.append(new_tokens)
            
            # Calculate diversity (unique tokens / total tokens)
            diversity = len(set(new_tokens)) / len(new_tokens) if new_tokens else 0
            diversities.append(diversity)
            
            # Calculate repetition rate (consecutive identical tokens)
            if len(new_tokens) > 1:
                repetitions = sum(1 for i in range(1, len(new_tokens)) 
                                if new_tokens[i] == new_tokens[i-1])
                repetition_rate = repetitions / (len(new_tokens) - 1)
            else:
                repetition_rate = 0
            repetition_rates.append(repetition_rate)
            
            # Show first few samples
            if sample_idx < 2:
                print(f"  Sample {sample_idx + 1}: {new_tokens[:10]}{'...' if len(new_tokens) > 10 else ''}")
        
        # Calculate statistics across samples
        avg_diversity = np.mean(diversities)
        avg_repetition = np.mean(repetition_rates)
        
        # Calculate inter-sample diversity (how different samples are from each other)
        all_tokens = set()
        for sample in samples:
            all_tokens.update(sample)
        inter_sample_diversity = len(all_tokens)
        
        results[method_name] = {
            'avg_diversity': avg_diversity,
            'avg_repetition': avg_repetition,
            'inter_sample_diversity': inter_sample_diversity,
            'samples': samples
        }
        
        print(f"  Avg diversity: {avg_diversity:.3f}")
        print(f"  Avg repetition rate: {avg_repetition:.3f}")
        print(f"  Inter-sample diversity: {inter_sample_diversity} unique tokens")
        print()
    
    return results

# Run comparison
sampling_results = compare_sampling_methods()

print("🏆 SAMPLING METHOD ANALYSIS:")
print("\nRanking by diversity (creativity):")
diversity_ranking = sorted(sampling_results.items(), 
                          key=lambda x: x[1]['avg_diversity'], reverse=True)
for i, (method, stats) in enumerate(diversity_ranking, 1):
    print(f"  {i}. {method}: {stats['avg_diversity']:.3f}")

print("\nRanking by coherence (low repetition):")
coherence_ranking = sorted(sampling_results.items(), 
                          key=lambda x: x[1]['avg_repetition'])
for i, (method, stats) in enumerate(coherence_ranking, 1):
    print(f"  {i}. {method}: {stats['avg_repetition']:.3f} repetition rate")

print("\n🎯 PRACTICAL RECOMMENDATIONS:")
print("• Greedy / very low temperature: Best for consistency, worst for creativity")
print("• High temperature without filtering: Most diverse, often incoherent")
print("• Top-k: Good balance, consistent vocabulary control")
print("• Nucleus: Adaptive, matches model confidence")
print("• Balanced: Often optimal - combines multiple techniques")

## 3. Beam Search: Structured Exploration

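The methods so far are **local**: each step picks one next token and commits to it. Beam search looks further ahead: it keeps several partial sequences alive and ranks them by cumulative probability, so a slightly less likely token can still win if it leads to a better overall sequence.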

### The Algorithm

**Beam Search maintains k "beams" (partial sequences):**

1. **Initialize**: Start with a single beam containing the prompt
2. **Expand**: For each beam, score all possible next tokens
3. **Score**: Calculate cumulative log-probability for each sequence
4. **Prune**: Keep only the k highest-scoring sequences
5. **Repeat**: Until all beams end or max length reached
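
The loop above can be traced by hand on a tiny example. The sketch below uses a hypothetical toy "model" (`TRANSITIONS`, a hand-written table where the next-token probabilities depend only on the last token) so that no neural network is needed. It is meant to show how a wider beam can find a sequence that greedy search misses, not to mirror the notebook's implementation.

```python
import math

# Hypothetical toy model: next-token probabilities given the last token
TRANSITIONS = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"x": 0.9, "y": 0.1},
    "x": {"</s>": 1.0},
    "y": {"</s>": 1.0},
}

def beam_search(beam_size, max_steps=3):
    beams = [(["<s>"], 0.0)]  # (sequence, cumulative log-probability)
    for _ in range(max_steps):
        candidates = []
        for seq, score in beams:
            if seq[-1] == "</s>":          # finished beams are carried over unchanged
                candidates.append((seq, score))
                continue
            for token, p in TRANSITIONS[seq[-1]].items():
                candidates.append((seq + [token], score + math.log(p)))
        # Prune: keep only the highest-scoring sequences
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams

greedy_seq, greedy_score = beam_search(beam_size=1)[0]
best_seq, best_score = beam_search(beam_size=2)[0]
print(greedy_seq, round(math.exp(greedy_score), 2))
print(best_seq, round(math.exp(best_score), 2))
```

Greedy search (beam size 1) commits to "a" because 0.6 > 0.4 and ends with a sequence of probability 0.30. A beam of size 2 keeps "b" alive and finds the sequence through "b" and "x" with probability 0.36.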

### The Mathematics

**Sequence Score**: S(x₁, x₂, ..., xₙ) = Σᵢ log P(xᵢ | x₁, ..., xᵢ₋₁)

**Length Normalization**: Often divide by sequence length to avoid bias toward shorter sequences

**Diversity Penalties**: Subtract penalties for similar beams to encourage exploration
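
A short worked example shows why length normalization matters. The per-token probabilities below are invented for illustration, and `sequence_score` is a hypothetical helper that divides the summed log-probability by `length ** length_penalty`, the same convention as the `length_penalty` parameter used later in this notebook.

```python
import math

def sequence_score(token_probs, length_penalty=0.0):
    """Sum of token log-probabilities, divided by length ** length_penalty."""
    raw = sum(math.log(p) for p in token_probs)
    return raw / (len(token_probs) ** length_penalty)

# Illustrative per-token probabilities (assumed values)
short_seq = [0.5, 0.5]            # 2 tokens
long_seq = [0.6, 0.6, 0.6, 0.6]   # 4 tokens, each individually more likely

for penalty in (0.0, 1.0):
    s = sequence_score(short_seq, penalty)
    l = sequence_score(long_seq, penalty)
    print(f"length_penalty={penalty}: short={s:.3f}, long={l:.3f}")
```

Without normalization the short sequence scores higher simply because it adds fewer negative terms. With `length_penalty=1.0` the score becomes an average log-probability per token, and the longer sequence wins.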

### Why Beam Search Works

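**Sequence-Level Scoring**: Unlike token-by-token sampling, beam search ranks whole sequences by cumulative probability. It is an approximate search, so it is not guaranteed to find the single most probable sequence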

**Structured Exploration**: Systematically explores the most promising paths

**Quality Control**: High-probability sequences are more likely to be coherent

### Trade-offs

**Advantages**:
- Higher quality, more coherent output
- Deterministic (same input → same output)
- Good for tasks requiring consistency

**Disadvantages**:
- Less creative/diverse than sampling
- Computationally expensive (k forward passes per step)
- Can get stuck in repetitive patterns

Let's implement beam search and compare it with sampling methods:

In [None]:
class BeamSearchGenerator:
    """Implement beam search for text generation."""
    
    def __init__(self, model, beam_size=5, max_length=50, length_penalty=1.0, diversity_penalty=0.0):
        self.model = model
        self.beam_size = beam_size
        self.max_length = max_length
        self.length_penalty = length_penalty
        self.diversity_penalty = diversity_penalty
    
    def generate(self, prompt_tokens):
        """Generate text using beam search.

        max_length counts newly generated tokens (the prompt is excluded).
        Returns a list of (sequence, length-normalized score), best first.
        """
        self.model.eval()
        target_len = prompt_tokens.size(1) + self.max_length

        # Initialize beams: (sequence, raw cumulative log-probability)
        beams = [(prompt_tokens.clone(), 0.0)]
        finished_beams = []

        with torch.no_grad():
            for step in range(self.max_length):
                # Each entry pairs a raw score (carried to the next step) with a
                # (sequence, ranking_score, finished) tuple used for pruning
                scored = []

                # Expand each beam
                for beam_seq, beam_score in beams:
                    # Get model predictions for this beam
                    outputs = self.model(beam_seq)
                    next_token_log_probs = F.log_softmax(outputs[0, -1, :], dim=-1)

                    # Get top tokens to consider (limit search space)
                    top_log_probs, top_indices = torch.topk(
                        next_token_log_probs,
                        min(self.beam_size * 2, next_token_log_probs.size(-1)))

                    # Create candidate sequences
                    for log_prob, token_id in zip(top_log_probs, top_indices):
                        new_seq = torch.cat([beam_seq, token_id.view(1, 1)], dim=1)
                        raw_score = beam_score + log_prob.item()

                        # Length-normalize for ranking only; the raw score is what accumulates
                        ranking_score = raw_score / (new_seq.size(1) ** self.length_penalty)

                        # Finish at the length limit (end-of-sequence logic could be added here)
                        is_finished = new_seq.size(1) >= target_len

                        scored.append((raw_score, (new_seq, ranking_score, is_finished)))

                # Rank candidates best-first
                scored.sort(key=lambda x: x[1][1], reverse=True)

                # Apply diversity penalty (expects best-first order and preserves it), then re-rank
                if self.diversity_penalty > 0:
                    penalized = self._apply_diversity_penalty([cand for _, cand in scored])
                    scored = [(raw, cand) for (raw, _), cand in zip(scored, penalized)]
                    scored.sort(key=lambda x: x[1][1], reverse=True)

                # Keep the top beam_size candidates; finished ones leave the active set
                beams = []
                for raw_score, (seq, ranking_score, finished) in scored[:self.beam_size]:
                    if finished:
                        finished_beams.append((seq, ranking_score))
                    else:
                        beams.append((seq, raw_score))

                # Stop if all beams are finished
                if not beams:
                    break

        # Add any remaining active beams to finished beams
        for seq, raw_score in beams:
            finished_beams.append((seq, raw_score / (seq.size(1) ** self.length_penalty)))

        # Sort finished beams by score
        finished_beams.sort(key=lambda x: x[1], reverse=True)

        return finished_beams
    
    def _apply_diversity_penalty(self, candidates):
        """Apply diversity penalty to encourage different sequences."""
        # Simple diversity penalty: reduce score if sequence is too similar to higher-scoring ones
        penalized_candidates = []
        
        for i, (seq, score, finished) in enumerate(candidates):
            penalty = 0.0
            
            # Compare with higher-scoring candidates
            for j in range(i):
                other_seq, _, _ = candidates[j]
                # Calculate similarity (simple: number of matching tokens)
                min_len = min(seq.size(1), other_seq.size(1))
                matches = (seq[0, :min_len] == other_seq[0, :min_len]).sum().item()
                similarity = matches / min_len if min_len > 0 else 0
                penalty += similarity * self.diversity_penalty
            
            penalized_score = score - penalty
            penalized_candidates.append((seq, penalized_score, finished))
        
        return penalized_candidates

# Test beam search
def compare_beam_search_vs_sampling():
    """Compare beam search with sampling methods."""
    
    prompt = torch.randint(0, config['vocab_size'], (1, 4), device=device)
    generation_length = 15
    
    print(f"🔍 BEAM SEARCH vs SAMPLING COMPARISON")
    print(f"Starting prompt: {prompt[0].tolist()}")
    print(f"Generating {generation_length} additional tokens...\n")
    
    # Beam search with different beam sizes
    beam_sizes = [1, 3, 5]
    
    for beam_size in beam_sizes:
        print(f"🔬 Beam Search (size={beam_size}):")
        
        generator = BeamSearchGenerator(
            model, beam_size=beam_size, max_length=generation_length, 
            length_penalty=1.0, diversity_penalty=0.0
        )
        
        beams = generator.generate(prompt)
        
        for i, (sequence, score) in enumerate(beams[:3]):
            new_tokens = sequence[0, len(prompt[0]):].tolist()
            print(f"  Beam {i+1}: {new_tokens} (score: {score:.3f})")
        
        # Calculate diversity within beams
        if len(beams) > 1:
            all_tokens = set()
            for seq, _ in beams[:beam_size]:
                tokens = seq[0, len(prompt[0]):].tolist()
                all_tokens.update(tokens)
            print(f"  Beam diversity: {len(all_tokens)} unique tokens across top {beam_size} beams")
        print()
    
    # Compare with sampling methods
    print("🎲 Sampling Methods (for comparison):")
    
    sampling_methods = {
        'Nucleus (p=0.9)': {'temperature': 1.0, 'top_k': None, 'top_p': 0.9},
        'Top-k (k=50)': {'temperature': 1.0, 'top_k': 50, 'top_p': None},
    }
    
    for method_name, params in sampling_methods.items():
        print(f"  {method_name}:")
        
        samples = []
        for i in range(3):
            generated = generate_text(model, prompt, max_length=generation_length, **params)
            new_tokens = generated[0, len(prompt[0]):].tolist()
            samples.append(new_tokens)
            print(f"    Sample {i+1}: {new_tokens}")
        
        # Calculate diversity
        all_tokens = set()
        for sample in samples:
            all_tokens.update(sample)
        print(f"    Sample diversity: {len(all_tokens)} unique tokens across 3 samples")
        print()
    
    return beams

beam_results = compare_beam_search_vs_sampling()

print("📊 BEAM SEARCH INSIGHTS:")
print("• Beam size 1 = greedy search (deterministic)")
print("• Larger beam sizes explore more possibilities")
print("• Beam search optimizes entire sequences, not just next tokens")
print("• Generally more coherent but less diverse than sampling")
print("• Good for tasks requiring consistency (translation, summarization)")
print("• Sampling better for creative tasks (story writing, dialogue)")

## 4. Quality Metrics: Measuring Generation Success

How do you know if your generation is good? Unlike training loss, there's no single metric for generation quality. You need multiple perspectives.

### Automatic Metrics

**Perplexity**: How "surprised" the model is by the generated text
- Lower = more predictable (potentially boring)
- Higher = more surprising (potentially incoherent)
- Formula: PP = exp(-1/N Σ log P(token_i))
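
The formula can be evaluated directly from the probabilities a model assigned to each generated token. The sketch below is pure Python with invented probabilities, and `perplexity` is a helper name assumed for this example only.

```python
import math

def perplexity(token_probs):
    """exp of the average negative log-probability per token."""
    n = len(token_probs)
    avg_nll = -sum(math.log(p) for p in token_probs) / n
    return math.exp(avg_nll)

# Illustrative per-token probabilities (assumed values)
predictable = [0.9, 0.8, 0.95]
surprising = [0.2, 0.1, 0.3]
coin_flip_4 = [0.25, 0.25, 0.25, 0.25]

print(f"predictable: {perplexity(predictable):.2f}")
print(f"surprising:  {perplexity(surprising):.2f}")
print(f"uniform over 4 choices: {perplexity(coin_flip_4):.2f}")
```

When every token has probability 0.25, the perplexity is 4: the model behaves as if it were choosing uniformly among four options at each step.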

**Diversity Metrics**:
- **Type-Token Ratio (TTR)**: Unique tokens / Total tokens
- **Entropy**: Information content of the distribution
- **Self-BLEU**: How similar generated samples are to each other (lower = more diverse)

**Repetition Metrics**:
- **n-gram Repetition**: Percentage of repeated n-grams
- **Longest Repeated Sequence**: Maximum length of repeated subsequence
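
Two of these metrics are simple enough to sketch on raw token lists. The helpers below (`type_token_ratio` and `ngram_repetition`) are hypothetical names used only here, and the definition of n-gram repetition (one minus the share of distinct n-grams) is one common choice among several.

```python
def type_token_ratio(tokens):
    """Unique tokens divided by total tokens."""
    return len(set(tokens)) / len(tokens) if tokens else 0.0

def ngram_repetition(tokens, n):
    """Fraction of n-grams that repeat an n-gram already seen in the sequence."""
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)

# Illustrative token ID sequences (assumed values)
looping = [5, 7, 5, 7, 5, 7, 5, 7]
varied = [5, 7, 9, 2, 4, 8, 1, 3]

for name, seq in [("looping", looping), ("varied", varied)]:
    print(f"{name}: TTR = {type_token_ratio(seq):.2f}, "
          f"bigram repetition = {ngram_repetition(seq, 2):.2f}")
```

The looping sequence has a TTR of 0.25 and a bigram repetition of about 0.71 (only 2 of its 7 bigrams are distinct), while the varied sequence scores 1.00 and 0.00.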

### Semantic Metrics

**Coherence**: Does the text make sense?
- Sentence-level: Grammar and syntax
- Document-level: Logical flow and consistency

**Relevance**: Does it match the prompt/context?
- Semantic similarity to prompt
- Topic consistency

**Factuality**: Is the information correct?
- Fact-checking against knowledge bases
- Consistency with known facts

### The Evaluation Challenge

**No Single "Best" Metric**: Different tasks need different evaluation criteria

**Human Evaluation**: Often the gold standard, but expensive and subjective

**Task-Specific Metrics**: Translation has BLEU, summarization has ROUGE, etc.

Let's implement a comprehensive evaluation suite:

In [None]:
class GenerationMetrics:
    """Comprehensive metrics for evaluating text generation quality."""
    
    def __init__(self, model):
        self.model = model
    
    def calculate_perplexity(self, text_tokens):
        """Calculate perplexity of generated text under the model."""
        self.model.eval()
        
        with torch.no_grad():
            if len(text_tokens.shape) == 1:
                text_tokens = text_tokens.unsqueeze(0)
            
            # Get model predictions
            outputs = self.model(text_tokens)
            
            # Calculate cross-entropy loss (negative log likelihood)
            # Shift tokens for teacher forcing
            logits = outputs[:, :-1, :].contiguous()
            targets = text_tokens[:, 1:].contiguous()
            
            # Calculate loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            perplexity = torch.exp(loss).item()
            
            return perplexity
    
    def calculate_diversity_metrics(self, token_sequences):
        """Calculate various diversity metrics for a set of generated sequences."""
        if not token_sequences:
            return {}
        
        # Convert any tensors to plain lists so individual tokens are hashable
        token_sequences = [seq.tolist() if isinstance(seq, torch.Tensor) else seq
                           for seq in token_sequences]
        
        metrics = {}
        
        # Type-Token Ratio (TTR) for each sequence
        ttrs = []
        for seq in token_sequences:
            if len(seq) > 0:
                ttr = len(set(seq)) / len(seq)
                ttrs.append(ttr)
        
        metrics['avg_ttr'] = np.mean(ttrs) if ttrs else 0
        metrics['std_ttr'] = np.std(ttrs) if ttrs else 0
        
        # Inter-sequence diversity (unique tokens across all sequences)
        all_tokens = set()
        total_tokens = 0
        for seq in token_sequences:
            all_tokens.update(seq)
            total_tokens += len(seq)
        
        metrics['inter_seq_ttr'] = len(all_tokens) / total_tokens if total_tokens > 0 else 0
        metrics['vocab_size'] = len(all_tokens)
        
        # Self-BLEU (how similar sequences are to each other - lower is more diverse)
        if len(token_sequences) > 1:
            self_bleu_scores = []
            for i, seq1 in enumerate(token_sequences):
                bleu_sum = 0
                count = 0
                for j, seq2 in enumerate(token_sequences):
                    if i != j:
                        # Simple n-gram overlap (proxy for BLEU)
                        overlap = self._calculate_ngram_overlap(seq1, seq2, n=2)
                        bleu_sum += overlap
                        count += 1
                if count > 0:
                    self_bleu_scores.append(bleu_sum / count)
            
            metrics['self_bleu'] = np.mean(self_bleu_scores) if self_bleu_scores else 0
        
        return metrics
    
    def calculate_repetition_metrics(self, token_sequence):
        """Calculate repetition metrics for a single sequence."""
        if isinstance(token_sequence, torch.Tensor):
            token_sequence = token_sequence.tolist()
        
        if len(token_sequence) <= 1:
            # Return every key the normal path produces so callers can index safely
            return {'repetition_rate': 0, 'max_repeated_ngram': 0,
                    'ngram_repetition_2': 0, 'ngram_repetition_3': 0}
        
        metrics = {}
        
        # Immediate repetition rate (consecutive identical tokens)
        repetitions = sum(1 for i in range(1, len(token_sequence)) 
                         if token_sequence[i] == token_sequence[i-1])
        metrics['repetition_rate'] = repetitions / (len(token_sequence) - 1)
        
        # N-gram repetition (for n=2, 3)
        for n in [2, 3]:
            if len(token_sequence) >= n:
                ngrams = [tuple(token_sequence[i:i+n]) for i in range(len(token_sequence)-n+1)]
                unique_ngrams = len(set(ngrams))
                total_ngrams = len(ngrams)
                repetition = 1 - (unique_ngrams / total_ngrams) if total_ngrams > 0 else 0
                metrics[f'ngram_repetition_{n}'] = repetition
        
        # Longest repeated subsequence
        max_repeat_len = self._find_longest_repeated_subsequence(token_sequence)
        metrics['max_repeated_ngram'] = max_repeat_len
        
        return metrics
    
    def _calculate_ngram_overlap(self, seq1, seq2, n=2):
        """Calculate n-gram overlap between two sequences (simple BLEU approximation)."""
        if len(seq1) < n or len(seq2) < n:
            return 0.0
        
        ngrams1 = set(tuple(seq1[i:i+n]) for i in range(len(seq1)-n+1))
        ngrams2 = set(tuple(seq2[i:i+n]) for i in range(len(seq2)-n+1))
        
        if len(ngrams1) == 0:
            return 0.0
        
        overlap = len(ngrams1.intersection(ngrams2))
        return overlap / len(ngrams1)
    
    def _find_longest_repeated_subsequence(self, sequence):
        """Find the length of the longest contiguous run that repeats later without overlapping."""
        max_length = 0
        n = len(sequence)
        
        # Check for repeated subsequences of increasing length
        for length in range(1, n // 2 + 1):
            for i in range(n - length):
                subseq = sequence[i:i+length]
                # Look for this subsequence later in the sequence
                for j in range(i + length, n - length + 1):
                    if sequence[j:j+length] == subseq:
                        max_length = max(max_length, length)
                        break
        
        return max_length
    
    def comprehensive_evaluation(self, generated_sequences, prompts=None):
        """Run comprehensive evaluation on generated text."""
        results = {
            'perplexities': [],
            'diversity_metrics': {},
            'repetition_metrics': [],
            'quality_scores': {}
        }
        
        # Calculate per-sequence metrics
        for seq in generated_sequences:
            # Perplexity
            if isinstance(seq, list):
                seq_tensor = torch.tensor([seq], device=device)
            else:
                seq_tensor = seq
            
            try:
                perplexity = self.calculate_perplexity(seq_tensor)
                results['perplexities'].append(perplexity)
            except Exception:  # a bare except would also swallow KeyboardInterrupt
                results['perplexities'].append(float('inf'))
            
            # Repetition metrics
            rep_metrics = self.calculate_repetition_metrics(seq)
            results['repetition_metrics'].append(rep_metrics)
        
        # Calculate diversity metrics across all sequences
        results['diversity_metrics'] = self.calculate_diversity_metrics(generated_sequences)
        
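        # Calculate overall quality score (lower is better)
        # Skip inf/NaN perplexities and avoid taking the mean of an empty list
        finite_perplexities = [p for p in results['perplexities'] if np.isfinite(p)]
        avg_perplexity = np.mean(finite_perplexities) if finite_perplexities else float('inf')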
        avg_repetition = np.mean([m['repetition_rate'] for m in results['repetition_metrics']])
        diversity_score = results['diversity_metrics'].get('avg_ttr', 0)
        
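        # Composite quality score (you can adjust weights). Perplexity is unbounded while
        # repetition and diversity lie in [0, 1], so perplexity dominates unless rescaled.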
        quality_score = {
            'perplexity_score': avg_perplexity,
            'repetition_penalty': avg_repetition * 100,  # Higher repetition = worse
            'diversity_bonus': diversity_score * 100,    # Higher diversity = better
            'composite_score': avg_perplexity + (avg_repetition * 50) - (diversity_score * 20)
        }
        results['quality_scores'] = quality_score
        
        return results

print("📊 Generation metrics toolkit ready!")

In [None]:
# Comprehensive evaluation of different generation methods

def evaluate_generation_methods():
    """Evaluate different generation methods using comprehensive metrics."""
    
    metrics_calculator = GenerationMetrics(model)
    
    # Test different generation methods
    methods = {
        # T=0.1 sharpens the distribution so sampling is almost, but not exactly, greedy
        'Near-greedy (T=0.1)': {'temperature': 0.1, 'top_k': None, 'top_p': None},
        'High Temperature': {'temperature': 2.0, 'top_k': None, 'top_p': None},
        'Top-k (k=40)': {'temperature': 1.0, 'top_k': 40, 'top_p': None},
        'Nucleus (p=0.9)': {'temperature': 1.0, 'top_k': None, 'top_p': 0.9},
        'Balanced': {'temperature': 0.8, 'top_k': 50, 'top_p': 0.95}
    }
    
    # Generate multiple samples for each method
    prompt = torch.randint(0, config['vocab_size'], (1, 4), device=device)
    num_samples = 5
    generation_length = 20
    
    print("🔬 COMPREHENSIVE GENERATION EVALUATION")
    print(f"Generating {num_samples} samples of {generation_length} tokens each...\n")
    
    all_results = {}
    
    for method_name, params in methods.items():
        print(f"📝 Evaluating {method_name}...")
        
        # Generate samples
        samples = []
        for _ in range(num_samples):
            generated = generate_text(model, prompt, max_length=generation_length, **params)
            # Extract only the newly generated tokens; the metrics below therefore
            # score the continuation without its prompt as context
            new_tokens = generated[0, prompt.size(1):].tolist()
            samples.append(new_tokens)
        
        # Run comprehensive evaluation
        results = metrics_calculator.comprehensive_evaluation(samples)
        all_results[method_name] = results
        
        # Print summary
        print(f"  Avg Perplexity: {results['quality_scores']['perplexity_score']:.2f}")
        print(f"  Avg TTR (diversity): {results['diversity_metrics']['avg_ttr']:.3f}")
        print(f"  Repetition Rate: {np.mean([m['repetition_rate'] for m in results['repetition_metrics']]):.3f}")
        print(f"  Composite Score: {results['quality_scores']['composite_score']:.2f} (lower = better)")
        print()
    
    return all_results

# Run evaluation
evaluation_results = evaluate_generation_methods()

# Create comparison visualization
def visualize_evaluation_results(results):
    """Visualize the evaluation results across different methods."""
    
    methods = list(results.keys())
    
    # Extract metrics
    perplexities = [results[method]['quality_scores']['perplexity_score'] for method in methods]
    diversities = [results[method]['diversity_metrics']['avg_ttr'] for method in methods]
    repetitions = [np.mean([m['repetition_rate'] for m in results[method]['repetition_metrics']]) 
                  for method in methods]
    composite_scores = [results[method]['quality_scores']['composite_score'] for method in methods]
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # Perplexity (lower is better)
    bars = axes[0, 0].bar(methods, perplexities, color='skyblue', alpha=0.8, edgecolor='black')
    axes[0, 0].set_title('Perplexity (Lower = Better)', fontsize=14, weight='bold')
    axes[0, 0].set_ylabel('Perplexity')
    axes[0, 0].tick_params(axis='x', rotation=45)
    
    for bar, val in zip(bars, perplexities):
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                       f'{val:.1f}', ha='center', va='bottom', fontsize=10)
    
    # Diversity (higher is better)
    bars = axes[0, 1].bar(methods, diversities, color='lightgreen', alpha=0.8, edgecolor='black')
    axes[0, 1].set_title('Diversity (TTR, Higher = Better)', fontsize=14, weight='bold')
    axes[0, 1].set_ylabel('Type-Token Ratio')
    axes[0, 1].tick_params(axis='x', rotation=45)
    
    for bar, val in zip(bars, diversities):
        axes[0, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
                       f'{val:.3f}', ha='center', va='bottom', fontsize=10)
    
    # Repetition Rate (lower is better)
    bars = axes[1, 0].bar(methods, repetitions, color='salmon', alpha=0.8, edgecolor='black')
    axes[1, 0].set_title('Repetition Rate (Lower = Better)', fontsize=14, weight='bold')
    axes[1, 0].set_ylabel('Repetition Rate')
    axes[1, 0].tick_params(axis='x', rotation=45)
    
    for bar, val in zip(bars, repetitions):
        axes[1, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
                       f'{val:.3f}', ha='center', va='bottom', fontsize=10)
    
    # Composite Score (lower is better)
    bars = axes[1, 1].bar(methods, composite_scores, color='gold', alpha=0.8, edgecolor='black')
    axes[1, 1].set_title('Composite Quality Score (Lower = Better)', fontsize=14, weight='bold')
    axes[1, 1].set_ylabel('Composite Score')
    axes[1, 1].tick_params(axis='x', rotation=45)
    
    for bar, val in zip(bars, composite_scores):
        axes[1, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                       f'{val:.1f}', ha='center', va='bottom', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    # Print rankings
    print("🏆 GENERATION METHOD RANKINGS:")
    
    # Best perplexity (lowest)
    perp_ranking = sorted(zip(methods, perplexities), key=lambda x: x[1])
    print("\n📉 Best Perplexity (Coherence):")
    for i, (method, score) in enumerate(perp_ranking, 1):
        print(f"  {i}. {method}: {score:.2f}")
    
    # Best diversity (highest)
    div_ranking = sorted(zip(methods, diversities), key=lambda x: x[1], reverse=True)
    print("\n🌈 Best Diversity (Creativity):")
    for i, (method, score) in enumerate(div_ranking, 1):
        print(f"  {i}. {method}: {score:.3f}")
    
    # Best composite score (lowest)
    comp_ranking = sorted(zip(methods, composite_scores), key=lambda x: x[1])
    print("\n🎯 Best Overall Quality:")
    for i, (method, score) in enumerate(comp_ranking, 1):
        print(f"  {i}. {method}: {score:.2f}")

visualize_evaluation_results(evaluation_results)

print("\n🔍 EVALUATION INSIGHTS:")
print("• Perplexity measures predictability - lower is usually more coherent, but very low values often mean repetitive text")
print("• Diversity (TTR) measures vocabulary variety - higher means fewer repeated tokens")
print("• Repetition rate counts consecutive duplicate tokens - lower means less stuttering")
print("• Composite score balances all factors - find the sweet spot for your task")
print("• No single 'best' method - choose based on your specific requirements")

## Summary: Master Advanced Text Generation

You've learned the complete toolkit for sophisticated text generation - from mathematical foundations to practical implementation.

### 🎨 The Four Pillars of Generation

**1. Temperature Scaling**
```python
# Control creativity through probability sharpening/flattening
scaled_logits = logits / temperature
probs = F.softmax(scaled_logits, dim=-1)
```
- **T < 1**: Conservative, predictable (good for formal text)
- **T = 1**: Raw model probabilities
- **T > 1**: Creative, diverse (good for creative writing)
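
To see the three regimes numerically, here is a small sketch in plain Python. The logits are toy values and `softmax_with_temperature` is an illustrative helper rather than the notebook's sampling function.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over a list of logits after dividing each by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

The top token's share grows as T drops below 1 and shrinks as T rises above 1, while the ranking of tokens never changes.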

**2. Quality Control Sampling**
```python
# Top-k: Fixed vocabulary filtering
top_k_probs, top_k_indices = torch.topk(probs, k)

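# Nucleus: Adaptive vocabulary based on cumulative probability
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
# Keep a token if the mass before it is below p, so the top token always survives
nucleus = (cumsum_probs - sorted_probs) < p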
```
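
The difference between a fixed and an adaptive cutoff is easiest to see on two toy distributions. The following is a sketch in plain Python with hypothetical helper names. Each helper returns a `{token_index: renormalised_probability}` dictionary.

```python
def top_k_filter(probs, k):
    """Keep the k most likely tokens and renormalise."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

def nucleus_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

peaked = [0.85, 0.10, 0.03, 0.02]  # model is confident
flat = [0.30, 0.25, 0.25, 0.20]    # model is uncertain

print(len(top_k_filter(peaked, 2)), len(top_k_filter(flat, 2)))
print(len(nucleus_filter(peaked, 0.9)), len(nucleus_filter(flat, 0.9)))
```

Top-k keeps two tokens in both cases. Nucleus sampling keeps two tokens when the model is confident and all four when it is uncertain.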

**3. Beam Search**
```python
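# Global optimization of entire sequences: one expansion step
# Assumes logits is (num_beams, vocab_size) and beam_scores is (num_beams,)
log_probs = F.log_softmax(logits, dim=-1)                     # expand all possible next tokens
candidates = (beam_scores.unsqueeze(1) + log_probs).view(-1)  # score by cumulative log-probability
beam_scores, flat_idx = torch.topk(candidates, beam_width)    # keep the best beam_width sequences
beam_idx = torch.div(flat_idx, log_probs.size(-1), rounding_mode='floor')  # beam each survivor extends
token_idx = flat_idx % log_probs.size(-1)                     # token each survivor appends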
```
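
To illustrate why keeping several beams can beat a greedy choice, here is a self-contained toy in plain Python. The bigram table, the `toy_model` callback and the function signature are all assumptions made for this sketch. It is not the notebook's beam search implementation.

```python
import math

def beam_search(next_token_probs, start, steps, beam_width):
    """Keep the `beam_width` highest-scoring partial sequences at each step.

    `next_token_probs(seq)` is a hypothetical callback returning {token: prob}.
    """
    beams = [([start], 0.0)]  # (sequence, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            for token, prob in next_token_probs(seq).items():
                candidates.append((seq + [token], score + math.log(prob)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams

# Toy bigram model: the most likely first token leads to a weaker continuation
TABLE = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"x": 0.55, "y": 0.45},
    "b": {"x": 0.9, "y": 0.1},
}

def toy_model(seq):
    return TABLE[seq[-1]]

greedy_seq, greedy_score = beam_search(toy_model, "<s>", steps=2, beam_width=1)[0]
beam_seq, beam_score = beam_search(toy_model, "<s>", steps=2, beam_width=2)[0]

print(greedy_seq, round(math.exp(greedy_score), 2))
print(beam_seq, round(math.exp(beam_score), 2))
```

With a width of 1 the search commits to "a" and ends with a sequence probability of 0.33. With a width of 2 it keeps "b" alive and finds the 0.36 sequence.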

**4. Quality Metrics**
```python
# Multi-dimensional evaluation
perplexity = torch.exp(cross_entropy_loss)
diversity = len(set(tokens)) / len(tokens)               # type-token ratio
repetition = consecutive_duplicates / (len(tokens) - 1)  # share of adjacent repeated tokens
```

### 🎯 Method Selection Guide

**For Creative Writing**:
- Temperature: 0.8-1.2
- Nucleus sampling: p=0.9-0.95
- Avoid beam search (too deterministic)

**For Technical Documentation**:
- Temperature: 0.3-0.7
- Top-k: k=20-40
- Consider beam search for consistency

**For Dialogue/Conversation**:
- Temperature: 0.7-1.0
- Nucleus: p=0.85-0.9
- Balance creativity with coherence

**For Code Generation**:
- Temperature: 0.1-0.5
- Top-k: k=10-30
- Beam search often beneficial
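
One way to keep these recommendations close to the code is a small preset table. The ranges below are copied from the guide above. The dictionary layout, the task keys and the midpoint helper `starting_params` are illustrative assumptions, and the midpoint is only a starting value to tune from.

```python
# Ranges come from the guide; key names and structure are hypothetical.
SAMPLING_PRESETS = {
    "creative_writing": {"temperature": (0.8, 1.2), "top_p": (0.9, 0.95)},
    "technical_docs": {"temperature": (0.3, 0.7), "top_k": (20, 40)},
    "dialogue": {"temperature": (0.7, 1.0), "top_p": (0.85, 0.9)},
    "code": {"temperature": (0.1, 0.5), "top_k": (10, 30)},
}

def starting_params(task):
    """Pick the midpoint of each recommended range as a first guess."""
    params = {}
    for name, (low, high) in SAMPLING_PRESETS[task].items():
        mid = (low + high) / 2
        # top_k must be an integer; the other parameters stay fractional
        params[name] = round(mid) if name == "top_k" else round(mid, 3)
    return params

print(starting_params("code"))
print(starting_params("creative_writing"))
```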

### 📊 Quality Assessment Framework

**Automatic Metrics**:
- **Perplexity**: Model confidence (lower = more predictable; very low values often indicate repetitive text)
- **TTR**: Vocabulary diversity (higher = more creative)
- **Repetition Rate**: Text quality (lower = better)
- **Self-BLEU**: Inter-sample similarity (lower = more diverse)

**Task-Specific Evaluation**:
- **Summarization**: ROUGE, factual consistency
- **Translation**: BLEU, semantic similarity
- **QA**: Exact match, F1 score
- **Creative writing**: Human evaluation, engagement

### 🔧 Advanced Techniques

**Controllable Generation**:
```python
# Prompt engineering for control
prompt = "Write a formal business email about [topic]:"

# Classifier-guided generation (classifier and apply_style_penalty are placeholder helpers)
if classifier(generated_text) != target_style:
    logits = apply_style_penalty(logits)
```
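
Classifier guidance needs a trained classifier, so the sketch below swaps in a simpler, related control: a per-token logit bias applied before the softmax. The token ids, bias values and helper names are made up for illustration.

```python
import math

def apply_logit_bias(logits, bias):
    """Return a copy of `logits` with per-token offsets added.

    `bias` maps token id -> offset; a large negative value all but bans a token.
    """
    return [x + bias.get(i, 0.0) for i, x in enumerate(logits)]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.5, 0.5, 0.0]
# Hypothetical style control: discourage token 0, encourage token 2
biased = apply_logit_bias(logits, {0: -5.0, 2: 1.0})

before, after = softmax(logits), softmax(biased)
print([round(p, 3) for p in before])
print([round(p, 3) for p in after])
```

Because the bias is added before normalisation, the result is still a valid distribution and any temperature, top-k or nucleus step can follow it unchanged.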

**Multi-objective Optimization**:
```python
# Balance multiple criteria
score = coherence_weight * coherence_score + \
        creativity_weight * diversity_score - \
        repetition_penalty * repetition_rate
```
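
A common use of such a score is reranking: generate several candidates, score each one, and keep the best. The sketch below assumes each candidate already carries scores in [0, 1]. The numbers, names and weights are invented purely for illustration.

```python
def composite_score(candidate, weights):
    """Weighted sum following the formula above; higher is better."""
    return (weights["coherence"] * candidate["coherence"]
            + weights["creativity"] * candidate["diversity"]
            - weights["repetition"] * candidate["repetition_rate"])

def rerank(candidates, weights):
    """Sort candidates from best to worst composite score."""
    return sorted(candidates, key=lambda c: composite_score(c, weights), reverse=True)

# Made-up scores for three hypothetical generations
candidates = [
    {"name": "safe", "coherence": 0.9, "diversity": 0.3, "repetition_rate": 0.4},
    {"name": "wild", "coherence": 0.4, "diversity": 0.9, "repetition_rate": 0.0},
    {"name": "steady", "coherence": 0.8, "diversity": 0.6, "repetition_rate": 0.1},
]

balanced = {"coherence": 1.0, "creativity": 0.5, "repetition": 1.0}
creative = {"coherence": 1.0, "creativity": 2.0, "repetition": 1.0}

print(rerank(candidates, balanced)[0]["name"])
print(rerank(candidates, creative)[0]["name"])
```

Changing only the creativity weight changes which candidate wins, which is the practical meaning of tuning the weights per task.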

### 💡 Best Practices

**1. Start with Proven Baselines**:
- Temperature 0.8 with nucleus p=0.9 is a reasonable starting point for many tasks
- Adjust based on specific requirements

**2. Use Multiple Metrics**:
- No single metric captures generation quality
- Combine automatic and human evaluation

**3. Task-Specific Tuning**:
- Different tasks need different strategies
- A/B test with real users when possible

**4. Monitor Edge Cases**:
- Watch for repetition loops
- Check for factual errors
- Validate semantic coherence
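
Repetition loops can be caught during generation with a cheap check on the tail of the sequence. The sketch below is one possible heuristic. The function name and the default thresholds are assumptions to tune for your own model.

```python
def ends_in_loop(tokens, max_period=4, min_repeats=3):
    """Return True if the sequence ends with a short pattern repeated back-to-back."""
    for period in range(1, max_period + 1):
        span = period * min_repeats
        if len(tokens) < span:
            continue
        tail = tokens[-span:]
        pattern = tail[:period]
        # The tail is a loop if it is the pattern tiled min_repeats times
        if all(tail[i] == pattern[i % period] for i in range(span)):
            return True
    return False

print(ends_in_loop([5, 1, 2, 1, 2, 1, 2]))  # the pair (1, 2) repeats three times
print(ends_in_loop([1, 2, 3, 4, 5, 6]))     # no repeated pattern at the end
```

A generation loop could call this after every new token and respond by stopping early or resampling with a higher temperature.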

### 🚀 Advanced Applications

**Interactive Generation**:
- Real-time adjustment of parameters
- User feedback incorporation
- Dynamic style adaptation

**Multi-modal Generation**:
- Text conditioned on images
- Audio-to-text generation
- Cross-modal consistency

**Personalized Generation**:
- User-specific style adaptation
- Contextual preference learning
- Progressive refinement

### 🎭 The Art and Science Balance

Great text generation combines:

**Science**: Mathematical foundations, rigorous evaluation, systematic optimization

**Art**: Creative intuition, aesthetic judgment, human-centered design

**Engineering**: Robust implementation, efficient algorithms, production scalability

You now have the complete toolkit to generate high-quality, controllable, and creative text. Use these techniques to build AI systems that truly understand and create human language! 🎨📚🤖