# Module 11.4: Long Context Techniques

**Goal**: Extend a model's usable context beyond its training length with YaRN (RoPE scaling) and streaming inference

**Time**: 90 minutes

**Concepts Covered**:
- YaRN (RoPE extension) implementation
- Streaming inference with sliding window
- Sink token preservation
- Context length extension beyond training
- Memory-efficient long-context handling
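
Memory-efficient handling matters because the key/value (KV) cache grows linearly with sequence length. The estimate below is a back-of-the-envelope sketch. It assumes a LLaMA-7B-like configuration (32 layers, 32 key/value heads, head_dim 128, fp16), chosen only for illustration. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(seq_len, num_layers=32, num_kv_heads=32, head_dim=128,
                   bytes_per_elem=2, batch_size=1):
    """Bytes needed to cache keys and values (2 tensors per layer) for seq_len tokens.

    The default configuration is an illustrative assumption, not a specific model's spec.
    """
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * seq_len * batch_size

GIB = 1024 ** 3

# Cache size grows linearly with context length
for seq_len in (2048, 8192, 32768):
    print(f"{seq_len:>6} tokens: {kv_cache_bytes(seq_len) / GIB:.1f} GiB")

# A fixed-size window caps the cache no matter how long the input stream gets
window = 1024
print(f"window of {window}: {kv_cache_bytes(window) / GIB:.1f} GiB")
```

Under these assumptions each token costs 0.5 MiB of cache, so quadrupling the context from 2048 to 8192 tokens quadruples the cache from 1 GiB to 4 GiB.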

## Setup

```python
!pip install torch transformers accelerate matplotlib seaborn numpy -q
```

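```python
# YaRN: Yet another RoPE extensioN method
import math
import torch

def rope_inv_freqs(head_dim, base=10000.0):
    """Standard RoPE frequencies theta_d = base^(-2d / head_dim), d = 0 .. head_dim/2 - 1."""
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

def apply_yarn_scaling(rope_freqs, scale_factor, original_max_len, alpha=1.0, beta=32.0):
    """Apply YaRN ("NTK-by-parts") scaling to RoPE frequencies.

    r = original_max_len / wavelength counts how many full rotations a
    dimension completes within the original training context:
      - r < alpha (low frequency): interpolate, theta / scale_factor
      - r > beta  (high frequency): keep theta unchanged
      - otherwise: blend the two with a linear ramp
    """
    wavelengths = 2 * math.pi / rope_freqs
    r = original_max_len / wavelengths
    # gamma = 0 -> fully interpolate, gamma = 1 -> keep the original frequency
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)
    return (1 - gamma) * rope_freqs / scale_factor + gamma * rope_freqs

def yarn_attention_factor(scale_factor):
    """YaRN temperature: scale q and k by 0.1 * ln(s) + 1 (attention logits by its square)."""
    return 0.1 * math.log(scale_factor) + 1.0

# Example
original_max_len = 2048
new_max_len = 8192
scale_factor = new_max_len / original_max_len

# Actual RoPE frequencies for a 128-dim attention head (64 values in (0, 1])
rope_freqs = rope_inv_freqs(head_dim=128)

scaled_freqs = apply_yarn_scaling(rope_freqs, scale_factor, original_max_len)

print("YaRN Scaling:")
print(f"  Original max length: {original_max_len}")
print(f"  New max length: {new_max_len}")
print(f"  Scale factor: {scale_factor:.2f}")
print(f"  Highest frequency: {rope_freqs[0].item():.4f} -> {scaled_freqs[0].item():.4f} (kept)")
print(f"  Lowest frequency: {rope_freqs[-1].item():.3e} -> {scaled_freqs[-1].item():.3e} "
      f"(divided by {scale_factor:.0f})")
print(f"  Attention factor (applied to q and k): {yarn_attention_factor(scale_factor):.4f}")
```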
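
Dividing the low frequencies by the scale factor keeps rotation angles inside the range the model saw during training. The sketch below checks this numerically. It assumes standard RoPE with base 10000 and a 128-dimensional head. `yarn_inv_freqs` and `rope_angles` are hypothetical helpers written for this illustration. The ramp bounds alpha = 1 and beta = 32 are assumed values (beta matches the hyperparameter in the YaRN cell).

```python
import math
import torch

def yarn_inv_freqs(head_dim, scale, original_max_len, base=10000.0, alpha=1.0, beta=32.0):
    """Return (original, YaRN-scaled) RoPE frequencies. Hypothetical helper."""
    # Standard RoPE frequencies: theta_d = base^(-2d / head_dim)
    inv = 1.0 / base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    # r = full rotations each dimension completes over the original context
    r = original_max_len * inv / (2 * math.pi)
    # Ramp: 0 -> fully interpolate, 1 -> keep the original frequency
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)
    return inv, (1 - gamma) * inv / scale + gamma * inv

def rope_angles(inv_freqs, max_len):
    """Rotation angle p * theta_d for every position p < max_len and dimension d."""
    positions = torch.arange(max_len, dtype=inv_freqs.dtype)
    return torch.outer(positions, inv_freqs)

original_max_len, new_max_len = 2048, 8192
scale = new_max_len / original_max_len
inv, scaled = yarn_inv_freqs(128, scale, original_max_len)

trained_max = rope_angles(inv, original_max_len)[-1]   # largest angles seen in training
naive_max = rope_angles(inv, new_max_len)[-1]          # no scaling, 8192 positions
yarn_max = rope_angles(scaled, new_max_len)[-1]        # YaRN scaling, 8192 positions

# Lowest-frequency dimension (last index): compare against the trained angle range
naive_ratio = (naive_max[-1] / trained_max[-1]).item()
yarn_ratio = (yarn_max[-1] / trained_max[-1]).item()
# Highest-frequency dimension (index 0) is left untouched
high_unchanged = bool(scaled[0] == inv[0])

print(f"lowest-freq angle range vs. training, no scaling: {naive_ratio:.3f}x")
print(f"lowest-freq angle range vs. training, YaRN:       {yarn_ratio:.3f}x")
print(f"highest-freq dimension unchanged: {high_unchanged}")
```

Without scaling, the lowest-frequency dimension rotates about 4x further than anything seen in training. With YaRN it stays within roughly the trained range. The high-frequency dimensions, which encode local token order, are unchanged.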

```python
# Streaming inference with a sliding window
import torch

class SlidingWindowInference:
    """Generate from inputs longer than the model's context by only ever
    feeding it the most recent `window_size` tokens.

    Assumes a Hugging Face-style causal LM: `model(input_ids=...)` returns an
    object with `.logits` of shape [batch, seq_len, vocab_size].
    Tokens that slide out of the window are forgotten, so information older
    than `window_size` tokens cannot influence the output.
    """

    def __init__(self, model, window_size=2048):
        self.model = model
        self.window_size = window_size

    @torch.no_grad()
    def process_long_sequence(self, tokens, max_new_tokens=100):
        """Greedy-decode `max_new_tokens` tokens after `tokens` (shape [batch, seq_len])."""
        ids = tokens
        for _ in range(max_new_tokens):
            # Re-encode only the last `window_size` tokens.
            # No KV cache here: simple, but each step costs O(window_size).
            window = ids[:, -self.window_size:]
            logits = self.model(input_ids=window).logits
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
        # Return only the newly generated tokens
        return ids[:, tokens.shape[1]:]

print("Sliding Window Inference:")
print("- Feed the model at most window_size tokens per forward pass")
print("- Tokens older than the window are dropped, so long-range context is lost")
print("- Per-step compute and memory stay bounded regardless of input length")
```
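
Sink token preservation, listed in the concepts above, changes what a sliding window evicts. The StreamingLLM work observed that language models place large attention weight on the first few tokens. Evicting those tokens degrades generation, so keeping a handful of initial "sink" tokens plus the recent window avoids this. Below is a minimal sketch of that eviction policy on a key/value cache. `SinkKVCache`, its parameters, and the [batch, heads, seq_len, head_dim] layout are illustrative assumptions, not a library API.

```python
import torch

class SinkKVCache:
    """Attention-sink KV cache eviction (StreamingLLM-style), a minimal sketch.

    Keeps the first `num_sinks` positions permanently plus the most recent
    `window` positions; everything in between is evicted.
    Assumed tensor layout: [batch, heads, seq_len, head_dim].
    """

    def __init__(self, num_sinks=4, window=1020):
        self.num_sinks = num_sinks
        self.window = window
        self.keys = None
        self.values = None

    def __len__(self):
        return 0 if self.keys is None else self.keys.shape[2]

    def append(self, k, v):
        """Append new key/value states along the sequence axis, then evict."""
        if self.keys is None:
            self.keys, self.values = k, v
        else:
            self.keys = torch.cat([self.keys, k], dim=2)
            self.values = torch.cat([self.values, v], dim=2)
        self._evict()

    def _evict(self):
        seq_len = len(self)
        if seq_len <= self.num_sinks + self.window:
            return
        keep = torch.cat([
            torch.arange(self.num_sinks),                  # sink tokens
            torch.arange(seq_len - self.window, seq_len),  # recent window
        ])
        self.keys = self.keys.index_select(2, keep)
        self.values = self.values.index_select(2, keep)

# Stream 100 tokens; every element of token t's toy key/value equals t
cache = SinkKVCache(num_sinks=4, window=8)
for t in range(100):
    kv = torch.full((1, 2, 1, 16), float(t))
    cache.append(kv, kv)

kept = cache.keys[0, 0, :, 0].long().tolist()
print(len(cache), kept)
```

The cache never exceeds `num_sinks + window` entries, and it always retains tokens 0 to 3 alongside the latest 8. In a real model each layer would hold its own cache. StreamingLLM also assigns positions by index within the cache rather than in the original text. That re-indexing is not shown here.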

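## Key Takeaways

- YaRN extends RoPE by interpolating low-frequency dimensions (dividing them by the scale factor), keeping high-frequency dimensions unchanged, and blending between the two; it also adds an attention temperature term.
- A sliding window bounds per-step memory and compute, but anything older than the window is forgotten.
- Preserving a few initial "sink" tokens alongside the recent window keeps streaming generation stable, because models concentrate attention on the first tokens.

✅ **Module Complete**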

## Next Steps

Continue to the next module in the course.