# Problem 1: Quadratic Complexity O(N²)

This notebook demonstrates the O(N²) complexity problem in transformer attention.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/01_quadratic_complexity/demo.ipynb)


In [None]:
# Install dependencies
!pip install torch matplotlib numpy -q

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time


## The Problem: O(N²) Attention

Standard self-attention computes scores between every pair of tokens, creating an N×N matrix.


In [None]:
def standard_attention(Q, K, V):
    """Scaled dot-product attention that materializes the full score matrix.

    Every query row is scored against every key row, the scores are divided
    by the square root of the size of Q's last dimension, softmax-normalized
    along the last (key) axis, and used to take a weighted sum of V.
    """
    scale = np.sqrt(Q.shape[-1])
    keys_t = K.transpose(-2, -1)
    # Pairwise query/key similarities: one row per query, one column per key.
    logits = (Q @ keys_t) / scale
    weights = logits.softmax(dim=-1)
    return weights @ V

# Measure quadratic scaling
d_model = 64
seq_lengths = [128, 256, 512, 1024, 2048]
n_runs = 10  # timed repetitions averaged per sequence length
times = []  # mean time per call in ms, parallel to seq_lengths

print("Seq Length | Time (ms) | Memory (MB)")
print("-" * 40)

for seq_len in seq_lengths:
    Q = torch.randn(1, seq_len, d_model)
    K = torch.randn(1, seq_len, d_model)
    V = torch.randn(1, seq_len, d_model)

    # Untimed warm-up call so one-time setup cost is not charged to the
    # first measurement.
    _ = standard_attention(Q, K, V)

    # perf_counter is the high-resolution clock intended for measuring short
    # durations; time.time() can be coarse and may jump with system clock
    # adjustments.
    start = time.perf_counter()
    for _ in range(n_runs):
        _ = standard_attention(Q, K, V)
    elapsed = (time.perf_counter() - start) / n_runs * 1000
    times.append(elapsed)

    # Theoretical size of one seq_len x seq_len score matrix assuming
    # 4-byte elements; this is a calculation, not a measured allocation.
    mem = seq_len * seq_len * 4 / (1024 * 1024)
    print(f"{seq_len:>10} | {elapsed:>9.2f} | {mem:>10.2f}")

print("\n⚠️ Double sequence = 4x time and memory!")


In [None]:
# Visualize quadratic growth
# Reads seq_lengths and times produced by the timing loop in an earlier cell.
fig, ax = plt.subplots(figsize=(8, 5))

# Measured timings: blue line with circular markers
ax.plot(seq_lengths, times, 'b-o', linewidth=2, markersize=8)

# Add quadratic fit line
# Least-squares degree-2 polynomial through the measurements, evaluated on a
# dense 100-point grid spanning the measured sequence lengths.
x = np.array(seq_lengths)
fit = np.polyfit(x, times, 2)
x_smooth = np.linspace(x.min(), x.max(), 100)
ax.plot(
    x_smooth,
    np.polyval(fit, x_smooth),
    'r--',
    alpha=0.5,
    label='Quadratic fit',
)

# Axis decoration
ax.set_title('Attention Time Grows Quadratically O(N²)', fontsize=14)
ax.set_xlabel('Sequence Length', fontsize=12)
ax.set_ylabel('Time (ms)', fontsize=12)
ax.grid(True, alpha=0.3)
ax.legend()

plt.tight_layout()
plt.show()


## Solution: Linear Attention

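Softmax forces us to build the full N×N score matrix before multiplying by V. If we replace the softmax with a feature map φ (ReLU below), attention becomes φ(Q) (φ(K)^T V), and matrix multiplication is associative. By computing K^T @ V first, we get a small d×d matrix and avoid creating the N×N matrix! Note that this approximates softmax attention rather than reproducing it exactly.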


In [None]:
def linear_attention(Q, K, V):
    """Kernelized attention that never forms the N x N score matrix.

    ReLU serves as the feature map for queries and keys. The keys are first
    contracted with the values over the sequence axis, and the queries are
    then applied to that summary. Each output row is divided by the dot
    product of its query features with the summed key features, plus 1e-6
    so the division stays finite when that dot product is zero.
    """
    phi_q, phi_k = F.relu(Q), F.relu(K)

    # Contract keys with values over the sequence axis first; the last two
    # dims of the result are both feature-sized rather than sequence-sized.
    kv_summary = phi_k.transpose(-2, -1) @ V
    numerator = phi_q @ kv_summary

    # Per-query normalizer: query features dotted with the summed key features.
    key_totals = phi_k.sum(dim=-2, keepdim=True)
    denominator = phi_q @ key_totals.transpose(-2, -1)
    return numerator / (denominator + 1e-6)

# Compare scaling
def _mean_runtime_ms(fn, *args, repeats=5):
    """Return the mean wall-clock time of ``fn(*args)`` in milliseconds.

    One untimed warm-up call is made first so one-time setup cost is not
    included, then ``repeats`` timed calls are averaged.
    """
    fn(*args)
    # perf_counter is the high-resolution clock intended for short durations.
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats * 1000


print("Standard vs Linear Attention:")
print("Seq Length | Standard (ms) | Linear (ms) | Speedup")
print("-" * 55)

for seq_len in [512, 1024, 2048, 4096]:
    Q = torch.randn(1, seq_len, 64)
    K = torch.randn(1, seq_len, 64)
    V = torch.randn(1, seq_len, 64)

    # Both implementations are timed on the same inputs.
    std_time = _mean_runtime_ms(standard_attention, Q, K, V)
    lin_time = _mean_runtime_ms(linear_attention, Q, K, V)

    # Guard against a zero reading so the ratio cannot raise ZeroDivisionError.
    speedup = std_time / lin_time if lin_time > 0 else float("inf")

    print(f"{seq_len:>10} | {std_time:>13.2f} | {lin_time:>11.2f} | {speedup:>6.1f}x")

print("\n✓ Linear attention scales much better for long sequences!")
