# ⚡ Sparse Transformer

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/transformer_architectures/06_sparse_transformer/demo.ipynb)

![Architecture](architecture.png)

### Key Innovation
- **Sparse Attention Patterns**: each query attends to only a subset of the tokens
- **O(N√N) Complexity**: vs O(N²) for standard attention (when the window size and stride are ≈ √N)
- **Scalable**: handles much longer sequences

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed the default RNG so the notebook's random draws are reproducible.
torch.manual_seed(42)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## Problem: Quadratic Attention Complexity

In [None]:
# Memory and FLOPs comparison
def compute_complexity(seq_lengths):
    """Return relative attention costs for each sequence length.

    Args:
        seq_lengths: Iterable of sequence lengths ``n``.

    Returns:
        A pair of lists ``(standard, sparse)``. ``standard[k]`` is ``n**2``
        (dense attention) and ``sparse[k]`` is ``n * sqrt(n)`` (sparse
        attention) for the k-th length.
    """
    quadratic_costs = []
    sub_quadratic_costs = []
    for length in seq_lengths:
        quadratic_costs.append(length ** 2)
        sub_quadratic_costs.append(length * math.sqrt(length))
    return quadratic_costs, sub_quadratic_costs

# Compare the theoretical cost of dense vs sparse attention over a range of lengths.
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
standard, sparse = compute_complexity(seq_lengths)

