# Algorithm 8: Pairformer Stack (AlphaFold3)

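In AlphaFold3, the Pairformer stack replaces AlphaFold2's Evoformer as the main trunk of the network. Each block first updates the pair representation with triangle multiplicative updates (outgoing and incoming edges), triangle attention (around the starting and the ending node), and a transition. It then updates the single representation with attention biased by the pair representation, followed by a transition.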

## Key Differences from the AlphaFold2 Evoformer

| Aspect | Evoformer (AF2) | Pairformer (AF3) |
|--------|-----------------|------------------|
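| MSA Track | Yes (full processing in every block) | No (the MSA is processed beforehand by a separate, smaller MSA module) |
| Single Track | Derived from the first MSA row | Explicit single representation, updated by attention with pair bias |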
| Blocks | 48 | 48 |

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/modules.py`
- **Class**: `PairFormerIteration`

In [None]:
import numpy as np
# Seed the global NumPy RNG: every layer below draws its weights with
# np.random.randn on each call, and the test inputs come from the same RNG,
# so seeding here makes the whole notebook run reproducible.
np.random.seed(42)

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Normalise ``x`` to zero mean and unit variance along ``axis``.

    No learnable scale or offset is applied. ``eps`` is added to the variance
    so constant slices do not cause a division by zero.
    """
    mu = np.mean(x, axis=axis, keepdims=True)
    sigma_sq = np.var(x, axis=axis, keepdims=True)
    centred = x - mu
    return centred / np.sqrt(sigma_sq + eps)

def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large logits
    cannot overflow; mathematically this leaves the result unchanged.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    unnormalised = np.exp(shifted)
    totals = np.sum(unnormalised, axis=axis, keepdims=True)
    return unnormalised / totals

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``.

    The input is clipped to [-500, 500] first so ``exp`` cannot overflow for
    very negative values.
    """
    bounded = np.clip(x, -500, 500)
    denominator = 1 + np.exp(-bounded)
    return 1 / denominator

In [None]:
def _triangle_multiplication(z, c, equation):
    """Shared body of the outgoing / incoming triangle multiplicative updates.

    Computes ``g * (LayerNorm(contract(a, b)) @ W_o)``, where, on LayerNorm(z),
        a = sigmoid(z W_a_gate) * (z W_a)
        b = sigmoid(z W_b_gate) * (z W_b)
        g = sigmoid(z W_g)
    and ``equation`` is the einsum that contracts ``a`` and ``b`` over the
    third node k. The contraction is the only difference between the two
    variants.

    Weights are drawn fresh from the global NumPy RNG on every call, as in
    the rest of this demo notebook.
    """
    c_z = z.shape[-1]
    z_norm = layer_norm(z)

    def _proj(c_out):
        # Random bias-free linear map c_z -> c_out, scaled by 1/sqrt(fan_in).
        return np.random.randn(c_z, c_out) * (c_z ** -0.5)

    # Gate and value use independent weights: the reference update is
    # sigmoid(Linear(z)) * Linear(z). Reusing a single weight for both would
    # collapse it to a plain SiLU activation.
    W_a, W_a_gate = _proj(c), _proj(c)
    W_b, W_b_gate = _proj(c), _proj(c)
    W_g = _proj(c_z)  # output gate acts on the c_z output channels

    a = sigmoid(z_norm @ W_a_gate) * (z_norm @ W_a)
    b = sigmoid(z_norm @ W_b_gate) * (z_norm @ W_b)
    g = sigmoid(z_norm @ W_g)

    update = layer_norm(np.einsum(equation, a, b))

    W_o = np.random.randn(c, c_z) * (c ** -0.5)
    return (update @ W_o) * g


def triangle_multiplication_outgoing(z, c=64):
    """Triangle multiplicative update using outgoing edges.

    Entry (i, j) combines the edges (i, k) and (j, k): sum_k a_ik * b_jk.

    Args:
        z: pair representation indexed z[i, j, channel]
            (assumes shape (N, N, c_z)).
        c: hidden channel width of the a / b projections.

    Returns:
        An update with the same shape as ``z`` (to be added residually).
    """
    return _triangle_multiplication(z, c, 'ikc,jkc->ijc')


def triangle_multiplication_incoming(z, c=64):
    """Triangle multiplicative update using incoming edges.

    Entry (i, j) combines the edges (k, i) and (k, j): sum_k a_ki * b_kj.

    Args:
        z: pair representation indexed z[i, j, channel]
            (assumes shape (N, N, c_z)).
        c: hidden channel width of the a / b projections.

    Returns:
        An update with the same shape as ``z`` (to be added residually).
    """
    return _triangle_multiplication(z, c, 'kic,kjc->ijc')

In [None]:
def triangle_attention_starting(z, num_heads=4, c=16):
    """Triangle attention around the starting node.

    For each row i, position (i, j) attends over positions (i, k) with logits

        q_ij . k_ik / sqrt(c) + b_jk

    where the triangle bias b_jk is a per-head linear projection of the pair
    entry for the third edge (j, k) of triangle (i, j, k). This bias is what
    distinguishes triangle attention from plain row-wise attention. The
    attended values are gated per head and projected back to c_z channels.

    Weights are drawn fresh from the global NumPy RNG on every call (demo).

    Args:
        z: pair representation indexed z[i, j, channel]. It must be square,
            shape (N, N, c_z), because the bias indexes z[j, k].
        num_heads: number of attention heads.
        c: per-head channel width.

    Returns:
        An update with the same shape as ``z`` (to be added residually).

    Raises:
        ValueError: if ``z`` is not a square rank-3 pair tensor.
    """
    if z.ndim != 3 or z.shape[0] != z.shape[1]:
        raise ValueError(
            f"triangle attention expects a square pair tensor (N, N, c_z), got {z.shape}"
        )
    c_z = z.shape[-1]
    z_norm = layer_norm(z)

    W_q = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_k = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_v = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_g = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)

    q = np.einsum('ijc,chd->ijhd', z_norm, W_q)
    k = np.einsum('ijc,chd->ijhd', z_norm, W_k)
    v = np.einsum('ijc,chd->ijhd', z_norm, W_v)
    g = sigmoid(np.einsum('ijc,chd->ijhd', z_norm, W_g))
    # Triangle bias b[j, k, h], projected from the edge (j, k) that closes
    # the triangle; it is shared across all rows i.
    tri_bias = np.einsum('jkc,ch->jkh', z_norm, W_b)

    attn_logits = np.einsum('ijhd,ikhd->ijkh', q, k) / np.sqrt(c)
    attn_logits = attn_logits + tri_bias[None, :, :, :]
    attn_weights = softmax(attn_logits, axis=2)  # normalise over k

    attended = np.einsum('ijkh,ikhd->ijhd', attn_weights, v) * g

    W_o = np.random.randn(num_heads, c, c_z) * ((num_heads * c) ** -0.5)
    return np.einsum('ijhd,hdc->ijc', attended, W_o)

In [None]:
def single_attention_with_pair_bias(s, z, num_heads=8, c=16):
    """Multi-head self-attention over the single track, biased by the pair track.

    The logit for query i and key j is q_i . k_j / sqrt(c) plus a per-head
    bias projected from the layer-normalised pair entry z[i, j]. Softmax runs
    over the keys j. The attended values are gated per head by a sigmoid of a
    projection of s, then projected back to c_s channels.

    Weights are drawn fresh from the global NumPy RNG on every call (demo).

    Args:
        s: single representation indexed s[i, channel]
            (assumes shape (N, c_s)).
        z: pair representation indexed z[i, j, channel]
            (assumes shape (N, N, c_z)).
        num_heads: number of attention heads.
        c: per-head channel width.

    Returns:
        An update with the same shape as ``s`` (to be added residually).
    """
    c_s, c_z = s.shape[-1], z.shape[-1]
    s_ln, z_ln = layer_norm(s), layer_norm(z)

    # The draw order (query, key, value, gate, then bias, then output) fixes
    # the RNG stream and must stay as is.
    in_scale = c_s ** -0.5
    w_query, w_key, w_value, w_gate = (
        np.random.randn(c_s, num_heads, c) * in_scale for _ in range(4)
    )

    def _heads(weight):
        # Project every position of the normalised single track into heads.
        return np.einsum('nc,chd->nhd', s_ln, weight)

    queries = _heads(w_query)
    keys = _heads(w_key)
    values = _heads(w_value)
    gate = sigmoid(_heads(w_gate))

    w_bias = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    pair_bias = np.einsum('ijc,ch->ijh', z_ln, w_bias)

    logits = np.einsum('ihd,jhd->ijh', queries, keys) / np.sqrt(c) + pair_bias
    weights = softmax(logits, axis=1)  # normalise over keys j

    gated = np.einsum('ijh,jhd->ihd', weights, values) * gate

    w_out = np.random.randn(num_heads, c, c_s) * ((num_heads * c) ** -0.5)
    return np.einsum('ihd,hdc->ic', gated, w_out)

In [None]:
def transition_block(x, n=4):
    """SwiGLU feed-forward transition.

    LayerNorm(x) is projected up to ``2 * n * c`` channels and split into two
    halves a and b. They are combined as SiLU(a) * b = a * sigmoid(a) * b and
    projected back down to c channels.

    Weights are drawn fresh from the global NumPy RNG on every call (demo).

    Args:
        x: array whose last axis holds the c channels.
        n: expansion factor of the hidden layer.

    Returns:
        An update with the same shape as ``x`` (to be added residually).
    """
    width = x.shape[-1]
    hidden_width = width * n
    x_ln = layer_norm(x)

    w_up = np.random.randn(width, hidden_width * 2) * (width ** -0.5)
    w_down = np.random.randn(hidden_width, width) * (hidden_width ** -0.5)

    up = x_ln @ w_up
    a, b = up[..., :hidden_width], up[..., hidden_width:]
    return (a * sigmoid(a) * b) @ w_down

In [None]:
def pairformer_block(s, z):
    """Apply one Pairformer block to the single and pair representations.

    Pair track (each step is a residual update):
        1. triangle multiplication, outgoing edges
        2. triangle multiplication, incoming edges
        3. triangle attention around the starting node
        4. triangle attention around the ending node
        5. transition
    Single track (reads the already-updated pair representation):
        1. attention with pair bias
        2. transition
    The pair track does not read the single representation.

    Args:
        s: single representation indexed s[i, channel]
            (assumes shape (N, c_s)).
        z: pair representation indexed z[i, j, channel]
            (assumes shape (N, N, c_z)).

    Returns:
        Tuple ``(s, z)`` of updated representations with unchanged shapes.
    """
    # Pair track updates
    z = z + triangle_multiplication_outgoing(z, c=32)
    z = z + triangle_multiplication_incoming(z, c=32)
    z = z + triangle_attention_starting(z, num_heads=4, c=16)
    # Ending-node attention is starting-node attention applied to the
    # transposed pair (attend along columns instead of rows), with the
    # result transposed back.
    z_t = np.swapaxes(z, 0, 1)
    z = z + np.swapaxes(triangle_attention_starting(z_t, num_heads=4, c=16), 0, 1)
    z = z + transition_block(z, n=2)

    # Single track updates
    s = s + single_attention_with_pair_bias(s, z, num_heads=4, c=16)
    s = s + transition_block(s, n=4)

    return s, z

def pairformer_stack(s, z, num_blocks=4):
    """Run ``num_blocks`` Pairformer blocks in sequence.

    Prints a header, then after every block prints the overall L2
    (Frobenius) norms of ``s`` and ``z`` as a cheap stability trace.

    Returns:
        The final ``(s, z)`` pair.
    """
    separator = "=" * 50
    print(f"Pairformer Stack ({num_blocks} blocks)")
    print(separator)

    for block_no in range(1, num_blocks + 1):
        s, z = pairformer_block(s, z)
        s_mag = np.linalg.norm(s)
        z_mag = np.linalg.norm(z)
        print(f"Block {block_no}: s_norm={s_mag:.2f}, z_norm={z_mag:.2f}")

    return s, z

In [None]:
# Test: smoke-run the Pairformer stack on small random inputs, then report
# the output shapes and whether every output entry is finite.
print("Test: Pairformer Stack")
print("="*60)

N = 24      # length of s and of each pair axis of z
c_s = 64    # single-representation channels
c_z = 32    # pair-representation channels

# Small-magnitude random inputs drawn from the global NumPy RNG.
s = np.random.randn(N, c_s) * 0.1
z = np.random.randn(N, N, c_z) * 0.1

s_out, z_out = pairformer_stack(s, z, num_blocks=2)

print(f"\nOutput shapes: s={s_out.shape}, z={z_out.shape}")
print(f"Outputs finite: {np.isfinite(s_out).all() and np.isfinite(z_out).all()}")

## Key Insights

1. **Decoupled MSA**: Pairformer doesn't process MSA directly
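2. **Explicit Single Track**: Maintains a separate single representation that reads the pair representation through pair-biased attention; the pair track is not updated from the single track
3. **SwiGLU Activation**: Transitions use SwiGLU instead of AlphaFold2's ReLU MLP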
4. **48 Blocks**: Same depth as Evoformer