# Algorithm 10: Attention with Pair Bias (Boltz)

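Self-attention over the single representation, where the pair representation is projected to one scalar per head and added as a bias to the attention logits.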

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/layers/attention.py`
- **Class**: `AttentionPairBias`

In [None]:
import numpy as np
# Fix the global NumPy RNG seed so the randomly drawn weights and test inputs
# below are reproducible from run to run.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along its last axis.

    No learned scale or shift is applied. ``eps`` is added to the variance
    before taking the square root so the division stays finite when the
    variance is zero.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    spread = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return centered / spread

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating; this leaves
    the result unchanged mathematically but avoids overflow in ``np.exp``.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer

def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    The input is clipped to [-500, 500] before exponentiation so that
    ``np.exp`` does not overflow for large-magnitude values.
    """
    bounded = np.clip(x, -500, 500)
    denominator = 1 + np.exp(-bounded)
    return 1 / denominator

In [None]:
def attention_pair_bias(s, z, num_heads=8, c=32):
    """
    Attention with Pair Bias.

    Multi-head self-attention over the single representation in which the
    pair representation is projected to one scalar per head and added to
    the attention logits before the softmax. The attended values are gated
    by a sigmoid of the (normalized) single representation and projected
    back to ``c_s`` channels.

    All projection weights are drawn from the global NumPy RNG on every
    call, so the result depends on the current RNG state. A short summary
    is printed to stdout.

    Args:
        s: Single representation [N, c_s]
        z: Pair representation [N, N, c_z]
        num_heads: Number of attention heads
        c: Head dimension

    Returns:
        Update to single [N, c_s]

    Raises:
        ValueError: If ``s`` is not 2-D, or ``z`` is not 3-D with both
            leading dimensions equal to N.
    """
    # Validate shapes up front: a pair tensor with a size-1 leading axis
    # would otherwise be silently broadcast when the bias is added to the
    # logits, giving a plausible-looking but wrong result.
    if s.ndim != 2:
        raise ValueError(f"s must have shape [N, c_s], got {s.shape}")
    N, c_s = s.shape
    if z.ndim != 3 or z.shape[0] != N or z.shape[1] != N:
        raise ValueError(
            f"z must have shape [{N}, {N}, c_z] to match s, got {z.shape}"
        )
    c_z = z.shape[-1]

    print("Attention with Pair Bias")
    print("=" * 50)
    print(f"Single: [{N}, {c_s}], Pair: [{N}, {N}, {c_z}]")

    # Both inputs are normalized over their channel axis before projection.
    s_norm = layer_norm(s)
    z_norm = layer_norm(z)

    # QKV from single (random weights scaled by 1/sqrt(fan_in)).
    # The order of the RNG draws is kept fixed so results are reproducible
    # for a given seed.
    W_q = np.random.randn(c_s, num_heads, c) * (c_s ** -0.5)
    W_k = np.random.randn(c_s, num_heads, c) * (c_s ** -0.5)
    W_v = np.random.randn(c_s, num_heads, c) * (c_s ** -0.5)
    W_g = np.random.randn(c_s, num_heads * c) * (c_s ** -0.5)

    q = np.einsum('ic,chd->ihd', s_norm, W_q)  # [N, H, c]
    k = np.einsum('jc,chd->jhd', s_norm, W_k)  # [N, H, c]
    v = np.einsum('jc,chd->jhd', s_norm, W_v)  # [N, H, c]
    g = sigmoid(s_norm @ W_g)  # [N, H*c]

    # Bias from pair: one scalar per (query i, key j, head h).
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    b = np.einsum('ijc,ch->ijh', z_norm, W_b)  # [N, N, H]

    # Attention: scaled dot-product logits plus pair bias, normalized over
    # the key axis (axis 1 of the [i, j, h] layout).
    attn = np.einsum('ihd,jhd->ijh', q, k) / np.sqrt(c)
    attn = attn + b
    attn = softmax(attn, axis=1)

    # Apply attention, flatten heads, then gate element-wise.
    output = np.einsum('ijh,jhd->ihd', attn, v)  # [N, H, c]
    output = output.reshape(N, -1) * g  # [N, H*c]

    # Output projection back to the single-representation width.
    W_o = np.random.randn(num_heads * c, c_s) * ((num_heads * c) ** -0.5)
    output = output @ W_o

    print(f"Output: {output.shape}")
    return output

In [None]:
# Smoke test: run the layer once on random inputs and report shape/finiteness
print("Test: Attention with Pair Bias")
print("=" * 60)

# Problem sizes: N positions, single channels c_s, pair channels c_z
N, c_s, c_z = 32, 128, 64

# Draw the single representation first, then the pair representation,
# so the RNG stream is consumed in a fixed order.
s = np.random.randn(N, c_s)
z = np.random.randn(N, N, c_z)

output = attention_pair_bias(s, z, num_heads=8, c=16)

print(f"\nOutput shape: {output.shape}")
print(f"Output finite: {np.isfinite(output).all()}")

## Key Insights

1. **Pair as Bias**: Pair representation modulates attention
2. **Gating**: Output gating for controlled updates
3. **Pair→Single**: Key mechanism for information flow
4. **Standard Attention**: QKV from single representation