# Module 5.6: Probability & Sampling

**Goal**: Implement four next-token selection strategies: greedy decoding, temperature sampling, top-k sampling, and top-p (nucleus) sampling

**Time**: 75 minutes

**Concepts Covered**:
- Greedy decoding implementation
- Temperature sampling with interactive demo
- Top-k sampling
- Top-p (nucleus) sampling
- Compare all strategies with visualizations
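
As a compact reference for the strategies listed above, here is one common way to write them down. The notation is introduced here for illustration and does not come from the module: $z_i$ is the logit of token $i$, $T > 0$ is the temperature, and $S$ is the set of tokens that survive filtering.

```latex
% Greedy decoding: take the highest-scoring token
\hat{y} = \arg\max_i z_i

% Temperature sampling: rescale the logits before the softmax
p_i(T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

% Top-k and top-p: keep a subset S of the tokens and renormalize over it
q_i =
\begin{cases}
  \dfrac{p_i}{\sum_{j \in S} p_j} & \text{if } i \in S \\
  0 & \text{otherwise}
\end{cases}
```

For top-k, $S$ is the $k$ tokens with the largest logits. For top-p, $S$ is the shortest prefix of the tokens, sorted by descending probability, whose cumulative probability first exceeds $p$. With $T = 1$ and no filtering, $q_i$ reduces to the ordinary softmax.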

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

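def greedy_decode(logits):
    """Greedy: always pick the highest-probability token (deterministic)."""
    return torch.argmax(logits, dim=-1)

def temperature_sample(logits, temperature=1.0):
    """Sample from softmax(logits / T). T < 1 sharpens the distribution, T > 1 flattens it."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use greedy_decode for the T -> 0 limit")
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=-1)
    return torch.multinomial(probs, 1)

def top_k_sample(logits, k=50):
    """Top-k: sample only among the k highest-scoring tokens."""
    k = min(k, logits.size(-1))  # torch.topk raises if k exceeds the vocabulary size
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
    top_k_probs = F.softmax(top_k_logits, dim=-1)  # renormalize over the k survivors
    sampled_idx = torch.multinomial(top_k_probs, 1)  # position within the top-k list
    return top_k_indices.gather(-1, sampled_idx)  # map back to a vocabulary id

def top_p_sample(logits, p=0.9):
    """Top-p (nucleus): sample from the smallest set of top tokens whose cumulative probability exceeds p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mark positions whose cumulative probability exceeds p, then shift the mask
    # right by one so the token that crosses the threshold is kept.
    sorted_indices_to_remove = cumsum_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False  # always keep the most likely token

    # Masked tokens get -inf, so the softmax assigns them probability 0
    sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, float('-inf'))
    sorted_probs = F.softmax(sorted_logits, dim=-1)

    sampled_idx = torch.multinomial(sorted_probs, 1)  # position in sorted order
    return sorted_indices.gather(-1, sampled_idx)  # map back to a vocabulary id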

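# Example
torch.manual_seed(0)  # fix the RNG so the sampled tokens are reproducible
logits = torch.randn(1, 1000)  # random scores for a 1000-token vocabulary, batch size 1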

print("Sampling strategies:")
print(f"Greedy: token {greedy_decode(logits).item()}")
print(f"Temperature (T=0.8): token {temperature_sample(logits, 0.8).item()}")
print(f"Top-k (k=50): token {top_k_sample(logits, 50).item()}")
print(f"Top-p (p=0.9): token {top_p_sample(logits, 0.9).item()}")
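
A single sampled token says little about how the strategies differ, so the sketch below inspects the distribution each one draws from instead. It is a minimal illustration under stated assumptions: the helper names `temperature_probs`, `top_k_probs`, and `top_p_probs` and the five-token logits vector are hypothetical and not part of the module's code. The helpers follow the same filtering logic as the sampling functions above, but return the full probability vector in vocabulary order rather than a sampled token.

```python
import torch
import torch.nn.functional as F

def temperature_probs(logits, temperature=1.0):
    """Distribution that temperature sampling draws from."""
    return F.softmax(logits / temperature, dim=-1)

def top_k_probs(logits, k):
    """Distribution after top-k filtering, in vocabulary order."""
    k = min(k, logits.size(-1))
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
    # Tokens scoring below the k-th largest logit get probability 0.
    # (Ties at the k-th value would all be kept in this simplified version.)
    filtered = logits.masked_fill(logits < kth_value, float("-inf"))
    return F.softmax(filtered, dim=-1)

def top_p_probs(logits, p):
    """Distribution after top-p (nucleus) filtering, in vocabulary order."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > p
    remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses p
    remove[..., 0] = False                      # always keep the top token
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Scatter the filtered logits back to their original vocabulary positions
    filtered = torch.empty_like(logits).scatter(-1, sorted_indices, sorted_logits)
    return F.softmax(filtered, dim=-1)

# Toy 5-token vocabulary (illustrative values, deliberately unsorted)
logits = torch.tensor([1.0, 4.0, 0.0, 3.0, 2.0])

for T in (0.5, 1.0, 2.0):
    probs = temperature_probs(logits, T)
    print(f"T={T}: top-token probability = {probs.max().item():.3f}")

for name, probs in [("top-k (k=2)", top_k_probs(logits, 2)),
                    ("top-p (p=0.9)", top_p_probs(logits, 0.9))]:
    kept = torch.nonzero(probs > 0).flatten().tolist()
    print(f"{name}: kept token ids {kept}")
```

On this toy vector, lowering the temperature concentrates probability on the top token and raising it spreads probability out. Top-k with `k=2` keeps exactly two candidates, while top-p with `p=0.9` keeps three, because the two most likely tokens together hold only about 0.87 of the probability mass.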

## Key Takeaways

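- Greedy decoding is deterministic: it always returns the highest-probability token.
- Temperature rescales the logits before the softmax. Values below 1 sharpen the distribution; values above 1 flatten it.
- Top-k sampling restricts the choice to a fixed number of candidates, the k highest-scoring tokens.
- Top-p (nucleus) sampling keeps the smallest set of top tokens whose cumulative probability exceeds p, so the number of candidates adapts to how peaked the distribution is.

✅ **Module Complete**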

## Next Steps

Continue to the next module in the course.