# Text Generation: From Probabilities to Creative Text

Once trained, how do transformers generate text? At each step the model produces a probability distribution over its vocabulary, and a sampling strategy chooses the next token from it (in the toy examples here, each token is a whole word).

## The Generation Process
1. **Start with prompt**: "The cat"
2. **Get probabilities**: Model outputs distribution over all possible next words
3. **Sample next word**: Use strategy (greedy, temperature, top-k, etc.)
4. **Add to sequence**: "The cat sat"
5. **Repeat**: Until an end-of-sequence token is produced or the maximum length is reached
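
The loop above corresponds to the standard autoregressive factorization of a sequence's probability, where each token is conditioned on all earlier ones. The symbols $x_t$ and $n$ are notation introduced here for illustration, not taken from the notebook:

$$
p(x_1, \dots, x_n) = \prod_{t=1}^{n} p(x_t \mid x_1, \dots, x_{t-1})
$$

For the example prompt, the model scores $p(x_3 \mid \text{The}, \text{cat})$, a strategy picks "sat", and the next step conditions on "The cat sat".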

## What You'll Learn
- **Autoregressive generation** - Step-by-step text creation
- **Sampling strategies** - Greedy, temperature, top-k, top-p
- **Temperature effects** - Controlling creativity vs coherence  
- **Quality metrics** - Measuring generation quality

import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional
import math
from collections import Counter
import random

plt.style.use('default')
sns.set_palette("husl")
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")

## Autoregressive Generation Demo

Demonstrate the step-by-step generation process using simulated probability distributions (random logits plus hand-set biases, not a trained model).

In [ ]:
# Simulate realistic step-by-step generation
vocab = ["The", "cat", "sat", "on", "the", "mat", ".", "dog", "ran"]
word_to_id = {word: i for i, word in enumerate(vocab)}

def simulate_generation_step(context_words):
    """Simulate model logits for the next word given the context"""
    # Use a local generator so the demo is reproducible without reseeding the global RNG
    gen = torch.Generator().manual_seed(42 + len(context_words))
    logits = torch.randn(len(vocab), generator=gen)
    
    # Add hand-set biases so plausible continuations stand out
    if "The" in context_words:
        logits[word_to_id["cat"]] += 2.0
    if "cat" in context_words:
        logits[word_to_id["sat"]] += 2.5
    if "sat" in context_words:
        logits[word_to_id["on"]] += 2.0
        
    return logits

# Demonstrate autoregressive generation step by step
context = ["The"]
print("🔄 Autoregressive Generation Process")
print("=" * 40)

for step in range(4):
    logits = simulate_generation_step(context)
    probs = F.softmax(logits, dim=0)
    
    print(f"\nStep {step + 1}:")
    print(f"Context: '{' '.join(context)}'")
    print("Top 3 next word predictions:")
    
    top_probs, top_indices = torch.topk(probs, 3)
    for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
        word = vocab[idx.item()]
        print(f"  {i+1}. '{word}': {prob.item():.3f}")
    
    # Select most likely (greedy decoding)
    next_word = vocab[top_indices[0].item()]
    context.append(next_word)
    print(f"✓ Selected: '{next_word}'")
    
    if next_word == ".":
        break

print(f"\n🎯 Final generated text: '{' '.join(context)}'")
print("✅ That's autoregressive generation!")
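
The demo above hard-codes greedy selection. As a minimal sketch of the same loop with the selection step made configurable, the function below accepts any callable that returns next-token logits. The names `generate`, `next_logits_fn`, `eos_id`, and `toy_logits` are hypothetical and stand in for a trained model; they are not part of the notebook.

```python
import torch
import torch.nn.functional as F

def generate(next_logits_fn, prompt_ids, max_new_tokens=10, eos_id=None, temperature=1.0):
    """Autoregressive loop: score, pick, append, repeat.

    next_logits_fn is a hypothetical stand-in for a trained model. It maps a
    list of token ids to a 1-D tensor of logits over the vocabulary.
    """
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_logits_fn(ids)
        if temperature == 0:
            # Greedy limit: take the single most likely token
            next_id = int(torch.argmax(logits))
        else:
            probs = F.softmax(logits / temperature, dim=0)
            next_id = int(torch.multinomial(probs, 1))
        ids.append(next_id)
        # Stop early once the end-of-sequence token appears
        if eos_id is not None and next_id == eos_id:
            break
    return ids

def toy_logits(ids, vocab_size=5):
    """Toy model (assumption): strongly prefers the id after the last one, cyclically."""
    logits = torch.zeros(vocab_size)
    logits[(ids[-1] + 1) % vocab_size] = 5.0
    return logits

print(generate(toy_logits, [0], max_new_tokens=6, eos_id=4, temperature=0))
# → [0, 1, 2, 3, 4]
```

With `temperature=0` the loop is deterministic and stops as soon as the end-of-sequence id is produced; any positive temperature switches to sampling.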

## Sampling Strategy Implementations

Implement and compare different token selection strategies.

In [ ]:
# Implement different sampling strategies
vocab = ["cat", "dog", "bird", "fish", "mouse", "horse", "cow", "sheep"]
probs = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.03, 0.01, 0.005, 0.005])
logits = torch.log(probs)

def greedy_selection(probs):
    """Always select the most likely token (returns a Python int index)"""
    return torch.argmax(probs).item()

def temperature_sampling(logits, temperature=1.0):
    """Scale logits by temperature before sampling"""
    if temperature == 0:
        # T = 0 is the greedy limit; dividing by zero is undefined
        return torch.argmax(logits).item()
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=0)
    return torch.multinomial(probs, 1).item()

def top_k_sampling(logits, k=3):
    """Sample from top-k most likely tokens"""
    k = min(k, logits.size(0))  # guard against k larger than the vocabulary
    top_k_logits, top_k_indices = torch.topk(logits, k)
    # Renormalize over the k survivors only
    probs = F.softmax(top_k_logits, dim=0)
    selected_idx = torch.multinomial(probs, 1).item()
    return top_k_indices[selected_idx].item()

def top_p_sampling(logits, p=0.9):
    """Sample from smallest set with cumulative probability >= p"""
    probs = F.softmax(logits, dim=0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
    
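    # Mark tokens outside the nucleus for removal. Shifting the mask right by one
    # keeps the token that first pushes the cumulative probability past p.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = False  # always keep the most likely token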
    
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    probs[indices_to_remove] = 0
    probs = probs / probs.sum()
    
    return torch.multinomial(probs, 1).item()

print("🎲 Generation Strategy Comparison")
print(f"Vocabulary: {vocab}")
print(f"True probabilities: {[round(x, 3) for x in probs.tolist()]}")

print(f"\nGreedy: {vocab[greedy_selection(probs)]} (always deterministic)")

print("\nTemperature sampling (5 samples each):")
for temp in [0.1, 0.5, 1.0, 2.0]:
    samples = [vocab[temperature_sampling(logits, temp)] for _ in range(5)]
    print(f"  T={temp}: {samples}")

print(f"\nTop-k (k=3): {[vocab[top_k_sampling(logits, k=3)] for _ in range(5)]}")
print(f"Top-p (p=0.9): {[vocab[top_p_sampling(logits, p=0.9)] for _ in range(5)]}")

print("\n🔑 Key insight: Strategy dramatically affects diversity!")
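
Because the sampled outputs above vary from run to run, it can help to look at which tokens top-p keeps before any sampling happens. The sketch below computes that set deterministically. `nucleus_support` is a hypothetical helper written for this illustration, and it uses float64 probabilities to avoid rounding surprises near the threshold.

```python
import torch

def nucleus_support(probs, p=0.9):
    """Return indices of the smallest set of tokens whose total probability is >= p."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)
    # Tokens whose running total is still below p, plus the one that crosses it
    keep = int((cumulative < p).sum().item()) + 1
    keep = min(keep, probs.numel())
    return sorted_indices[:keep].tolist()

vocab = ["cat", "dog", "bird", "fish", "mouse", "horse", "cow", "sheep"]
probs = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.03, 0.01, 0.005, 0.005], dtype=torch.float64)

for p in (0.6, 0.85, 0.97):
    kept = nucleus_support(probs, p)
    print(f"p={p}: {[vocab[i] for i in kept]}")
# p=0.6  keeps cat, dog
# p=0.85 keeps cat, dog, bird
# p=0.97 keeps cat, dog, bird, fish, mouse
```

Unlike top-k, the number of surviving tokens changes with the shape of the distribution: a peaked distribution yields a small nucleus, a flat one a large nucleus.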

## Temperature Effects Visualization

Visualize how temperature affects probability distributions and creativity.
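
As used in the sampling code in this notebook, temperature $T$ divides the logits before the softmax. Writing $z_i$ for the logit of token $i$ (notation introduced here), the resulting probability is:

$$
p_i(T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
$$

The ratio between two tokens is therefore $p_a / p_b = \exp\big((z_a - z_b)/T\big)$. For two tokens whose logits differ by $0.5$, the more likely one is favored by a factor of about $1.65$ at $T = 1$, about $148$ at $T = 0.1$, and only about $1.28$ at $T = 2$.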

In [ ]:
# Visualize temperature effects on probability distributions
vocab = ["the", "cat", "dog", "sat", "ran", "jumped", "quickly", "slowly"]
logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5])

temperatures = [0.1, 0.5, 1.0, 2.0]
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

for i, temp in enumerate(temperatures):
    if temp == 0:
        probs = torch.zeros_like(logits)
        probs[torch.argmax(logits)] = 1.0
    else:
        scaled_logits = logits / temp
        probs = F.softmax(scaled_logits, dim=0)
    
    bars = axes[i].bar(range(len(vocab)), probs.numpy(), alpha=0.7, color='skyblue')
    axes[i].set_xlabel('Tokens')
    axes[i].set_ylabel('Probability')
    axes[i].set_title(f'Temperature = {temp}')
    axes[i].set_xticks(range(len(vocab)))
    axes[i].set_xticklabels(vocab, rotation=45)
    axes[i].grid(True, alpha=0.3)
    
    # Highlight most likely token
    max_idx = torch.argmax(probs)
    bars[max_idx].set_color('red')
    
    # Add entropy (measure of randomness)
    entropy = -torch.sum(probs * torch.log(probs + 1e-10))
    axes[i].text(0.7, 0.9, f'Entropy: {entropy:.2f}', transform=axes[i].transAxes, 
                bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

plt.tight_layout()
plt.show()

print("🌡️ Temperature Effects:")
print("• T < 1.0: More focused, deterministic (low entropy)")
print("• T = 1.0: Original distribution")  
print("• T > 1.0: More random, creative (high entropy)")
print("• T → 0:   Greedy decoding (entropy = 0)")
print("• T → ∞:   Uniform random (maximum entropy)")
print("\n🎯 Use case guide (rules of thumb; tune per model and task):")
print("• Factual tasks: T = 0.3-0.7")
print("• Creative writing: T = 0.8-1.2")
print("• Brainstorming: T = 1.0-2.0")
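
The entropy trend described above can also be checked numerically without a plot. This is a small sketch using the same eight logits; `softmax_entropy` is a hypothetical helper, and the extra `T=10.0` setting is added here only to show the approach toward the uniform limit.

```python
import math
import torch
import torch.nn.functional as F

def softmax_entropy(logits, temperature):
    """Entropy in nats of softmax(logits / temperature)."""
    # log_softmax is numerically safer than log(softmax(...) + epsilon)
    log_probs = F.log_softmax(logits / temperature, dim=0)
    return float(-(log_probs.exp() * log_probs).sum())

logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5])
max_entropy = math.log(len(logits))  # entropy of the uniform distribution

for t in [0.1, 0.5, 1.0, 2.0, 10.0]:
    h = softmax_entropy(logits, t)
    print(f"T={t:<4} entropy={h:.3f} nats ({h / max_entropy:.0%} of maximum)")
```

Entropy rises monotonically with temperature and is bounded above by `log(vocab_size)`, the entropy of a uniform distribution.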

## Text Quality Analysis

Analyze generated text quality using diversity and repetition metrics.

In [ ]:
def analyze_text_quality(text):
    """Analyze text quality using various metrics"""
    # Lowercase so that "The" and "the" count as the same word
    words = text.lower().split()
    unique_words = set(words)
    
    print(f"Text: '{text}'")
    print(f"Length: {len(words)} words")
    print(f"Unique words: {len(unique_words)}")
    
    # Repetition analysis
    if len(words) > 0:
        repetition_ratio = 1 - len(unique_words)/len(words)
        print(f"Repetition ratio: {repetition_ratio:.3f}")
        
        # Diversity score (type-token ratio; always equals 1 - repetition ratio)
        diversity = len(unique_words) / len(words)
        print(f"Diversity score: {diversity:.3f}")
        
        # Find repeated words
        word_counts = Counter(words)
        repeated_words = {word: count for word, count in word_counts.items() if count > 1}
        if repeated_words:
            print(f"Repeated words: {repeated_words}")
    
    print()

# Analyze different quality examples
examples = [
    "The cat sat on the mat and looked around carefully",  # Good quality
    "The the the cat cat sat sat on on the the",          # High repetition
    "Quantum flux temporal paradox synthesis nebula",      # Too random (incoherent, yet maximal diversity)
    "Cat cat cat cat cat cat cat cat cat",                 # Extreme repetition
    "The weather is nice today and tomorrow looks good"    # Balanced
]

print("📊 Text Quality Analysis Examples:")
print("=" * 45)

for i, text in enumerate(examples, 1):
    print(f"Example {i}:")
    analyze_text_quality(text)

print("Quality Guidelines:")
print("✅ Lower repetition ratio = better (< 0.3 good)")
print("✅ Higher diversity = more interesting (> 0.7 good)") 
print("✅ Balance diversity with coherence")
print("✅ Avoid excessive repetition")
print("⚠️ Human evaluation often most reliable metric")
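
Word-level diversity misses repeated phrases. One common complement is a distinct-n score, the fraction of n-grams that are unique. The sketch below is an illustration under simple assumptions (whitespace tokenization, lowercasing); `distinct_n` is a hypothetical helper, not part of the notebook.

```python
def distinct_n(text, n=2):
    """Fraction of n-grams that are unique (1.0 means no n-gram repeats)."""
    words = text.lower().split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0  # text shorter than n words
    return len(set(ngrams)) / len(ngrams)

samples = [
    "The cat sat on the mat and looked around carefully",
    "Cat cat cat cat cat cat cat cat cat",
]
for text in samples:
    print(f"distinct-1={distinct_n(text, 1):.3f}  distinct-2={distinct_n(text, 2):.3f}  | {text}")
# First sample:  distinct-1=0.900, distinct-2=1.000
# Second sample: distinct-1=0.111, distinct-2=0.125
```

Like the metrics above, this only measures surface variety; it says nothing about whether the text is coherent.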

## Summary

You've now covered the core mechanics of text generation with transformers.

**Core Concepts**:
- **Autoregressive**: Generate one token at a time using previous context
- **Strategies**: Greedy, temperature, top-k, top-p each balance coherence vs creativity
- **Temperature**: Controls randomness (low = focused, high = creative)
- **Quality**: Balance diversity, coherence, and repetition avoidance

**Strategy Selection Guide**:

**For factual, coherent text**:
- Low temperature (0.3-0.7)
- Top-p sampling (p=0.8-0.9)
- Minimize randomness

**For creative writing**:
- Higher temperature (0.8-1.2) 
- Top-k or top-p sampling
- Allow exploration

**For reliable completion**:
- Greedy or low-temperature sampling
- Focus on consistency

**Next Steps**: You now understand the transformer pipeline from tokenization to generation. The same decoding ideas carry over to other autoregressive language models.

Ready to build your own language model! 🚀

## Advanced Generation Techniques

Let's explore some advanced techniques for improving generation quality.
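
One such technique is a repetition penalty, which lowers the logits of tokens that have already been generated before the next token is chosen. The sketch below follows one common formulation (divide positive logits by the penalty, multiply negative ones). The function name and signature are hypothetical, and the exact rule varies between libraries.

```python
import torch

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Return a copy of logits with already-generated tokens made less likely."""
    penalized = logits.clone()
    for token_id in set(generated_ids):
        if penalized[token_id] > 0:
            # Positive logit: shrink it toward zero
            penalized[token_id] /= penalty
        else:
            # Negative (or zero) logit: push it further down
            penalized[token_id] *= penalty
    return penalized

logits = torch.tensor([2.0, 1.0, -1.0, 0.5])
out = apply_repetition_penalty(logits, generated_ids=[0, 2, 2], penalty=2.0)
print(out.tolist())
# → [1.0, 1.0, -2.0, 0.5]
```

A penalty of 1.0 leaves the logits unchanged. The penalized logits can then be passed to any of the sampling strategies above.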

## Summary

This notebook covered, or pointed to, the following topics in text generation:

1. **Autoregressive Generation** - Step-by-step token prediction process
2. **Generation Strategies** - Greedy, temperature, top-k, top-p sampling
3. **Temperature Effects** - Controlling randomness and creativity
4. **Beam Search** - Exploring multiple sequence possibilities
5. **Advanced Techniques** - Repetition penalty, typical sampling
6. **Quality Metrics** - Measuring repetition, diversity, coherence

### Key Generation Insights:

- **Strategy matters**: Different approaches produce different styles
- **Temperature is crucial**: Controls the creativity-coherence tradeoff
- **Top-p is often a strong default**: Its cutoff adapts to how the probability mass is spread
- **Repetition is the enemy**: Use penalties and diverse sampling
- **Quality is multifaceted**: No single metric captures everything

### Generation Strategy Guide:

**For coherent, factual text:**
- Low temperature (0.3-0.7)
- Top-p sampling (p=0.8-0.9)
- Mild repetition penalty (1.1-1.3)

**For creative writing:**
- Higher temperature (0.8-1.2)
- Top-p or top-k sampling
- Strong repetition penalty (1.3-1.5)

**For reliable completion:**
- Beam search with beam width 3-5
- Length penalties to avoid too-short sequences
- Multiple candidates for selection

### Modern Developments:

- **Contrastive search**: Balances probability and diversity
- **Typical sampling**: Avoids both too-common and too-rare tokens
- **MCTS-based generation**: Uses tree search for better planning
- **Classifier-free guidance**: Steers generation toward desired attributes

### Quality Considerations:

- **Coherence**: Does the text make sense?
- **Consistency**: Are facts and details consistent?
- **Relevance**: Does it address the prompt appropriately?
- **Fluency**: Is the language natural and grammatical?
- **Diversity**: Is the output varied and interesting?

The art of text generation lies in balancing these competing objectives. The best approach depends on your specific use case, from creative writing to factual question answering to code generation.

Congratulations! You've now completed a comprehensive journey through transformer architecture, training, and generation. You understand how these powerful models work from the ground up! 🎉