plt.figure(figsize=(10, 5))
plt.plot(seq_lengths, standard, 'ro-', label='Standard O(N²)', linewidth=2)
plt.plot(seq_lengths, sparse, 'go-', label='Sparse O(N√N)', linewidth=2)
plt.xlabel('Sequence Length')
plt.ylabel('Operations (relative)')
plt.title('Attention Complexity Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.show()

# Report the largest length actually in `seq_lengths` instead of a hard-coded
# value, so the header stays correct if the list is edited. The "speedup" is the
# ratio of the two operation-count formulas, not a measured timing.
print(f'\nAt seq_len={seq_lengths[-1]}:')
print(f'  Standard: {standard[-1]:,.0f} operations')
print(f'  Sparse: {sparse[-1]:,.0f} operations')
print(f'  Speedup: {standard[-1]/sparse[-1]:.1f}x')

## Sparse Attention Patterns

In [None]:
def create_attention_patterns(seq_len):
    """Create different causal attention patterns for visualization.

    Args:
        seq_len: Number of query/key positions.

    Returns:
        A dict mapping a display name to a ``(seq_len, seq_len)`` float tensor.
        Entry ``[i, j]`` is 1 if query ``i`` may attend to key ``j`` and 0
        otherwise. Every pattern is causal (only ``j <= i`` can be 1).
    """
    # Broadcast a column of query indices against a row of key indices
    # instead of filling the masks element by element in Python loops.
    q = torch.arange(seq_len).unsqueeze(1)  # query index i, shape (seq_len, 1)
    k = torch.arange(seq_len).unsqueeze(0)  # key index j, shape (1, seq_len)
    offset = q - k                          # i - j
    causal = offset >= 0

    patterns = {}

    # 1. Full attention (standard, causal)
    full = causal.float()
    patterns['Full (O(N²))'] = full

    # 2. Local attention (sliding window): the query plus the previous
    #    seq_len // 4 tokens.
    window_size = seq_len // 4
    local = (causal & (offset <= window_size)).float()
    patterns['Local Window'] = local

    # 3. Strided attention: every stride-th key counting from position 0.
    #    max(1, ...) guards against a zero stride for seq_len == 0.
    stride = max(1, int(math.sqrt(seq_len)))
    strided = (causal & (k % stride == 0)).float()
    patterns['Strided'] = strided

    # 4. Combined (Sparse Transformer): union of local and strided
    combined = (local + strided).clamp(0, 1)
    patterns['Sparse (Local+Strided)'] = combined

    # 5. Fixed positions: local window plus the first 4 tokens, kept causal
    fixed = local.clone()
    fixed[:, :4] = 1
    fixed = torch.tril(fixed)
    patterns['Fixed + Local'] = fixed

    return patterns

# Draw each pattern side by side, titled with the fraction of the
# seq_len x seq_len grid that is allowed to attend.
seq_len = 32
patterns = create_attention_patterns(seq_len)
n_cells = seq_len ** 2

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
for ax, name in zip(axes, patterns):
    grid = patterns[name]
    density_pct = grid.sum() / n_cells * 100
    ax.imshow(grid, cmap='Blues')
    ax.set_title(f'{name}\nDensity: {density_pct:.1f}%')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')
plt.tight_layout()
plt.show()

## Sparse Transformer Implementation

In [None]:
class SparseAttention(nn.Module):
    """Causal multi-head self-attention restricted to a local + strided pattern.

    Query ``i`` may attend to key ``j`` when ``j <= i`` and either
    ``i - j <= window_size`` (local window) or ``j % stride == 0``
    (strided keys counted from position 0).

    NOTE: this is a masked *dense* implementation. The full ``T x T`` score
    matrix is computed, and disallowed entries are filled with ``-inf`` before
    the softmax. It reproduces the sparse attention *pattern* but not the
    O(N*sqrt(N)) compute/memory savings of a true sparse kernel.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        seq_len: Maximum sequence length; the mask is precomputed for it.
        window_size: Number of preceding tokens in the local window (the query
            itself is always included). ``None`` selects
            ``max(1, int(sqrt(seq_len)))``; an explicit ``0`` is honoured.
        stride: Spacing of the strided keys (>= 1). ``None`` selects
            ``max(1, int(sqrt(seq_len)))``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``, or if
            ``window_size < 0`` or ``stride < 1``.
    """
    def __init__(self, d_model, n_heads, seq_len, window_size=None, stride=None, dropout=0.1):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads})')
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = self.d_k ** -0.5
        self.seq_len = seq_len

        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Compare against None (not truthiness) so an explicit window_size=0
        # is not silently replaced by the default.
        default = max(1, int(math.sqrt(seq_len)))
        if window_size is None:
            window_size = default
        if stride is None:
            stride = default
        if window_size < 0:
            raise ValueError(f'window_size must be >= 0, got {window_size}')
        if stride < 1:
            raise ValueError(f'stride must be >= 1, got {stride}')

        mask = self._create_sparse_mask(seq_len, window_size, stride)
        self.register_buffer('sparse_mask', mask)

    def _create_sparse_mask(self, seq_len, window_size, stride):
        """Return a ``(seq_len, seq_len)`` float mask; 1 where query i may attend to key j.

        The mask is the combined local + strided pattern, restricted to causal positions.
        """
        q = torch.arange(seq_len).unsqueeze(1)  # query index i, shape (seq_len, 1)
        k = torch.arange(seq_len).unsqueeze(0)  # key index j, shape (1, seq_len)
        causal = k <= q
        local = causal & (q - k <= window_size)   # nearby preceding tokens
        strided = causal & (k % stride == 0)      # every stride-th token
        return (local | strided).float()

    def forward(self, x):
        """Apply sparse self-attention.

        Args:
            x: Tensor of shape ``(B, T, d_model)`` with ``T <= seq_len``.

        Returns:
            ``(out, attn)``. ``out`` has shape ``(B, T, d_model)``. ``attn`` has
            shape ``(B, n_heads, T, T)`` and holds the attention weights after
            dropout.

        Raises:
            ValueError: if ``T`` exceeds the ``seq_len`` the mask was built for.
        """
        B, T, C = x.shape
        if T > self.seq_len:
            raise ValueError(f'sequence length {T} exceeds the configured seq_len {self.seq_len}')

        # QKV projection -> (3, B, n_heads, T, d_k)
        qkv = self.W_qkv(x).reshape(B, T, 3, self.n_heads, self.d_k).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Dense attention scores; the sparsity comes only from the mask below.
        attn = (Q @ K.transpose(-2, -1)) * self.scale

        # Every row keeps at least its diagonal entry, so no row is all -inf.
        mask = self.sparse_mask[:T, :T]
        attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ V).transpose(1, 2).reshape(B, T, C)
        return self.W_o(out), attn

