# Text Generation: From Probabilities to Creative Writing

Once a transformer is trained, how do we use it to generate text? This notebook walks through text generation strategies, from simple greedy decoding to sampling techniques that trade predictability for variety.

## What You'll Learn

1. **Autoregressive Generation** - How transformers generate text step by step
2. **Greedy Decoding** - Always pick the most likely next token
3. **Sampling Strategies** - Temperature, top-k, top-p (nucleus) sampling
4. **Beam Search** - Exploring multiple possibilities
5. **Generation Quality** - Controlling repetition, coherence, and creativity
6. **Advanced Techniques** - Repetition penalty, typical sampling, contrastive search

Let's unlock the creative potential of transformers!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional, Dict, Any
import math
from collections import Counter
import random

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")

## 1. Autoregressive Generation Basics

Transformers generate text autoregressively - one token at a time, using previously generated tokens as context. Let's see how this works step by step.

In [None]:
def demonstrate_autoregressive_generation():
    """Show how autoregressive generation works step by step."""
    
    print("Autoregressive Text Generation")
    print("=" * 35)
    
    # Simple vocabulary for demonstration
    vocab = ["<PAD>", "The", "cat", "sat", "on", "the", "mat", ".", "dog", "ran"]
    vocab_size = len(vocab)
    word_to_id = {word: i for i, word in enumerate(vocab)}
    id_to_word = {i: word for i, word in enumerate(vocab)}
    
    print(f"Vocabulary: {vocab}")
    print(f"Vocabulary size: {vocab_size}")
    
    # Simulate model predictions (fake logits for demonstration)
    # In reality, these would come from a trained transformer
    def get_fake_logits(context_tokens):
        """Simulate model predictions based on context."""
        torch.manual_seed(42 + len(context_tokens))  # Different seed for each step
        logits = torch.randn(vocab_size)
        
        # Add some bias to make more realistic predictions
        if len(context_tokens) == 1 and context_tokens[0] == word_to_id["The"]:
            logits[word_to_id["cat"]] += 2.0  # "The" -> "cat" more likely
            logits[word_to_id["dog"]] += 1.5  # "The" -> "dog" somewhat likely
        elif len(context_tokens) >= 2 and context_tokens[-1] == word_to_id["cat"]:
            logits[word_to_id["sat"]] += 3.0  # "cat" -> "sat" very likely
        elif len(context_tokens) >= 2 and context_tokens[-1] == word_to_id["sat"]:
            logits[word_to_id["on"]] += 2.5   # "sat" -> "on" likely
        
        return logits
    
    # Generate text step by step
    prompt = "The"
    context = [word_to_id[prompt]]
    max_length = 6
    
    print(f"\nStarting with prompt: '{prompt}'")
    print("Generation process:")
    print("-" * 50)
    
    for step in range(max_length - 1):
        # Get model predictions
        logits = get_fake_logits(context)
        probabilities = F.softmax(logits, dim=0)
        
        # Show current context
        context_words = [id_to_word[token_id] for token_id in context]
        print(f"\nStep {step + 1}:")
        print(f"Context: {' '.join(context_words)}")
        
        # Show top 3 predictions
        top_probs, top_indices = torch.topk(probabilities, 3)
        print("Top predictions:")
        for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
            word = id_to_word[idx.item()]
            print(f"  {i+1}. {word}: {prob.item():.3f}")
        
        # Greedy selection (pick most likely)
        next_token_id = top_indices[0].item()
        next_word = id_to_word[next_token_id]
        
        # Add to context
        context.append(next_token_id)
        
        print(f"Selected: {next_word}")
        
        # Stop if we hit end token
        if next_word == ".":
            break
    
    # Final result
    final_text = " ".join([id_to_word[token_id] for token_id in context])
    print(f"\nFinal generated text: '{final_text}'")
    
    print("\nKey Points:")
    print("• Each step uses ALL previous tokens as context")
    print("• Model predicts probability distribution over vocabulary")
    print("• Selection strategy determines the next token")
    print("• Process continues until stopping condition")
    
    return context, vocab, word_to_id, id_to_word

context, vocab, word_to_id, id_to_word = demonstrate_autoregressive_generation()
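The demo above mixes printing with the generation loop. Stripped to its core, greedy autoregressive decoding is a short loop: score the whole context, append the best token, repeat until an end token or a length limit. Here is a minimal sketch in plain Python, where `generate_ids`, `toy_next_logits`, and the bigram table are hypothetical stand-ins for a real model call:

```python
from typing import Callable, List, Optional

def generate_ids(
    next_logits: Callable[[List[int]], List[float]],
    prompt_ids: List[int],
    max_new_tokens: int,
    eos_id: Optional[int] = None,
) -> List[int]:
    """Greedy autoregressive loop: score the full context, append the argmax, repeat."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_logits(ids)  # one score per vocabulary entry, given ALL previous ids
        next_id = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break  # stop as soon as the end token is produced
    return ids

# Hypothetical toy "model": the last token fully determines the preferred next token.
VOCAB = ["<eos>", "The", "cat", "sat", "down"]
BIGRAM_NEXT = {1: 2, 2: 3, 3: 4, 4: 0}

def toy_next_logits(ids: List[int]) -> List[float]:
    logits = [0.0] * len(VOCAB)
    logits[BIGRAM_NEXT.get(ids[-1], 0)] = 5.0
    return logits

out = generate_ids(toy_next_logits, [1], max_new_tokens=10, eos_id=0)
print(" ".join(VOCAB[i] for i in out))  # The cat sat down <eos>
```

Swapping the `max(...)` line for a sampling rule is all it takes to turn this into any of the strategies in the next section.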

## 2. Generation Strategies

How we choose the next token dramatically affects the quality and creativity of generated text. Let's explore different strategies.

In [None]:
def demonstrate_generation_strategies():
    """Compare different token selection strategies."""
    
    print("Token Selection Strategies")
    print("=" * 30)
    
    # Create example probability distribution
    vocab = ["cat", "dog", "bird", "fish", "mouse", "horse", "cow", "sheep"]
    # Simulate realistic probability distribution (some tokens much more likely)
    raw_probs = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.03, 0.01, 0.005, 0.005])
    probs = raw_probs / raw_probs.sum()  # Normalize
    
    print(f"Example probability distribution:")
    for word, prob in zip(vocab, probs):
        print(f"  {word:<6}: {prob:.3f} {'█' * int(prob * 50)}")
    
    return vocab, probs

vocab, probs = demonstrate_generation_strategies()

# 1. Greedy Decoding
def greedy_selection(probs):
    """Always select the most likely token (returns a plain int index)."""
    return torch.argmax(probs).item()

# 2. Temperature Sampling
def temperature_sampling(logits, temperature=1.0):
    """Sample after dividing logits by the temperature (T=0 falls back to greedy)."""
    if temperature == 0:
        return torch.argmax(logits).item()
    
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=0)
    return torch.multinomial(probs, 1).item()

# 3. Top-k Sampling
def top_k_sampling(logits, k=3):
    """Sample from top-k most likely tokens."""
    top_k_logits, top_k_indices = torch.topk(logits, k)
    probs = F.softmax(top_k_logits, dim=0)
    selected_idx = torch.multinomial(probs, 1).item()
    return top_k_indices[selected_idx].item()

# 4. Top-p (Nucleus) Sampling
def top_p_sampling(logits, p=0.9):
    """Sample from smallest set of tokens with cumulative probability >= p."""
    probs = F.softmax(logits, dim=0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    
    # Find cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
    
    # Remove tokens with cumulative probability above threshold
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = 0  # Keep at least one token
    
    # Zero out probabilities of removed tokens
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    probs[indices_to_remove] = 0
    
    # Renormalize
    probs = probs / probs.sum()
    
    return torch.multinomial(probs, 1).item()

# Test strategies
logits = torch.log(probs)  # Convert back to logits

print("\nStrategy Comparison:")
print("-" * 25)

# Greedy
greedy_idx = greedy_selection(probs)
print(f"Greedy:       {vocab[greedy_idx]} (always same)")

# Temperature sampling with different temperatures
temperatures = [0.1, 0.5, 1.0, 2.0]
print("\nTemperature sampling (5 samples each):")
for temp in temperatures:
    samples = []
    for _ in range(5):
        idx = temperature_sampling(logits, temp)
        samples.append(vocab[idx])
    print(f"  T={temp:3.1f}: {samples}")

# Top-k sampling
print("\nTop-k sampling (k=3, 5 samples):")
top_k_samples = []
for _ in range(5):
    idx = top_k_sampling(logits, k=3)
    top_k_samples.append(vocab[idx])
print(f"  {top_k_samples}")

# Top-p sampling
print("\nTop-p sampling (p=0.9, 5 samples):")
top_p_samples = []
for _ in range(5):
    idx = top_p_sampling(logits, p=0.9)
    top_p_samples.append(vocab[idx])
print(f"  {top_p_samples}")

## 3. Temperature Effects

Temperature is one of the most important parameters for controlling generation. Let's visualize how it affects the probability distribution.

In [None]:
def visualize_temperature_effects():
    """Visualize how temperature affects probability distributions."""
    
    # Create example logits
    vocab = ["the", "cat", "dog", "sat", "ran", "jumped", "quickly", "slowly"]
    logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5])
    
    temperatures = [0.1, 0.5, 1.0, 2.0]
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    
    for i, temp in enumerate(temperatures):
        if temp == 0:
            # Special case: greedy (one-hot)
            probs = torch.zeros_like(logits)
            probs[torch.argmax(logits)] = 1.0
        else:
            scaled_logits = logits / temp
            probs = F.softmax(scaled_logits, dim=0)
        
        # Bar plot
        bars = axes[i].bar(range(len(vocab)), probs, alpha=0.7)
        axes[i].set_xlabel('Tokens')
        axes[i].set_ylabel('Probability')
        axes[i].set_title(f'Temperature = {temp}')
        axes[i].set_xticks(range(len(vocab)))
        axes[i].set_xticklabels(vocab, rotation=45)
        axes[i].grid(True, alpha=0.3)
        
        # Color the highest probability bar
        max_idx = torch.argmax(probs)
        bars[max_idx].set_color('red')
        bars[max_idx].set_alpha(0.9)
        
        # Add probability values on bars
        for j, (bar, prob) in enumerate(zip(bars, probs)):
            if prob > 0.01:  # Only show if probability > 1%
                axes[i].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                           f'{prob:.2f}', ha='center', va='bottom', fontsize=8)
    
    plt.tight_layout()
    plt.show()
    
    print("Temperature Effects:")
    print("=" * 20)
    print("• Low temperature (T < 1):  More focused, less creative")
    print("• Temperature = 1:          Original distribution")
    print("• High temperature (T > 1): More random, more creative")
    print("• Temperature → 0:          Greedy decoding")
    print("• Temperature → ∞:          Uniform random")
    
    # Calculate entropy for each temperature
    print("\nEntropy (measure of randomness):")
    for temp in temperatures:
        if temp == 0:
            entropy = 0.0
        else:
            scaled_logits = logits / temp
            probs = F.softmax(scaled_logits, dim=0)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
        print(f"  T={temp:3.1f}: entropy={entropy:.3f}")

visualize_temperature_effects()
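The plots follow directly from the temperature-scaled softmax. With logits $z_i$ over a vocabulary of size $V$ and temperature $T > 0$:

$$
p_i(T) = \frac{\exp(z_i / T)}{\sum_{j=1}^{V} \exp(z_j / T)},
\qquad
\frac{p_i(T)}{p_j(T)} = \exp\!\left(\frac{z_i - z_j}{T}\right),
\qquad
H(T) = -\sum_{i=1}^{V} p_i(T) \log p_i(T)
$$

The ratio shows that temperature never changes the ranking of tokens: $T < 1$ stretches the logit gaps (sharper distribution), $T > 1$ shrinks them. As $T \to 0^+$ all mass moves to the largest logit (assuming it is unique), and as $T \to \infty$ every ratio tends to 1, giving the uniform distribution with $H = \log V$ ($\log 8 \approx 2.079$ for the 8-token example). Unless all logits are equal, $H(T)$ increases with $T$, which is the trend in the printed entropies.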

## 4. Real Text Generation with a Trained Model

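Let's build a tiny transformer, train it for a few epochs on a short paragraph, and compare generation strategies on its output. With so little training data, expect rough text; the point is to compare the strategies, not to judge the model.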

In [None]:
from src.model.transformer import GPTModel, create_model_config
from src.data.tokenizer import create_tokenizer

class TextGenerator:
    """Advanced text generator with multiple strategies."""
    
    def __init__(self, model, tokenizer, device='cpu'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()
    
    def generate(self, prompt: str, max_length: int = 50, strategy: str = "temperature", **kwargs):
        """Generate text using specified strategy."""
        
        # Encode prompt
        tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(tokens).unsqueeze(0).to(self.device)
        
        generated_tokens = tokens.copy()
        
        with torch.no_grad():
            for _ in range(max_length):
                # Get model predictions
                logits, _ = self.model(input_ids)
                next_token_logits = logits[0, -1, :]  # Last position
                
                # Apply generation strategy
                if strategy == "greedy":
                    next_token = torch.argmax(next_token_logits).item()
                
                elif strategy == "temperature":
                    temperature = kwargs.get('temperature', 1.0)
                    if temperature == 0:
                        next_token = torch.argmax(next_token_logits).item()
                    else:
                        scaled_logits = next_token_logits / temperature
                        probs = F.softmax(scaled_logits, dim=0)
                        next_token = torch.multinomial(probs, 1).item()
                
                elif strategy == "top_k":
                    k = kwargs.get('k', 10)
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, k)
                    probs = F.softmax(top_k_logits, dim=0)
                    selected_idx = torch.multinomial(probs, 1).item()
                    next_token = top_k_indices[selected_idx].item()
                
                elif strategy == "top_p":
                    p = kwargs.get('p', 0.9)
                    probs = F.softmax(next_token_logits, dim=0)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                    
                    # Remove tokens with cumulative probability above threshold
                    sorted_indices_to_remove = cumulative_probs > p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = 0
                    
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    probs[indices_to_remove] = 0
                    probs = probs / probs.sum()
                    
                    next_token = torch.multinomial(probs, 1).item()
                
                else:
                    raise ValueError(f"Unknown strategy: {strategy}")
                
                # Add to sequence
                generated_tokens.append(next_token)
                input_ids = torch.cat([input_ids, torch.tensor([[next_token]]).to(self.device)], dim=1)
                
                # Stop if we hit a natural stopping point
                if len(generated_tokens) >= len(tokens) + 5:  # Generate at least 5 new tokens
                    decoded = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
                    if decoded.endswith('.') or decoded.endswith('!') or decoded.endswith('?'):
                        break
        
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Set up model and generator
def setup_text_generator():
    """Set up a text generator with a tiny model trained briefly on sample text."""
    
    # Create a small model for demonstration
    config = create_model_config("tiny")
    config["vocab_size"] = 200
    model = GPTModel(**config)
    
    tokenizer = create_tokenizer("simple")
    
    # Quick training on sample text to make generation more interesting
    from src.data.dataset import SimpleTextDataset, create_dataloader
    
    training_text = """
    The cat sat on the mat and looked around the room. The dog ran quickly across the yard.
    A beautiful bird flew high in the sky above the trees. The sun shone brightly on the garden.
    Children played happily in the park with their friends. The ocean waves crashed on the shore.
    Mountains stood tall against the clear blue sky. Flowers bloomed in the spring meadow.
    """
    
    dataset = SimpleTextDataset(training_text, tokenizer, block_size=16)
    dataloader = create_dataloader(dataset, batch_size=2, shuffle=True)
    
    # Quick training
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    
    model.train()
    for epoch in range(10):  # Quick training
        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            logits, _ = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
    
    generator = TextGenerator(model, tokenizer)
    
    print(f"Model ready with {sum(p.numel() for p in model.parameters()):,} parameters")
    return generator

generator = setup_text_generator()

# Test different generation strategies
prompt = "The cat"
print(f"Prompt: '{prompt}'\n")

strategies = [
    ("greedy", {}),
    ("temperature", {"temperature": 0.5}),
    ("temperature", {"temperature": 1.0}),
    ("temperature", {"temperature": 1.5}),
    ("top_k", {"k": 5}),
    ("top_p", {"p": 0.8}),
]

for strategy, kwargs in strategies:
    generated = generator.generate(prompt, max_length=15, strategy=strategy, **kwargs)
    params_str = ", ".join([f"{k}={v}" for k, v in kwargs.items()]) if kwargs else ""
    strategy_label = f"{strategy}({params_str})" if params_str else strategy
    print(f"{strategy_label:<20}: '{generated}'")

print("\nNotice how different strategies produce different styles of text!")

## 5. Beam Search

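Beam search keeps the `beam_width` highest-scoring partial sequences at each step, where a sequence's score is the sum of its token log-probabilities. Instead of committing to a single token per step, it extends every kept sequence with its top candidates and then prunes back to the best `beam_width`.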

In [None]:
class BeamSearchGenerator:
    """Beam search text generator."""
    
    def __init__(self, model, tokenizer, device='cpu'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()
    
    def beam_search(self, prompt: str, beam_width: int = 3, max_length: int = 20):
        """Generate text using beam search."""
        
        # Encode prompt
        initial_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        
        # Initialize beams: (sequence, log_probability)
        beams = [(initial_tokens, 0.0)]
        
        print(f"Beam Search with beam_width={beam_width}")
        print(f"Starting prompt: '{prompt}'")
        print("=" * 50)
        
        with torch.no_grad():
            for step in range(max_length):
                candidates = []
                
                print(f"\nStep {step + 1}:")
                
                for beam_idx, (sequence, log_prob) in enumerate(beams):
                    # Convert to tensor
                    input_ids = torch.tensor(sequence).unsqueeze(0).to(self.device)
                    
                    # Get model predictions
                    logits, _ = self.model(input_ids)
                    next_token_logits = logits[0, -1, :]
                    
                    # Get top-k tokens for this beam
                    log_probs = F.log_softmax(next_token_logits, dim=0)
                    top_log_probs, top_indices = torch.topk(log_probs, beam_width)
                    
                    # Show current beam
                    current_text = self.tokenizer.decode(sequence, skip_special_tokens=True)
                    print(f"  Beam {beam_idx}: '{current_text}' (score: {log_prob:.3f})")
                    
                    # Add candidates
                    for token_idx, token_log_prob in zip(top_indices, top_log_probs):
                        new_sequence = sequence + [token_idx.item()]
                        new_log_prob = log_prob + token_log_prob.item()
                        candidates.append((new_sequence, new_log_prob))
                
                # Keep top beam_width candidates
                candidates.sort(key=lambda x: x[1], reverse=True)  # Sort by log probability
                beams = candidates[:beam_width]
                
                # Show selected beams
                print("  Selected for next step:")
                for i, (sequence, log_prob) in enumerate(beams):
                    text = self.tokenizer.decode(sequence, skip_special_tokens=True)
                    print(f"    {i+1}. '{text}' (score: {log_prob:.3f})")
                
                # Check if all beams end with sentence terminators
                all_terminated = True
                for sequence, _ in beams:
                    text = self.tokenizer.decode(sequence, skip_special_tokens=True)
                    if not (text.endswith('.') or text.endswith('!') or text.endswith('?')):
                        all_terminated = False
                        break
                
                if all_terminated and step >= 4:  # Run at least 5 steps
                    break
        
        # Return best beam
        best_sequence, best_score = beams[0]
        best_text = self.tokenizer.decode(best_sequence, skip_special_tokens=True)
        
        print(f"\nFinal Results:")
        print("-" * 20)
        for i, (sequence, score) in enumerate(beams):
            text = self.tokenizer.decode(sequence, skip_special_tokens=True)
            print(f"{i+1}. '{text}' (score: {score:.3f})")
        
        return best_text, beams

# Test beam search
beam_generator = BeamSearchGenerator(generator.model, generator.tokenizer)
best_text, all_beams = beam_generator.beam_search("The cat", beam_width=3, max_length=10)

print(f"\nBest result: '{best_text}'")

print("\nBeam Search vs. Sampling:")
print("• Beam search: Finds high-probability sequences")
print("• Pro: More coherent, grammatically correct")
print("• Con: Can be repetitive, less creative")
print("• Sampling: More diverse but potentially less coherent")
print("• Best: Combine both - beam search + sampling within beams")

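`BeamSearchGenerator` keeps extending beams that already end in `.`, `!`, or `?`, and ranks purely by summed log-probability, which favors shorter outputs because every extra token adds a negative term. A common refinement, sketched below in plain Python, sets finished hypotheses aside and ranks them by a length-normalized score $\text{score} / n^{\alpha}$, where $n$ is the number of generated tokens. This is one simple form of length penalty (other formulas exist), and the transition table is a hypothetical toy model:

```python
import math
from typing import Callable, Dict, List, Tuple

def beam_search_ids(
    next_log_probs: Callable[[List[int]], Dict[int, float]],
    prompt_ids: List[int],
    beam_width: int,
    max_new_tokens: int,
    eos_id: int,
    length_alpha: float = 0.0,
) -> List[Tuple[List[int], float]]:
    """Beam search that stops extending finished hypotheses and ranks all
    candidates by sum(log p) / (num_new_tokens ** length_alpha)."""
    def normalized(seq: List[int], score: float) -> float:
        n = len(seq) - len(prompt_ids)
        return score / (n ** length_alpha) if n > 0 else score

    beams = [(list(prompt_ids), 0.0)]
    finished: List[Tuple[List[int], float]] = []
    for _ in range(max_new_tokens):
        candidates = [(seq + [tok], score + lp)
                      for seq, score in beams
                      for tok, lp in next_log_probs(seq).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates:
            if seq[-1] == eos_id:
                finished.append((seq, score))  # complete: set aside, never extend again
            else:
                beams.append((seq, score))
            if len(beams) == beam_width:
                break
        if not beams:
            break
    pool = finished + beams  # unfinished beams compete too if the length limit was hit
    return sorted(pool, key=lambda c: normalized(*c), reverse=True)

# Hypothetical toy model: 0 = <eos>, 1 = "The", 2 = "cat", 3 = "sat"
TRANSITIONS = {1: {0: 0.5, 2: 0.5}, 2: {3: 0.8, 0: 0.2}, 3: {0: 0.8, 2: 0.2}}

def toy_log_probs(seq: List[int]) -> Dict[int, float]:
    return {tok: math.log(p) for tok, p in TRANSITIONS[seq[-1]].items()}

for alpha in (0.0, 1.0):
    best_seq, _ = beam_search_ids(toy_log_probs, [1], beam_width=2,
                                  max_new_tokens=5, eos_id=0, length_alpha=alpha)[0]
    print(alpha, best_seq)
# 0.0 [1, 0]
# 1.0 [1, 2, 3, 0]
```

Without normalization the one-token ending wins (log 0.5 beats three multiplied probabilities); with `length_alpha=1.0` the longer, high-confidence sequence wins because its per-token average is better.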
## 6. Advanced Generation Techniques

Let's explore some advanced techniques for improving generation quality.

In [None]:
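def repetition_penalty_sampling(logits, previous_tokens, penalty=1.2):
    """Return logits with previously seen tokens penalized (no sampling happens here).

    Positive logits are divided by the penalty and negative logits are multiplied by it,
    so a repeated token always becomes less likely. The penalty compounds once per
    occurrence, so frequently repeated tokens are pushed down harder.
    """
    
    # Count token frequencies in previous context
    token_counts = Counter(previous_tokens)
    
    # Apply penalty
    penalized_logits = logits.clone()
    for token_id, count in token_counts.items():
        if token_id < len(penalized_logits):
            factor = penalty ** count
            if penalized_logits[token_id] > 0:
                penalized_logits[token_id] /= factor
            else:
                penalized_logits[token_id] *= factor  # dividing a negative logit would raise it
    
    return penalized_logits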

def typical_sampling(logits, tau=0.95):
    """Typical sampling - sample from tokens with 'typical' probability."""
    
    probs = F.softmax(logits, dim=0)
    
    # Calculate entropy
    entropy = -torch.sum(probs * torch.log(probs + 1e-10))
    
    # Calculate "surprisal" (negative log probability)
    surprisals = -torch.log(probs + 1e-10)
    
    # Find tokens with surprisal close to entropy ("typical" tokens)
    differences = torch.abs(surprisals - entropy)
    
    # Sort by how "typical" they are
    sorted_diffs, sorted_indices = torch.sort(differences)
    
    # Keep tokens until we reach tau probability mass
    cumulative_prob = 0
    typical_indices = []
    
    for idx in sorted_indices:
        typical_indices.append(idx.item())
        cumulative_prob += probs[idx].item()
        if cumulative_prob >= tau:
            break
    
    # Create new probability distribution over typical tokens
    typical_probs = torch.zeros_like(probs)
    for idx in typical_indices:
        typical_probs[idx] = probs[idx]
    
    typical_probs = typical_probs / typical_probs.sum()
    
    return torch.multinomial(typical_probs, 1).item()

def demonstrate_advanced_techniques():
    """Demonstrate advanced generation techniques."""
    
    print("Advanced Generation Techniques")
    print("=" * 35)
    
    # Create example scenario
    vocab = ["the", "cat", "dog", "sat", "ran", "quickly", "slowly", ".", "and", "then"]
    logits = torch.tensor([2.0, 3.0, 1.5, 2.5, 1.0, 0.5, 0.3, 1.8, 1.2, 0.8])
    previous_tokens = [1, 3, 1, 1]  # "cat sat cat cat" - repetitive!
    
    print(f"Vocabulary: {vocab}")
    print(f"Previous tokens: {[vocab[i] for i in previous_tokens]}")
    print()
    
    # Show original probabilities
    original_probs = F.softmax(logits, dim=0)
    print("Original probabilities:")
    for i, (word, prob) in enumerate(zip(vocab, original_probs)):
        marker = " ← repetitive!" if i in previous_tokens else ""
        print(f"  {word:<8}: {prob:.3f}{marker}")
    
    # Apply repetition penalty
    penalized_logits = repetition_penalty_sampling(logits, previous_tokens, penalty=1.5)
    penalized_probs = F.softmax(penalized_logits, dim=0)
    
    print("\nAfter repetition penalty:")
    for i, (word, prob) in enumerate(zip(vocab, penalized_probs)):
        change = prob - original_probs[i]
        arrow = "↓" if change < -0.01 else "↑" if change > 0.01 else "→"
        print(f"  {word:<8}: {prob:.3f} {arrow}")
    
    # Test typical sampling
    print("\nTypical Sampling:")
    typical_samples = []
    for _ in range(5):
        idx = typical_sampling(logits, tau=0.8)
        typical_samples.append(vocab[idx])
    print(f"Samples: {typical_samples}")
    
    print("\nTechnique Benefits:")
    print("• Repetition penalty: Reduces boring repetition")
    print("• Typical sampling: Avoids both too-common and too-rare tokens")
    print("• Contrastive search: Balances coherence and diversity")
    print("• Dynamic temperature: Adjusts creativity based on confidence")

demonstrate_advanced_techniques()

# Implement an advanced generator
class AdvancedGenerator:
    """Generator with advanced techniques."""
    
    def __init__(self, model, tokenizer, device='cpu'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()
    
    def advanced_generate(self, prompt: str, max_length: int = 30, 
                         temperature: float = 1.0, repetition_penalty: float = 1.1,
                         top_p: float = 0.9):
        """Generate with multiple advanced techniques."""
        
        tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(tokens).unsqueeze(0).to(self.device)
        generated_tokens = tokens.copy()
        
        with torch.no_grad():
            for step in range(max_length):
                # Get model predictions
                logits, _ = self.model(input_ids)
                next_token_logits = logits[0, -1, :]
                
                # Apply repetition penalty
                if repetition_penalty > 1.0:
                    context_window = generated_tokens[-20:]  # Last 20 tokens
                    next_token_logits = repetition_penalty_sampling(
                        next_token_logits, context_window, repetition_penalty
                    )
                
                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                
                # Top-p sampling
                probs = F.softmax(next_token_logits, dim=0)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                
                # Remove tokens with cumulative probability above threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0
                
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                probs[indices_to_remove] = 0
                probs = probs / probs.sum()
                
                # Sample
                next_token = torch.multinomial(probs, 1).item()
                
                generated_tokens.append(next_token)
                input_ids = torch.cat([input_ids, torch.tensor([[next_token]]).to(self.device)], dim=1)
                
                # Check for natural stopping
                if step > 5:
                    decoded = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
                    if decoded.endswith('.') or decoded.endswith('!') or decoded.endswith('?'):
                        break
        
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Test advanced generator
advanced_gen = AdvancedGenerator(generator.model, generator.tokenizer)

print("\nAdvanced Generation Comparison:")
print("=" * 40)

prompt = "The cat"
configs = [
    {"temperature": 0.8, "repetition_penalty": 1.0, "top_p": 1.0, "label": "Basic"},
    {"temperature": 0.8, "repetition_penalty": 1.2, "top_p": 1.0, "label": "+ Rep. Penalty"},
    {"temperature": 0.8, "repetition_penalty": 1.0, "top_p": 0.9, "label": "+ Top-p"},
    {"temperature": 0.8, "repetition_penalty": 1.2, "top_p": 0.9, "label": "All Techniques"},
]

for config in configs:
    label = config.pop("label")
    result = advanced_gen.advanced_generate(prompt, max_length=20, **config)
    print(f"{label:<15}: '{result}'")

print("\n✅ Advanced generation techniques improve quality and diversity!")

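Contrastive search is mentioned above but not implemented. As described by Su et al. (2022), it restricts the choice to the top-k candidates and scores each candidate $v$ as $(1-\alpha)\,p(v \mid x) - \alpha \max_j \cos(h_v, h_j)$, where $h_v$ is the model's hidden state for the candidate and $h_j$ are the hidden states of the context tokens. The notebook's model call only returns logits, so this sketch takes the hidden states as inputs; `contrastive_pick` and the toy vectors are hypothetical:

```python
import torch
import torch.nn.functional as F

def contrastive_pick(cand_ids, cand_probs, cand_hidden, context_hidden, alpha=0.6):
    """Pick the candidate maximizing (1 - alpha) * p - alpha * max cosine similarity.

    cand_ids:       (k,)   candidate token ids, e.g. the model's top-k
    cand_probs:     (k,)   model probabilities p(v | context) for those candidates
    cand_hidden:    (k, d) hidden state the model produces for each candidate
    context_hidden: (t, d) hidden states of the tokens already in the context
    """
    # (k, 1, d) vs (1, t, d) broadcasts to a (k, t) similarity matrix
    sims = F.cosine_similarity(cand_hidden.unsqueeze(1), context_hidden.unsqueeze(0), dim=-1)
    degeneration_penalty = sims.max(dim=1).values  # closest match to anything already said
    scores = (1 - alpha) * cand_probs - alpha * degeneration_penalty
    return cand_ids[torch.argmax(scores)].item()

cand_ids = torch.tensor([7, 9])
cand_probs = torch.tensor([0.6, 0.4])
cand_hidden = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # token 7 looks like the context
context_hidden = torch.tensor([[1.0, 0.0]])
print(contrastive_pick(cand_ids, cand_probs, cand_hidden, context_hidden, alpha=0.0))  # 7
print(contrastive_pick(cand_ids, cand_probs, cand_hidden, context_hidden, alpha=0.6))  # 9
```

With `alpha=0` this is plain greedy over the candidates; raising `alpha` increasingly rejects candidates whose representation duplicates the context, which is how the method discourages loops.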
## 7. Generation Quality Metrics

How do we measure the quality of generated text? Let's explore some metrics.

In [None]:
def analyze_generation_quality(texts: List[str]):
    """Analyze various quality metrics for generated texts."""
    
    print("Generation Quality Analysis")
    print("=" * 30)
    
    for i, text in enumerate(texts):
        print(f"\nText {i+1}: '{text}'")
        
        # Basic statistics
        words = text.split()
        unique_words = set(words)
        
        print(f"  Length: {len(words)} words")
        print(f"  Unique words: {len(unique_words)}")
        print(f"  Repetition ratio: {1 - len(unique_words)/len(words):.3f}")
        
        # Repetition analysis
        word_counts = Counter(words)
        repeated_words = {word: count for word, count in word_counts.items() if count > 1}
        if repeated_words:
            print(f"  Repeated words: {repeated_words}")
        
        # N-gram repetition
        bigrams = [tuple(words[i:i+2]) for i in range(len(words)-1)]
        bigram_counts = Counter(bigrams)
        repeated_bigrams = {bg: count for bg, count in bigram_counts.items() if count > 1}
        if repeated_bigrams:
            print(f"  Repeated bigrams: {repeated_bigrams}")
        
        # Distinct-2: unique bigrams / total bigrams, a common lexical-diversity metric
        # (real perplexity would require scoring the text with a language model)
        distinct_2 = len(bigram_counts) / len(bigrams) if bigrams else 0.0
        print(f"  Distinct-2: {distinct_2:.3f}")

# Generate examples with different strategies for comparison
prompt = "The beautiful garden"
examples = []

# Greedy (often repetitive)
greedy_text = generator.generate(prompt, max_length=15, strategy="greedy")
examples.append(greedy_text)

# High temperature (should be diverse but potentially incoherent)
high_temp_text = generator.generate(prompt, max_length=15, strategy="temperature", temperature=2.0)
examples.append(high_temp_text)

# Balanced approach
balanced_text = advanced_gen.advanced_generate(prompt, max_length=15, temperature=0.8, repetition_penalty=1.2, top_p=0.9)
examples.append(balanced_text)

analyze_generation_quality(examples)

print("\nQuality Metrics Summary:")
print("=" * 25)
print("• Repetition ratio: Lower is better (less repetitive)")
print("• Unique word count: Higher often better (more diverse)")
print("• Perplexity: Moderate is best (not too predictable, not too random)")
print("• Coherence: Must be evaluated by humans or advanced models")
print("• Factual accuracy: Requires external knowledge validation")

# Visualize repetition patterns
def visualize_repetition_patterns(text: str):
    """Visualize word repetition in generated text."""
    
    words = text.split()
    word_positions = {}
    
    # Track positions of each word
    for i, word in enumerate(words):
        if word not in word_positions:
            word_positions[word] = []
        word_positions[word].append(i)
    
    # Find repeated words
    repeated_words = {word: positions for word, positions in word_positions.items() if len(positions) > 1}
    
    if repeated_words:
        plt.figure(figsize=(12, 6))
        
        colors = ['red', 'blue', 'green', 'orange', 'purple']
        
        for i, (word, positions) in enumerate(repeated_words.items()):
            color = colors[i % len(colors)]
            plt.scatter(positions, [i] * len(positions), c=color, s=100, label=word, alpha=0.7)
            
            # Draw lines between repetitions
            for j in range(len(positions) - 1):
                plt.plot([positions[j], positions[j+1]], [i, i], color=color, alpha=0.3, linewidth=2)
        
        plt.xlabel('Word Position')
        plt.ylabel('Repeated Words')
        plt.title(f'Repetition Pattern: "{text}"')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
    else:
        print(f"No repetitions found in: "{text}"")

# Visualize repetition for the examples
print("\nRepetition Pattern Visualization:")
for i, text in enumerate(examples):
    print(f"\nExample {i+1}:")
    visualize_repetition_patterns(text)
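The quality summary mentions perplexity, but nothing above computes it. Perplexity is the exponential of the mean negative log-likelihood of each actual next token, which `F.cross_entropy` already returns. A minimal sketch (`perplexity_from_logits` is a hypothetical helper); note that scoring a model's own greedy output with that same model gives optimistically low values, so a separate reference model is preferable:

```python
import math
import torch
import torch.nn.functional as F

def perplexity_from_logits(logits: torch.Tensor, target_ids: torch.Tensor) -> float:
    """exp(mean negative log-likelihood) of target_ids under next-token logits.

    logits:     (seq_len, vocab) scores at position t for the token at position t + 1
    target_ids: (seq_len,) the tokens that actually came next
    """
    nll = F.cross_entropy(logits, target_ids, reduction="mean")
    return math.exp(nll.item())

# A uniform model over 10 tokens is "10-way confused" at every step
print(perplexity_from_logits(torch.zeros(5, 10), torch.tensor([0, 1, 2, 3, 4])))  # ≈ 10

# With the notebook's model (assumes model(input_ids) returns (logits, _) as above):
# ids = torch.tensor([generator.tokenizer.encode(text, add_special_tokens=False)])
# logits, _ = generator.model(ids)
# print(perplexity_from_logits(logits[0, :-1], ids[0, 1:]))
```

The slice `logits[0, :-1]` against `ids[0, 1:]` aligns each prediction with the token that followed it.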

## Summary

In this notebook, we explored how transformers turn next-token probabilities into text:

1. **Autoregressive Generation** - Step-by-step token prediction process
2. **Generation Strategies** - Greedy, temperature, top-k, top-p sampling
3. **Temperature Effects** - Controlling randomness and creativity
4. **Beam Search** - Exploring multiple sequence possibilities
5. **Advanced Techniques** - Repetition penalty, typical sampling
6. **Quality Metrics** - Measuring repetition, diversity, coherence

### Key Generation Insights:

- **Strategy matters**: Different approaches produce different styles
- **Temperature is crucial**: Controls the creativity-coherence tradeoff
- **Top-p is a common default**: Its cutoff adapts to how peaked the distribution is
- **Watch for repetition**: Penalties and sampling help break loops
- **Quality is multifaceted**: No single metric captures everything

### Generation Strategy Guide:

These ranges are heuristic starting points; good values depend on the model and the task.

**For coherent, factual text:**
- Low temperature (0.3-0.7)
- Top-p sampling (p=0.8-0.9)
- Mild repetition penalty (1.1-1.3)

**For creative writing:**
- Higher temperature (0.8-1.2)
- Top-p or top-k sampling
- Strong repetition penalty (1.3-1.5)

**For reliable completion:**
- Beam search with beam width 3-5
- Length penalties to avoid too-short sequences
- Multiple candidates for selection
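If you want to reuse these settings, one option is to store them as keyword-argument presets. The values below are midpoints of the ranges listed above plus one assumption (`top_p=0.95` for creative writing, since the guide gives no value); `GENERATION_PRESETS` and `preset` are hypothetical names whose keys match the `advanced_generate` and `beam_search` parameters defined earlier:

```python
from typing import Any, Dict

GENERATION_PRESETS: Dict[str, Dict[str, Any]] = {
    # Keyword arguments for AdvancedGenerator.advanced_generate(...)
    "factual":  {"temperature": 0.5, "top_p": 0.85, "repetition_penalty": 1.2},
    "creative": {"temperature": 1.0, "top_p": 0.95, "repetition_penalty": 1.4},  # top_p assumed
    # Keyword arguments for BeamSearchGenerator.beam_search(...)
    "reliable": {"beam_width": 4},
}

def preset(name: str, **overrides: Any) -> Dict[str, Any]:
    """Return a copy of a preset with optional per-call overrides (the stored preset is unchanged)."""
    return {**GENERATION_PRESETS[name], **overrides}

print(preset("factual", temperature=0.3))
```

For example, `advanced_gen.advanced_generate("The cat", max_length=20, **preset("creative"))` or `beam_generator.beam_search("The cat", **preset("reliable"))`.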

### Modern Developments:

- **Contrastive search**: Balances model confidence against similarity to earlier tokens
- **Typical sampling**: Avoids both too-common and too-rare tokens
- **MCTS-based generation**: Uses tree search for better planning
- **Classifier-free guidance**: Steers generation toward desired attributes
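Classifier-free guidance can be sketched at the level of a single decoding step: run the model twice, once on a context that includes the steering text (for example an instruction or style prefix) and once without it, then extrapolate from the unconditional next-token log-probabilities toward the conditional ones. `cfg_log_probs` is a hypothetical helper and the logits below are made-up numbers:

```python
import torch
import torch.nn.functional as F

def cfg_log_probs(cond_logits, uncond_logits, guidance_scale: float):
    """Guided next-token scores: uncond + scale * (cond - uncond), in log-prob space.

    guidance_scale = 1 reproduces the conditional distribution; values > 1
    amplify whatever the steering context changed. The result is not normalized
    for scale != 1, so apply softmax before sampling.
    """
    cond = F.log_softmax(cond_logits, dim=-1)
    uncond = F.log_softmax(uncond_logits, dim=-1)
    return uncond + guidance_scale * (cond - uncond)

cond_logits = torch.tensor([2.0, 1.0, 0.0])    # scores given "steering text + prompt"
uncond_logits = torch.tensor([1.0, 1.0, 1.0])  # scores given the prompt alone
for scale in (1.0, 2.0):
    guided = F.softmax(cfg_log_probs(cond_logits, uncond_logits, scale), dim=-1)
    print(scale, guided)
```

The cost is a second forward pass per step; the benefit is that the guidance strength becomes a single tunable knob.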

### Quality Considerations:

- **Coherence**: Does the text make sense?
- **Consistency**: Are facts and details consistent?
- **Relevance**: Does it address the prompt appropriately?
- **Fluency**: Is the language natural and grammatical?
- **Diversity**: Is the output varied and interesting?

The art of text generation lies in balancing these competing objectives. The best approach depends on your specific use case, from creative writing to factual question answering to code generation.

Congratulations! You've now completed a comprehensive journey through transformer architecture, training, and generation. You understand how these powerful models work from the ground up! 🎉