# 🎭 Performer: Fast Attention via FAVOR+

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/transformer_architectures/07_performer/demo.ipynb)

![Architecture](architecture.png)

### Key Innovation
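- **FAVOR+**: Fast Attention Via positive Orthogonal Random features
- **O(N) Complexity**: linear instead of quadratic in the sequence length N
- **Unbiased Estimator**: random features give an unbiased estimate of the softmax kernel `exp(q^T k / √d)`, which is used to approximate softmax attention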

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed torch's global RNG once so every cell below draws reproducible numbers.
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## The Math Behind FAVOR+

Standard attention: `Attn = softmax(QK^T / √d) V`

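FAVOR+ approximation:
1. `K(x,y) = exp(x^T y) ≈ φ(x)^T φ(y)`, where `φ(x) = exp(ω^T x − ‖x‖²/2) / √m` is a positive random feature map built from `m` random vectors `ω ~ N(0, I)`. Queries and keys are first scaled by `d^(-1/4)` so that `x^T y = q^T k / √d`.
2. `Attn ≈ D⁻¹ φ(Q) (φ(K)^T V)` with row normaliser `D = diag(φ(Q) (φ(K)^T 1))`. Computing `φ(K)^T V` first never forms the N×N matrix, so the cost is O(N) in sequence length!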

In [None]:
def visualize_favor_math(N=8, d=4, seq_lens=(64, 256, 1024, 4096)):
    """Draw a three-panel figure contrasting standard attention with FAVOR+.

    Left: standard attention forms the N x N matrix QK^T.
    Middle: FAVOR+ evaluates phi(Q) [phi(K)^T V]. Because phi(K)^T V is computed
    first, the only intermediate is (features x d), i.e. d x d when features == d.
    Right: operation counts on a log scale for several sequence lengths,
    N^2 * d for standard attention vs. N * d^2 for FAVOR+ (features == d).

    Args:
        N: sequence length quoted in the text panels.
        d: head dimension, also used as the number of random features.
        seq_lens: sequence lengths shown in the bar chart.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Panel 1: standard attention materialises an N x N score matrix.
    ax = axes[0]
    ax.text(0.5, 0.9, 'Standard Attention', fontsize=14, ha='center', transform=ax.transAxes)
    ax.text(0.5, 0.7, r'$\mathbf{QK}^T$ → N×N matrix', fontsize=11, ha='center', transform=ax.transAxes)
    ax.text(0.5, 0.5, f'For N={N}, d={d}:', fontsize=10, ha='center', transform=ax.transAxes)
    ax.text(0.5, 0.3, f'Memory: O(N²) = O({N**2})', fontsize=10, ha='center', transform=ax.transAxes, color='red')
    ax.axis('off')

    # Panel 2: FAVOR+ reassociates the product so the N x N matrix never exists.
    ax = axes[1]
    ax.text(0.5, 0.9, 'FAVOR+ Trick', fontsize=14, ha='center', transform=ax.transAxes)
    ax.text(0.5, 0.7, r'$\phi(Q) [\phi(K)^T V]$', fontsize=11, ha='center', transform=ax.transAxes)
    ax.text(0.5, 0.5, 'Compute (K^T V) first!', fontsize=10, ha='center', transform=ax.transAxes)
    ax.text(0.5, 0.3, f'Memory: O(d²) = O({d**2})', fontsize=10, ha='center', transform=ax.transAxes, color='green')
    ax.axis('off')

    # Panel 3: operation counts. Both series include the factor d, so the
    # comparison is like-for-like (standard / FAVOR+ = N / d). The previous
    # version plotted N^2 against N*d^2, dropping d on one side only.
    ax = axes[2]
    standard = [n**2 * d for n in seq_lens]
    favor = [n * d**2 for n in seq_lens]

    x = np.arange(len(seq_lens))
    width = 0.35
    ax.bar(x - width/2, standard, width, label='Standard O(N²d)', color='coral')
    ax.bar(x + width/2, favor, width, label='FAVOR+ O(Nd²)', color='lightgreen')
    ax.set_xticks(x)
    ax.set_xticklabels([str(n) for n in seq_lens])
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Operations')
    ax.legend()
    ax.set_yscale('log')
    ax.set_title('Complexity Comparison')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

visualize_favor_math()

## Random Feature Maps

In [None]:
def random_feature_map(x, n_features, is_query=True, seed=0):
    """FAVOR+ positive random feature map.

    Maps ``x`` (last dimension ``d``) to ``n_features`` strictly positive
    features::

        phi(x) = exp(omega^T x - ||x||^2 / 2) / sqrt(n_features)

    The columns of ``omega`` are drawn i.i.d. from N(0, I_d). For two inputs x
    and y mapped with the same ``omega``, E[phi(x) . phi(y)] = exp(x . y). If the
    caller pre-scales q and k by d ** -0.25, this is an unbiased estimate of the
    softmax kernel exp(q . k / sqrt(d)).

    Args:
        x: tensor whose last dimension is the feature dimension ``d``.
        n_features: number of random features ``m``.
        is_query: unused; kept for backward compatibility.
        seed: seed of the private generator that draws ``omega``. Queries and
            keys must be mapped with the same seed so they share one projection.

    Returns:
        Tensor with the same leading shape as ``x`` and last dimension
        ``n_features``.
    """
    d = x.shape[-1]

    # Draw omega from a private generator. The same seed reproduces the same
    # omega for Q and K, and torch's global RNG state is left untouched.
    # (Calling torch.manual_seed here would reset it for the whole notebook.)
    gen = torch.Generator().manual_seed(seed)
    # omega ~ N(0, I_d) with NO 1/sqrt(d) factor: shrinking omega would make
    # phi(x) . phi(y) estimate a different kernel than exp(x . y).
    # NOTE: these are plain i.i.d. Gaussian features, not orthogonalised.
    omega = torch.randn(d, n_features, generator=gen).to(device=x.device, dtype=x.dtype)

    projection = x @ omega                                  # (..., n_features)
    half_norm_sq = (x ** 2).sum(dim=-1, keepdim=True) / 2   # (..., 1)

    # exp(...) keeps every feature positive, which keeps the attention
    # normaliser phi(Q) . sum phi(K) positive as well.
    return torch.exp(projection - half_norm_sq) / math.sqrt(n_features)

# Demonstrate approximation quality
def compare_attention_methods(seq_len=64, d_model=32, n_features=64):
    """Measure how closely FAVOR+ linear attention matches exact softmax attention.

    Draws Q, K, V of shape (1, seq_len, d_model) from torch's global RNG. It
    computes exact softmax attention and its FAVOR+ approximation built from
    ``random_feature_map``, then compares the two outputs.

    Returns:
        (mse, cosine): mean-squared error and cosine similarity between the
        flattened exact and approximate outputs, as Python floats.
    """
    shape = (1, seq_len, d_model)
    Q = torch.randn(shape)
    K = torch.randn(shape)
    V = torch.randn(shape)

    # Exact softmax attention: materialises the (seq_len x seq_len) weights.
    weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_model), dim=-1)
    exact = weights @ V

    # FAVOR+: divide each side by d_model ** (1/4) so that the feature-space
    # dot product targets q . k / sqrt(d_model), the softmax temperature.
    quarter_root = math.sqrt(math.sqrt(d_model))
    q_feat = random_feature_map(Q / quarter_root, n_features)
    k_feat = random_feature_map(K / quarter_root, n_features)

    # Associativity: phi(Q) (phi(K)^T V) never forms a seq_len x seq_len matrix.
    numerator = q_feat @ (k_feat.transpose(-2, -1) @ V)                          # (1, seq_len, d_model)
    # Row normaliser phi(Q) (phi(K)^T 1): the analogue of the softmax denominator.
    denominator = q_feat @ k_feat.sum(dim=1, keepdim=True).transpose(-2, -1)     # (1, seq_len, 1)
    approx = numerator / (denominator + 1e-6)

    mse = F.mse_loss(exact, approx).item()
    cosine = F.cosine_similarity(exact.flatten(), approx.flatten(), dim=0).item()
    return mse, cosine

# Test with different number of random features: sweep the feature count and
# record how close FAVOR+ gets to exact softmax attention at each setting.
n_features_list = [16, 32, 64, 128, 256]
results = [compare_attention_methods(n_features=n_feat) for n_feat in n_features_list]
mses = [mse for mse, _ in results]
cosines = [cos for _, cos in results]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(n_features_list, mses, 'ro-')
axes[0].set_xlabel('Number of Random Features')
axes[0].set_ylabel('MSE')
axes[0].set_title('Approximation Error')
axes[0].grid(True, alpha=0.3)

axes[1].plot(n_features_list, cosines, 'go-')
axes[1].set_xlabel('Number of Random Features')
axes[1].set_ylabel('Cosine Similarity')
axes[1].set_title('Output Similarity')
# A perfect approximation would sit on this reference line.
axes[1].axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print('More random features → better approximation!')

## Performer Implementation

In [None]:
class PerformerAttention(nn.Module):
    """Multi-head FAVOR+ attention whose cost is linear in sequence length.

    Q and K are mapped through a positive random feature map phi, so that
    phi(q) . phi(k) is an unbiased estimate of exp(q . k / sqrt(d_k)). Attention
    is then phi(Q) (phi(K)^T V), normalised row-wise by phi(Q) (phi(K)^T 1);
    the T x T score matrix is never formed.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        n_features: random features per head (defaults to the head size d_k).
        dropout: dropout probability applied to the output projection.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, n_features=None, dropout=0.1):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads})')
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_features = n_features or self.d_k

        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Fixed (non-trainable) projection shared by Q and K. It is a buffer, so
        # it follows .to(device) and is saved in the state dict.
        # omega ~ N(0, I): dividing it by sqrt(d_k) would bias the kernel estimate.
        self.register_buffer('omega', torch.randn(self.d_k, self.n_features))

    def _feature_map(self, x):
        """Positive random features exp(omega^T x - ||x||^2 / 2) / sqrt(n_features)."""
        projection = x @ self.omega
        half_norm_sq = (x ** 2).sum(dim=-1, keepdim=True) / 2
        return torch.exp(projection - half_norm_sq) / math.sqrt(self.n_features)

    @staticmethod
    def _causal_attention(Q_prime, K_prime, V):
        """Causal linear attention via running prefix sums over time.

        At step t the state holds sum_{j<=t} phi(k_j) v_j^T and sum_{j<=t} phi(k_j).
        Memory therefore stays O(n_features * d_k) per head rather than growing
        with T.
        """
        B, H, T, m = K_prime.shape
        d_k = V.shape[-1]
        if T == 0:
            return V.new_zeros(B, H, 0, d_k)

        # new_zeros follows the feature tensors' dtype/device (not hard-coded float32).
        KV = K_prime.new_zeros(B, H, m, d_k)
        K_sum = K_prime.new_zeros(B, H, m, 1)
        outputs = []
        for t in range(T):
            k_t = K_prime[:, :, t:t+1, :].transpose(-2, -1)  # (B, H, m, 1)
            v_t = V[:, :, t:t+1, :]                          # (B, H, 1, d_k)
            q_t = Q_prime[:, :, t:t+1, :]                    # (B, H, 1, m)

            KV = KV + k_t @ v_t                              # (B, H, m, d_k)
            K_sum = K_sum + k_t                              # (B, H, m, 1)
            outputs.append(q_t @ KV / (q_t @ K_sum + 1e-6))
        return torch.cat(outputs, dim=2)                     # (B, H, T, d_k)

    def forward(self, x, causal=True):
        """Apply FAVOR+ attention.

        Args:
            x: input of shape (B, T, d_model).
            causal: if True, position t attends only to positions <= t.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape

        qkv = self.W_qkv(x).reshape(B, T, 3, self.n_heads, self.d_k).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]  # each (B, H, T, d_k)

        # Divide both sides by d_k ** (1/4), so phi(q) . phi(k) estimates
        # exp(q . k / sqrt(d_k)), i.e. the softmax kernel.
        scale = math.sqrt(math.sqrt(self.d_k))
        Q_prime = self._feature_map(Q / scale)  # (B, H, T, n_features)
        K_prime = self._feature_map(K / scale)  # (B, H, T, n_features)

        if causal:
            out = self._causal_attention(Q_prime, K_prime, V)
        else:
            # Non-causal: aggregate every key once, then query it: Q'(K'V).
            KV = K_prime.transpose(-2, -1) @ V                          # (B, H, n_feat, d_k)
            K_sum = K_prime.sum(dim=2, keepdim=True).transpose(-2, -1)  # (B, H, n_feat, 1)
            out = Q_prime @ KV / (Q_prime @ K_sum + 1e-6)

        out = out.transpose(1, 2).reshape(B, T, C)
        # Output dropout: the configured dropout module was previously never applied.
        return self.dropout(self.W_o(out))

class PerformerBlock(nn.Module):
    """Pre-LayerNorm transformer block built on FAVOR+ attention.

    Computes h = x + Attn(LN1(x)) and returns h + FFN(LN2(h)). The FFN expands
    to 4 * d_model with a GELU and ends in dropout.
    """

    def __init__(self, d_model, n_heads, n_features=None, dropout=0.1):
        super().__init__()
        hidden = 4 * d_model
        # Submodules are created in this order on purpose: it fixes both the
        # parameter-initialisation RNG sequence and the parameters() order.
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = PerformerAttention(d_model, n_heads, n_features, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = x + self.attn(self.norm1(x))
        return h + self.ff(self.norm2(h))

class Performer(nn.Module):
    """Token language model built from a stack of PerformerBlocks.

    Token embeddings and learned absolute position embeddings are summed and
    passed through dropout. The result then goes through ``n_layers``
    PerformerBlocks, a final LayerNorm, and a linear head that returns
    per-position logits over the vocabulary.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        n_heads: attention heads per block.
        n_features: random features per head in each attention layer.
        n_layers: number of PerformerBlocks.
        max_len: number of position embeddings, i.e. the longest input accepted.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3, n_features=64, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([PerformerBlock(d_model, n_heads, n_features, dropout) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T is larger than the number of position embeddings.
        """
        _, T = x.shape
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            # Fail with a clear message instead of an embedding index error
            # (on CUDA that surfaces as an opaque device-side assert).
            raise ValueError(f'sequence length {T} exceeds max_len {max_len}')
        pos = torch.arange(T, device=x.device).unsqueeze(0)  # (1, T), broadcast over the batch

        h = self.dropout(self.embed(x) + self.pos_embed(pos))
        for layer in self.layers:
            h = layer(h)

        return self.head(self.norm(h))

# Build a small model just to report its size.
model = Performer(vocab_size=1000, d_model=64, n_heads=4, n_layers=2, n_features=32).to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f'Performer Parameters: {n_params:,}')

## Training Performer

In [None]:
# Dataset: a toy character-level corpus made of one pangram repeated many times.
text = 'the quick brown fox jumps over the lazy dog ' * 300

# Vocabulary = the distinct characters in sorted order; each id is its rank.
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = dict(zip(chars, range(vocab_size)))

# Encode the whole corpus as a single 1-D tensor of int64 token ids.
data = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)

# Training: next-character prediction on random windows of the corpus.
seq_len = 64
batch_size = 16
model = Performer(vocab_size=vocab_size, d_model=64, n_heads=4, n_layers=2, n_features=32, max_len=seq_len).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
n_steps = 300
# Offsets 0..seq_len-1; adding a start index yields that window's positions.
window = torch.arange(seq_len)

print('Training Performer with linear attention...')
for step in range(n_steps):
    # A start i is valid while i + seq_len + 1 <= len(data), so the shifted
    # target window still fits. randint's upper bound is exclusive, hence
    # len(data) - seq_len (the old "- 1" never sampled the last window).
    idx = torch.randint(0, len(data) - seq_len, (batch_size,))
    positions = idx.unsqueeze(1) + window  # (batch_size, seq_len)
    x = data[positions].to(device)
    y = data[positions + 1].to(device)

    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    if (step + 1) % 50 == 0:
        print(f'Step {step+1}: Loss = {loss.item():.4f}')

# Plot the per-step training loss collected by the loop above.
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Performer Training (Linear Attention)')
plt.grid(True, alpha=0.3)
plt.show()

# Summary printed at the end of the notebook (mojibake in the originals fixed).
print('\n🎯 Key Takeaways:')
for takeaway in (
    '1. FAVOR+ approximates softmax via random features',
    '2. O(N) complexity instead of O(N²)',
    '3. More random features = better approximation',
    '4. Great for very long sequences (10K+ tokens)',
):
    print(takeaway)