# 📜 Longformer: The Long-Document Transformer

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/transformer/transformer_architectures/09_longformer/demo.ipynb)

![Architecture](architecture.png)

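### Key Innovations
- **Sliding Window Attention**: each token attends only to a fixed-size window of neighbors, capturing local context in O(N·w) instead of O(N²)
- **Dilated Sliding Window**: gaps between the attended positions enlarge the receptive field at the same cost
- **Global Attention**: selected tokens (e.g. [CLS]) attend to every position, and every position attends to them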

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's RNG so weight init and sampled batches are reproducible.
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## Longformer Attention Patterns

In [None]:
def create_longformer_patterns(seq_len, window_size=4, dilation=2, global_indices=None):
    """Build the Longformer attention masks used for visualization.

    Args:
        seq_len: number of positions N; every returned mask is an (N, N) tensor.
        window_size: width of the sliding window; each query sees
            ``window_size // 2`` keys on each side plus itself.
        dilation: spacing between attended keys in the dilated window.
        global_indices: iterable of positions that receive global attention.
            Defaults to ``[0]`` (a single [CLS]-style token).

    Returns:
        dict mapping a pattern name to an (N, N) 0/1 float tensor (rows are
        queries, columns are keys). The keys are 'Sliding Window',
        'Dilated Window', 'Global Tokens' and 'Longformer Combined'. The
        combined pattern is sliding window + global; the dilated window is
        shown separately.
    """
    # None sentinel instead of a shared mutable list default.
    if global_indices is None:
        global_indices = [0]
    half = window_size // 2
    patterns = {}

    # 1. Sliding window: query i sees keys j with |i - j| <= half
    sliding = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - half)
        end = min(seq_len, i + half + 1)
        sliding[i, start:end] = 1
    patterns['Sliding Window'] = sliding

    # 2. Dilated sliding window: same offsets, spaced `dilation` apart.
    # The range must start at -(window_size // 2). Writing -window_size // 2
    # floors to (-window_size) // 2, which adds an extra step on the left for
    # odd window sizes and makes the pattern asymmetric.
    dilated = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        for offset in range(-half, half + 1):
            j = i + offset * dilation
            if 0 <= j < seq_len:
                dilated[i, j] = 1
    patterns['Dilated Window'] = dilated

    # 3. Global attention (specific tokens attend to everything)
    global_attn = torch.zeros(seq_len, seq_len)
    for idx in global_indices:
        global_attn[idx, :] = 1  # Global token attends to all
        global_attn[:, idx] = 1  # All attend to global token
    patterns['Global Tokens'] = global_attn

    # 4. Combined Longformer pattern (union of sliding window and global)
    combined = (sliding + global_attn).clamp(0, 1)
    patterns['Longformer Combined'] = combined

    return patterns

# Plot each mask side by side. Each title reports how many of the N*N
# query-key pairs are allowed and what fraction of the matrix that is.
seq_len = 32
patterns = create_longformer_patterns(seq_len, window_size=6, global_indices=[0, 15, 31])

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for panel, (label, mask) in zip(axes, patterns.items()):
    n_active = mask.sum().item()
    pct = mask.mean().item() * 100
    panel.imshow(mask, cmap='Blues')
    panel.set_title(f'{label}\n{n_active:.0f}/{seq_len**2} = {pct:.1f}%')
    panel.set_xlabel('Key')
    panel.set_ylabel('Query')
plt.tight_layout()
plt.show()

print('Sliding window: Each token attends to nearby tokens')
print('Global tokens: [CLS], special positions attend everywhere')

## Complexity Analysis

In [None]:
def complexity_comparison(seq_lengths=(512, 1024, 2048, 4096, 8192, 16384),
                          window_size=256, n_global=4):
    """Plot and print the attention cost of full attention vs. Longformer.

    Cost is the number of query-key scores computed:
    - full attention: ``N**2``;
    - Longformer: ``N * window_size + N * n_global`` (every token scores its
      window plus each global token).

    The global tokens' own full rows are not counted, so the Longformer
    figure is a slight under-estimate.

    Args:
        seq_lengths: sequence lengths N to compare, plotted in the given order.
        window_size: sliding-window width w.
        n_global: number of global tokens.

    Raises:
        ValueError: if ``seq_lengths`` is empty (there would be nothing to
            plot or summarize).
    """
    seq_lengths = list(seq_lengths)
    if not seq_lengths:
        raise ValueError('seq_lengths must contain at least one length')

    # Full attention: O(N²)
    full = [n**2 for n in seq_lengths]

    # Longformer: O(N × window) + O(N × n_global)
    longformer = [n * window_size + n * n_global for n in seq_lengths]

    speedup = [f / l for f, l in zip(full, longformer)]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Absolute cost (log scale: full attention grows quadratically)
    ax = axes[0]
    ax.plot(seq_lengths, full, 'ro-', label='Full Attention O(N²)', linewidth=2)
    ax.plot(seq_lengths, longformer, 'go-', label='Longformer O(N·w)', linewidth=2)
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Operations')
    ax.set_title('Absolute Complexity')
    ax.legend()
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)

    # Speedup, one bar per sequence length
    ax = axes[1]
    ax.bar(range(len(seq_lengths)), speedup, color='green', alpha=0.7)
    ax.set_xticks(range(len(seq_lengths)))
    ax.set_xticklabels(seq_lengths)
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Speedup (×)')
    ax.set_title('Longformer Speedup over Full Attention')
    for i, s in enumerate(speedup):
        ax.text(i, s + 0.5, f'{s:.0f}×', ha='center')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summarize the longest sequence rather than a hard-coded length.
    print(f'\nAt seq_len={seq_lengths[-1]}:')
    print(f'  Full: {full[-1]:,} ops')
    print(f'  Longformer: {longformer[-1]:,} ops')
    print(f'  Speedup: {speedup[-1]:.0f}×')

# Draw the full-vs-Longformer cost plots and print the summary.
complexity_comparison()

## Longformer Implementation

In [None]:
class LongformerAttention(nn.Module):
    """Multi-head self-attention restricted to a sliding window plus global tokens.

    Each query may attend to keys within ``window_size // 2`` positions on
    either side. Positions selected by ``global_mask`` also attend to, and are
    attended by, every position. A causal (lower-triangular) constraint is
    applied on top, so no position ever attends to a later one; this keeps the
    layer usable for next-token prediction.

    Note: for clarity this computes the dense ``T x T`` score matrix and masks
    it. It therefore demonstrates the *pattern* but still costs O(T^2) time and
    memory, rather than O(T * w).
    """
    def __init__(self, d_model, n_heads, window_size=64, dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of a cryptic reshape error in forward().
        if d_model % n_heads != 0:
            raise ValueError(
                f'd_model ({d_model}) must be divisible by n_heads ({n_heads})'
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = self.d_k ** -0.5
        self.window_size = window_size

        # One fused projection producing Q, K and V.
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _create_window_mask(self, seq_len, device=None):
        """Return a (seq_len, seq_len) 0/1 mask of the sliding window.

        Entry [i, j] is 1 when ``|i - j| <= window_size // 2``. The mask is
        built vectorized and directly on ``device``, so it needs no Python loop
        and no host-to-device copy.
        """
        positions = torch.arange(seq_len, device=device)
        distance = (positions[:, None] - positions[None, :]).abs()
        return (distance <= self.window_size // 2).to(torch.get_default_dtype())

    def forward(self, x, global_mask=None):
        """Apply windowed + global causal attention.

        Args:
            x: input of shape (B, T, d_model).
            global_mask: optional index tensor (or boolean mask over T)
                selecting the global positions.

        Returns:
            Tuple ``(output, attn)``:
            - output: shape (B, T, d_model);
            - attn: attention weights of shape (B, n_heads, T, T), taken after
              dropout.
        """
        B, T, C = x.shape

        # (B, T, 3*C) -> (3, B, n_heads, T, d_k)
        qkv = self.W_qkv(x).reshape(B, T, 3, self.n_heads, self.d_k).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores, shape (B, n_heads, T, T)
        attn = (Q @ K.transpose(-2, -1)) * self.scale

        # Boolean "may attend" mask, starting from the sliding window.
        allowed = self._create_window_mask(T, device=x.device).bool()
        if global_mask is not None:
            allowed[global_mask, :] = True  # global queries see every key
            allowed[:, global_mask] = True  # every query sees the global keys

        # Causal constraint. The diagonal always survives, so no row is fully -inf.
        causal = torch.tril(torch.ones(T, T, device=x.device)).bool()
        allowed &= causal

        attn = attn.masked_fill(~allowed, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ V).transpose(1, 2).reshape(B, T, C)
        return self.W_o(out), attn

class LongformerBlock(nn.Module):
    """Pre-norm transformer block.

    LayerNorm -> LongformerAttention -> residual, then
    LayerNorm -> 4x GELU MLP -> residual.
    """
    def __init__(self, d_model, n_heads, window_size=64, dropout=0.1):
        super().__init__()
        hidden = 4 * d_model
        # Submodules are created in this order: norm1, attn, norm2, ff.
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = LongformerAttention(d_model, n_heads, window_size, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, global_mask=None):
        """Return ``(block output, attention weights from the attention sublayer)``."""
        attended, weights = self.attn(self.norm1(x), global_mask)
        h = x + attended
        h = h + self.ff(self.norm2(h))
        return h, weights

class Longformer(nn.Module):
    """Small Longformer language model.

    Pipeline: token + learned absolute position embeddings -> ``n_layers``
    LongformerBlocks -> final LayerNorm -> linear projection to
    ``vocab_size`` logits.
    """
    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3, window_size=64, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # One learned vector per position; sequences may not exceed max_len.
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([LongformerBlock(d_model, n_heads, window_size, dropout) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x, global_mask=None):
        """Run the model on token ids.

        Args:
            x: integer token ids of shape (B, T), with T <= max_len.
            global_mask: passed unchanged to every layer's attention.

        Returns:
            Tuple ``(logits, attn_weights)``:
            - logits: shape (B, T, vocab_size);
            - attn_weights: list with one attention-weight tensor per layer.

        Raises:
            ValueError: if T exceeds the number of learned positions. Without
                this check the position embedding lookup fails with an opaque
                index error.
        """
        _, T = x.shape
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            raise ValueError(f'sequence length {T} exceeds max_len={max_len}')
        pos = torch.arange(T, device=x.device).unsqueeze(0)

        h = self.dropout(self.embed(x) + self.pos_embed(pos))

        attn_weights = []
        for layer in self.layers:
            h, attn = layer(h, global_mask)
            attn_weights.append(attn)

        return self.head(self.norm(h)), attn_weights

# Instantiate a toy model with a 1000-token vocabulary and report its size.
model = Longformer(
    vocab_size=1000,
    d_model=64,
    n_heads=4,
    n_layers=2,
    window_size=16,
).to(device)
print(f'Longformer Parameters: {sum(param.numel() for param in model.parameters()):,}')

## Training Longformer

In [None]:
# Dataset: a character-level corpus made by repeating one sentence 200 times
text = 'the quick brown fox jumps over the lazy dog and the cat sleeps on the warm mat ' * 200
chars = sorted(list(set(text)))
vocab_size = len(chars)  # rebinds the notebook-level vocab_size to the character vocabulary size
char_to_idx = {c: i for i, c in enumerate(chars)}
data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)

# Training with global tokens
seq_len = 128
# Fresh model sized to the character vocab; max_len == seq_len so every position has an embedding
model = Longformer(vocab_size=vocab_size, d_model=64, n_heads=4, n_layers=2, window_size=16, max_len=seq_len).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Global mask: first token and every 32nd token (the same positions for every sequence in the batch)
global_indices = torch.tensor([0, 32, 64, 96]).to(device)

losses = []
n_steps = 300

print('Training Longformer with sliding window + global attention...')
for step in range(n_steps):
    # Sample 16 random start offsets; y is x shifted by one character (next-character targets)
    idx = torch.randint(0, len(data) - seq_len - 1, (16,))
    x = torch.stack([data[i:i+seq_len] for i in idx]).to(device)
    y = torch.stack([data[i+1:i+seq_len+1] for i in idx]).to(device)
    
    optimizer.zero_grad()
    logits, _ = model(x, global_mask=global_indices)
    # Flatten (batch, time) so cross-entropy scores every position's prediction
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())
    if (step + 1) % 50 == 0:
        print(f'Step {step+1}: Loss = {loss.item():.4f}')

# Plot the per-step training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Longformer Training')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Visualize the attention pattern of the trained model on a random input.
model.eval()  # disable dropout so the plotted weights are the plain softmax
# Single source of truth for the global positions. The mask, the title and
# the guide lines all read from it, so they cannot drift apart.
global_positions = [0, 31, 63]
test_input = torch.randint(0, vocab_size, (1, 64)).to(device)
global_mask = torch.tensor(global_positions).to(device)

with torch.no_grad():
    _, attn_weights = model(test_input, global_mask=global_mask)

# Last layer, first batch element, first head
attn = attn_weights[-1][0, 0].cpu()

plt.figure(figsize=(10, 8))
plt.imshow(attn, cmap='Blues')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
positions_label = ', '.join(str(g) for g in global_positions)
plt.title(f'Longformer Attention Pattern\n(Global tokens: {positions_label})')
plt.colorbar()

# Mark global positions
for g in global_positions:
    plt.axhline(y=g, color='red', linestyle='--', alpha=0.5, linewidth=0.5)
    plt.axvline(x=g, color='red', linestyle='--', alpha=0.5, linewidth=0.5)

plt.show()

# Summary printed at the end of the notebook. The strings originally carried
# UTF-8-read-as-Latin-1 mojibake ('ðŸŽ¯', 'Ã—'), restored here to '🎯' and '×'.
print('\n🎯 Key Takeaways:')
print('1. Sliding window: Local attention O(N × window)')
print('2. Global tokens: Selected positions attend everywhere')
print('3. Ideal for long documents: [CLS] gets global view')
print('4. Used in document QA, summarization, classification')