# Algorithm 4: MSA Module (Boltz)

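Refines the MSA representation with pair-biased averaging and feeds it back into the pair representation via the outer product mean, returning updated MSA and pair representations.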

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/trunk.py`
- **Class**: `MSAModule`

In [None]:
import numpy as np
# Seed NumPy's global RNG: every projection weight in the cells below is drawn
# with np.random.randn at call time, so this seed makes the notebook's numbers
# reproducible from run to run.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalise ``x`` over its last axis to zero mean and unit variance.

    This is a parameter-free layer norm: no learnable scale or shift is
    applied. ``eps`` is added to the variance before the square root to
    keep constant rows finite.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    centred = x - mu
    return centred / sigma

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    inputs do not overflow; this does not change the result.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numer = np.exp(shifted)
    denom = np.sum(numer, axis=axis, keepdims=True)
    return numer / denom

In [None]:
def outer_product_mean(msa, c=32, c_z=128):
    """Outer Product Mean - project the MSA into a pair-representation update.

    Each MSA entry is layer-normalised and projected to two ``c``-dim
    embeddings ``a`` and ``b``. For every token pair ``(i, j)`` the channel
    outer product ``a[s, i] (x) b[s, j]`` is averaged over the ``N_msa``
    sequences, flattened to ``c * c`` features and projected to ``c_z``.

    Projection weights are drawn fresh from NumPy's global RNG on every call
    (untrained illustration), always in the order W_a, W_b, W_o.

    Args:
        msa: MSA representation, shape [N_msa, N_token, c_m].
        c: Hidden width of the ``a`` / ``b`` projections.
        c_z: Output (pair) channel count.

    Returns:
        Pair update of shape [N_token, N_token, c_z].

    Raises:
        ValueError: if ``msa`` is not 3-D, has no sequences (the mean over
            sequences would be 0/0), or if ``c_m`` or ``c`` is zero (the
            weight scaling ``dim ** -0.5`` would be undefined).
    """
    msa = np.asarray(msa)
    if msa.ndim != 3:
        raise ValueError(
            f"msa must have shape [N_msa, N_token, c_m], got {msa.shape}"
        )
    N_msa, N_token, c_m = msa.shape
    if N_msa == 0:
        raise ValueError("msa must contain at least one sequence")
    if c_m < 1 or c < 1:
        raise ValueError(f"channel sizes must be positive (c_m={c_m}, c={c})")

    msa_norm = layer_norm(msa)

    W_a = np.random.randn(c_m, c) * (c_m ** -0.5)
    W_b = np.random.randn(c_m, c) * (c_m ** -0.5)

    a = msa_norm @ W_a  # [N_msa, N_token, c]
    b = msa_norm @ W_b  # [N_msa, N_token, c]

    # Sum over sequences s of a[s, i, :] (x) b[s, j, :], then take the mean.
    # optimize=True lets NumPy route this contraction through BLAS
    # instead of the naive einsum loop.
    outer = np.einsum('sid,sje->ijde', a, b, optimize=True) / N_msa
    outer_flat = outer.reshape(N_token, N_token, c * c)

    W_o = np.random.randn(c * c, c_z) * ((c * c) ** -0.5)
    return outer_flat @ W_o

In [None]:
def msa_attention_with_pair_bias(msa, z, num_heads=8, head_dim=None):
    """MSA update by pair-biased averaging over the token axis.

    Note that there are no query/key projections here. The per-head weights
    come solely from the pair representation, as ``softmax_j(LN(z) @ W_b)``.
    They are shared by every MSA row and average the projected MSA values
    across tokens ``j``.

    Projection weights are drawn fresh from NumPy's global RNG on every call
    (untrained illustration), always in the order W_b, W_v, W_o.

    Args:
        msa: MSA representation, shape [N_msa, N_token, c_m].
        z: Pair representation, shape [N_token, N_token, c_z].
        num_heads: Number of heads.
        head_dim: Per-head value width; defaults to ``c_m // num_heads``.

    Returns:
        MSA update of shape [N_msa, N_token, c_m].

    Raises:
        ValueError: if the inputs are not 3-D, their token counts disagree,
            ``num_heads < 1``, or any channel/head width would be zero.
            Zero widths would otherwise crash in the ``dim ** -0.5``
            weight scaling.
    """
    msa = np.asarray(msa)
    z = np.asarray(z)
    if msa.ndim != 3 or z.ndim != 3:
        raise ValueError(
            f"expected 3-D msa and z, got shapes {msa.shape} and {z.shape}"
        )
    N_msa, N_token, c_m = msa.shape
    c_z = z.shape[-1]
    if z.shape[:2] != (N_token, N_token):
        raise ValueError(
            f"z must have shape [{N_token}, {N_token}, c_z], got {z.shape}"
        )
    if num_heads < 1:
        raise ValueError(f"num_heads must be >= 1, got {num_heads}")
    c = c_m // num_heads if head_dim is None else head_dim
    if c < 1 or c_z < 1:
        raise ValueError(
            f"head width and c_z must be positive (head_dim={c}, c_z={c_z}); "
            f"c_m={c_m} may be smaller than num_heads={num_heads}"
        )

    msa_norm = layer_norm(msa)
    z_norm = layer_norm(z)

    # Pair bias: one logit per (i, j, head).
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    bias = np.einsum('ijc,ch->ijh', z_norm, W_b)

    # Value projection, split into heads: [N_msa, N_token, num_heads, c].
    W_v = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    v = np.einsum('sic,chd->sihd', msa_norm, W_v)

    # Normalise over the key-token axis j (axis=1 of [i, j, h]).
    weights = softmax(bias, axis=1)

    # Weighted average of values over j, identical for every MSA row s.
    output = np.einsum('ijh,sjhd->sihd', weights, v)
    output = output.reshape(N_msa, N_token, -1)

    W_o = np.random.randn(num_heads * c, c_m) * ((num_heads * c) ** -0.5)
    return output @ W_o

In [None]:
def msa_module(msa, z, num_blocks=4, *, num_heads=8, c_opm=32, verbose=True):
    """
    MSA Module - alternately refines the MSA and pair representations.

    Each block applies a residual pair-biased MSA update
    (``msa_attention_with_pair_bias``). It then adds a residual outer
    product mean of the updated MSA to the pair representation
    (``outer_product_mean``).

    Args:
        msa: MSA representation [N_msa, N_token, c_m]
        z: Pair representation [N_token, N_token, c_z]
        num_blocks: Number of MSA blocks
        num_heads: Heads used by the pair-biased MSA update (keyword-only).
        c_opm: Hidden width of the outer product mean (keyword-only).
        verbose: Print a header and per-block Frobenius norms (keyword-only).

    Returns:
        Updated (msa, z), with the same shapes as the inputs.

    Raises:
        ValueError: if ``msa`` / ``z`` are not 3-D or their token counts
            disagree.
    """
    msa = np.asarray(msa)
    z = np.asarray(z)
    if msa.ndim != 3 or z.ndim != 3:
        raise ValueError(
            f"expected 3-D msa and z, got shapes {msa.shape} and {z.shape}"
        )
    n_tok = msa.shape[1]
    if z.shape[:2] != (n_tok, n_tok):
        raise ValueError(
            f"z must have shape [{n_tok}, {n_tok}, c_z], got {z.shape}"
        )

    if verbose:
        print(f"MSA Module ({num_blocks} blocks)")
        print("=" * 50)

    c_z = z.shape[-1]

    for i in range(num_blocks):
        # MSA attention with pair bias (residual)
        msa = msa + msa_attention_with_pair_bias(msa, z, num_heads=num_heads)

        # Outer product mean: MSA -> pair (residual)
        z = z + outer_product_mean(msa, c=c_opm, c_z=c_z)

        if verbose:
            # Frobenius norms of the running representations (not LayerNorm).
            print(f"  Block {i+1}: msa_norm={np.linalg.norm(msa):.2f}, z_norm={np.linalg.norm(z):.2f}")

    return msa, z

In [None]:
# Test: run the full module on small random inputs, then report the output
# shapes and whether every value is finite.
print("Test: MSA Module")
print("=" * 60)

N_msa, N_token = 64, 32  # number of sequences, number of tokens
c_m, c_z = 64, 128       # MSA and pair channel widths

# Draw order matters for reproducibility under the global seed: msa first, then z.
msa = 0.1 * np.random.randn(N_msa, N_token, c_m)
z = 0.1 * np.random.randn(N_token, N_token, c_z)

msa_out, z_out = msa_module(msa, z, num_blocks=4)

print(f"\nOutput shapes: msa={msa_out.shape}, z={z_out.shape}")
print(
    f"Outputs finite: {np.isfinite(msa_out).all() and np.isfinite(z_out).all()}"
)

## Key Insights

1. **Bidirectional Flow**: MSA informs pair, pair biases MSA attention
2. **Outer Product Mean**: Key mechanism for MSA → pair information flow
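3. **Pair Bias**: Pair representation guides MSA attention — in this implementation the attention weights come entirely from the pair bias (there is no query–key term)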
4. **Multiple Blocks**: Iterative refinement