class SparseTransformerBlock(nn.Module):
    """Pre-norm transformer block: sparse self-attention followed by an MLP.

    Each sub-layer is applied to a LayerNorm-ed input and its result is added
    back to the residual stream.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        seq_len: Maximum sequence length (forwarded to ``SparseAttention``).
        dropout: Dropout probability for the attention weights and the MLP output.
        window_size: Optional local-window size forwarded to ``SparseAttention``;
            ``None`` keeps that class's default.
        stride: Optional stride forwarded to ``SparseAttention``; ``None`` keeps
            that class's default.
    """
    def __init__(self, d_model, n_heads, seq_len, dropout=0.1, window_size=None, stride=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = SparseAttention(
            d_model, n_heads, seq_len,
            window_size=window_size, stride=stride, dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(d_model)
        # Position-wise MLP with the usual 4x hidden expansion
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``(x_out, attn_weights)``.

        ``x_out`` has the same shape as ``x``. ``attn_weights`` is the weight
        tensor returned by the attention sub-layer.
        """
        attn_out, attn_weights = self.attn(self.norm1(x))
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x, attn_weights

class SparseTransformer(nn.Module):
    """Language model built from a stack of ``SparseTransformerBlock`` layers.

    Token embeddings and learned absolute position embeddings are summed. The
    result goes through ``n_layers`` blocks, a final LayerNorm, and a linear
    head that produces vocabulary logits.

    Args:
        vocab_size: Number of token ids.
        d_model: Model width.
        n_heads: Attention heads per block.
        n_layers: Number of blocks.
        seq_len: Maximum supported sequence length; it sizes both the position
            embedding table and the attention masks.
        dropout: Dropout probability.
    """
    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3, seq_len=256, dropout=0.1):
        super().__init__()
        self.seq_len = seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList(
            [SparseTransformerBlock(d_model, n_heads, seq_len, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Run the model on token ids.

        Args:
            x: Integer tensor of shape ``(B, T)`` with ``T <= seq_len``.

        Returns:
            ``(logits, attn_weights)``. ``logits`` has shape
            ``(B, T, vocab_size)``. ``attn_weights`` is a list with one entry
            per layer.

        Raises:
            ValueError: if ``T`` exceeds ``seq_len``. Without this check the
                position-embedding lookup would go out of range, which on CUDA
                surfaces as an opaque device-side assert.
        """
        _, T = x.shape
        if T > self.seq_len:
            raise ValueError(f'sequence length {T} exceeds the model seq_len {self.seq_len}')
        pos = torch.arange(T, device=x.device).unsqueeze(0)

        x = self.dropout(self.embed(x) + self.pos_embed(pos))

        attn_weights = []
        for layer in self.layers:
            x, attn = layer(x)
            attn_weights.append(attn)

        x = self.norm(x)
        return self.head(x), attn_weights

# Small configuration, used only to report the parameter count.
model = SparseTransformer(vocab_size=1000, d_model=64, n_heads=4, n_layers=2, seq_len=128).to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f'Sparse Transformer Parameters: {n_params:,}')

## Training on Longer Sequences

In [None]:
# Create dataset: a small, highly repetitive character-level corpus
text = 'the quick brown fox jumps over the lazy dog and the cat sleeps on the mat ' * 200
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}
data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)

# Training
seq_len = 128  # Longer sequence!
model = SparseTransformer(vocab_size=vocab_size, d_model=64, n_heads=4, n_layers=2, seq_len=seq_len).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
n_steps = 300
batch_size = 16

print(f'Training Sparse Transformer (seq_len={seq_len})...')
model.train()
for step in range(n_steps):
    # Valid window starts are 0 .. len(data) - seq_len - 1 inclusive, because the
    # target slice needs one token past the input window. randint's `high` is
    # exclusive, so the bound is len(data) - seq_len.
    idx = torch.randint(0, len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in idx]).to(device)
    y = torch.stack([data[i + 1:i + seq_len + 1] for i in idx]).to(device)

    optimizer.zero_grad()
    logits, _ = model(x)
    # Flatten (B, T, V) / (B, T) for token-level cross-entropy
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if (step + 1) % 50 == 0:
        print(f'Step {step+1}: Loss = {loss.item():.4f}')

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Sparse Transformer Training')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Compare the static sparse mask with the weights the trained model
# actually produces on a random 64-token input.
model.eval()
n_tokens = 64
test_input = torch.randint(0, vocab_size, (1, n_tokens)).to(device)

with torch.no_grad():
    _, attn_weights = model(test_input)

# Last layer, first batch element, first head
attn = attn_weights[-1][0, 0].cpu()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
mask_view = model.layers[-1].attn.sparse_mask[:n_tokens, :n_tokens].cpu()
panels = [
    (mask_view, 'Sparse Attention Mask', {'cmap': 'Greys', 'alpha': 0.3}),
    (attn, 'Learned Attention Weights', {'cmap': 'Blues'}),
]
for ax, (image, title, imshow_kwargs) in zip(axes, panels):
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.tight_layout()
plt.show()

# Summary printout. The strings previously contained mojibake (UTF-8 decoded as
# cp1252) and are restored to the intended characters. Point 1 describes the
# attention pattern; the demo implementation computes dense scores and masks them.
print('\n🎯 Key Takeaways:')
print('1. Sparse patterns reduce O(N²) to O(N√N)')
print('2. Local + strided captures both nearby and distant relationships')
print('3. Enables processing of much longer sequences')
print('4. Used in GPT-3 and other large models for efficiency')