# Algorithm 5: Outer Product Mean (Boltz)

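Transfers information from the MSA representation to the pair representation. For every pair of token positions (i, j), the projected MSA features at i and j are combined by an outer product within each sequence, and the result is averaged over all sequences.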

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/layers/outer_product_mean.py`

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along its last axis.

    No learnable scale or shift is applied. ``eps`` is added to the variance
    before the square root so constant rows map to 0 instead of dividing by 0.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return (x - mu) / sigma

In [None]:
def outer_product_mean(msa, c=32, c_z=128, *, mask=None, rng=None, verbose=True):
    """
    Outer Product Mean.

    Computes pairwise features from an MSA. For every pair of token positions
    (i, j), the projected features of i and j are combined by an outer
    product within each sequence, and those outer products are averaged over
    the sequences.

    The projection weights are sampled randomly on every call (this demo has
    no trained parameters). Repeated calls on the same MSA therefore return
    different outputs unless the random source is controlled (see ``rng``).

    Args:
        msa: MSA representation [N_msa, N_token, c_m]
        c: Hidden dimension for projections
        c_z: Output pair dimension
        mask: Optional [N_msa, N_token] array with 1 for valid entries and 0
            for padding. Masked entries contribute nothing. Each (i, j) is
            divided by the number of sequences in which both i and j are
            valid, clamped to at least 1. ``None`` (default) treats every
            entry as valid, i.e. a plain mean over N_msa.
        rng: Optional ``np.random.Generator`` (or ``RandomState``) used to
            sample the weights. ``None`` (default) draws from the global
            ``np.random`` state in the order W_a, W_b, W_o.
        verbose: If True (default), print shape diagnostics.

    Returns:
        Pair representation [N_token, N_token, c_z]

    Raises:
        ValueError: If ``msa`` is not 3-D, has no sequences, or ``mask``
            does not have shape [N_msa, N_token].
    """
    msa = np.asarray(msa)
    if msa.ndim != 3:
        raise ValueError(
            f"msa must have shape [N_msa, N_token, c_m], got {msa.shape}"
        )
    N_msa, N_token, c_m = msa.shape
    if N_msa == 0:
        # A mean over zero sequences is undefined (it would otherwise be 0/0 -> NaN).
        raise ValueError("msa must contain at least one sequence")

    def randn(*shape):
        # The global RNG reproduces the original draws; a Generator makes runs reproducible.
        if rng is None:
            return np.random.randn(*shape)
        return rng.standard_normal(shape)

    if verbose:
        print("Outer Product Mean")
        print("=" * 50)
        print(f"MSA: [{N_msa}, {N_token}, {c_m}]")

    # Layer norm over the channel axis
    msa_norm = layer_norm(msa)

    # Project to the hidden dimension c (weights scaled by 1/sqrt(fan_in))
    W_a = randn(c_m, c) * (c_m ** -0.5)
    W_b = randn(c_m, c) * (c_m ** -0.5)
    a = msa_norm @ W_a  # [N_msa, N_token, c]
    b = msa_norm @ W_b  # [N_msa, N_token, c]

    if mask is None:
        counts = N_msa
    else:
        mask = np.asarray(mask, dtype=msa_norm.dtype)
        if mask.shape != (N_msa, N_token):
            raise ValueError(
                f"mask must have shape {(N_msa, N_token)}, got {mask.shape}"
            )
        # Zero out padded entries so they add nothing to the sum over s.
        a = a * mask[..., None]
        b = b * mask[..., None]
        # counts[i, j] = number of sequences where both i and j are valid (>= 1)
        counts = np.maximum(mask.T @ mask, 1.0)[:, :, None, None]

    # Sum over s of the outer product a[s, i] x b[s, j], divided by the count
    # to give the mean. optimize=True lets NumPy evaluate the contraction over
    # s via tensordot (BLAS) instead of its naive einsum loop.
    outer = np.einsum('sid,sje->ijde', a, b, optimize=True) / counts  # [N_token, N_token, c, c]

    # Flatten the c x c outer product per pair
    outer_flat = outer.reshape(N_token, N_token, c * c)
    if verbose:
        print(f"Outer product: {outer_flat.shape}")

    # Project to pair dimension
    W_o = randn(c * c, c_z) * ((c * c) ** -0.5)
    output = outer_flat @ W_o

    if verbose:
        print(f"Output: {output.shape}")

    return output

In [None]:
# Test: run outer_product_mean on a random MSA, then check the result against
# an explicit per-pair loop that recomputes the mean of outer products.
print("Test: Outer Product Mean")
print("="*60)

N_msa = 128
N_token = 32
c_m = 64
c_hidden = 32
c_z = 128

msa = np.random.randn(N_msa, N_token, c_m)

# Snapshot the global RNG so the reference below can redraw the same weights
# outer_product_mean samples from it (W_a, W_b, then W_o).
rng_state = np.random.get_state()
output = outer_product_mean(msa, c=c_hidden, c_z=c_z)

print(f"\nOutput shape: {output.shape}")
print(f"Output finite: {np.isfinite(output).all()}")
print(f"Output norm: {np.linalg.norm(output):.2f}")

assert output.shape == (N_token, N_token, c_z)
assert np.isfinite(output).all()

# Reference: for each pair (i, j), mean over sequences of a[s, i] x b[s, j].
np.random.set_state(rng_state)
W_a = np.random.randn(c_m, c_hidden) * (c_m ** -0.5)
W_b = np.random.randn(c_m, c_hidden) * (c_m ** -0.5)
W_o = np.random.randn(c_hidden * c_hidden, c_z) * ((c_hidden * c_hidden) ** -0.5)
msa_norm = layer_norm(msa)
a = msa_norm @ W_a  # [N_msa, N_token, c_hidden]
b = msa_norm @ W_b  # [N_msa, N_token, c_hidden]
expected = np.empty_like(output)
for i in range(N_token):
    for j in range(N_token):
        pair = a[:, i, :].T @ b[:, j, :] / N_msa  # [c_hidden, c_hidden]
        expected[i, j] = pair.reshape(-1) @ W_o
assert np.allclose(output, expected)
print("Reference check passed")

## Key Insights

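1. **Coevolution Signal**: Captures how the features at two positions co-vary across the sequences of the MSA
2. **Outer Product**: Every token pair (i, j) receives a full c × c feature matrix
3. **Mean Aggregation**: Averaging (rather than summing) over sequences keeps the feature scale independent of MSA depth
4. **Dimensionality Reduction**: The flattened c*c features are linearly projected to c_z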