# Text Generation: From Probabilities to Words

Once trained, how do transformers generate text? This notebook explores generation strategies from simple greedy decoding to advanced sampling techniques.

## Learning Objectives

1. **Autoregressive Generation**: How transformers generate text step-by-step
2. **Generation Strategies**: Greedy, temperature, top-k, top-p sampling  
3. **Temperature Effects**: Controlling creativity vs coherence
4. **Beam Search**: Exploring multiple possibilities simultaneously
5. **Quality Metrics**: Measuring generation quality and diversity

Let's unlock the creative potential of your trained transformer! 🎨

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional, Dict, Any
import math
from collections import Counter
import random

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")

## Autoregressive Generation

**Autoregressive = One token at a time, using previous tokens as context**

**Process**:
1. Start with prompt: "The cat"
2. Model predicts probabilities for next word
3. Select next word using some strategy  
4. Add to sequence: "The cat sat"
5. Repeat until complete

This is how autoregressive language models such as GPT generate text, one token per forward pass!

In [None]:
vocab = ["The", "cat", "sat", "on", "the", "mat", ".", "dog", "ran"]
word_to_id = {word: i for i, word in enumerate(vocab)}

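def simulate_generation_step(context_words):
    """Stand-in for a trained model: return next-token logits given the words so far."""
    # Reseed per step so the demo prints the same output on every run
    torch.manual_seed(42 + len(context_words))
    logits = torch.randn(len(vocab))
    
    # Hand-coded biases so the "model" prefers a plausible sentence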
    if "The" in context_words:
        logits[word_to_id["cat"]] += 2.0
    if "cat" in context_words:
        logits[word_to_id["sat"]] += 2.5
    if "sat" in context_words:
        logits[word_to_id["on"]] += 2.0
        
    return logits

# Demonstrate step-by-step generation
context = ["The"]
max_steps = 4

print("🔄 Autoregressive Generation Demo")
print("=" * 35)

for step in range(max_steps):
    logits = simulate_generation_step(context)
    probs = F.softmax(logits, dim=0)
    
    print(f"\nStep {step + 1}:")
    print(f"Context: {' '.join(context)}")
    print("Top 3 predictions:")
    
    top_probs, top_indices = torch.topk(probs, 3)
    for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
        word = vocab[idx.item()]
        print(f"  {i+1}. {word}: {prob.item():.3f}")
    
    # Select most likely (greedy)
    next_word = vocab[top_indices[0].item()]
    context.append(next_word)
    print(f"Selected: {next_word}")
    
    if next_word == ".":
        break

print(f"\nFinal text: '{' '.join(context)}'")
print("✅ That's autoregressive generation!")
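The demo above hard-codes its "model". The same loop works with anything that maps the token-id sequence so far to next-token logits. A minimal sketch, where `generate` and the bigram table are hypothetical illustrations rather than part of any library:

```python
from typing import Callable, List, Optional

import torch


def generate(
    next_token_logits: Callable[[List[int]], torch.Tensor],
    prompt_ids: List[int],
    max_new_tokens: int,
    eos_id: Optional[int] = None,
) -> List[int]:
    """Greedy autoregressive decoding against any next-token scorer."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)       # condition on the whole sequence so far
        next_id = int(torch.argmax(logits))   # greedy choice for this step
        ids.append(next_id)                   # feed the choice back in as context
        if eos_id is not None and next_id == eos_id:
            break                             # stop once end-of-sequence is produced
    return ids


# Hypothetical toy "model": row i of this bigram table holds the logits for
# the token that follows token i (it only looks at the last token).
bigram = torch.tensor([
    [0.0, 3.0, 0.0, 0.0],  # after 0, prefer 1
    [0.0, 0.0, 3.0, 0.0],  # after 1, prefer 2
    [0.0, 0.0, 0.0, 3.0],  # after 2, prefer 3
    [3.0, 0.0, 0.0, 0.0],  # after 3, prefer 0
])

print(generate(lambda ids: bigram[ids[-1]], prompt_ids=[0], max_new_tokens=10, eos_id=3))
# -> [0, 1, 2, 3]
```

Swapping `torch.argmax` for one of the sampling rules in the next section turns this greedy loop into a sampling loop without changing anything else.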

## Generation Strategies

**How do we choose the next token?** Different strategies produce different text styles:

1. **Greedy**: Always pick the most likely token (deterministic; often dull and repetitive)
2. **Temperature**: Divide logits by T before the softmax to control randomness  
3. **Top-k**: Sample from the k most likely tokens
4. **Top-p** (nucleus): Sample from the smallest set of most likely tokens whose cumulative probability exceeds p

Each has trade-offs between coherence and creativity!

In [None]:
vocab = ["cat", "dog", "bird", "fish", "mouse", "horse", "cow", "sheep"]
probs = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.03, 0.01, 0.005, 0.005])
logits = torch.log(probs)

def greedy_selection(probs):
    # Always pick the single most likely token; return a plain int index
    return torch.argmax(probs).item()

def temperature_sampling(logits, temperature=1.0):
    # T = 0 is the greedy limit; both branches return an int index
    if temperature == 0:
        return torch.argmax(logits).item()
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=0)
    return torch.multinomial(probs, 1).item()

def top_k_sampling(logits, k=3):
    # Renormalize over the k highest logits, then sample among them
    top_k_logits, top_k_indices = torch.topk(logits, k)
    probs = F.softmax(top_k_logits, dim=0)
    selected_idx = torch.multinomial(probs, 1).item()
    return top_k_indices[selected_idx].item()

def top_p_sampling(logits, p=0.9):
    probs = F.softmax(logits, dim=0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
    
    # Drop tokens once the cumulative mass exceeds p; shifting the mask right
    # keeps the token that crosses p, and the top token is always kept
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = False
    
    # Zero out the dropped tokens and renormalize over the nucleus
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    probs[indices_to_remove] = 0
    probs = probs / probs.sum()
    
    return torch.multinomial(probs, 1).item()

print("Generation Strategy Comparison:")
print(f"Vocabulary: {vocab}")

print(f"\nGreedy: {vocab[greedy_selection(probs)]} (always same)")

print("\nTemperature sampling (5 samples each):")
for temp in [0.1, 0.5, 1.0, 2.0]:
    samples = [vocab[temperature_sampling(logits, temp)] for _ in range(5)]
    print(f"  T={temp}: {samples}")

print(f"\nTop-k (k=3): {[vocab[top_k_sampling(logits, k=3)] for _ in range(5)]}")
print(f"Top-p (p=0.9): {[vocab[top_p_sampling(logits, p=0.9)] for _ in range(5)]}")

print("\n🔑 Key insight: Strategy dramatically affects output diversity!")
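In practice these rules are often chained: scale by temperature, truncate with top-k and/or top-p, then sample once. A minimal sketch of that chaining, assuming 1-D logits; `filter_logits` is a hypothetical helper name, and the order shown (temperature, then top-k, then top-p) is one common convention rather than the only one:

```python
import torch
import torch.nn.functional as F


def filter_logits(logits: torch.Tensor, temperature: float = 1.0,
                  top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Turn 1-D logits into sampling probabilities: temperature, then top-k, then top-p.

    temperature must be > 0; top_k=0 and top_p=1.0 disable those filters.
    """
    logits = logits / temperature
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[-1]
        # Ties with the k-th logit survive, so slightly more than k tokens may remain
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    probs = F.softmax(logits, dim=-1)
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # Keep a token only while the higher-ranked tokens hold less than top_p
        remove_sorted = mass_before >= top_p
        remove = torch.zeros_like(remove_sorted)
        remove[sorted_idx] = remove_sorted     # map the mask back to vocabulary order
        probs = probs.masked_fill(remove, 0.0)
        probs = probs / probs.sum()
    return probs


demo_logits = torch.log(torch.tensor([0.5, 0.3, 0.1, 0.05, 0.05]))
demo_probs = filter_logits(demo_logits, temperature=0.8, top_k=3, top_p=0.9)
next_id = torch.multinomial(demo_probs, 1).item()
print(demo_probs, next_id)
```

Returning probabilities instead of a sampled index keeps the filtering testable and lets the caller choose between sampling and taking the argmax.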

## Temperature Effects

Temperature controls the "sharpness" of probability distributions, directly affecting creativity vs coherence.
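Concretely, temperature divides every logit $z_i$ by $T$ before the softmax, so the probability ratio between any two tokens depends only on their logit gap divided by $T$. Using the first two logits of the cell below ("the" = 3.0, "cat" = 2.5):

$$
p_i(T) = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)},
\qquad
\frac{p_i(T)}{p_j(T)} = \exp\!\left(\frac{z_i - z_j}{T}\right)
$$

$$
z_{\text{the}} - z_{\text{cat}} = 0.5 \;\Rightarrow\; \frac{p_{\text{the}}}{p_{\text{cat}}} = e^{5} \approx 148.4 \;(T = 0.1),\quad e^{0.5} \approx 1.65 \;(T = 1),\quad e^{0.25} \approx 1.28 \;(T = 2)
$$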

In [None]:
vocab = ["the", "cat", "dog", "sat", "ran", "jumped", "quickly", "slowly"]
logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5])

temperatures = [0.1, 0.5, 1.0, 2.0]
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

for i, temp in enumerate(temperatures):
    if temp == 0:
        probs = torch.zeros_like(logits)
        probs[torch.argmax(logits)] = 1.0
    else:
        scaled_logits = logits / temp
        probs = F.softmax(scaled_logits, dim=0)
    
    bars = axes[i].bar(range(len(vocab)), probs, alpha=0.7)
    axes[i].set_xlabel('Tokens')
    axes[i].set_ylabel('Probability')
    axes[i].set_title(f'Temperature = {temp}')
    axes[i].set_xticks(range(len(vocab)))
    axes[i].set_xticklabels(vocab, rotation=45)
    axes[i].grid(True, alpha=0.3)
    
    # Highlight max probability
    max_idx = torch.argmax(probs)
    bars[max_idx].set_color('red')

plt.tight_layout()
plt.show()

print("Temperature Effects:")
print("• T < 1.0: More focused, deterministic")
print("• T = 1.0: Original distribution")  
print("• T > 1.0: More random, creative")
print("• T → 0:   Greedy decoding")
print("• T → ∞:   Uniform random")
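The bar charts can be summarized by a single number: the entropy of the distribution, which shrinks toward 0 as T → 0 and grows toward log(V) for a vocabulary of V tokens as T → ∞. A small sketch reusing the same logits; `entropy_at_temperature` is a helper introduced here for illustration:

```python
import math

import torch
import torch.nn.functional as F


def entropy_at_temperature(logits: torch.Tensor, temperature: float) -> float:
    """Shannon entropy, in nats, of softmax(logits / temperature)."""
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    return -(log_probs.exp() * log_probs).sum().item()


# Same logits as the temperature cell above
temp_logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5])
for t in [0.1, 0.5, 1.0, 2.0, 100.0]:
    print(f"T={t:>5}: entropy = {entropy_at_temperature(temp_logits, t):.3f} nats")
print(f"Upper bound (uniform over {len(temp_logits)} tokens): {math.log(len(temp_logits)):.3f} nats")
```

Using `log_softmax` rather than `log(softmax(...))` avoids taking the log of probabilities that underflow to zero at very low temperatures.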

## Generation Quality Metrics

How do we measure the quality of generated text? Key metrics include diversity, coherence, and repetition.

In [None]:
def analyze_text_quality(text):
    # Lowercase so "The" and "the" count as the same word
    words = text.lower().split()
    unique_words = set(words)
    # Type-token ratio: unique words / total words (0 for empty text)
    diversity = len(unique_words) / len(words) if words else 0.0
    
    # Basic metrics
    print(f"Text: '{text}'")
    print(f"Length: {len(words)} words")
    print(f"Unique words: {len(unique_words)}")
    # Repetition ratio is simply 1 - diversity
    print(f"Repetition ratio: {1 - diversity:.3f}")
    
    # Which words repeat, and how often
    word_counts = Counter(words)
    repeated_words = {word: count for word, count in word_counts.items() if count > 1}
    if repeated_words:
        print(f"Repeated words: {repeated_words}")
    
    print(f"Diversity score (type-token ratio): {diversity:.3f}")
    print()

# Example texts with different quality patterns
examples = [
    "The cat sat on the mat and looked around",  # Good
    "The the the cat cat sat sat on on",         # Repetitive  
    "Quantum flux temporal paradox synthesis",    # Too random: word-level metrics still rate it as highly diverse
]

print("📊 Text Quality Analysis:")
print("=" * 30)

for i, text in enumerate(examples, 1):
    print(f"Example {i}:")
    analyze_text_quality(text)

print("Quality Guidelines:")
print("• Lower repetition ratio = better")
print("• Higher diversity = more interesting") 
print("• Must balance diversity with coherence")
print("• Human evaluation often most reliable")
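The unique-word ratio only looks at single words. A common extension is distinct-n, the ratio of unique n-grams to total n-grams, which also catches repeated phrases. A minimal sketch, with lowercased whitespace splitting as a simplifying assumption (real evaluations usually count the model's own tokens):

```python
def distinct_n(text: str, n: int) -> float:
    """Fraction of the n-grams in `text` that are unique (0.0 if there are none)."""
    words = text.lower().split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    return len(set(ngrams)) / len(ngrams)


for text in ["The cat sat on the mat and looked around",
             "The the the cat cat sat sat on on"]:
    print(f"{text!r}: distinct-1={distinct_n(text, 1):.3f}, "
          f"distinct-2={distinct_n(text, 2):.3f}")
```

For the repetitive example, distinct-1 is 4/9 while distinct-2 is 7/8, which shows why the value of n matters when reading these scores.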

## Summary: Text Generation Mastery

You've learned how to generate text with transformers!

### Core Concepts
- **Autoregressive**: Generate one token at a time using previous context
- **Strategies**: Greedy, temperature, top-k, top-p each have different trade-offs
- **Temperature**: Controls creativity (low = focused, high = random)
- **Quality**: Balance diversity, coherence, and repetition avoidance

### Strategy Guide

**For factual, coherent text**:
- Low temperature (0.3-0.7)
- Top-p sampling (p=0.8-0.9)
- Avoid high randomness

**For creative writing**:
- Higher temperature (0.8-1.2) 
- Top-k or top-p sampling
- Allow more exploration

**For reliable completion**:
- Beam search for multiple candidates
- Lower temperature for consistency

### Next Steps
Now you understand the complete transformer pipeline: tokenization → attention → blocks → training → generation!

You're ready to build and deploy your own language models! 🚀

## Advanced Generation Techniques

Let's explore some advanced techniques for improving generation quality.
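A widely used technique is a repetition penalty, which makes tokens that already appear in the output less likely. A minimal sketch of the CTRL-style rule (divide a repeated token's logit by the penalty if it is positive, multiply it if negative), the same rule Hugging Face's `RepetitionPenaltyLogitsProcessor` applies; `apply_repetition_penalty` is a hypothetical helper name:

```python
from typing import Iterable

import torch


def apply_repetition_penalty(logits: torch.Tensor, previous_ids: Iterable[int],
                             penalty: float = 1.2) -> torch.Tensor:
    """Return a copy of 1-D logits with already-generated tokens penalized."""
    logits = logits.clone()                            # don't modify the caller's tensor
    ids = torch.tensor(sorted(set(previous_ids)), dtype=torch.long)
    if ids.numel() == 0:
        return logits
    seen = logits[ids]
    # Positive logits shrink toward 0, negative ones grow more negative:
    # either way the repeated token becomes less likely when penalty > 1
    logits[ids] = torch.where(seen > 0, seen / penalty, seen * penalty)
    return logits


rp_logits = torch.tensor([2.0, -1.0, 0.5, 1.0])
print(apply_repetition_penalty(rp_logits, previous_ids=[0, 1, 0], penalty=2.0))
# token 0: 2.0 -> 1.0, token 1: -1.0 -> -2.0, tokens 2 and 3 unchanged
```

Dividing a negative logit would move it toward zero and make the token *more* likely, which is why the sign check is needed.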

## Summary

In this notebook, we've explored the fascinating world of text generation:

1. **Autoregressive Generation** - Step-by-step token prediction process
2. **Generation Strategies** - Greedy, temperature, top-k, top-p sampling
3. **Temperature Effects** - Controlling randomness and creativity
4. **Beam Search** - Exploring multiple sequence possibilities
5. **Advanced Techniques** - Repetition penalty, typical sampling
6. **Quality Metrics** - Measuring repetition, diversity, coherence

### Key Generation Insights:

- **Strategy matters**: Different approaches produce different styles
- **Temperature is crucial**: Controls the creativity-coherence tradeoff
- **Top-p is a strong default**: Its cutoff adapts to how peaked each distribution is
- **Repetition is the enemy**: Use penalties and diverse sampling
- **Quality is multifaceted**: No single metric captures everything

### Generation Strategy Guide:

**For coherent, factual text:**
- Low temperature (0.3-0.7)
- Top-p sampling (p=0.8-0.9)
- Mild repetition penalty (1.1-1.3)

**For creative writing:**
- Higher temperature (0.8-1.2)
- Top-p or top-k sampling
- Strong repetition penalty (1.3-1.5)

**For reliable completion:**
- Beam search with beam width 3-5
- Length penalties to avoid too-short sequences
- Multiple candidates for selection
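Beam search keeps the `beam_width` highest-scoring partial sequences at each step instead of committing to a single token. A minimal sketch, assuming the model is exposed as a function from token ids to next-token logits; `beam_search`, `toy_model` and its probability table are hypothetical, and the length normalization (summed log-probability divided by length**alpha) is one simple variant among several. In the toy table, greedy decoding (width 1) commits to B and ends with probability 0.54 × 0.40 = 0.216, while width 2 also keeps C and finds C then EOS with 0.44 × 0.90 = 0.396:

```python
from typing import Callable, List, Tuple

import torch
import torch.nn.functional as F


def beam_search(
    next_token_logits: Callable[[List[int]], torch.Tensor],
    prompt_ids: List[int],
    beam_width: int,
    max_new_tokens: int,
    eos_id: int,
    length_alpha: float = 0.0,
) -> List[Tuple[List[int], float]]:
    """Return (ids, score) hypotheses, best first.

    score = summed log-prob / (number of generated tokens ** length_alpha);
    length_alpha=0 ranks by raw log-probability, larger values favor longer outputs.
    """
    beams = [(list(prompt_ids), 0.0)]              # (token ids, summed log-prob)
    finished: List[Tuple[List[int], float]] = []
    for _ in range(max_new_tokens):
        candidates = []
        for ids, score in beams:
            log_probs = F.log_softmax(next_token_logits(ids), dim=-1)
            k = min(beam_width, log_probs.size(-1))
            top_lp, top_ids = torch.topk(log_probs, k)
            for lp, tok in zip(top_lp.tolist(), top_ids.tolist()):
                candidates.append((ids + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for ids, score in candidates:
            if ids[-1] == eos_id:
                finished.append((ids, score))      # completed hypothesis
            else:
                beams.append((ids, score))         # still growing
            if len(beams) == beam_width:
                break
        if not beams:                              # every survivor ended with EOS
            break
    finished.extend(beams)                         # keep unfinished beams as fallbacks

    def normalized(ids: List[int], score: float) -> float:
        length = max(len(ids) - len(prompt_ids), 1)
        return score / (length ** length_alpha)

    ranked = [(ids, normalized(ids, score)) for ids, score in finished]
    return sorted(ranked, key=lambda r: r[1], reverse=True)


# Hypothetical toy model over tokens {0: A, 1: B, 2: C, 3: EOS}; it only looks
# at the last token and returns log-probabilities as its logits.
TABLE = {
    0: [0.01, 0.54, 0.44, 0.01],  # after A (the prompt): B slightly ahead of C
    1: [0.40, 0.30, 0.20, 0.10],  # after B: no strong continuation
    2: [0.04, 0.03, 0.03, 0.90],  # after C: EOS is very likely
    3: [0.25, 0.25, 0.25, 0.25],  # after EOS: never queried here
}


def toy_model(ids: List[int]) -> torch.Tensor:
    return torch.log(torch.tensor(TABLE[ids[-1]]))


for width in (1, 2):
    best_ids, best_score = beam_search(toy_model, [0], width, max_new_tokens=2, eos_id=3)[0]
    print(f"beam_width={width}: {best_ids}, log-prob {best_score:.3f}")
```

With `beam_width=1` the search reduces to greedy decoding, which makes it easy to compare the two on the same model.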

### Modern Developments:

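- **Contrastive search**: Balances model confidence against similarity to the preceding context (a degeneration penalty)
- **Typical sampling**: Keeps tokens whose surprisal is close to the distribution's entropy, avoiding both overly predictable and overly unlikely tokens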
- **MCTS-based generation**: Uses tree search for better planning
- **Classifier-free guidance**: Steers generation toward desired attributes
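Typical sampling, as described by Meister et al. ("Locally Typical Sampling"), ranks tokens by how close their surprisal -log p is to the distribution's entropy and keeps the smallest such set whose probability reaches a target mass. A minimal sketch of that filtering step; `typical_filter` and the `mass` parameter name are illustrative, and sampling would follow with `torch.multinomial` as in the other strategies:

```python
import torch
import torch.nn.functional as F


def typical_filter(logits: torch.Tensor, mass: float = 0.9) -> torch.Tensor:
    """Locally typical filtering of 1-D logits; returns renormalized probabilities."""
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum()              # expected surprisal
    deviation = (-log_probs - entropy).abs()          # |surprisal - entropy|
    order = torch.argsort(deviation)                  # most typical tokens first
    sorted_probs = probs[order]
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    keep_sorted = mass_before < mass                  # stop once `mass` is covered
    keep = torch.zeros_like(keep_sorted)
    keep[order] = keep_sorted                         # back to vocabulary order
    filtered = probs.masked_fill(~keep, 0.0)
    return filtered / filtered.sum()


# One likely token (0.4) and ten equally likely ones (0.06 each)
typ_logits = torch.log(torch.tensor([0.4] + [0.06] * 10))
typ_probs = typical_filter(typ_logits, mass=0.5)
print(typ_probs)  # the 0.4 token is dropped: its surprisal is far from the entropy
```

Unlike top-k and top-p, this filter can discard the single most likely token, which is how it avoids overly predictable choices.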

### Quality Considerations:

- **Coherence**: Does the text make sense?
- **Consistency**: Are facts and details consistent?
- **Relevance**: Does it address the prompt appropriately?
- **Fluency**: Is the language natural and grammatical?
- **Diversity**: Is the output varied and interesting?

The art of text generation lies in balancing these competing objectives. The best approach depends on your specific use case, from creative writing to factual question answering to code generation.

Congratulations! You've now completed a comprehensive journey through transformer architecture, training, and generation. You understand how these powerful models work from the ground up! 🎉