# Algorithm 4: Affinity Heads Transformer (Boltz-2)

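Transformer-based heads that predict binding affinity from the pair representation. This notebook is a simplified NumPy sketch: every weight matrix is drawn at random on each call, so the printed predictions are illustrative rather than outputs of a trained model.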

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/affinity.py`
- **Class**: `AffinityHeadsTransformer`

In [None]:
import numpy as np
# Fix the global NumPy RNG: every weight matrix in this notebook is drawn
# from np.random, so seeding here makes the demo outputs reproducible.
np.random.seed(seed=42)

def layer_norm(x, eps=1e-5):
    """Normalise ``x`` to zero mean and unit variance along its last axis.

    No learnable scale or shift is applied. ``eps`` is added to the
    variance before the square root so constant inputs do not divide by 0.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return (x - mu) / sigma

def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    logits cannot overflow ``np.exp``; each slice of the result sums to 1.
    """
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    denom = np.sum(exps, axis=axis, keepdims=True)
    return exps / denom

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``.

    Inputs are clipped to [-500, 500] first so ``np.exp`` cannot overflow
    for very negative ``x``.
    """
    clipped = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-clipped))

def swish(x):
    """Swish (SiLU) activation: ``x`` gated by its own sigmoid."""
    gate = sigmoid(x)
    return x * gate

In [None]:
def affinity_attention(z, num_heads=4, c=32):
    """Multi-head self-attention across every entry of a pair tensor.

    The N x N grid of pair vectors is layer-normalised, flattened into a
    sequence of N*N tokens, and attended over with Gaussian projection
    weights sampled from the global ``np.random`` state on every call
    (repeated calls differ unless the seed is reset).

    Args:
        z: Pair representation; reshaped below to ``(N * N, c_z)`` with
            N = ``z.shape[0]``, so the first two axes must both have size N.
        num_heads: Number of attention heads.
        c: Per-head query/key/value width.

    Returns:
        Array of shape ``(N, N, c_z)`` holding the attention update only;
        it is not added to ``z`` (the caller applies the residual).
    """
    n_tok = z.shape[0]
    width = z.shape[-1]
    seq = layer_norm(z).reshape(n_tok * n_tok, width)

    # Sample Q, K, V projections in that order; the output projection is
    # drawn afterwards. Keeping this order keeps global-RNG consumption fixed.
    w_query, w_key, w_value = (
        np.random.randn(width, num_heads, c) * (width ** -0.5) for _ in range(3)
    )
    query, key, value = (
        np.einsum('nc,chd->nhd', seq, w) for w in (w_query, w_key, w_value)
    )

    # Scores have shape (N*N, N*N, heads): memory grows as N**4 per head.
    logits = np.einsum('nhd,mhd->nmh', query, key) / np.sqrt(c)
    weights = softmax(logits, axis=1)  # normalise over the key axis (m)
    mixed = np.einsum('nmh,mhd->nhd', weights, value).reshape(n_tok * n_tok, -1)

    w_out = np.random.randn(num_heads * c, width) * ((num_heads * c) ** -0.5)
    return (mixed @ w_out).reshape(n_tok, n_tok, width)

In [None]:
def affinity_heads_transformer(z, num_blocks=4, num_heads=4, c=32, *, verbose=True):
    """
    Affinity Heads Transformer.

    Runs ``num_blocks`` pre-norm residual blocks over the pair
    representation. Each block is pair self-attention followed by a swish
    feed-forward layer with 4x expansion. The result is mean-pooled over both
    token axes, layer-normalised, and fed to two linear heads: one for an
    affinity value and one (through a sigmoid) for a binder probability.

    Every weight matrix is drawn fresh from the global ``np.random`` state,
    so results are reproducible only if the seed is reset before the call.

    Args:
        z: Pair representation [N, N, c_z], with N >= 1 and c_z >= 1.
        num_blocks: Number of transformer blocks
        num_heads: Number of attention heads
        c: Head dimension
        verbose: Keyword-only. If True (default), print a header and the two
            predictions, as earlier versions always did.

    Returns:
        affinity_value: Predicted affinity (scalar)
        affinity_binary: Binder probability in [0, 1]

    Raises:
        ValueError: If ``z`` is not a non-empty [N, N, c_z] array.
    """
    z = np.asarray(z)
    if z.ndim != 3 or z.shape[0] != z.shape[1]:
        raise ValueError(f"z must have shape [N, N, c_z], got {z.shape}")
    if z.shape[0] == 0 or z.shape[-1] == 0:
        # Without this check, N == 0 fails deep inside softmax and
        # c_z == 0 fails with ZeroDivisionError in `c_z ** -0.5`.
        raise ValueError(f"z must have N >= 1 and c_z >= 1, got {z.shape}")
    c_z = z.shape[-1]

    if verbose:
        print("Affinity Heads Transformer")
        print("=" * 50)
        print(f"Blocks: {num_blocks}, Heads: {num_heads}")

    for _ in range(num_blocks):
        # Attention sub-layer: affinity_attention returns only the update,
        # so the residual connection is added here.
        z = z + affinity_attention(z, num_heads, c)

        # Feed-forward sub-layer: pre-norm, expand 4x, swish, project back.
        z_norm = layer_norm(z)
        W_up = np.random.randn(c_z, c_z * 4) * (c_z ** -0.5)
        hidden = swish(np.einsum('ijc,cd->ijd', z_norm, W_up))
        W_down = np.random.randn(c_z * 4, c_z) * ((c_z * 4) ** -0.5)
        z = z + np.einsum('ijd,dc->ijc', hidden, W_down)

    # Mean-pool over both token axes into one c_z vector, then normalise it.
    z_pooled = layer_norm(z.mean(axis=(0, 1)))

    # Regression head: raw linear projection to a scalar affinity value.
    W_val = np.random.randn(c_z, 1) * (c_z ** -0.5)
    affinity_value = (z_pooled @ W_val).item()

    # Classification head: linear projection squashed to a probability.
    W_bin = np.random.randn(c_z, 1) * (c_z ** -0.5)
    affinity_binary = sigmoid((z_pooled @ W_bin).item())

    if verbose:
        print(f"Affinity value: {affinity_value:.3f}")
        print(f"Binder prob: {affinity_binary:.3f}")

    return affinity_value, affinity_binary

In [None]:
# Smoke test: run two transformer blocks on a small random pair tensor.
print("Test: Affinity Heads Transformer")
print("=" * 60)

N, c_z = 20, 64  # token count and pair-channel width
z = 0.1 * np.random.randn(N, N, c_z)

aff_val, aff_bin = affinity_heads_transformer(z, num_blocks=2)

## Key Insights

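1. **Pair-based**: Operates directly on the pair representation `z` of shape [N, N, c_z].
2. **Self-attention**: Flattening the N × N pairs into one sequence lets every pair entry attend to every other one, giving each entry global context. The cost is attention scores of size (N², N²) per head.
3. **Dual head**: The mean-pooled features feed two linear heads. One is a regression head for the affinity value; the other is a sigmoid classification head for the binder probability.
4. **Deep processing**: `num_blocks` pre-norm residual blocks are stacked before pooling, each an attention layer followed by a swish feed-forward layer.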