# Advanced Text Generation

Transform model probabilities into coherent, creative text through sophisticated sampling techniques.

## Generation Challenge

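At every step, a language model outputs a probability distribution over its whole vocabulary (GPT-2's vocabulary, for example, has 50,257 tokens). Greedy decoding, which always picks the highest-probability token, tends to produce bland, repetitive text, while sampling from the full distribution can wander into incoherence.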

## Core Techniques

- **Temperature**: Controls creativity vs coherence
- **Top-k/Nucleus**: Truncate low-probability tokens for quality control
- **Beam Search**: Sequence-level (approximate) optimization
- **Quality Metrics**: Multi-dimensional evaluation

## What You'll Learn

1. Temperature scaling for creativity control
2. Top-k and nucleus sampling for quality
3. Beam search for structured exploration  
4. Comprehensive quality evaluation
5. Method selection for different tasks

### Environment Setup

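This cell imports PyTorch, NumPy, and the plotting libraries along with the project's `GPTModel`, fixes random seeds for reproducibility, selects a GPU when one is available, and builds the `'tiny'` model configuration. No checkpoint is loaded here, so unless `GPTModel` loads weights itself the model is untrained and the token IDs generated below are essentially arbitrary. The experiments therefore show how each decoding method behaves, not readable text.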

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional

from src.model.transformer import GPTModel, create_model_config

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

config = create_model_config('tiny')
model = GPTModel(**config).to(device)

print("Environment ready for advanced generation experiments! 🚀")

## 1. Temperature: Controlling the Heat of Creativity

Temperature is the most fundamental knob for controlling text generation. It rescales the model's logits before the softmax, sharpening or flattening the resulting probability distribution.

### Mathematical Foundation

**Softmax with Temperature**: `p_i = exp(logit_i / T) / Σ exp(logit_j / T)`

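- **T < 1**: Sharpens the distribution (conservative, predictable); as T → 0 it approaches greedy argmax decoding
- **T = 1**: Raw model probabilities
- **T > 1**: Flattens the distribution (creative, diverse); as T → ∞ it approaches uniform sampling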
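To see the formula in action without a model, here is a minimal plain-Python sketch (the helper `softmax_with_temperature` and the four logits are illustrative, not part of the project code). The same logits are converted at three temperatures: the top token's share grows as T shrinks and the distribution flattens as T grows, while the ranking of tokens never changes.

```python
import math

def softmax_with_temperature(logits, temperature):
    """p_i = exp(logit_i / T) / sum_j exp(logit_j / T)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]  # illustrative values
for T in (0.3, 1.0, 2.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: {[round(p, 3) for p in probs]}")
```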

### Practical Effects

Rough rules of thumb (the exact behavior depends on the model):

- **T = 0.3**: Very conservative
- **T = 0.8**: Good balance
- **T = 1.5**: More creative
- **T = 2.0**: Often incoherent

### Temperature Implementation

This cell defines `generate_text`, an autoregressive sampler that divides the next-token logits by T and can optionally apply top-k and nucleus filtering (reused in the next section). It then samples 15 tokens from one random prompt at five temperatures to show how the output changes as T varies.

In [None]:
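def apply_temperature(logits, temperature=1.0):
    # Dividing logits by T < 1 sharpens the softmax; T > 1 flattens it
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for greedy decoding")
    return logits / temperature

def generate_text(model, prompt, max_length=20, temperature=1.0, top_k=None, top_p=None):
    """Sample `max_length` new tokens after `prompt` (shape [1, seq_len])."""
    model.eval()
    generated = prompt.clone()
    
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(generated)              # assumes logits of shape [1, seq_len, vocab]
            next_token_logits = outputs[0, -1, :]   # logits for the next position only
            
            scaled_logits = apply_temperature(next_token_logits, temperature)
            
            # Top-k: keep the k largest logits and mask everything else to -inf
            if top_k is not None:
                top_k_logits, top_k_indices = torch.topk(scaled_logits, min(top_k, scaled_logits.size(-1)))
                scaled_logits = torch.full_like(scaled_logits, float('-inf'))
                scaled_logits.scatter_(0, top_k_indices, top_k_logits)
            
            # Nucleus: keep the smallest set of top tokens whose cumulative probability exceeds top_p
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses top_p is kept; always keep the top token
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = False
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                scaled_logits[indices_to_remove] = float('-inf')
            
            # softmax renormalizes over the surviving tokens; masked (-inf) tokens get probability 0
            probs = F.softmax(scaled_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)  # shape [1]
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
    
    return generated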

def demonstrate_temperature_effects():
    prompt = torch.randint(0, config['vocab_size'], (1, 3), device=device)
    temperatures = [0.1, 0.7, 1.0, 1.5, 2.0]
    
    print("🌡️ TEMPERATURE EFFECTS DEMONSTRATION")
    print(f"Starting prompt: {prompt[0].tolist()}\n")
    
    for temp in temperatures:
        generated = generate_text(model, prompt, max_length=15, temperature=temp)
        new_tokens = generated[0, len(prompt[0]):].tolist()
        print(f"T={temp}: {new_tokens}")

demonstrate_temperature_effects()

## 2. Top-k and Nucleus Sampling: Quality Control

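Temperature reshapes the whole distribution but never removes a token, so the long tail of individually unlikely tokens keeps a share of the probability mass, and sampling from that tail occasionally derails the output. Top-k and nucleus sampling truncate the tail before sampling.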

### Top-k Sampling

1. Sort tokens by probability
2. Keep only top k tokens
3. Renormalize and sample

- **Effect**: Fixed vocabulary filtering (always exactly k candidates)
- **Typical values**: k = 40-100
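A minimal plain-Python sketch of these three steps, operating on a list of probabilities rather than a tensor (`top_k_filter` and the example distribution are hypothetical):

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens, zero out the rest, and renormalize."""
    if k <= 0:
        raise ValueError("k must be positive")
    # Steps 1-2: sort token indices by probability and keep the first k
    keep = set(sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k])
    # Step 3: renormalize so the kept probabilities sum to 1
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]

probs = [0.5, 0.2, 0.15, 0.1, 0.05]
filtered = top_k_filter(probs, 2)
print(filtered)  # only tokens 0 and 1 remain
```

Sampling would then draw from `filtered`, for example with `random.choices(range(len(filtered)), weights=filtered)`.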

### Nucleus (Top-p) Sampling

1. Sort tokens by probability
2. Add tokens until cumulative probability ≥ p
3. Renormalize and sample from this "nucleus"

- **Effect**: Adaptive vocabulary size (few candidates when the model is confident, many when it is uncertain)
- **Typical values**: p = 0.9-0.95
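The same idea for nucleus sampling, again as a plain-Python sketch with a hypothetical helper (`nucleus_filter`) and made-up distributions. With one fixed p, a peaked distribution keeps a single token while a flat one keeps all five, which is the adaptive behavior described above.

```python
def nucleus_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:  # stop once the nucleus holds at least p of the mass
            break
    return [probs[i] / cumulative if i in keep else 0.0 for i in range(len(probs))]

peaked = [0.92, 0.04, 0.02, 0.01, 0.01]
flat = [0.25, 0.22, 0.20, 0.18, 0.15]
for name, dist in (("peaked", peaked), ("flat", flat)):
    nucleus_size = sum(1 for q in nucleus_filter(dist, 0.9) if q > 0)
    print(name, nucleus_size)  # peaked 1, flat 5
```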

### Sampling Strategy Analysis

This code analyzes and compares top-k and nucleus sampling strategies across different probability distributions. It demonstrates how each method adapts to high-confidence vs low-confidence model predictions, showing effective vocabulary sizes and entropy changes.

In [None]:
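def renormalized_entropy(kept_probs):
    # Renormalize the kept tokens first: sampling happens from this rescaled distribution
    kept_probs = kept_probs / kept_probs.sum()
    return -torch.sum(kept_probs * torch.log(kept_probs + 1e-10)).item()

def analyze_sampling_strategies():
    vocab_size = 1000
    base_logits = torch.randn(vocab_size)
    # Same logits at two scales: a peaked (high-confidence) and a flat (low-confidence) distribution
    distributions = {
        'High-confidence': F.softmax(base_logits * 4.0, dim=-1),
        'Low-confidence': F.softmax(base_logits * 0.5, dim=-1),
    }
    
    k_values = [10, 20, 50, 100]
    p_values = [0.8, 0.9, 0.95, 0.99]
    
    print("📊 SAMPLING STRATEGY ANALYSIS\n")
    
    for dist_name, probs in distributions.items():
        sorted_probs, _ = torch.sort(probs, descending=True)
        cumsum_probs = torch.cumsum(sorted_probs, dim=0)
        print(f"{dist_name} distribution (full entropy={renormalized_entropy(probs):.3f}):")
        
        print("  Top-k sampling:")
        for k in k_values:
            # Top-k keeps exactly k tokens whatever the distribution's shape
            kept_mass = cumsum_probs[k - 1].item()
            entropy = renormalized_entropy(sorted_probs[:k])
            print(f"    k={k}: Effective vocab={k}, Kept mass={kept_mass:.3f}, Entropy={entropy:.3f}")
        
        print("  Nucleus (top-p) sampling:")
        for p in p_values:
            # Smallest prefix whose cumulative probability reaches p (includes the crossing token)
            nucleus_size = min(int((cumsum_probs < p).sum().item()) + 1, vocab_size)
            entropy = renormalized_entropy(sorted_probs[:nucleus_size])
            print(f"    p={p}: Effective vocab={nucleus_size}, Entropy={entropy:.3f}")
        print()

analyze_sampling_strategies()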

### Practical Sampling Comparison

This cell draws five samples per method from one random prompt and reports the average fraction of distinct tokens per sample, a simple within-sample diversity measure. It compares near-greedy (T=0.1), high-temperature (T=2.0), top-k, nucleus, and a balanced combination of all three controls.

In [None]:
def compare_sampling_methods():
    prompt = torch.randint(0, config['vocab_size'], (1, 3), device=device)
    generation_length = 20
    num_samples = 5
    
    methods = {
        'Near-greedy (T=0.1)': {'temperature': 0.1, 'top_k': None, 'top_p': None},
        'High temperature (T=2.0)': {'temperature': 2.0, 'top_k': None, 'top_p': None},
        'Top-k=40': {'temperature': 1.0, 'top_k': 40, 'top_p': None},
        'Nucleus p=0.9': {'temperature': 1.0, 'top_k': None, 'top_p': 0.9},
        'Balanced': {'temperature': 0.8, 'top_k': 50, 'top_p': 0.95}
    }
    
    results = {}
    print(f"🎲 SAMPLING METHOD COMPARISON")
    print(f"Starting with prompt tokens: {prompt[0].tolist()}")
    
    for method_name, params in methods.items():
        print(f"📝 {method_name}:")
        samples = []
        diversities = []
        
        for sample_idx in range(num_samples):
            generated = generate_text(model, prompt, max_length=generation_length, **params)
            new_tokens = generated[0, len(prompt[0]):].tolist()
            samples.append(new_tokens)
            
            diversity = len(set(new_tokens)) / len(new_tokens) if new_tokens else 0
            diversities.append(diversity)
            
            if sample_idx < 2:
                print(f"  Sample {sample_idx + 1}: {new_tokens[:10]}{'...' if len(new_tokens) > 10 else ''}")
        
        avg_diversity = np.mean(diversities)
        print(f"  Avg diversity: {avg_diversity:.3f}\n")
        
        results[method_name] = {'avg_diversity': avg_diversity, 'samples': samples}
    
    return results

sampling_results = compare_sampling_methods()

## 3. Beam Search: Structured Exploration

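Beam search scores whole sequences rather than single tokens: at each step it keeps several partial sequences (beams) and ranks them by cumulative log-probability. It is a heuristic search, so it usually finds higher-probability sequences than greedy decoding but is not guaranteed to find the most probable one.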

### Algorithm

1. **Initialize**: Start with a single beam containing the prompt
2. **Expand**: Extend each beam with every possible next token
3. **Score**: Add the new token's log-probability to the beam's cumulative score
4. **Prune**: Keep only the k highest-scoring sequences
5. **Repeat**: Until every beam ends (end-of-sequence token) or reaches the length limit
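A toy, model-free sketch of these steps (the `TOY_MODEL` table and `beam_search` helper are hypothetical). The table is built so that greedy decoding, which is beam search with a beam of size 1, commits to the locally best first token and ends with a lower-probability sequence than a beam of size 2 finds.

```python
import math

# Hypothetical toy "model": next-token probabilities keyed by the sequence so far
TOY_MODEL = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.3, "y": 0.3, "z": 0.4},
    ("B",): {"x": 0.9, "y": 0.1},
}

def beam_search(next_probs, beam_size, steps):
    beams = [((), 0.0)]  # (sequence, cumulative log-probability)
    for _ in range(steps):
        # Expand every beam with every possible next token and score it
        candidates = [
            (seq + (token,), score + math.log(p))
            for seq, score in beams
            for token, p in next_probs[seq].items()
        ]
        # Prune to the beam_size best sequences
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams

greedy = beam_search(TOY_MODEL, beam_size=1, steps=2)[0]
best = beam_search(TOY_MODEL, beam_size=2, steps=2)[0]
print(greedy[0], round(math.exp(greedy[1]), 2))  # ('A', 'z') 0.24
print(best[0], round(math.exp(best[1]), 2))      # ('B', 'x') 0.36
```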

### Key Formula

**Sequence Score**: `S(x₁...xₙ) = Σᵢ log P(xᵢ | x₁...xᵢ₋₁)`
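Because every term in this sum is negative, raw scores favor shorter sequences whenever hypotheses can end at different lengths. A common remedy divides the score by the sequence length raised to a power α (the `length_penalty` argument of `BeamSearchGenerator` below plays this role). This plain-Python sketch uses a hypothetical `length_normalized` helper and made-up log-probabilities to show the preference flipping:

```python
def length_normalized(log_probs, alpha=1.0):
    """Sum of token log-probabilities divided by length ** alpha (alpha = 0 gives the raw sum)."""
    return sum(log_probs) / (len(log_probs) ** alpha)

short_hyp = [-1.0, -1.0]        # two tokens, raw score -2.0
long_hyp = [-0.8, -0.8, -0.8]   # three tokens, raw score -2.4
for alpha in (0.0, 1.0):
    better = "short" if length_normalized(short_hyp, alpha) > length_normalized(long_hyp, alpha) else "long"
    print(f"alpha={alpha}: prefers {better}")  # alpha=0.0: short, alpha=1.0: long
```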

### Trade-offs

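- **Advantages**: Deterministic; usually finds higher-probability sequences than greedy decoding; works well for tasks with a narrow set of correct outputs, such as translation
- **Disadvantages**: Cost grows with beam size; in open-ended generation, high-probability outputs tend to be generic and repetitive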

### Beam Search Implementation

This cell implements beam search with optional length normalization (`length_penalty`) and compares its top beams with sampled continuations from the same prompt. Beam search keeps the highest-scoring hypotheses instead of sampling, so it is deterministic for a given model and prompt, which illustrates the coherence vs creativity trade-off.

In [None]:
class BeamSearchGenerator:
    """Beam search without EOS handling: every beam is extended by exactly max_length new tokens."""
    def __init__(self, model, beam_size=5, max_length=50, length_penalty=1.0):
        self.model = model
        self.beam_size = beam_size
        self.max_length = max_length          # number of NEW tokens to generate
        self.length_penalty = length_penalty  # alpha in: log_prob / num_generated ** alpha
    
    def _normalize(self, log_prob, num_generated):
        # Length-normalized score. Without EOS all beams share one length, so this rescales
        # scores but does not change their ranking; it matters once beams can end early.
        return log_prob / (max(num_generated, 1) ** self.length_penalty)
    
    def generate(self, prompt_tokens):
        self.model.eval()
        prompt_len = prompt_tokens.size(1)
        # Each beam is (token sequence, raw cumulative log-probability of its generated tokens).
        # Keeping the raw sum separate ensures normalization is applied exactly once.
        beams = [(prompt_tokens.clone(), 0.0)]
        
        with torch.no_grad():
            for step in range(self.max_length):
                candidates = []
                for beam_seq, beam_log_prob in beams:
                    outputs = self.model(beam_seq)
                    # log_softmax is numerically safer than log(softmax(...))
                    log_probs = F.log_softmax(outputs[0, -1, :], dim=-1)
                    # Only a beam's own top beam_size tokens can enter the global top beam_size
                    top_log_probs, top_indices = torch.topk(log_probs, min(self.beam_size, log_probs.size(-1)))
                    for log_p, token_id in zip(top_log_probs, top_indices):
                        new_seq = torch.cat([beam_seq, token_id.reshape(1, 1)], dim=1)
                        candidates.append((new_seq, beam_log_prob + log_p.item()))
                
                # Prune to the beam_size highest-scoring hypotheses
                num_generated = step + 1
                candidates.sort(key=lambda c: self._normalize(c[1], num_generated), reverse=True)
                beams = candidates[:self.beam_size]
        
        # (sequence, length-normalized score) pairs, best first
        return [(seq, self._normalize(log_prob, seq.size(1) - prompt_len)) for seq, log_prob in beams]

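def compare_beam_search_vs_sampling():
    prompt = torch.randint(0, config['vocab_size'], (1, 4), device=device)
    generation_length = 15
    
    print("🔍 BEAM SEARCH vs SAMPLING")
    print(f"Starting prompt: {prompt[0].tolist()}\n")
    
    generator = BeamSearchGenerator(model, beam_size=3, max_length=generation_length)
    beams = generator.generate(prompt)
    
    print("🔬 Beam Search (size=3):")
    for i, (sequence, score) in enumerate(beams[:3]):
        new_tokens = sequence[0, len(prompt[0]):].tolist()
        print(f"  Beam {i+1}: {new_tokens} (score: {score:.3f})")
    
    # Sampled continuations from the same prompt, for contrast with the deterministic beams
    print("\n🎲 Nucleus sampling (T=0.8, p=0.9):")
    for i in range(2):
        sampled = generate_text(model, prompt, max_length=generation_length, temperature=0.8, top_p=0.9)
        print(f"  Sample {i+1}: {sampled[0, len(prompt[0]):].tolist()}")
    
    return beams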

beam_results = compare_beam_search_vs_sampling()

## 4. Quality Metrics: Measuring Generation Success

Evaluating text generation requires multiple metrics since there's no single "correct" output.

### Automatic Metrics

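**Perplexity**: How surprised a model is by the text
- Lower = more predictable under the scoring model
- Formula: `PP = exp(-1/N Σᵢ log P(xᵢ | x₁...xᵢ₋₁))`
- Caveat: when a model scores its own samples, low-temperature decoding gets low perplexity almost by construction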
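Given per-token log-probabilities from any model, the formula reduces to a one-liner. This plain-Python sketch (hypothetical `perplexity` helper, made-up probabilities) also shows a useful sanity check: a uniform guess over V tokens has perplexity V.

```python
import math

def perplexity(token_log_probs):
    """PP = exp(-(1/N) * sum_i log P(x_i | x_<i)), from natural-log token probabilities."""
    if not token_log_probs:
        raise ValueError("need at least one scored token")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

# Tokens the model predicted with probability 1/2, 1/4 and 1/8
print(perplexity([math.log(0.5), math.log(0.25), math.log(0.125)]))  # ≈ 4.0
# A uniform guess over 50,000 tokens
print(perplexity([math.log(1 / 50_000)] * 10))  # ≈ 50000
```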

**Diversity Metrics**:
- **TTR** (type-token ratio): Unique tokens / Total tokens
- **Self-BLEU**: Similarity between samples (lower = more diverse)
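The metrics toolkit later in this notebook does not compute Self-BLEU, so here is a simplified plain-Python sketch over token-ID lists. The helpers are hypothetical and use unsmoothed BLEU up to bigrams by default, so a sample sharing no unigrams with the others scores 0; for real evaluations, an established BLEU implementation with smoothing is the safer choice.

```python
import math
from collections import Counter

def ngram_counts(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def sentence_bleu(hypothesis, references, max_n=2):
    """Unsmoothed BLEU: clipped n-gram precisions (n = 1..max_n), geometric mean, brevity penalty."""
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = ngram_counts(hypothesis, n)
        if not hyp_counts:
            return 0.0
        # Clip each hypothesis n-gram by its largest count in any single reference
        max_ref_counts = Counter()
        for ref in references:
            for gram, count in ngram_counts(ref, n).items():
                max_ref_counts[gram] = max(max_ref_counts[gram], count)
        clipped = sum(min(count, max_ref_counts[gram]) for gram, count in hyp_counts.items())
        if clipped == 0:
            return 0.0
        log_precision_sum += math.log(clipped / sum(hyp_counts.values()))
    # Brevity penalty against the reference length closest to the hypothesis length
    hyp_len = len(hypothesis)
    ref_len = min((abs(len(r) - hyp_len), len(r)) for r in references)[1]
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision_sum / max_n)

def self_bleu(samples, max_n=2):
    """Average BLEU of each sample scored against all the others (lower = more diverse)."""
    scores = [sentence_bleu(s, samples[:i] + samples[i + 1:], max_n) for i, s in enumerate(samples)]
    return sum(scores) / len(scores)

print(self_bleu([[1, 2, 3, 4]] * 3))                           # identical samples
print(self_bleu([[1, 2, 3, 4], [1, 2, 5, 6], [7, 8, 9, 10]]))  # partial overlap
```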

**Repetition Metrics**:
- **n-gram repetition**: Repeated subsequences
- **Consecutive repetition**: Immediate token repeats
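Both repetition measures fit in a few lines of plain Python for any n (hypothetical `repetition_metrics` helper; the toolkit below computes only the bigram case). In the example, one of six transitions repeats the previous token immediately, and half of the six bigrams repeat an earlier bigram.

```python
def repetition_metrics(tokens, n=2):
    """Consecutive-repeat rate and n-gram repetition (1 - unique n-grams / total n-grams)."""
    consecutive = sum(1 for a, b in zip(tokens, tokens[1:]) if a == b)
    grams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return {
        "consecutive_repetition": consecutive / (len(tokens) - 1) if len(tokens) > 1 else 0.0,
        f"ngram_repetition_{n}": 1 - len(set(grams)) / len(grams) if grams else 0.0,
    }

print(repetition_metrics([1, 1, 2, 1, 2, 1, 2]))
```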

### Evaluation Challenge

No single metric captures all aspects of quality. Combine automatic metrics with human evaluation for comprehensive assessment.

### Quality Metrics Implementation

This cell implements perplexity, type-token-ratio diversity, and repetition metrics, and combines them into an ad hoc composite score. Self-BLEU is described above but not implemented in this toolkit, and none of these metrics is semantic: they count tokens and n-grams rather than judging meaning.

In [None]:
class GenerationMetrics:
    def __init__(self, model):
        self.model = model
    
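    def calculate_perplexity(self, text_tokens):
        """exp(mean next-token cross-entropy) of text_tokens under self.model."""
        self.model.eval()
        with torch.no_grad():
            if text_tokens.dim() == 1:
                text_tokens = text_tokens.unsqueeze(0)
            if text_tokens.size(1) < 2:
                return float('inf')  # no (context, target) pair to score
            
            outputs = self.model(text_tokens)
            # Logits at position t predict token t+1: drop the last logit and the first target
            logits = outputs[:, :-1, :].contiguous()
            targets = text_tokens[:, 1:].contiguous()
            
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return torch.exp(loss).item()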
    
    def calculate_diversity_metrics(self, token_sequences):
        if not token_sequences:
            return {}
        
        if isinstance(token_sequences[0], torch.Tensor):
            token_sequences = [seq.tolist() for seq in token_sequences]
        
        metrics = {}
        
        ttrs = []
        for seq in token_sequences:
            if len(seq) > 0:
                ttr = len(set(seq)) / len(seq)
                ttrs.append(ttr)
        
        metrics['avg_ttr'] = np.mean(ttrs) if ttrs else 0
        
        all_tokens = set()
        total_tokens = 0
        for seq in token_sequences:
            all_tokens.update(seq)
            total_tokens += len(seq)
        
        metrics['inter_seq_ttr'] = len(all_tokens) / total_tokens if total_tokens > 0 else 0
        metrics['vocab_size'] = len(all_tokens)
        
        return metrics
    
    def calculate_repetition_metrics(self, token_sequence):
        if isinstance(token_sequence, torch.Tensor):
            token_sequence = token_sequence.tolist()
        
        if len(token_sequence) <= 1:
            return {'repetition_rate': 0, 'ngram_repetition_2': 0}
        
        metrics = {}
        
        repetitions = sum(1 for i in range(1, len(token_sequence)) 
                         if token_sequence[i] == token_sequence[i-1])
        metrics['repetition_rate'] = repetitions / (len(token_sequence) - 1)
        
        if len(token_sequence) >= 2:
            ngrams = [tuple(token_sequence[i:i+2]) for i in range(len(token_sequence)-1)]
            unique_ngrams = len(set(ngrams))
            total_ngrams = len(ngrams)
            repetition = 1 - (unique_ngrams / total_ngrams) if total_ngrams > 0 else 0
            metrics['ngram_repetition_2'] = repetition
        
        return metrics
    
    def comprehensive_evaluation(self, generated_sequences):
        results = {
            'perplexities': [],
            'diversity_metrics': {},
            'repetition_metrics': [],
            'quality_scores': []
        }
        
        for seq in generated_sequences:
            if isinstance(seq, list):
                seq_tensor = torch.tensor([seq], device=device)
            else:
                seq_tensor = seq
            
            try:
                perplexity = self.calculate_perplexity(seq_tensor)
            except (RuntimeError, IndexError):
                # e.g. a sequence longer than the model's context window
                perplexity = float('inf')
            results['perplexities'].append(perplexity)
            
            rep_metrics = self.calculate_repetition_metrics(seq)
            results['repetition_metrics'].append(rep_metrics)
        
        results['diversity_metrics'] = self.calculate_diversity_metrics(generated_sequences)
        
        finite_ppl = [p for p in results['perplexities'] if np.isfinite(p)]
        avg_perplexity = np.mean(finite_ppl) if finite_ppl else float('inf')
        avg_repetition = np.mean([m['repetition_rate'] for m in results['repetition_metrics']])
        diversity_score = results['diversity_metrics'].get('avg_ttr', 0)
        
        # Ad hoc composite (lower = better). The weights are arbitrary and perplexity's scale
        # depends on the model and vocabulary, so use this only as a rough ranking aid.
        repetition_penalty = avg_repetition * 50
        diversity_bonus = diversity_score * 20
        results['quality_scores'] = {
            'perplexity_score': avg_perplexity,
            'repetition_penalty': repetition_penalty,
            'diversity_bonus': diversity_bonus,
            'composite_score': avg_perplexity + repetition_penalty - diversity_bonus
        }
        
        return results

print("📊 Generation metrics toolkit ready!")

### Comprehensive Method Evaluation

This cell runs every decoding method through the metrics toolkit, prints per-method perplexity, diversity, repetition, and composite score, plots them side by side, and ranks the methods by composite score. Because the composite is a single ad hoc formula, the ranking is a rough guide rather than a verdict on which method suits a given task.

In [None]:
def evaluate_generation_methods():
    metrics_calculator = GenerationMetrics(model)
    
    methods = {
        'Near-greedy (T=0.1)': {'temperature': 0.1, 'top_k': None, 'top_p': None},
        'High Temperature': {'temperature': 2.0, 'top_k': None, 'top_p': None},
        'Top-k (k=40)': {'temperature': 1.0, 'top_k': 40, 'top_p': None},
        'Nucleus (p=0.9)': {'temperature': 1.0, 'top_k': None, 'top_p': 0.9},
        'Balanced': {'temperature': 0.8, 'top_k': 50, 'top_p': 0.95}
    }
    
    prompt = torch.randint(0, config['vocab_size'], (1, 4), device=device)
    num_samples = 5
    generation_length = 20
    
    print(f"🔬 COMPREHENSIVE GENERATION EVALUATION")
    print(f"Generating {num_samples} samples of {generation_length} tokens each...\n")
    
    all_results = {}
    
    for method_name, params in methods.items():
        print(f"📝 Evaluating {method_name}...")
        
        samples = []
        for _ in range(num_samples):
            generated = generate_text(model, prompt, max_length=generation_length, **params)
            new_tokens = generated[0, len(prompt[0]):].tolist()
            samples.append(new_tokens)
        
        results = metrics_calculator.comprehensive_evaluation(samples)
        all_results[method_name] = results
        
        print(f"  Avg Perplexity: {results['quality_scores']['perplexity_score']:.2f}")
        print(f"  Avg TTR (diversity): {results['diversity_metrics']['avg_ttr']:.3f}")
        print(f"  Repetition Rate: {np.mean([m['repetition_rate'] for m in results['repetition_metrics']]):.3f}")
        print(f"  Composite Score: {results['quality_scores']['composite_score']:.2f} (lower = better)\n")
    
    return all_results

def visualize_evaluation_results(results):
    methods = list(results.keys())
    
    perplexities = [results[method]['quality_scores']['perplexity_score'] for method in methods]
    diversities = [results[method]['diversity_metrics']['avg_ttr'] for method in methods]
    composite_scores = [results[method]['quality_scores']['composite_score'] for method in methods]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    axes[0].bar(methods, perplexities, color='skyblue', alpha=0.8)
    axes[0].set_title('Perplexity (Lower = Better)', weight='bold')
    axes[0].tick_params(axis='x', rotation=45)
    
    axes[1].bar(methods, diversities, color='lightgreen', alpha=0.8)
    axes[1].set_title('Diversity (TTR, Higher = Better)', weight='bold')
    axes[1].tick_params(axis='x', rotation=45)
    
    axes[2].bar(methods, composite_scores, color='gold', alpha=0.8)
    axes[2].set_title('Composite Quality Score (Lower = Better)', weight='bold')
    axes[2].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    print("🏆 GENERATION METHOD RANKINGS:")
    
    comp_ranking = sorted(zip(methods, composite_scores), key=lambda x: x[1])
    print("\n🎯 Best Overall Quality:")
    for i, (method, score) in enumerate(comp_ranking, 1):
        print(f"  {i}. {method}: {score:.2f}")

evaluation_results = evaluate_generation_methods()
visualize_evaluation_results(evaluation_results)

## Summary: Advanced Text Generation Mastery

Transform raw model probabilities into high-quality text through sophisticated sampling and evaluation.

### Core Techniques

**Temperature Scaling**: `logits / T` - Controls creativity vs coherence
- T < 1: Conservative, predictable
- T > 1: Creative, diverse  
- Sweet spot: 0.7-1.2

**Quality Control Sampling**:
- **Top-k**: Fixed vocabulary filtering (k=40-100)
- **Nucleus**: Adaptive vocabulary (p=0.9-0.95)
- **Combined**: Top-k as a hard cap, then nucleus filtering within it (the `Balanced` setting above)

**Beam Search**: Global sequence optimization
- Maintains k candidate sequences
- Approximately maximizes cumulative log-probability
- Better coherence, less creativity

### Method Selection Guide

Starting points rather than fixed rules; tune per model and task:

- **Creative Writing**: T=0.8-1.2, nucleus p=0.9, avoid beam search
- **Technical Docs**: T=0.3-0.7, top-k=20-40, consider beam search
- **Dialogue**: T=0.7-1.0, nucleus p=0.85-0.9
- **Code Generation**: T=0.1-0.5, top-k=10-30, beam search can help

### Quality Assessment

**Key Metrics**:
- **Perplexity**: Model confidence (lower = more coherent)
- **TTR**: Vocabulary diversity (higher = more creative)
- **Repetition Rate**: Text quality (lower = better)
- **Self-BLEU**: Inter-sample similarity (lower = more diverse)

**Best Practice**: Combine automatic metrics with human evaluation for comprehensive quality assessment.

### Implementation Framework

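```python
import torch
import torch.nn.functional as F

def apply_top_k_nucleus(logits, top_k=None, top_p=None):
    """Mask tokens outside the top-k and the top-p nucleus; return renormalized probabilities."""
    logits = logits.clone()
    if top_k is not None:
        kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[-1]
        logits[logits < kth_largest] = float('-inf')
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        remove = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) > top_p
        remove[1:] = remove[:-1].clone()  # keep the token that crosses top_p
        remove[0] = False                 # always keep the most likely token
        logits[sorted_idx[remove]] = float('-inf')
    return F.softmax(logits, dim=-1)

temperature, top_k, top_p = 0.8, 50, 0.95  # balanced generation (good default)
logits = torch.randn(1000)  # stand-in for one step of next-token logits
next_token = torch.multinomial(apply_top_k_nucleus(logits / temperature, top_k, top_p), 1)
```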

These techniques are the main controls for decoding. The right settings still depend on the model and task, so validate them with the metrics above and with human review.