# Algorithm 6: MSA Row Attention (AlphaFold3)

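MSA Row Attention updates the MSA representation with attention weights computed from the pair representation. It is simpler than AF2's MSA row-wise gated self-attention with pair bias: there are no query/key projections, so the weights depend only on the pair representation and are shared by every MSA row. The AF3 supplementary information calls this operation MSAPairWeightedAveraging.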
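Written per MSA sequence $s$, query token $i$, key token $j$, and head $h$, the computation in the code cell below is roughly the following. This is a sketch in our own notation: $W_b, W_v, W_g, b_g, W_o$ are names chosen here for the `W_logits`, `W_v`, `W_g`, `b_g`, and `W_o` arrays, $\sigma$ is the logistic sigmoid, and $\Delta m_{si}$ is the update the function returns.

$$
\begin{aligned}
\hat{m}_{si} &= \mathrm{LayerNorm}(m_{si}), \qquad \hat{z}_{ij} = \mathrm{LayerNorm}(z_{ij}) \\
b^h_{ij} &= (W_b^{h})^{\top} \hat{z}_{ij} \\
w^h_{ij} &= \mathrm{softmax}_j\big(b^h_{ij}\big) \\
v^h_{sj} &= (W_v^{h})^{\top} \hat{m}_{sj} \\
g^h_{si} &= \sigma\big((W_g^{h})^{\top} \hat{m}_{si} + b_g^{h}\big) \\
o^h_{si} &= g^h_{si} \odot \sum_j w^h_{ij}\, v^h_{sj} \\
\Delta m_{si} &= W_o^{\top} \operatorname{concat}_h\big(o^h_{si}\big)
\end{aligned}
$$

The weights $w^h_{ij}$ carry no sequence index $s$, which is exactly why one weight tensor serves every MSA row.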

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/modules.py`
- **Class**: `MSAAttention`

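```python
import numpy as np
np.random.seed(42)  # fixed seed so the random "weights" below are reproducible

def layer_norm(x, eps=1e-5):
    # LayerNorm over the last (channel) axis, without learned scale/offset
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    # Numerically stable softmax: subtract the max before exponentiating
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def sigmoid(x):
    # Clip the input so np.exp cannot overflow
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
```
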
```python
def msa_row_attention(msa, pair, num_heads=8):
    """
    MSA row attention with weights derived only from the pair representation.

    The projection matrices are drawn at random on every call; they stand in
    for learned parameters, so only shapes and data flow are meaningful here.

    Args:
        msa: MSA representation [N_msa, N_token, c_m]
        pair: Pair representation [N_token, N_token, c_z]
        num_heads: Number of attention heads (must divide c_m)

    Returns:
        MSA update [N_msa, N_token, c_m], to be added residually to msa
    """
    N_msa, N_token, c_m = msa.shape
    c_z = pair.shape[-1]
    assert c_m % num_heads == 0, "c_m must be divisible by num_heads"
    c = c_m // num_heads  # per-head value dimension

    print("MSA Row Attention")
    print("=" * 50)
    print(f"MSA: [{N_msa}, {N_token}, {c_m}]")
    print(f"Heads: {num_heads}, Head dim: {c}")

    # Normalize both inputs over their channel axis
    msa_norm = layer_norm(msa)
    pair_norm = layer_norm(pair)

    # Per-head logits from the pair representation (no query/key projections)
    W_logits = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    logits = np.einsum('ijc,ch->ijh', pair_norm, W_logits)  # [N_token, N_token, H]
    logits = logits.transpose(2, 0, 1)  # [H, N_token, N_token]

    # Softmax over the key axis j; these weights are shared by every MSA row
    weights = softmax(logits, axis=-1)  # [H, N_token, N_token]

    # Value projection
    W_v = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    v = np.einsum('sic,chd->sihd', msa_norm, W_v)  # [N_msa, N_token, H, c]

    # Weighted average of values, using the same weights for every sequence s
    v_avg = np.einsum('hqk,skhc->sqhc', weights, v)  # [N_msa, N_token, H, c]
    v_avg = v_avg.reshape(N_msa, N_token, -1)  # [N_msa, N_token, H*c]

    # Sigmoid gate computed from the normalized MSA
    W_g = np.random.randn(c_m, num_heads * c) * (c_m ** -0.5)
    b_g = np.ones(num_heads * c)  # gate bias initialized to 1
    gate = sigmoid(np.einsum('sic,cd->sid', msa_norm, W_g) + b_g)
    v_avg = v_avg * gate

    # Output projection back to c_m channels
    W_o = np.random.randn(num_heads * c, c_m) * ((num_heads * c) ** -0.5)
    output = np.einsum('sid,dc->sic', v_avg, W_o)

    print(f"Output: {output.shape}")

    return output
```

```python
# Test
print("Test: MSA Row Attention")
print("=" * 60)

N_msa = 64
N_token = 32
c_m = 64
c_z = 128

msa = np.random.randn(N_msa, N_token, c_m)
pair = np.random.randn(N_token, N_token, c_z)

output = msa_row_attention(msa, pair, num_heads=8)

print(f"\nShape preserved: {output.shape == msa.shape}")
print(f"Output finite: {np.isfinite(output).all()}")
```

## Key Insights

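1. **Pair Bias**: The attention logits come entirely from a per-head linear projection of the normalized pair representation.
2. **Shared Weights**: One `[H, N_token, N_token]` weight tensor is computed and applied to every MSA sequence, so computing it does not scale with `N_msa`.
3. **Gating**: A sigmoid gate, computed per sequence and per token from the normalized MSA, scales each channel of the averaged values before the output projection.
4. **No Q/K**: Unlike standard attention, there is no query-key dot product; the softmax of the pair logits is used directly as the weights.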
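Insights 2 and 4 can be checked numerically. The sketch below re-implements the cell above with parameters passed in explicitly so two calls can share them; the helper names `init_params`, `pair_weights`, and `row_update` are hypothetical, and the random weights stand in for trained ones. It confirms that the weight tensor has no sequence axis, that each row is updated independently, and compares the weight-tensor size with what per-sequence query-key logits (as in AF2 row attention) would require.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the channel axis (no learned scale/offset)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def init_params(c_m, c_z, num_heads, rng):
    # Hypothetical random parameters standing in for trained weights
    c = c_m // num_heads
    return {
        "W_b": rng.standard_normal((c_z, num_heads)) / np.sqrt(c_z),
        "W_v": rng.standard_normal((c_m, num_heads, c)) / np.sqrt(c_m),
        "W_g": rng.standard_normal((c_m, num_heads * c)) / np.sqrt(c_m),
        "b_g": np.ones(num_heads * c),
        "W_o": rng.standard_normal((num_heads * c, c_m)) / np.sqrt(num_heads * c),
    }

def pair_weights(pair, p):
    # Weights depend only on the pair representation: [H, N_token, N_token]
    logits = np.einsum("ijc,ch->hij", layer_norm(pair), p["W_b"])
    return softmax(logits, axis=-1)

def row_update(msa, pair, p):
    # Same computation as msa_row_attention, but with explicit parameters
    n_msa, n_tok, _ = msa.shape
    m = layer_norm(msa)
    w = pair_weights(pair, p)
    v = np.einsum("sic,chd->sihd", m, p["W_v"])              # [S, N, H, c]
    o = np.einsum("hqk,skhd->sqhd", w, v).reshape(n_msa, n_tok, -1)
    gate = 1.0 / (1.0 + np.exp(-(m @ p["W_g"] + p["b_g"])))  # [S, N, H*c]
    return (o * gate) @ p["W_o"]                             # [S, N, c_m]

rng = np.random.default_rng(0)
n_msa, n_tok, c_m, c_z, H = 4, 6, 16, 8, 4
p = init_params(c_m, c_z, H, rng)
msa = rng.standard_normal((n_msa, n_tok, c_m))
pair = rng.standard_normal((n_tok, n_tok, c_z))
out = row_update(msa, pair, p)

# Perturb sequence 0 only: rows 1.. of the output must not change
msa2 = msa.copy()
msa2[0] += rng.standard_normal((n_tok, c_m))
out2 = row_update(msa2, pair, p)
rows_independent = np.allclose(out[1:], out2[1:])
row_0_changed = not np.allclose(out[0], out2[0])

# Processing each row on its own reproduces the batched result
per_row = np.stack([row_update(msa[s:s + 1], pair, p)[0] for s in range(n_msa)])
batched_matches = np.allclose(out, per_row)

# Weights are normalized over keys j for every head and query
w = pair_weights(pair, p)
rows_sum_to_one = np.allclose(w.sum(axis=-1), 1.0)

# Weight-tensor size: shared here vs. one copy per sequence with q.k logits
shared_size = w.size             # H * N_token**2
per_row_size = n_msa * w.size    # N_msa * H * N_token**2

print(out.shape, rows_independent, row_0_changed, batched_matches, rows_sum_to_one)
print(shared_size, per_row_size)
```

Row independence on its own holds for any row-wise attention, including AF2's. What distinguishes this module is that `pair_weights` never sees the MSA, so a single weight tensor serves all `N_msa` rows.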