### Decoding
- Decoding is the process of generating text from a language model, one token at a time.
- After training, when you want your LLM to generate text, you need a strategy for deciding which token to pick next.

#### The Core Problem
At each step, your model outputs a probability distribution over your entire vocabulary (let's say 50,000+ tokens). You need to decide:
- Which token do I pick?
- Do I always pick the most likely one?
- How do I balance quality vs diversity?
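Here is a single decoding step in dependency-free Python. The five-token vocabulary and the logits are made up for illustration; a real model produces one logit per vocabulary entry. The `softmax` helper is written out by hand here rather than taken from a library.

```python
import math
import random

def softmax(logits):
    """Turn raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical 5-token vocabulary with made-up logits for one step
vocab = ["the", "cat", "sat", "on", "mat"]
logits = [2.0, 1.0, 0.5, 0.1, -1.0]
probs = softmax(logits)

# Option 1: always take the most likely token (greedy)
greedy_choice = vocab[probs.index(max(probs))]

# Option 2: draw a token at random, weighted by its probability (sampling)
rng = random.Random(0)
sampled_choice = rng.choices(vocab, weights=probs, k=1)[0]

print([round(p, 3) for p in probs])
print(greedy_choice, sampled_choice)
```

Every strategy below is a different answer to the question of how to go from `probs` to one token.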

#### Why This Matters
The decoding strategy dramatically affects:
- Quality: Coherence and relevance of generated text
- Diversity: Whether you get repetitive or varied outputs
- Speed: Some strategies need more computation per step (e.g., beam search tracks several sequences at once)
- Use case fit: Different strategies work better for different tasks

### Greedy Decoding
- Greedy decoding is the simplest approach: at each step, always pick the token with the highest probability.

In [1]:
import torch

In [None]:
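def greedy_decode(model, input_ids, eos_token_id, max_length=50):
    # Assumes model(input_ids) returns logits of shape [batch, seq_len, vocab_size]
    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(generated)
            next_logits = logits[:, -1, :]  # logits for the next position only

        # softmax is monotonic, so argmax over probs equals argmax over logits;
        # the softmax just mirrors the sampling methods below
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.argmax(probs, dim=-1, keepdim=True)  # [batch, 1]

        generated = torch.cat([generated, next_token], dim=1)

        # Stop once every sequence in the batch has just produced EOS
        if (next_token == eos_token_id).all():
            break

    return generated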

#### Advantages
- **Deterministic**: Same input always produces same output
- **Fast**: No sampling overhead, just argmax
- **Simple**: Easy to implement and debug

#### Disadvantages
- **Repetitive**: Often gets stuck in loops ("very very very very...")
- **No diversity**: Can't generate multiple different completions
- **Not always best**: The highest probability at each step doesn't guarantee the best overall sequence
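Here is a toy illustration of that last point. The two-step "model" is a hypothetical lookup table with invented probabilities, and the helper names are illustrative. Greedy commits to the locally best first token and misses the sequence with the highest overall probability.

```python
# Hypothetical two-step "model": next-token probabilities given the prefix.
# Numbers are invented to expose greedy's shortsightedness.
NEXT = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.4, "y": 0.3, "z": 0.3},
    ("B",): {"x": 0.9, "y": 0.05, "z": 0.05},
}

def greedy(steps=2):
    seq = ()
    for _ in range(steps):
        dist = NEXT[seq]
        seq = seq + (max(dist, key=dist.get),)  # locally best token
    return seq

def sequence_prob(seq):
    """Probability of a full sequence = product of per-step probabilities."""
    p = 1.0
    for i, tok in enumerate(seq):
        p *= NEXT[seq[:i]][tok]
    return p

# Exhaustive search over all 2-token sequences (only feasible for toy sizes)
all_seqs = [(a, b) for a in NEXT[()] for b in NEXT[(a,)]]
best = max(all_seqs, key=sequence_prob)

print(greedy(), sequence_prob(greedy()))  # ('A', 'x'), p = 0.6 * 0.4
print(best, sequence_prob(best))          # ('B', 'x'), p = 0.4 * 0.9
```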

### Beam Search
- Beam search tries to fix greedy's shortsightedness by exploring multiple candidate sequences in parallel. Instead of committing to one token at each step, it keeps track of the k most promising sequences (the "beam").

#### The Core Idea
- Beam width (k): Number of candidate sequences to track simultaneously
- At each step: expand all beams, keep only the k best overall sequences
- Score sequences by their cumulative log probability
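The three rules above can be sketched on a hypothetical toy model, a lookup table of made-up next-token probabilities rather than a neural network. Each beam carries its cumulative log probability. After every expansion only the `beam_width` best sequences survive. With width 1 this reduces to greedy; with width 2 it recovers the sequence greedy misses.

```python
import math

# Hypothetical toy model: next-token probabilities given the prefix
NEXT = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.4, "y": 0.3, "z": 0.3},
    ("B",): {"x": 0.9, "y": 0.05, "z": 0.05},
}

def beam_search_toy(beam_width, steps=2):
    # Each beam is (sequence, cumulative log probability)
    beams = [((), 0.0)]
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            # Expand every beam by every possible next token
            for tok, p in NEXT[seq].items():
                candidates.append((seq + (tok,), score + math.log(p)))
        # Keep only the beam_width best sequences overall
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams[0]

print(beam_search_toy(beam_width=1))  # same as greedy: ('A', 'x')
print(beam_search_toy(beam_width=2))  # finds ('B', 'x')
```

Working in log space turns the product of probabilities into a sum. This avoids numerical underflow on long sequences.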

In [None]:
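def beam_search(model, input_ids, vocab_size, eos_token_id, beam_width, max_length):
    # Assumes batch size 1 (input_ids: [1, seq_len]) and that model(...)
    # returns logits of shape [batch, seq_len, vocab_size]
    beams = input_ids.repeat(beam_width, 1)  # [beam_width, seq_len]
    beam_scores = torch.zeros(beam_width, dtype=torch.float32)
    # All beams start identical: keep only the first alive at step 0 so the
    # first expansion doesn't pick the same token beam_width times
    beam_scores[1:] = float('-inf')

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(beams)
            next_logits = logits[:, -1, :]  # [beam_width, vocab_size]

        # Add log probabilities (summing logs = multiplying probabilities)
        log_probs = torch.log_softmax(next_logits, dim=-1)
        next_scores = beam_scores.unsqueeze(1) + log_probs

        # Flatten so topk ranks every (beam, token) pair across all beams
        next_scores = next_scores.view(-1)
        topk_values, topk_indices = torch.topk(next_scores, k=beam_width)

        beam_idx = torch.div(topk_indices, vocab_size, rounding_mode='floor')  # beam being extended
        token_idx = topk_indices % vocab_size  # token appended to it

        beams = torch.cat([beams[beam_idx], token_idx.unsqueeze(1)], dim=1)
        beam_scores = topk_values  # carry cumulative scores forward

        # Simplification: a full implementation sets finished beams aside
        # (and usually length-normalizes); here we stop once all beams hit EOS
        if (token_idx == eos_token_id).all():
            break

    best_beam_idx = torch.argmax(beam_scores)
    return beams[best_beam_idx].unsqueeze(0)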

### Temperature Sampling
- Temperature sampling introduces randomness into text generation by controlling how "sharp" or "flat" the probability distribution is before sampling.
- Key idea: Divide the logits by a temperature T before the softmax, then sample the next token from the resulting distribution with torch.multinomial.

#### Temperature Values

- T = 1.0: Original distribution (no change)
- T < 1.0 (e.g., 0.5): Makes distribution sharper (more confident, less random)
- T > 1.0 (e.g., 2.0): Makes distribution flatter (less confident, more random)
- T → 0: Approaches greedy (always pick max)
- T → ∞: Uniform distribution (all tokens equally likely)
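The effect of T can be checked on a made-up 3-token logit vector. This is a pure-Python stand-in: `softmax_with_temperature` is an illustrative helper, and `random.choices` plays the role that `torch.multinomial` plays in the PyTorch code. As T grows, the top token's probability shrinks and the distribution flattens.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Divide logits by T, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for the T -> 0 limit")
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 2.0, 1.0]  # made-up logits for a 3-token vocabulary

for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])

# Sampling step: draw one token index weighted by its probability
rng = random.Random(42)
probs = softmax_with_temperature(logits, 0.7)
token = rng.choices(range(len(logits)), weights=probs, k=1)[0]
```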

In [None]:
def temperature_sampling(model, input_ids, temperature, max_length, eos_token_id):
    generated = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(generated)
            next_logits = logits[:,-1,:] / temperature

        probs = torch.softmax(next_logits, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.cat([generated, next_token], dim=-1)

        if (next_token == eos_token_id).all():
            break

    return generated

In [8]:
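# Dividing logits by T rescales the gap between them before exp():
# the ratio of two probabilities is exp((z1 - z2) / T).
# At T = 1, logits 4 and 1 give a ratio of exp(3) ≈ 20.
#
# When T < 1: logits get scaled UP → exp() amplifies differences
# Example: exp(4/0.5) = exp(8) ≈ 2981 vs exp(1/0.5) = exp(2) ≈ 7.39
# Ratio: exp(8 - 2) = exp(6) ≈ 403 (huge gap!)
#
# When T > 1: logits get scaled DOWN → exp() compresses differences
# Example: exp(4/2) = exp(2) ≈ 7.39 vs exp(1/2) = exp(0.5) ≈ 1.65
# Ratio: exp(2 - 0.5) = exp(1.5) ≈ 4.5 (smaller gap!)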

#### Advantages ✅

- Diversity: Different outputs each time
- Controllable randomness: Tune with single parameter
- Simple: Just one line of code
- Works well: Good default for creative tasks

#### Disadvantages ❌

- Can be incoherent at high T (too random)
- Can still repeat at low T (approaching greedy)
- No quality control: Might sample low-probability nonsense
- Hard to tune: Optimal T varies by task

#### When to Use Different Temperatures

- T = 0.1-0.5: Factual tasks (translation, Q&A, code generation): More deterministic, focused on high-probability tokens

- T = 0.7-0.8: Default for chat/general use: Good balance of quality and diversity

- T = 1.0-1.5: Creative writing, brainstorming: More variety and unexpected outputs

- T > 1.5: Experimental/artistic generation: Very random, often incoherent

### Top-k Sampling

- Top-k sampling addresses a key problem with temperature sampling: even with low temperature, you might still sample terrible low-probability tokens.
- Solution: Only consider the top-k most likely tokens, set all others to zero probability, then sample.
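Here is the filtering step on a made-up probability vector, in plain Python. The helper name `top_k_filter` is illustrative. It zeroes probabilities outside the top k and renormalizes. This is equivalent to the PyTorch version's approach of setting those logits to `-inf` before the softmax.

```python
import random

def top_k_filter(probs, k):
    """Zero out everything outside the k most likely tokens, then renormalize."""
    # Indices of the k largest probabilities (stable sort: ties keep lower index)
    keep = set(sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k])
    filtered = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(filtered)
    return [p / total for p in filtered]

probs = [0.50, 0.25, 0.15, 0.07, 0.03]  # made-up next-token distribution
filtered = top_k_filter(probs, k=2)
print(filtered)  # only tokens 0 and 1 survive, rescaled to sum to 1

rng = random.Random(0)
token = rng.choices(range(len(probs)), weights=filtered, k=1)[0]
```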

In [None]:
def top_k_sampling(model, input_ids, k, temperature, eos_token_id, max_length):
    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(generated)
            next_logits = logits[:,-1,:] / temperature

        topk_values, topk_indices = torch.topk(next_logits, k=k, dim=-1)

        filtered_logits = torch.full_like(next_logits, fill_value=float('-inf'))
        filtered_logits.scatter_(dim=-1, index=topk_indices, src=topk_values)

        probs = torch.softmax(filtered_logits, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.concat([generated, next_token], dim=-1)

        if (next_token == eos_token_id).all():
            break

    return generated

#### Advantages ✅

- Prevents low-quality tokens: The long tail of unlikely tokens is never sampled
- Vocabulary-independent: The same k works regardless of vocabulary size
- Combines with temperature: Apply both for fine control
- Simple parameter: Just choose k

#### Disadvantages ❌

- Fixed k: Always keeps k tokens, even if the top 2 are clearly the best
- Context-independent: k=50 might be too many for "The capital of France is ___" but too few for "The weather today is ___"
- Can cut off good options: If k=5 but token #6 is still reasonable

### Top-p (Nucleus) Sampling
- Top-p sampling (also called nucleus sampling) fixes top-k's rigidity by using a dynamic cutoff based on cumulative probability instead of a fixed number of tokens.
- Key idea: Keep adding tokens (sorted by probability) until their cumulative probability reaches p. Then sample only from those tokens.
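Here is the cutoff rule in plain Python on two made-up distributions. `top_p_filter` is an illustrative helper. The token that pushes the running total to p or above is included in the nucleus. With the same p = 0.9, a confident distribution keeps a single token, while a flat one keeps all five. This is the adaptivity that a fixed k lacks.

```python
def top_p_filter(probs, p):
    """Keep the smallest set of most-likely tokens whose total mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:  # stop as soon as the nucleus covers p
            break
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

peaked = [0.92, 0.04, 0.02, 0.01, 0.01]  # confident model (made-up numbers)
flat = [0.25, 0.22, 0.20, 0.18, 0.15]    # uncertain model (made-up numbers)

for name, dist in (("peaked", peaked), ("flat", flat)):
    nucleus = top_p_filter(dist, p=0.9)
    print(name, sum(1 for q in nucleus if q > 0), "tokens kept")
```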

In [None]:
def top_p_sampling(model, input_ids, p, temperature, eos_token_id, max_length):
    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(generated)
            next_logits = logits[:,-1,:] / temperature

        sorted_logits, sorted_indices = torch.sort(next_logits, descending=True, dim=-1)

        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

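        # Drop a token only if the tokens ranked above it already cover p,
        # so the token that crosses the threshold stays in the nucleus
        nucleus_mask = (cumulative_probs - sorted_probs) >= p
        nucleus_mask[..., 0] = False  # always keep the most likely token
        sorted_logits[nucleus_mask] = float('-inf')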
        
        filtered_logits = torch.full_like(next_logits, fill_value=float('-inf'))
        filtered_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)

        probs = torch.softmax(filtered_logits, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.concat([generated, next_token], dim=-1)

        if eos_token_id is not None and (next_token == eos_token_id).all():
            break

    return generated

#### Advantages ✅

- Adaptive: Adjusts to model confidence
- Flexible: Works well across different contexts
- Quality control: Prevents sampling junk while allowing diversity
- Intuitive parameter: p=0.9 means "sample from the smallest set of tokens covering 90% of the probability mass"

#### Disadvantages ❌

- Slightly more complex than top-k
- Can still include low-quality tokens when probability is spread evenly over many tokens
- Computational overhead: Sorting the full vocabulary each step is O(V log V) for vocabulary size V

### Contrastive Search

- The Problem It Solves: All the previous methods have a fundamental trade-off:
    - Greedy/Beam Search: High quality but repetitive ("very very very...")
    - Sampling (temp/top-k/top-p): Diverse but can be incoherent

- Contrastive search tries to get the best of both worlds: high quality AND diversity.

#### The Core Idea
Instead of just picking high-probability tokens, contrastive search balances two objectives:

1. Model confidence (like greedy): Pick tokens the model thinks are likely
2. Degeneration penalty: Avoid tokens that are too similar to what we've already generated

The insight: Repetition happens because similar contexts lead to similar next tokens. If we explicitly penalize similarity to past tokens, we can break the repetition cycle while maintaining coherence.

```
score(token) = (1 - α) × model_confidence(token) 
               - α × max_similarity_to_context(token)
```

Where:

- α (alpha): Hyperparameter controlling the trade-off (typically 0.6)
- model_confidence: Probability from the model (like in greedy)
- max_similarity_to_context: The highest cosine similarity between the candidate's hidden representation and those of the previous tokens

```
For candidate token w:
1. Append w to the sequence, run the model, and take the last-layer hidden state at w's position: h(w)
2. Compare to hidden states of previously generated tokens: h(context)
3. Compute cosine similarity: cos_sim(h(w), h(context))
4. Take the MAXIMUM similarity (penalize being similar to ANY past token)
```

Pick the token with the highest score.
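The scoring rule can be worked through with numbers. The sketch below uses made-up 3-dimensional vectors as stand-ins for real last-layer hidden states; the helper names are illustrative. The "repeat" candidate is more likely (0.5 vs 0.3), but its hidden state nearly matches a context token. With α = 0.6 the penalty outweighs its higher confidence, so the "novel" candidate wins, where greedy would have picked "repeat".

```python
import math

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def contrastive_score(prob, candidate_h, context_hs, alpha=0.6):
    """(1 - alpha) * model confidence - alpha * max similarity to the context."""
    penalty = max(cosine_similarity(candidate_h, h) for h in context_hs)
    return (1 - alpha) * prob - alpha * penalty

# Made-up hidden states for two tokens already in the context
context_hs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

# Hypothetical candidates: (model probability, hidden state)
candidates = {
    "repeat": (0.5, [0.9, 0.1, 0.0]),  # likely, but close to a past token
    "novel": (0.3, [0.0, 0.2, 1.0]),   # less likely, points somewhere new
}

scores = {tok: contrastive_score(p, h, context_hs) for tok, (p, h) in candidates.items()}
best = max(scores, key=scores.get)
print(scores, best)
```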

#### Advantages ✅

- Deterministic: Same input → same output (unlike sampling)
- Less repetition: Explicitly penalizes tokens that resemble the context
- Maintains coherence: Still respects model probabilities
- No temperature: The randomness knob is replaced by α and the candidate count k
- Human-like text: Often produces more natural outputs than beam search

#### Disadvantages ❌

- Computationally expensive: Each step needs hidden states for all k candidates (an extra forward pass over the k extended sequences) plus the similarity computation
- Requires hidden states: Need access to model internals (not just logits)
- Memory intensive: Must store hidden states of all generated tokens
- Less diverse than sampling: Still deterministic, no randomness
- Hyperparameter sensitive: α needs tuning per task

#### Context Window Consideration
An important practical detail: which past tokens do we compare to?

- The original formulation compares against every previous token in the context (prompt plus generated text)
- A cheaper variant restricts the comparison to a sliding window (e.g., the last 20-50 tokens)
- Recent context matters most for avoiding repetition, but a window can miss repeats of older text
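The choice of comparison set can change the penalty. This sketch contrasts the full context with a sliding window, using made-up 2-dimensional hidden states; `degeneration_penalty` and its `window` parameter are illustrative names. The candidate matches the first context token exactly. The full-context penalty catches that match; a window over the last two tokens does not.

```python
import math

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def degeneration_penalty(candidate_h, context_hs, window=None):
    """Max cosine similarity to past hidden states; window=None means full context."""
    if window is not None and window < 1:
        raise ValueError("window must be a positive integer or None")
    recent = context_hs if window is None else context_hs[-window:]
    return max(cosine_similarity(candidate_h, h) for h in recent)

context_hs = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]  # made-up hidden states
candidate_h = [1.0, 0.0]

full = degeneration_penalty(candidate_h, context_hs)                # sees the early match
windowed = degeneration_penalty(candidate_h, context_hs, window=2)  # last 2 tokens only
print(full, windowed)
```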

In [10]:
def contrastive_search(model, input_ids, eos_token_id, max_length=50, alpha=0.6, k=4):
    # Assumes a Hugging Face-style causal LM: calling it with
    # output_hidden_states=True returns an object with .logits and
    # .hidden_states, and input_ids has batch size 1 ([1, seq_len]).
    # No KV cache: the full prefix is recomputed each step for clarity.
    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(generated, output_hidden_states=True)
        next_logits = outputs.logits[:, -1, :]  # [1, vocab_size]
        # Last-layer hidden states of every token so far: [seq_len, hidden_dim]
        context_h = outputs.hidden_states[-1][0]

        # 1. Model confidence: keep the k most likely candidates
        probs = torch.softmax(next_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, k=k, dim=-1)  # [1, k]

        # 2. Candidate representations: append each candidate to the sequence
        #    and run the model once on the k extended sequences (batch of k)
        candidates = torch.cat([generated.repeat(k, 1), top_k_indices.T], dim=1)
        with torch.no_grad():
            cand_outputs = model(candidates, output_hidden_states=True)
        candidate_h = cand_outputs.hidden_states[-1][:, -1, :]  # [k, hidden_dim]

        # 3. Degeneration penalty: max cosine similarity to any context token
        similarities = torch.nn.functional.cosine_similarity(
            candidate_h.unsqueeze(1),  # [k, 1, hidden_dim]
            context_h.unsqueeze(0),    # [1, seq_len, hidden_dim]
            dim=-1,
        )                              # [k, seq_len]
        max_similarity = similarities.max(dim=-1).values  # [k]

        # 4. Combine both terms and pick the best candidate (no sampling)
        scores = (1 - alpha) * top_k_probs[0] - alpha * max_similarity
        best_idx = torch.argmax(scores).item()
        next_token = top_k_indices[:, best_idx:best_idx + 1]  # [1, 1]

        generated = torch.cat([generated, next_token], dim=1)

        if next_token.item() == eos_token_id:
            break

    return generated