# Algorithm 5: Outer Product Mean (AlphaFold3)

The Outer Product Mean transfers information from the MSA representation to the pair representation.

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/modules.py`

In [None]:
import numpy as np
# Seed NumPy's global RNG so the np.random draws in this notebook are reproducible.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Standardize ``x`` along its last axis.

    Each slice along the final axis is shifted by its mean and divided by
    the square root of its (population) variance plus ``eps``. No learned
    scale or bias is applied.

    Args:
        x: Array to normalize; statistics are taken over ``axis=-1``.
        eps: Small constant added to the variance before the square root,
            which keeps the division finite when a slice is constant.

    Returns:
        Array with the same shape as ``x``.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps)

In [None]:
def outer_product_mean(msa, c=32, c_z=128):
    """
    Build a pair-representation update from an MSA representation.

    The MSA is layer-normalized, projected twice to ``c`` channels, and the
    per-sequence outer products of the two projections are averaged over the
    MSA axis. The resulting ``c x c`` block for every token pair is flattened
    and projected to ``c_z`` channels.

    All projection weights are freshly sampled from NumPy's global RNG on
    every call (this is a shape/flow demonstration, not a trained layer), and
    progress information is printed to stdout.

    Args:
        msa: MSA representation [N_msa, N_token, c_m]
        c: Hidden dimension used for the two projections
        c_z: Channel dimension of the returned pair update

    Returns:
        Pair update [N_token, N_token, c_z]
    """
    n_seq, n_tok, msa_channels = msa.shape

    print("Outer Product Mean")
    print("=" * 50)
    print(f"MSA: [{n_seq}, {n_tok}, {msa_channels}]")

    normed = layer_norm(msa)

    # Left/right projection weights, scaled by 1/sqrt(fan_in).
    # They are drawn left-then-right so the RNG stream is consumed in a
    # fixed order.
    left_weights = np.random.randn(msa_channels, c) * (msa_channels ** -0.5)
    right_weights = np.random.randn(msa_channels, c) * (msa_channels ** -0.5)

    a = np.einsum('sic,cd->sid', normed, left_weights)   # [N_msa, N_token, c]
    b = np.einsum('sjc,ce->sje', normed, right_weights)  # [N_msa, N_token, c]

    print(f"a: {a.shape}, b: {b.shape}")

    # Summing over s contracts the MSA axis; dividing by the number of
    # sequences turns that sum of outer products into a mean.
    pair_outer = np.einsum('sid,sje->ijde', a, b) / n_seq  # [N_token, N_token, c, c]

    # Collapse the trailing (c, c) block into a single feature axis.
    outer_flat = pair_outer.reshape(n_tok, n_tok, c * c)

    print(f"Outer product: {outer_flat.shape}")

    # Output projection from the c*c flattened features to c_z channels.
    out_weights = np.random.randn(c * c, c_z) * ((c * c) ** -0.5)
    output = np.einsum('ijc,cd->ijd', outer_flat, out_weights)

    print(f"Output: {output.shape}")

    return output

In [None]:
# Smoke test: run the outer product mean on a random MSA and report basic
# sanity statistics of the result.
print("Test: Outer Product Mean")
print("=" * 60)

# Problem size: number of MSA sequences, number of tokens, MSA channels.
N_msa, N_token, c_m = 64, 32, 64

msa = np.random.randn(N_msa, N_token, c_m)
output = outer_product_mean(msa, c=32, c_z=128)

# The output should contain no NaN/inf; the norm is printed for reference.
print(f"\nOutput finite: {np.isfinite(output).all()}")
print(f"Output norm: {np.linalg.norm(output):.2f}")

## Key Insights

1. **Information Transfer**: Moves evolutionary information from MSA to pairwise representation
2. **Outer Product**: Captures correlations between positions i and j across MSA
3. **Mean Aggregation**: Averages over MSA sequences for robust signal
4. **Dimensionality**: The flattened c*c outer product is linearly projected to c_z channels to form the final pair update