# Module 4.2: Sliding Window Attention

**Goal**: Implement sliding window attention to process long sequences efficiently

**Time**: 75 minutes

**Concepts Covered**:
- Implement sliding window attention
- Compare full vs. windowed attention
- Measure effective context length (L × W, for L layers with window size W)
- Benchmark speed and memory
- Visualize attention patterns
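The L × W figure comes from stacking layers. A token in layer l attends to a window of positions in layer l − 1, and each of those already attended to a window one layer lower, so reach grows with depth. The sketch below checks this by composing the one-layer reachability mask. It assumes a causal window of W tokens that includes the current token, so each layer reaches W − 1 positions further back and L layers reach L·(W − 1), which is roughly L × W. `receptive_field` is a hypothetical helper, not part of any library.

```python
import numpy as np


def receptive_field(num_layers, window_size, seq_len):
    """Hypothetical helper: how far back the last token can 'see' after
    stacking num_layers causal sliding-window layers.

    Assumes a causal window covers the current token plus the
    window_size - 1 tokens before it.
    """
    pos = np.arange(seq_len)
    dist = pos[:, None] - pos[None, :]  # dist[i, j] = i - j
    one_layer = ((dist >= 0) & (dist < window_size)).astype(np.int64)

    reach = one_layer.copy()
    for _ in range(num_layers - 1):
        # Boolean matrix product: i reaches j if some chain of windowed hops connects them
        reach = ((reach @ one_layer) > 0).astype(np.int64)

    last = seq_len - 1
    earliest = np.flatnonzero(reach[last])[0]  # smallest key index the last token can reach
    return int(last - earliest)


for layers in (1, 2, 4):
    print(layers, receptive_field(layers, window_size=8, seq_len=128))
```

With the `window_size = 512` used later in this module and a hypothetical 32-layer model, the bound works out to 32 × 511 = 16,352 positions. This is an upper bound on how far information can propagate, not a measure of how much of it the model actually uses.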

## Setup

```python
!pip install torch transformers accelerate matplotlib seaborn numpy -q
```

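## Implementation

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def sliding_window_attention(Q, K, V, window_size=512, causal=True):
    """Sliding window attention (dense reference implementation).

    Each query position i attends only to key positions j inside its window:
      - causal=True:  the last `window_size` positions, j in [i - window_size + 1, i]
      - causal=False: a symmetric window, |i - j| <= window_size // 2

    Note: this version still materializes the full (seq_len, seq_len) score
    matrix and then masks it, so time and memory are O(seq_len^2). It is
    useful for checking correctness and visualizing patterns, not for saving memory.

    Args:
        Q, K, V: Query, Key, Value tensors of shape (batch, seq_len, d_k)
        window_size: Number of key positions each query may attend to (>= 1)
        causal: If True, a query never attends to future positions

    Returns:
        output: (batch, seq_len, d_k)
        attn_weights: (batch, seq_len, seq_len)
    """
    batch_size, seq_len, d_k = Q.shape

    # Scaled dot-product scores: (batch, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)

    # Build the window mask without a Python loop; dist[i, j] = i - j.
    # Create it on Q's device so masked_fill also works on GPU.
    positions = torch.arange(seq_len, device=Q.device)
    dist = positions[:, None] - positions[None, :]
    if causal:
        mask = (dist >= 0) & (dist < window_size)
    else:
        mask = dist.abs() <= window_size // 2

    # Out-of-window scores become -inf so softmax gives them zero weight.
    # The diagonal (dist == 0) is always kept, so no row is entirely -inf.
    scores = scores.masked_fill(~mask.unsqueeze(0), float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)

    return output, attn_weights
```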

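The dense function above computes every one of the seq_len × seq_len scores and only then masks most of them away, so its memory cost is still O(seq_len²). Getting to O(seq_len × window_size) requires computing only the in-window scores. The sketch below is one way to do that for the causal case: it processes queries in blocks of `window_size` and hands each block only the keys it can reach. `blocked_sliding_window_attention` and `dense_reference` are hypothetical names, and the window convention (current token plus the `window_size − 1` tokens before it) is an assumption.

```python
import math

import torch


def blocked_sliding_window_attention(Q, K, V, window_size):
    """Hypothetical helper: causal sliding window attention, block by block.

    Queries are processed in blocks of window_size. A block covering
    [q_start, q_end) only needs keys in [q_start - window_size + 1, q_end),
    so at most window_size x (2 * window_size - 1) scores exist at once.
    """
    batch_size, seq_len, d_k = Q.shape
    W = window_size
    output = torch.empty_like(Q)

    for q_start in range(0, seq_len, W):
        q_end = min(q_start + W, seq_len)
        k_start = max(0, q_start - W + 1)  # oldest key any query in this block can see

        q = Q[:, q_start:q_end]  # (batch, block, d_k)
        k = K[:, k_start:q_end]  # (batch, keys, d_k)
        v = V[:, k_start:q_end]

        scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)

        # Same rule as a dense causal window: keep 0 <= i - j < W
        qi = torch.arange(q_start, q_end, device=Q.device)[:, None]
        kj = torch.arange(k_start, q_end, device=Q.device)[None, :]
        dist = qi - kj
        mask = (dist >= 0) & (dist < W)

        scores = scores.masked_fill(~mask, float("-inf"))
        output[:, q_start:q_end] = torch.softmax(scores, dim=-1) @ v

    return output


def dense_reference(Q, K, V, window_size):
    """Dense masked version, used only to check the blocked result."""
    seq_len, d_k = Q.shape[1], Q.shape[2]
    pos = torch.arange(seq_len)
    dist = pos[:, None] - pos[None, :]
    mask = (dist >= 0) & (dist < window_size)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V


torch.manual_seed(0)
Qs, Ks, Vs = (torch.randn(2, 300, 32) for _ in range(3))
blocked = blocked_sliding_window_attention(Qs, Ks, Vs, window_size=64)
dense = dense_reference(Qs, Ks, Vs, window_size=64)
print(torch.allclose(blocked, dense, atol=1e-5))
```

Each block holds at most window_size × (2·window_size − 1) scores, which is 512 × 1023 = 523,776 for the settings in the next cell, versus 4,194,304 for the full dense matrix. The Python loop over blocks keeps the sketch readable; a production kernel would fuse these steps instead of looping in Python.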
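```python
# Test sliding window attention
seq_len = 2048
d_k = 64
window_size = 512

Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_k)

output, attn = sliding_window_attention(Q, K, V, window_size)

# Count how many keys each query actually attends to (non-zero weights)
keys_per_query = (attn[0] > 0).sum(dim=-1)

print(f"Sequence length: {seq_len}")
print(f"Window size: {window_size}")
print(f"Output shape: {output.shape}")
print(f"Max keys attended by one query: {keys_per_query.max().item()}")
# This dense implementation still stores every score, so memory is O(seq_len²);
# only an implementation that computes in-window scores reaches O(seq_len × window_size)
print(f"Attention matrix stored: {tuple(attn.shape)}")
```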

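A quick way to see the band structure of the attention pattern, and to put a number on the O(seq_len × window_size) claim, is to print a small causal window mask and count the (query, key) pairs it allows. The sketch assumes a causal window of size W covers the current token and the W − 1 tokens before it; `causal_window_mask` is a hypothetical helper. For a heatmap of real weights, the matplotlib import from the setup cell could be used with `plt.imshow` on a slice of the returned attention weights.

```python
import torch


def causal_window_mask(seq_len, window_size):
    """Hypothetical helper: boolean (seq_len, seq_len) mask,
    True where 0 <= i - j < window_size (query i may attend to key j)."""
    pos = torch.arange(seq_len)
    dist = pos[:, None] - pos[None, :]
    return (dist >= 0) & (dist < window_size)


# Text "heatmap" of a small mask: each row is a query, each column a key
small = causal_window_mask(8, 3)
for row in small.int().tolist():
    print(" ".join("#" if x else "." for x in row))

# Number of (query, key) pairs inside the window vs. the full dense matrix
n, w = 2048, 512
window_pairs = int(causal_window_mask(n, w).sum())
dense_pairs = n * n
print(window_pairs, dense_pairs, f"{window_pairs / dense_pairs:.1%}")
```

For seq_len = 2048 and window_size = 512, only 917,760 of the 4,194,304 pairs (about 22%) fall inside the window. The first window_size − 1 rows see fewer keys because there is nothing earlier to attend to.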
## Key Takeaways

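- Sliding window attention limits each query to a fixed window of keys, so the attention work that matters per token is O(window_size) instead of O(seq_len).
- Masking a full score matrix gives correct outputs but still costs O(seq_len²) memory; real savings require computing only the in-window scores.
- Stacking L windowed layers lets information travel roughly L × W positions, so the effective context is much longer than a single window.

✅ **Module Complete**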

## Next Steps

Continue to the next module in the course.