# Top-K and Top-P (Nucleus) Sampling

### Problem Statement

Temperature scaling controls the sharpness of probability distributions, but it doesn't prevent sampling from low-probability tokens that could derail generation. **Top-K** and **Top-P (Nucleus) Sampling** are filtering techniques that truncate the probability distribution to improve generation quality.

Your task is to implement both Top-K and Top-P sampling from scratch and understand when to use each.

---

### Background

#### The Problem with Pure Temperature Sampling

Even with temperature control, sampling from the full vocabulary can produce problematic outputs:
- **Low-probability tokens**: With 50,000+ tokens, even 0.01% probability tokens get sampled occasionally
- **Inconsistent quality**: Rare tokens can derail coherent generation
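- **The long tail problem**: Most of the probability mass sits in a small number of tokens, yet the long tail of thousands of unlikely tokens can collectively hold enough mass that one of them gets sampled surprisingly often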
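A back-of-the-envelope calculation makes the long-tail effect concrete. The numbers are hypothetical. At every step the top token holds 90% of the mass, and the remaining 10% is spread evenly over 50,000 tail tokens. Steps are treated as independent.

```python
# Hypothetical single-step distribution: one head token with 90% of the mass,
# the remaining 10% spread evenly across a 50,000-token tail.
head_prob = 0.9
tail_size = 50_000
tail_mass = 1.0 - head_prob

per_token_tail_prob = tail_mass / tail_size   # each tail token is individually negligible
print(f"Each tail token: {per_token_tail_prob:.1e}")

# Probability that at least one of n steps samples *some* tail token,
# assuming (for simplicity) every step has this same distribution.
n_steps = 20
p_any_tail = 1.0 - head_prob ** n_steps
print(f"P(at least one tail token in {n_steps} steps) = {p_any_tail:.3f}")
```

Under these assumptions, each tail token has only about a 2e-6 chance per step. Even so, roughly one step in ten samples some tail token, and most 20-step continuations contain at least one.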

#### Top-K Sampling (Fan et al., 2018)

**Idea**: Only consider the K most likely tokens and redistribute the probability mass among them.

**Algorithm**:
1. Sort tokens by probability (descending)
2. Keep only the top K tokens
3. Set all other probabilities to 0
4. Renormalize to sum to 1
5. Sample from truncated distribution
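The five steps map almost line for line onto code. The sketch below works on a 1-D NumPy probability vector rather than on logits; the exercises later in this notebook use PyTorch logits instead. The helper name `top_k_probs` is illustrative only.

```python
import numpy as np

def top_k_probs(probs: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest probabilities, zero the rest, renormalize (1-D input)."""
    if k <= 0 or k >= probs.shape[-1]:
        return probs.copy()                      # filtering disabled / nothing to drop
    order = np.argsort(probs)[::-1]              # step 1: indices sorted by prob, descending
    keep = order[:k]                             # step 2: top-k indices
    filtered = np.zeros_like(probs)              # step 3: everything else becomes 0
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()             # step 4: renormalize to sum to 1

rng = np.random.default_rng(0)
probs = np.array([0.05, 0.40, 0.10, 0.30, 0.15])
truncated = top_k_probs(probs, k=2)              # only tokens 1 and 3 survive
token = rng.choice(len(truncated), p=truncated)  # step 5: sample from the truncated distribution
print(truncated, token)
```

The two surviving tokens keep their original 0.40 : 0.30 ratio. Renormalization only rescales them so that they sum to 1.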

**Limitation**: Fixed K doesn't adapt to confidence. When the model is very confident (one token has 95% probability), K=50 still considers 49 unlikely tokens. When uncertain (flat distribution), K=50 might exclude reasonable options.
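Both failure modes can be quantified by asking how much probability mass a fixed K captures. The distributions below are synthetic, chosen only to illustrate the limitation. Both use a 1,000-token vocabulary: one is peaked and one is uniform.

```python
import numpy as np

def top_k_mass(probs: np.ndarray, k: int) -> float:
    """Total probability held by the k most likely tokens."""
    return float(np.sort(probs)[::-1][:k].sum())

vocab_size = 1000
# Peaked: one token holds 95%, the remaining 5% is spread uniformly.
peaked = np.full(vocab_size, 0.05 / (vocab_size - 1))
peaked[0] = 0.95
# Flat: every token equally likely.
flat = np.full(vocab_size, 1.0 / vocab_size)

k = 50
print(f"peaked: top-{k} mass = {top_k_mass(peaked, k):.3f}")  # the 49 extra tokens add almost nothing
print(f"flat:   top-{k} mass = {top_k_mass(flat, k):.3f}")    # 95% of the mass is discarded
```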

#### Top-P (Nucleus) Sampling (Holtzman et al., 2020)

**Idea**: Dynamically select the smallest set of tokens whose cumulative probability reaches or exceeds the threshold P.

**Algorithm**:
1. Sort tokens by probability (descending)
2. Compute cumulative probabilities
3. Find smallest set where cumsum >= P
4. Keep only tokens in this "nucleus"
5. Renormalize and sample
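Below is a minimal NumPy sketch of these steps on a 1-D probability vector. `top_p_probs` is an illustrative name and is not part of any library. Calling `np.searchsorted` on the cumulative sum returns the first position where the running total reaches P, which is exactly the "smallest set" in step 3.

```python
import numpy as np

def top_p_probs(probs: np.ndarray, p: float) -> np.ndarray:
    """Keep the smallest prefix (by descending prob) whose mass reaches p; renormalize."""
    order = np.argsort(probs)[::-1]                 # step 1: sort descending
    cumsum = np.cumsum(probs[order])                # step 2: cumulative probabilities
    # Step 3: searchsorted(side="left") returns the first index i with cumsum[i] >= p,
    # so the nucleus has i + 1 tokens. The min() guards against a float cumsum
    # that ends just below 1.0 when p == 1.0.
    k = min(int(np.searchsorted(cumsum, p, side="left")) + 1, len(probs))
    keep = order[:k]                                # step 4: the nucleus
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()                # step 5: renormalize (then sample)

logits = np.array([0.0, 1.0, 2.0, 3.0])
probs = np.exp(logits) / np.exp(logits).sum()
print(top_p_probs(probs, 0.7))   # only tokens 3 and 2 survive
print(top_p_probs(probs, 0.9))   # tokens 3, 2 and 1 survive
```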

**Advantage**: Adapts to model confidence. Confident predictions use few tokens; uncertain predictions use more.
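The adaptation shows up directly in the nucleus size at P = 0.9. The check below uses two synthetic distributions over a 1,000-token vocabulary, one peaked and one uniform. The setup is illustrative and is not drawn from a real model.

```python
import numpy as np

def nucleus_size(probs: np.ndarray, p: float) -> int:
    """Number of tokens in the smallest descending-order prefix with mass >= p."""
    cumsum = np.cumsum(np.sort(probs)[::-1])
    return min(int(np.searchsorted(cumsum, p, side="left")) + 1, len(probs))

vocab_size = 1000
peaked = np.full(vocab_size, 0.05 / (vocab_size - 1))
peaked[0] = 0.95                                   # confident: one dominant token
flat = np.full(vocab_size, 1.0 / vocab_size)       # uncertain: uniform

print(nucleus_size(peaked, 0.9))  # the top token alone already exceeds 0.9
print(nucleus_size(flat, 0.9))    # roughly 900 tokens
```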

#### Combining Techniques

In practice, these techniques are often combined:
1. Apply temperature scaling first
2. Apply Top-K filtering
3. Apply Top-P filtering
4. Sample from result
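Here is a compact NumPy sketch of the four-step pipeline. It assumes a 1-D logit vector and a temperature > 0. The name `sample_token` is illustrative, and so are its conventions (`top_k=0` and `top_p=1.0` mean "disabled"), which mirror the exercise's `sample` signature.

```python
import numpy as np

def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=None):
    """Temperature -> Top-K -> Top-P -> sample, on a 1-D NumPy logit vector."""
    rng = np.random.default_rng() if rng is None else rng
    z = np.asarray(logits, dtype=np.float64) / temperature      # 1. temperature (T > 0)
    if 0 < top_k < z.size:                                      # 2. Top-K on the logits
        kth_largest = np.sort(z)[-top_k]
        z = np.where(z < kth_largest, -np.inf, z)               # ties at the cutoff are kept
    probs = np.exp(z - z.max())                                 # stable softmax; exp(-inf) = 0
    probs /= probs.sum()
    if top_p < 1.0:                                             # 3. Top-P on the Top-K result
        order = np.argsort(probs)[::-1]
        cumsum = np.cumsum(probs[order])
        n_keep = min(int(np.searchsorted(cumsum, top_p, side="left")) + 1, z.size)
        mask = np.zeros_like(probs)
        mask[order[:n_keep]] = 1.0
        probs = probs * mask / (probs * mask).sum()
    return int(rng.choice(z.size, p=probs))                     # 4. sample

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])
draws = [sample_token(logits, temperature=0.7, top_k=3, top_p=0.9, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # Top-K keeps 3 tokens, then Top-P trims that set further
```

Applying Top-P after Top-K means the nucleus is computed on the renormalized Top-K distribution, not on the full vocabulary.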

**Commonly used ranges** (exact defaults vary by provider and task):
- Temperature: 0.7-1.0
- Top-P: 0.9-0.95
- Top-K: 40-100 (or disabled)

---

### Mathematical Formulation

#### Top-K

Given probabilities sorted in descending order, $p_1 \geq p_2 \geq \dots \geq p_V$:

$$p_i^{\text{top-k}} = \begin{cases} \frac{p_i}{\sum_{j=1}^{K} p_j} & \text{if } i \leq K \\ 0 & \text{otherwise} \end{cases}$$
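As a small worked example (numbers chosen for illustration), take $V = 3$ with $(p_1, p_2, p_3) = (0.5, 0.3, 0.2)$ and $K = 2$. The normalizer is $p_1 + p_2 = 0.8$, so:

$$p^{\text{top-k}} = \left(\frac{0.5}{0.8},\ \frac{0.3}{0.8},\ 0\right) = (0.625,\ 0.375,\ 0)$$

The kept tokens retain their original ratio $0.5 : 0.3$; only their scale changes.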

#### Top-P (Nucleus)

Find the smallest $k$ such that:

$$\sum_{i=1}^{k} p_i \geq P$$

Then:

$$p_i^{\text{top-p}} = \begin{cases} \frac{p_i}{\sum_{j=1}^{k} p_j} & \text{if } i \leq k \\ 0 & \text{otherwise} \end{cases}$$
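Worked example: the logits $(0, 1, 2, 3)$ used in the Top-P test below give sorted probabilities $(p_1, p_2, p_3, p_4) \approx (0.644, 0.237, 0.087, 0.032)$. With $P = 0.7$:

$$\sum_{i=1}^{1} p_i \approx 0.644 < 0.7, \qquad \sum_{i=1}^{2} p_i \approx 0.881 \geq 0.7 \;\Rightarrow\; k = 2$$

$$p^{\text{top-p}} \approx \left(\frac{0.644}{0.881},\ \frac{0.237}{0.881},\ 0,\ 0\right) \approx (0.731,\ 0.269,\ 0,\ 0)$$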

---

### Learning Objectives

By the end of this notebook, you will:
1. Understand why filtering is necessary beyond temperature scaling
2. Implement Top-K sampling with proper renormalization
3. Implement Top-P (nucleus) sampling with dynamic cutoff
4. Know when to use each technique and common parameter values
5. Combine multiple sampling techniques effectively

---

### Requirements

1. **Top-K Filtering**: Implement `top_k_filtering(logits, k)` that sets all but the top-k logits to -inf
2. **Top-P Filtering**: Implement `top_p_filtering(logits, p)` that sets logits outside the nucleus to -inf
3. **Combined Sampling**: Implement `sample(logits, temperature, top_k, top_p)` combining all techniques
4. **Validation**: Test that filtering produces expected behavior

---

<details>
<summary>Hint 1: Top-K Implementation</summary>

Use `torch.topk()` to get indices of top-k values. Create a mask or use scatter to set non-top-k positions to -inf.

```python
values, indices = torch.topk(logits, k)
# Create output with -inf for filtered positions
```
</details>

<details>
<summary>Hint 2: Top-P Implementation</summary>

1. Convert logits to probabilities with softmax
2. Sort probabilities and compute cumsum
3. Find where cumsum exceeds p
4. Create the mask in sorted order, map it back to the original token order (e.g. with scatter), and apply it to the original logits

```python
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum = torch.cumsum(sorted_probs, dim=-1)
```
</details>

<details>
<summary>Hint 3: Handling Edge Cases</summary>

- If k >= vocab_size, Top-K should return original logits
- If p >= 1.0, Top-P should return original logits  
- Always keep at least one token (the most probable)
- Use -inf (not 0) to mask out logits before softmax: a logit of 0 still receives probability exp(0)/Z > 0
</details>

<details>
<summary>Hint 4: Combining Techniques</summary>

Order matters! Standard order:
1. Temperature scaling (divide logits by T)
2. Top-K filtering
3. Top-P filtering
4. Softmax and sample
</details>

---

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)

## Part 1: Top-K Filtering

Implement a function that keeps only the top-k logits and sets all others to negative infinity.
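Why negative infinity rather than 0? A logit of 0 is not removed, because softmax maps it to exp(0)/Z > 0. The NumPy sketch below compares the two choices on the same logits as the test cell, using a hand-written `softmax` helper.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - np.max(x))      # subtract the max for numerical stability
    return e / e.sum()

logits = np.array([1.0, 4.0, 2.0, 5.0, 3.0])
keep = np.array([False, True, False, True, False])   # the top-2 positions

masked_with_zero = np.where(keep, logits, 0.0)
masked_with_inf = np.where(keep, logits, -np.inf)

print(softmax(masked_with_zero))  # "filtered" tokens still get probability exp(0)/Z > 0
print(softmax(masked_with_inf))   # filtered tokens get exactly 0
```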

In [None]:
def top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Filter logits to keep only top-k values, setting others to -inf.
    
    Args:
        logits: Tensor of shape (batch_size, vocab_size) or (vocab_size,)
        k: Number of top tokens to keep
        
    Returns:
        Filtered logits with same shape, non-top-k positions set to -inf
    """
    # TODO: Implement top-k filtering
    # 1. Handle edge cases (k <= 0, k >= vocab_size)
    # 2. Find the k-th largest value as threshold
    # 3. Create mask for values below threshold
    # 4. Set filtered positions to -inf
    pass

In [None]:
# Test Top-K filtering
print("=== Testing Top-K Filtering ===")

# Simple test case
logits = torch.tensor([1.0, 4.0, 2.0, 5.0, 3.0])
print(f"Original logits: {logits}")

filtered_k2 = top_k_filtering(logits, k=2)
print(f"Top-2 filtered:  {filtered_k2}")

filtered_k3 = top_k_filtering(logits, k=3)
print(f"Top-3 filtered:  {filtered_k3}")

# Verify only k values remain finite
assert (filtered_k2 > float('-inf')).sum() == 2, "Should have exactly 2 finite values"
assert (filtered_k3 > float('-inf')).sum() == 3, "Should have exactly 3 finite values"

# Verify the top-k values are preserved
probs_k2 = F.softmax(filtered_k2, dim=-1)
print(f"Probabilities after Top-2: {probs_k2}")
assert probs_k2[1] > 0 and probs_k2[3] > 0, "Tokens 1 and 3 should have probability"
assert probs_k2[0] == 0 and probs_k2[2] == 0 and probs_k2[4] == 0, "Other tokens should have 0 probability"

print("\n Top-K filtering tests passed!")

## Part 2: Top-P (Nucleus) Filtering

Implement a function that keeps the smallest set of tokens whose cumulative probability exceeds p.
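A common pitfall is an off-by-one error at the nucleus boundary. Comparing the inclusive cumulative sum against p also removes the token that pushes the total past p, so the kept set no longer reaches p. The NumPy sketch below contrasts the naive mask with the shifted one. It uses rounded probabilities for the logits `[0, 1, 2, 3]` from the test cell, and it illustrates only the masking rule, not a full `top_p_filtering`.

```python
import numpy as np

probs = np.array([0.032, 0.087, 0.237, 0.644])   # rounded softmax of logits [0, 1, 2, 3]
p = 0.7

order = np.argsort(probs)[::-1]                  # token indices, most probable first
sorted_probs = probs[order]
cumsum = np.cumsum(sorted_probs)                 # approx [0.644, 0.881, 0.968, 1.0]

# Naive rule: drop every token whose inclusive cumsum exceeds p.
# This also drops the token that *crosses* the threshold, leaving mass < p.
naive_remove = cumsum > p

# Corrected rule: drop a token only if the mass *before* it already reaches p.
# (cumsum - sorted_probs) is the exclusive cumsum, i.e. the naive mask shifted right by one.
shifted_remove = (cumsum - sorted_probs) >= p

print("naive keeps:  ", order[~naive_remove])    # only the top token (mass 0.644 < 0.7)
print("shifted keeps:", order[~shifted_remove])  # top two tokens (mass 0.881 >= 0.7)
```

The shifted rule also guarantees that at least one token survives. The exclusive sum for the top token is 0, which is below any p > 0.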

In [None]:
def top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
    """
    Filter logits using nucleus (top-p) sampling.
    Keep smallest set of tokens with cumulative probability >= p.
    
    Args:
        logits: Tensor of shape (batch_size, vocab_size) or (vocab_size,)
        p: Cumulative probability threshold (0 < p <= 1)
        
    Returns:
        Filtered logits with same shape, tokens outside nucleus set to -inf
    """
    # TODO: Implement top-p (nucleus) filtering
    # 1. Convert logits to probabilities
    # 2. Sort probabilities descending
    # 3. Compute cumulative sum
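    # 4. Find tokens to remove: those whose *preceding* cumulative mass
    #    (cumsum minus their own prob) already reaches p. A plain `cumsum > p`
    #    mask would also drop the token that crosses the threshold.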
    # 5. Create mask and apply to original logits
    pass

In [None]:
# Test Top-P filtering
print("=== Testing Top-P (Nucleus) Filtering ===")

# Create logits where we know the probabilities
# logits = [0, 1, 2, 3] -> softmax ~ [0.032, 0.087, 0.237, 0.644]
logits = torch.tensor([0.0, 1.0, 2.0, 3.0])
probs = F.softmax(logits, dim=-1)
print(f"Original logits: {logits}")
print(f"Original probs:  {probs}")
print(f"Cumulative:      {torch.cumsum(probs.sort(descending=True)[0], dim=-1)}")

# Top-p with p=0.9 keeps the smallest prefix whose cumsum reaches 0.9
# Sorted: 0.644 (cumsum=0.644), 0.237 (cumsum=0.881), 0.087 (cumsum=0.968)
# 0.881 < 0.9 but 0.968 >= 0.9, so we keep the top 3
filtered_p09 = top_p_filtering(logits, p=0.9)
probs_p09 = F.softmax(filtered_p09, dim=-1)
print(f"\nTop-p=0.9 filtered: {filtered_p09}")
print(f"Probs after p=0.9:  {probs_p09}")

# Top-p with p=0.7 should keep fewer tokens
# We need cumsum >= 0.7, so just the top 2 (0.644 + 0.237 = 0.881 >= 0.7)
filtered_p07 = top_p_filtering(logits, p=0.7)
probs_p07 = F.softmax(filtered_p07, dim=-1)
print(f"\nTop-p=0.7 filtered: {filtered_p07}")
print(f"Probs after p=0.7:  {probs_p07}")

# Verify the nucleus property
print(f"\n--- Verification ---")
kept_p09 = (probs_p09 > 0).sum().item()
kept_p07 = (probs_p07 > 0).sum().item()
print(f"Tokens kept with p=0.9: {kept_p09}")
print(f"Tokens kept with p=0.7: {kept_p07}")
assert kept_p07 <= kept_p09, "Lower p should keep fewer or equal tokens"

# Test that probabilities still sum to 1
assert torch.allclose(probs_p09.sum(), torch.tensor(1.0)), "Probs should sum to 1"
assert torch.allclose(probs_p07.sum(), torch.tensor(1.0)), "Probs should sum to 1"

print("\n Top-P filtering tests passed!")

## Part 3: Combined Sampling Function

Implement a complete sampling function that combines temperature, top-k, and top-p.
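Order matters mainly for Top-P. Dividing logits by a positive temperature never changes their ranking, so the Top-K set is the same for every T. It does, however, reshape the probabilities, and with them the size of the nucleus. The small NumPy check below uses made-up logits and the illustrative helpers `softmax` and `nucleus_size`.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def nucleus_size(probs, p):
    """Size of the smallest descending-order prefix whose mass reaches p."""
    cumsum = np.cumsum(np.sort(probs)[::-1])
    return min(int(np.searchsorted(cumsum, p, side="left")) + 1, len(probs))

logits = np.array([3.0, 2.0, 1.0, 0.5, 0.0, -1.0])
for T in (0.5, 1.0, 2.0):
    scaled = logits / T
    top3 = sorted(np.argsort(scaled)[::-1][:3].tolist())   # ranking is unchanged by T > 0
    print(f"T={T}: top-3 set={top3}, nucleus size at p=0.9: {nucleus_size(softmax(scaled), 0.9)}")
```

For these logits, the top-3 set stays {0, 1, 2} at every temperature. The p=0.9 nucleus, by contrast, grows from 2 tokens at T=0.5 to 3 at T=1.0 and 5 at T=2.0.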

In [None]:
def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> torch.Tensor:
    """
    Sample from logits with temperature scaling and optional top-k/top-p filtering.
    
    Args:
        logits: Tensor of shape (batch_size, vocab_size) or (vocab_size,)
        temperature: Temperature for scaling; must be > 0 (default 1.0)
        top_k: If > 0, only sample from top-k tokens (default 0 = disabled)
        top_p: If < 1.0, use nucleus sampling (default 1.0 = disabled)
        
    Returns:
        Sampled token indices of shape (batch_size, 1), or (1,) for 1-D input
    """
    # TODO: Implement combined sampling
    # 1. Apply temperature scaling
    # 2. Apply top-k filtering (if top_k > 0)
    # 3. Apply top-p filtering (if top_p < 1.0)
    # 4. Convert to probabilities and sample
    pass

In [None]:
# Test combined sampling
print("=== Testing Combined Sampling ===")

torch.manual_seed(42)

# Create test logits
vocab_size = 100
logits = torch.randn(vocab_size)

# Test 1: Pure sampling (no filtering)
samples_pure = [sample(logits, temperature=1.0).item() for _ in range(100)]
unique_pure = len(set(samples_pure))
print(f"Pure sampling (T=1.0): {unique_pure} unique tokens from 100 samples")

# Test 2: Low temperature (more deterministic)
samples_low_t = [sample(logits, temperature=0.1).item() for _ in range(100)]
unique_low_t = len(set(samples_low_t))
print(f"Low temp (T=0.1): {unique_low_t} unique tokens from 100 samples")

# Test 3: Top-K sampling
samples_topk = [sample(logits, top_k=5).item() for _ in range(100)]
unique_topk = len(set(samples_topk))
print(f"Top-K (k=5): {unique_topk} unique tokens from 100 samples")
assert unique_topk <= 5, "Top-K should limit to at most K unique tokens"

# Test 4: Top-P sampling
samples_topp = [sample(logits, top_p=0.5).item() for _ in range(100)]
unique_topp = len(set(samples_topp))
print(f"Top-P (p=0.5): {unique_topp} unique tokens from 100 samples")

# Test 5: Combined (typical production settings)
samples_combined = [sample(logits, temperature=0.7, top_k=40, top_p=0.9).item() for _ in range(100)]
unique_combined = len(set(samples_combined))
print(f"Combined (T=0.7, k=40, p=0.9): {unique_combined} unique tokens from 100 samples")

# Test 6: Batch processing
batch_logits = torch.randn(4, vocab_size)
batch_samples = sample(batch_logits, temperature=0.8, top_k=10)
assert batch_samples.shape == (4, 1), f"Expected shape (4, 1), got {batch_samples.shape}"
print(f"Batch sampling works: shape {batch_samples.shape}")

print("\n Combined sampling tests passed!")

## Part 4: Visualization

Let's visualize how Top-K and Top-P affect the probability distribution.

In [None]:
def visualize_filtering():
    """Visualize the effects of Top-K and Top-P filtering."""
    # Create a peaked distribution: a few dominant tokens plus a random tail,
    # loosely mimicking the long-tailed shape of a language model's output
    vocab_size = 50
    # Hand-set the top few logits so a handful of tokens dominate
    logits = torch.randn(vocab_size) * 2
    logits[0] = 5.0   # One very likely token
    logits[1] = 3.5
    logits[2] = 2.5
    logits[3] = 2.0
    logits[4] = 1.5
    
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # Original distribution
    original_probs = F.softmax(logits, dim=-1).numpy()
    sorted_probs = np.sort(original_probs)[::-1]
    
    # Plot 1: Original distribution
    ax1 = axes[0, 0]
    ax1.bar(range(vocab_size), sorted_probs, color='steelblue', alpha=0.7)
    ax1.set_xlabel('Token Rank')
    ax1.set_ylabel('Probability')
    ax1.set_title('Original Distribution')
    ax1.axhline(y=0.01, color='red', linestyle='--', label='1% threshold')
    ax1.legend()
    
    # Plot 2: Top-K filtered (k=10)
    filtered_k10 = top_k_filtering(logits, k=10)
    probs_k10 = F.softmax(filtered_k10, dim=-1).numpy()
    sorted_probs_k10 = np.sort(probs_k10)[::-1]
    
    ax2 = axes[0, 1]
    colors = ['coral' if p > 0 else 'lightgray' for p in sorted_probs_k10]
    ax2.bar(range(vocab_size), sorted_probs_k10, color=colors, alpha=0.7)
    ax2.set_xlabel('Token Rank')
    ax2.set_ylabel('Probability')
    ax2.set_title('Top-K Filtered (k=10)')
    ax2.annotate(f'Kept: 10 tokens', xy=(0.7, 0.9), xycoords='axes fraction')
    
    # Plot 3: Top-P filtered (p=0.9)
    filtered_p09 = top_p_filtering(logits, p=0.9)
    probs_p09 = F.softmax(filtered_p09, dim=-1).numpy()
    sorted_probs_p09 = np.sort(probs_p09)[::-1]
    n_kept_p09 = (probs_p09 > 0).sum()
    
    ax3 = axes[1, 0]
    colors = ['seagreen' if p > 0 else 'lightgray' for p in sorted_probs_p09]
    ax3.bar(range(vocab_size), sorted_probs_p09, color=colors, alpha=0.7)
    ax3.set_xlabel('Token Rank')
    ax3.set_ylabel('Probability')
    ax3.set_title('Top-P Filtered (p=0.9)')
    ax3.annotate(f'Kept: {n_kept_p09} tokens', xy=(0.7, 0.9), xycoords='axes fraction')
    
    # Plot 4: Combined (k=20, p=0.9)
    filtered_combined = top_p_filtering(top_k_filtering(logits, k=20), p=0.9)
    probs_combined = F.softmax(filtered_combined, dim=-1).numpy()
    sorted_probs_combined = np.sort(probs_combined)[::-1]
    n_kept_combined = (probs_combined > 0).sum()
    
    ax4 = axes[1, 1]
    colors = ['purple' if p > 0 else 'lightgray' for p in sorted_probs_combined]
    ax4.bar(range(vocab_size), sorted_probs_combined, color=colors, alpha=0.7)
    ax4.set_xlabel('Token Rank')
    ax4.set_ylabel('Probability')
    ax4.set_title('Combined: Top-K (k=20) + Top-P (p=0.9)')
    ax4.annotate(f'Kept: {n_kept_combined} tokens', xy=(0.7, 0.9), xycoords='axes fraction')
    
    plt.tight_layout()
    plt.savefig('filtering_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved visualization to filtering_visualization.png")

# Uncomment to run after implementing the functions:
# visualize_filtering()

## Part 5: Adaptive Behavior Demonstration

Show how Top-P adapts to model confidence while Top-K does not.

In [None]:
def demonstrate_adaptive_behavior():
    """Show how Top-P adapts to confidence while Top-K doesn't."""
    vocab_size = 100
    
    print("=== Adaptive Behavior: Top-K vs Top-P ===")
    print()
    
    # Scenario 1: High confidence (one dominant token)
    print("Scenario 1: HIGH CONFIDENCE (one token dominates)")
    logits_confident = torch.zeros(vocab_size)
    logits_confident[0] = 10.0  # One very confident prediction
    probs_confident = F.softmax(logits_confident, dim=-1)
    print(f"  Top token probability: {probs_confident[0]:.4f}")
    print(f"  Top 5 tokens sum: {probs_confident[:5].sum():.4f}")
    
    # Top-K still keeps K tokens even though only 1 matters
    filtered_k10 = top_k_filtering(logits_confident, k=10)
    n_kept_k = (F.softmax(filtered_k10, dim=-1) > 1e-6).sum().item()
    print(f"  Top-K (k=10) keeps: {n_kept_k} tokens (wasteful!)")
    
    # Top-P adapts and keeps fewer
    filtered_p09 = top_p_filtering(logits_confident, p=0.9)
    n_kept_p = (F.softmax(filtered_p09, dim=-1) > 1e-6).sum().item()
    print(f"  Top-P (p=0.9) keeps: {n_kept_p} tokens (adaptive!)")
    print()
    
    # Scenario 2: Low confidence (flat distribution)
    print("Scenario 2: LOW CONFIDENCE (flat distribution)")
    logits_uncertain = torch.randn(vocab_size) * 0.1  # Very flat
    probs_uncertain = F.softmax(logits_uncertain, dim=-1)
    print(f"  Max token probability: {probs_uncertain.max():.4f}")
    print(f"  Top 5 tokens sum: {probs_uncertain.topk(5)[0].sum():.4f}")
    
    # Top-K still only keeps K tokens
    filtered_k10_unc = top_k_filtering(logits_uncertain, k=10)
    n_kept_k_unc = (F.softmax(filtered_k10_unc, dim=-1) > 1e-6).sum().item()
    print(f"  Top-K (k=10) keeps: {n_kept_k_unc} tokens (might miss good options!)")
    
    # Top-P adapts and keeps more
    filtered_p09_unc = top_p_filtering(logits_uncertain, p=0.9)
    n_kept_p_unc = (F.softmax(filtered_p09_unc, dim=-1) > 1e-6).sum().item()
    print(f"  Top-P (p=0.9) keeps: {n_kept_p_unc} tokens (explores more options!)")
    print()
    
    print("Key Insight:")
    print("  - Top-K: Fixed number of tokens regardless of confidence")
    print("  - Top-P: Adapts to model confidence (fewer when confident, more when uncertain)")

# Uncomment to run after implementing the functions:
# demonstrate_adaptive_behavior()

---

## Summary

### Key Concepts

1. **Top-K Sampling**: Keep only K most likely tokens
   - Simple and predictable
   - Doesn't adapt to model confidence
   - Good for limiting search space

2. **Top-P (Nucleus) Sampling**: Keep smallest set with cumulative prob >= P
   - Adapts to model confidence
   - More flexible than Top-K
   - Widely used in practice (p around 0.9 is a common choice)

3. **Combining Techniques**: Temperature + Top-K + Top-P
   - Order: Temperature -> Top-K -> Top-P -> Sample
   - Typical settings: T around 0.7 to 1.0, top_p around 0.9, top_k around 40 (or disabled); exact defaults vary by provider

4. **Implementation Details**:
   - Use -inf (not 0) to filter logits before softmax
   - Always keep at least one token
   - Handle batch dimension properly

---

## Interview Tips

**Q: What's the difference between Top-K and Top-P sampling?**

A: Top-K keeps exactly K tokens regardless of their probabilities. Top-P (nucleus) dynamically selects tokens until cumulative probability reaches threshold P. Top-P adapts to model confidence - when confident, it uses fewer tokens; when uncertain, it uses more.

**Q: Why not just use temperature alone?**

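A: Temperature affects distribution sharpness but still allows sampling from the entire vocabulary. With 50K+ tokens, each tail token is individually rare (e.g. 0.001%), but the tail as a whole can hold a noticeable share of the mass. Over a long generation, some tail token is therefore likely to be sampled, and it can derail the output. Filtering removes low-probability tokens entirely.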

**Q: What are typical production values for these parameters?**

A: Common ranges are Temperature 0.7-1.0 and Top-P 0.9-0.95. Top-K is often disabled or set to 40-100. Exact defaults vary by provider and task; for example, the OpenAI Chat Completions API defaults both `temperature` and `top_p` to 1.

**Q: In what order should you apply these techniques?**

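A: Temperature first (scales logits), then Top-K (fixed filtering), then Top-P (adaptive filtering), then softmax and sample. Dividing by a positive temperature does not change the token ranking, so it does not affect which tokens Top-K keeps. It does change the probabilities, however, and therefore how many tokens end up in the Top-P nucleus.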

**Q: When would you use Top-K over Top-P?**

A: Top-K when you want predictable compute/memory (always exactly K candidates) or when the model's confidence estimates are unreliable. Top-P is generally preferred for text generation due to its adaptive behavior.

**Q: How do you handle the edge case where filtering removes all tokens?**

A: Always keep at least the most probable token. In Top-P, even with p=0.01, the top token is always included. This prevents sampling from an empty distribution.

**Q: What's the computational complexity of Top-P?**

A: O(V log V) for sorting the vocabulary, then O(V) for the cumsum and masking. In practice, it can be optimized with partial sorting (selecting the top candidates first), since the nucleus is often small.
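Below is one way to realize the partial-sort optimization, sketched in NumPy under the assumption that the nucleus usually fits within a small candidate pool. It selects the top `m` tokens with `np.argpartition` in linear time and sorts only those. If their mass is still below p, it falls back to a full sort. The function name and the default `m = 64` are illustrative choices.

```python
import numpy as np

def nucleus_indices_partial(probs: np.ndarray, p: float, m: int = 64) -> np.ndarray:
    """Indices of the Top-P nucleus, most probable first (1-D input).

    Tries an O(V) partial selection of m candidates first and only falls back
    to a full O(V log V) sort when those candidates hold less than p of the mass.
    """
    V = probs.shape[0]
    m = min(m, V)
    # argpartition moves the m largest probabilities into the last m slots (unordered).
    cand = np.argpartition(probs, V - m)[V - m:]
    cand = cand[np.argsort(probs[cand])[::-1]]      # sort only the candidates: O(m log m)
    cumsum = np.cumsum(probs[cand])
    if cumsum[-1] < p and m < V:
        cand = np.argsort(probs)[::-1]               # fallback: full sort
        cumsum = np.cumsum(probs[cand])
    k = min(int(np.searchsorted(cumsum, p, side="left")) + 1, len(cand))
    return cand[:k]

# Usage on a synthetic, shuffled Zipf-like distribution over 50,000 tokens.
rng = np.random.default_rng(0)
V = 50_000
probs = 1.0 / np.arange(1, V + 1) ** 1.5
probs = rng.permutation(probs / probs.sum())
print(len(nucleus_indices_partial(probs, p=0.5)))
```

When the fast path applies, the result matches a full sort. Once sorted, the top-m candidates are exactly the first m entries of the fully sorted order (ties aside).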

---

## References

1. Fan, A., Lewis, M., & Dauphin, Y. (2018). "Hierarchical Neural Story Generation." *ACL 2018*. (Introduced Top-K sampling)

2. Holtzman, A., Buys, J., Du, L., Forbes, M., & Choi, Y. (2020). "The Curious Case of Neural Text Degeneration." *ICLR 2020*. (Introduced nucleus/Top-P sampling)

3. [HuggingFace Transformers - Generation Strategies](https://huggingface.co/docs/transformers/generation_strategies)

4. [OpenAI API Documentation - Temperature and Top-P](https://platform.openai.com/docs/api-reference/chat/create)