# Lecture 11: Efficient Transformers

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/11_efficient_transformers/demo.ipynb)

FlashAttention, sparse attention, and linear attention implementations.


In [None]:
!pip install torch -q
import torch
import torch.nn.functional as F
import time

def standard_attention(Q, K, V):
    """Scaled dot-product attention that builds the full score matrix.

    The last two dimensions of ``Q``, ``K`` and ``V`` are used as
    (sequence, feature). Scores between every query and every key are
    materialized, so the intermediate grows quadratically with sequence
    length.

    Returns:
        The softmax-weighted combination of the rows of ``V``.
    """
    scale = Q.shape[-1] ** 0.5
    # One score per (query, key) pair: this is the N×N intermediate.
    logits = (Q @ K.transpose(-2, -1)) / scale
    # Normalize over the key axis so each query's weights sum to 1.
    weights = logits.softmax(dim=-1)
    return weights @ V

def linear_attention(Q, K, V):
    """Kernelized attention that never forms the N×N score matrix.

    ``Q`` and ``K`` are passed through the positive feature map
    ``elu(x) + 1``. Keys and values are then contracted first, giving a
    (feature × value-feature) summary whose size does not depend on
    sequence length. The last two dimensions of each input are used as
    (sequence, feature).

    Returns:
        The numerator ``phi(Q) @ (phi(K)^T @ V)`` divided by the per-query
        normalizer ``phi(Q) @ sum_n phi(K_n)`` plus 1e-6.
    """
    phi_q = F.elu(Q) + 1
    phi_k = F.elu(K) + 1

    # Per-query denominator: dot product of each query with the summed keys.
    key_sum = phi_k.sum(dim=-2, keepdim=True)
    denom = phi_q @ key_sum.transpose(-2, -1)

    # Contract keys with values first (d×d instead of N×N), then apply queries.
    kv_summary = phi_k.transpose(-2, -1) @ V
    numer = phi_q @ kv_summary

    # The small epsilon guards the division.
    return numer / (denom + 1e-6)

# Compare memory and speed of standard vs. linear attention.
# Memory figures are analytic estimates of the dominant intermediate
# (N×N scores vs. d×d key-value summary), not measured allocations.
print("Attention Comparison")
print("=" * 50)

HEAD_DIM = 64        # feature dimension d used for Q, K and V
BYTES_PER_FLOAT = 4  # torch.randn produces float32 unless the default dtype is changed

# The d×d summary is independent of sequence length, so compute it once.
lin_mem = HEAD_DIM * HEAD_DIM * BYTES_PER_FLOAT / 1e6  # MB

# Tiny warm-up call so one-time library setup is less likely to be charged
# to the first timed measurement.
_warm = torch.randn(1, 8, HEAD_DIM)
standard_attention(_warm, _warm, _warm)
linear_attention(_warm, _warm, _warm)

for seq_len in [512, 1024, 2048, 4096]:
    # Q, K and V deliberately share one tensor; neither function mutates its inputs.
    Q = K = V = torch.randn(1, seq_len, HEAD_DIM)

    # Size of the full N×N attention matrix.
    std_mem = seq_len * seq_len * BYTES_PER_FLOAT / 1e6  # MB

    # time.perf_counter() is the clock intended for measuring short intervals;
    # time.time() is wall-clock time and may be coarse or adjusted mid-run.
    t0 = time.perf_counter()
    _ = standard_attention(Q, K, V)
    std_time = (time.perf_counter() - t0) * 1000  # ms

    t0 = time.perf_counter()
    _ = linear_attention(Q, K, V)
    lin_time = (time.perf_counter() - t0) * 1000  # ms

    print(f"N={seq_len}: Standard {std_mem:.1f}MB, {std_time:.1f}ms | Linear {lin_mem:.3f}MB, {lin_time:.1f}ms")
