# Algorithm 8: Pairformer Stack (AlphaFold3)

The Pairformer Stack replaces AlphaFold2's Evoformer as the main trunk. It does not operate on the MSA. Instead, it updates a pair representation together with an explicit single (per-token) representation.

## Key Difference from AlphaFold2 Evoformer

| Aspect | Evoformer (AF2) | Pairformer (AF3) |
|--------|-----------------|------------------|
| MSA Track | Yes (full processing) | No (MSA handled separately) |
| Single Track | Derived from MSA | Explicit single representation |
| Blocks | 48 | 48 |
| Track coupling | MSA → pair via outer product mean; pair → MSA via attention bias | Pair → single only, via single attention with pair bias |

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/modules.py`
- **Class**: `PairFormerIteration`

## Overview

### Pairformer Block Structure

```
Input: (single, pair)
       ↓
┌──────────────────────────────────────┐
│  Triangle Multiplication (Outgoing)  │
│  Triangle Multiplication (Incoming)  │
│  Triangle Attention (Starting)       │
│  Triangle Attention (Ending)         │
│  Pair Transition                     │
└──────────────────────────────────────┘
       ↓
       pair_updated
       ↓
┌──────────────────────────────────────┐
│  Single Attention with Pair Bias     │
│  Single Transition                   │
└──────────────────────────────────────┘
       ↓
Output: (single_updated, pair_updated)
```

In [None]:
import numpy as np

# Seed NumPy's global RNG once so the random inputs and the randomly
# initialised weights drawn with np.random in the cells below are reproducible.
np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along ``axis``.

    No learnable scale or offset is applied. ``eps`` is added to the variance
    before the square root so constant slices do not divide by zero.
    """
    centered = x - np.mean(x, axis=axis, keepdims=True)
    variance = np.var(x, axis=axis, keepdims=True)
    return centered / np.sqrt(variance + eps)


def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    logits do not overflow.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(shifted)
    denominator = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / denominator


def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``.

    Inputs are clipped to [-500, 500] before exponentiation to keep
    ``np.exp`` from overflowing for float64 inputs.
    """
    bounded = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-bounded))

In [None]:
def _triangle_multiplication(z, c, equation):
    """
    Shared body of the outgoing / incoming triangle multiplicative updates.

    Both variants follow the same steps:
    - gate two edge projections a and b;
    - contract them over the third node k with ``equation``;
    - layer-normalise the result and project it back to c_z channels;
    - apply a sigmoid output gate.

    Args:
        z: Pair representation [N, N, c_z].
        c: Hidden channel width of the edge projections a and b.
        equation: ``np.einsum`` subscripts that contract a and b over k
            (the only difference between the two variants).

    Returns:
        Update for z with shape [N, N, c_z].
    """
    c_z = z.shape[-1]
    z_norm = layer_norm(z)

    # Weights are drawn fresh on every call (demo only, no trained params).
    # Each edge projection has its own linear layer AND its own gate layer:
    # a = sigmoid(Linear_gate(z)) * Linear(z). Reusing one matrix for both
    # would collapse the gate into a plain SiLU of a single projection.
    W_a = np.random.randn(c_z, c) * (c_z ** -0.5)
    W_a_gate = np.random.randn(c_z, c) * (c_z ** -0.5)
    W_b = np.random.randn(c_z, c) * (c_z ** -0.5)
    W_b_gate = np.random.randn(c_z, c) * (c_z ** -0.5)
    # The output gate multiplies the c_z-channel output, so it must produce
    # c_z channels (producing c channels only broadcasts when c == c_z).
    W_g = np.random.randn(c_z, c_z) * (c_z ** -0.5)
    W_o = np.random.randn(c, c_z) * (c ** -0.5)

    a = sigmoid(z_norm @ W_a_gate) * (z_norm @ W_a)
    b = sigmoid(z_norm @ W_b_gate) * (z_norm @ W_b)
    g = sigmoid(z_norm @ W_g)

    output = layer_norm(np.einsum(equation, a, b))
    return (output @ W_o) * g


def triangle_multiplication_outgoing(z, c=128):
    """
    Triangle Multiplication with outgoing edges.

    Computes the update for z_ij from z_ik and z_jk for all k:
        g_ij * Linear(LayerNorm(sum_k a_ik * b_jk))

    Args:
        z: Pair representation [N, N, c_z].
        c: Hidden channel width of the edge projections.

    Returns:
        Update for z with shape [N, N, c_z].
    """
    # a[i, k] * b[j, k] summed over k
    return _triangle_multiplication(z, c, 'ikc,jkc->ijc')


def triangle_multiplication_incoming(z, c=128):
    """
    Triangle Multiplication with incoming edges.

    Computes the update for z_ij from z_ki and z_kj for all k:
        g_ij * Linear(LayerNorm(sum_k a_ki * b_kj))

    Args:
        z: Pair representation [N, N, c_z].
        c: Hidden channel width of the edge projections.

    Returns:
        Update for z with shape [N, N, c_z].
    """
    # a[k, i] * b[k, j] summed over k
    return _triangle_multiplication(z, c, 'kic,kjc->ijc')

In [None]:
def triangle_attention_starting(z, num_heads=4, c=32):
    """
    Triangle Attention around starting node.

    For every edge (i, j), attends over the edges (i, k) that share the
    starting node i. It adds a per-head bias taken from the edge (j, k)
    that closes the triangle:

        a_ijk = softmax_k(q_ij . k_ik / sqrt(c) + b_jk)
        o_ij  = g_ij * sum_k a_ijk v_ik

    Args:
        z: Pair representation [N, N, c_z].
        num_heads: Number of attention heads.
        c: Per-head channel width.

    Returns:
        Update for z with shape [N, N, c_z].
    """
    c_z = z.shape[-1]

    z_norm = layer_norm(z)

    # QKV, bias and gate projections (random weights, drawn on every call).
    W_q = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_k = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_v = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    W_g = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)

    q = np.einsum('ijc,chd->ijhd', z_norm, W_q)
    k = np.einsum('ijc,chd->ijhd', z_norm, W_k)
    v = np.einsum('ijc,chd->ijhd', z_norm, W_v)
    b = np.einsum('ijc,ch->ijh', z_norm, W_b)  # b[j, k, h] for edge (j, k)
    g = sigmoid(np.einsum('ijc,chd->ijhd', z_norm, W_g))

    # Attention: for each (i, j), attend over k using keys k[i, k].
    attn_logits = np.einsum('ijhd,ikhd->ijkh', q, k) / np.sqrt(c)
    # The bias for logit (i, j, k) is b[j, k]. It does not depend on i, so a
    # new leading axis broadcasts it. Transposing the last two edge axes
    # would instead use b[k, j], the reversed edge.
    attn_logits = attn_logits + b[None, :, :, :]
    attn_weights = softmax(attn_logits, axis=2)

    # Apply attention, then the per-channel sigmoid gate.
    attended = np.einsum('ijkh,ikhd->ijhd', attn_weights, v)
    attended = attended * g

    # Output projection back to c_z channels.
    W_o = np.random.randn(num_heads, c, c_z) * ((num_heads * c) ** -0.5)
    output = np.einsum('ijhd,hdc->ijc', attended, W_o)

    return output

In [None]:
def single_attention_with_pair_bias(s, z, num_heads=8, c=32):
    """
    Single Attention with Pair Bias.

    Every token i attends over all tokens j. Queries, keys, values and a
    sigmoid gate are projected from the layer-normalised single
    representation. A per-head additive bias is projected from the
    layer-normalised pair entry z_ij.

    Key difference from AF2: explicit single representation attention.

    Args:
        s: Single representation [N_tokens, c_s]
        z: Pair representation [N_tokens, N_tokens, c_z]
        num_heads: Number of attention heads.
        c: Per-head channel width.

    Returns:
        Update for s with shape [N_tokens, c_s].
    """
    single_dim = s.shape[-1]
    pair_dim = z.shape[-1]

    def random_weights(shape, fan_in):
        # Fresh Gaussian weights scaled by 1/sqrt(fan_in). The call order
        # below fixes the draw order from the global RNG.
        return np.random.randn(*shape) * (fan_in ** -0.5)

    proj_q = random_weights((single_dim, num_heads, c), single_dim)
    proj_k = random_weights((single_dim, num_heads, c), single_dim)
    proj_v = random_weights((single_dim, num_heads, c), single_dim)
    proj_g = random_weights((single_dim, num_heads, c), single_dim)
    proj_b = random_weights((pair_dim, num_heads), pair_dim)
    proj_o = random_weights((num_heads, c, single_dim), num_heads * c)

    s_ln = layer_norm(s)
    z_ln = layer_norm(z)

    queries = np.einsum('ic,chd->ihd', s_ln, proj_q)
    keys = np.einsum('jc,chd->jhd', s_ln, proj_k)
    values = np.einsum('jc,chd->jhd', s_ln, proj_v)
    gate = sigmoid(np.einsum('ic,chd->ihd', s_ln, proj_g))
    pair_bias = np.einsum('ijc,ch->ijh', z_ln, proj_b)

    logits = np.einsum('ihd,jhd->ijh', queries, keys) / np.sqrt(c) + pair_bias
    weights = softmax(logits, axis=1)  # normalise over the attended tokens j

    gated = np.einsum('ijh,jhd->ihd', weights, values) * gate
    return np.einsum('ihd,hdc->ic', gated, proj_o)

In [None]:
def transition_block(x, n=4):
    """Feed-forward transition with SwiGLU activation.

    The steps are:
    - layer-normalise ``x``;
    - project it up to ``2 * n * c`` channels and split into halves (a, b);
    - combine them as ``a * sigmoid(a) * b`` (SwiGLU);
    - project back down to ``c`` channels.
    """
    width = x.shape[-1]
    hidden_width = width * n

    up = np.random.randn(width, hidden_width * 2) * (width ** -0.5)
    down = np.random.randn(hidden_width, width) * (hidden_width ** -0.5)

    gate_half, value_half = np.split(layer_norm(x) @ up, 2, axis=-1)
    return (gate_half * sigmoid(gate_half) * value_half) @ down

In [None]:
def pairformer_block(s, z):
    """
    Single Pairformer block.

    Pair track, each step a residual update:
    - triangle multiplication (outgoing, then incoming);
    - triangle attention (starting node, then ending node);
    - a pair transition.

    Single track, also residual:
    - attention with pair bias, using the already-updated pair;
    - a single transition.

    Args:
        s: Single representation [N, c_s]
        z: Pair representation [N, N, c_z]

    Returns:
        Updated (s, z)
    """
    N = s.shape[0]
    c_s = s.shape[-1]
    c_z = z.shape[-1]

    print("Pairformer Block")
    print(f"  Single: [{N}, {c_s}], Pair: [{N}, {N}, {c_z}]")

    # Pair track updates
    z = z + triangle_multiplication_outgoing(z, c=c_z // 2)
    print("  + Triangle Mult Outgoing")

    z = z + triangle_multiplication_incoming(z, c=c_z // 2)
    print("  + Triangle Mult Incoming")

    z = z + triangle_attention_starting(z, num_heads=4, c=c_z // 4)
    print("  + Triangle Attention Starting")

    # Ending-node attention = starting-node attention on the transposed pair,
    # transposed back. Edge (i, j) then attends over edges (k, j) with bias
    # taken from edge (k, i).
    z_t = z.transpose(1, 0, 2)
    z = z + triangle_attention_starting(z_t, num_heads=4, c=c_z // 4).transpose(1, 0, 2)
    print("  + Triangle Attention Ending")

    z = z + transition_block(z, n=2)
    print("  + Pair Transition")

    # Single track updates (using updated pair)
    s = s + single_attention_with_pair_bias(s, z, num_heads=8, c=c_s // 8)
    print("  + Single Attention with Pair Bias")

    s = s + transition_block(s, n=4)
    print("  + Single Transition")

    return s, z


def pairformer_stack(s, z, num_blocks=4):
    """
    Run ``num_blocks`` Pairformer blocks in sequence.

    Each block receives the (single, pair) representations produced by the
    previous block. A banner and a per-block header are printed along the way.

    Args:
        s: Initial single representation
        z: Initial pair representation
        num_blocks: Number of Pairformer blocks

    Returns:
        Final (s, z)
    """
    print(f"\nPairformer Stack ({num_blocks} blocks)")
    print("=" * 50)

    for block_number in range(1, num_blocks + 1):
        print(f"\nBlock {block_number}:")
        s, z = pairformer_block(s, z)

    return s, z

## Test Examples

In [None]:
# Test 1: Single Pairformer block
# Run one block on float32 random inputs and report the shapes.
print("Test 1: Single Pairformer Block")
print("=" * 60)

N, c_s, c_z = 32, 128, 64

s = np.random.randn(N, c_s).astype(np.float32)
z = np.random.randn(N, N, c_z).astype(np.float32)

s_out, z_out = pairformer_block(s, z)

print(f"\nInput:  s={s.shape}, z={z.shape}")
print(f"Output: s={s_out.shape}, z={z_out.shape}")

In [None]:
# Test 2: Pairformer Stack
# Run a two-block stack on small-magnitude (x0.1) random inputs.
print("\nTest 2: Pairformer Stack")
print("=" * 60)

# Re-seed so this cell's output does not depend on earlier cells.
np.random.seed(42)
N, c_s, c_z = 24, 64, 32

s = np.random.randn(N, c_s).astype(np.float32) * 0.1
z = np.random.randn(N, N, c_z).astype(np.float32) * 0.1

s_out, z_out = pairformer_stack(s, z, num_blocks=2)

In [None]:
# Test 3: Single Attention with Pair Bias analysis
print("\nTest 3: Single Attention Analysis")
print("=" * 60)

N, c_s, c_z = 16, 64, 32

# Structured pair (diagonal dominance): every channel of each diagonal entry
# z_ii is 2, plus small Gaussian noise (std 0.1) on all entries.
s = np.random.randn(N, c_s)
diagonal_pair = np.eye(N)[:, :, None] * np.ones(c_z) * 2
z = diagonal_pair + np.random.randn(N, N, c_z) * 0.1

s_out = single_attention_with_pair_bias(s, z, num_heads=4, c=16)

print(f"Single input norm: {np.linalg.norm(s):.4f}")
print(f"Single output norm: {np.linalg.norm(s_out):.4f}")
print(f"Pair diagonal mean: {np.diag(z[:, :, 0]).mean():.4f}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N = 20
c_s = 64
c_z = 32

s = np.random.randn(N, c_s)
z = np.random.randn(N, N, c_z)

s_out, z_out = pairformer_block(s, z)

# Property 1: Shapes preserved
shapes_preserved = s_out.shape == s.shape and z_out.shape == z.shape
print(f"Property 1 - Shapes preserved: {shapes_preserved}")

# Property 2: Finite outputs
all_finite = np.isfinite(s_out).all() and np.isfinite(z_out).all()
print(f"Property 2 - All finite: {all_finite}")

# Property 3: Representations change (non-trivial updates)
s_changed = not np.allclose(s, s_out)
z_changed = not np.allclose(z, z_out)
print(f"Property 3 - Single changed: {s_changed}, Pair changed: {z_changed}")

# Property 4: On a symmetric pair (z_ij == z_ji), the outgoing update
# sum_k a_ik * b_jk and the incoming update sum_k a_ki * b_kj coincide when
# both use the same weights. Both functions draw their random weights in the
# same order and shapes, so re-seeding before each call gives them identical
# weights. c=c_z keeps the gate and output channel counts equal.
z_sym = (z + z.transpose(1, 0, 2)) / 2
np.random.seed(0)
tri_out = triangle_multiplication_outgoing(z_sym, c=c_z)
np.random.seed(0)
tri_in = triangle_multiplication_incoming(z_sym, c=c_z)
print(f"Property 4 - Triangle output finite: {np.isfinite(tri_out).all()}")
outgoing_matches_incoming = np.allclose(tri_out, tri_in)
print(f"Property 4 - Outgoing == incoming on symmetric pair: {outgoing_matches_incoming}")

## Key Insights

1. **Decoupled MSA**: Unlike Evoformer, Pairformer doesn't process MSA. MSA is handled in a separate, simpler module.

2. **Explicit Single Track**: AlphaFold3 maintains an explicit single (per-token) representation that's updated alongside pair.

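3. **Single ← Pair**: The single representation receives pair information only as an attention bias, and nothing flows back from single to pair inside the Pairformer. By contrast, in AlphaFold2 the single representation is taken from the first MSA row, and the MSA feeds the pair through the outer product mean.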

4. **SwiGLU Activation**: Uses SwiGLU (Swish-Gated Linear Unit) instead of ReLU for transitions.

5. **Same Triangle Operations**: Triangle multiplication and attention remain similar to AlphaFold2.

6. **48 Blocks**: Same depth as Evoformer, maintaining model capacity.