# Algorithm 10: Outer Product Mean

The Outer Product Mean computes pairwise features from the MSA representation. It takes the outer product of projected MSA features at two residue positions and averages it across sequences. This is a key mechanism for transferring evolutionary co-variation information from the MSA representation to the pair representation.
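The toy example below shows why averaging products over sequences carries this signal. It makes two simplifying assumptions for illustration: each residue has one scalar feature (`c = 1`), and the projections are identities rather than learned layers. Residue 1 copies residue 0 in every sequence, while residue 2 is independent:

```python
import numpy as np

# Toy data (hypothetical): 1000 sequences, 3 residues, one +/-1 feature each
rng = np.random.default_rng(0)
n_seq = 1000
x0 = rng.choice([-1.0, 1.0], size=n_seq)   # residue 0
x1 = x0.copy()                              # residue 1 co-varies perfectly with residue 0
x2 = rng.choice([-1.0, 1.0], size=n_seq)   # residue 2 is independent
msa = np.stack([x0, x1, x2], axis=1)        # [N_seq, N_res]

# Mean over sequences of the (here scalar) outer product a_si * b_sj, with a = b = msa
outer_mean = np.einsum('si,sj->ij', msa, msa) / n_seq   # [N_res, N_res]
print(np.round(outer_mean, 2))
```

Entry `[0, 1]` is exactly 1, and `[0, 2]` is close to 0. The co-varying pair stands out, and the learned projections in the real module generalize this to `c`-dimensional features.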

## Algorithm Pseudocode

![Outer Product Mean](../imgs/algorithms/OuterProductMean.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `OuterProductMean`
- **Lines**: 1414-1498

## Algorithm Steps

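1. **Layer Normalization**: Normalize the MSA input over the channel axis
2. **Left/Right Projection**: Project each MSA entry with two separate linear layers to `c` channels (`c = 32` in the paper) and zero out masked entries
3. **Outer Product**: For each sequence `s` and residue pair `(i, j)`, compute the outer product of the left projection at `i` and the right projection at `j` (a `c × c` matrix)
4. **Mean over Sequences**: Average over sequences, dividing by the number of sequences in which both residues are unmasked
5. **Output Projection**: Flatten each `c × c` matrix to `c²` features and project to the pair dimension `c_z`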
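The steps above can also be written as equations. The mask $\mu_{si}$ and $\epsilon = 10^{-3}$ come from the source code, not from the paper's plain mean over sequences, so this is a sketch of the implemented computation rather than a transcription of the pseudocode. The normalization is placed before the output projection, as in the NumPy implementation in this notebook. The AF2 source applies it after the projection instead, which also scales the output bias.

$$
\begin{aligned}
\hat{m}_{si} &= \mathrm{LayerNorm}(m_{si}) \\
a_{si} &= \mu_{si}\,\mathrm{Linear}_a(\hat{m}_{si}), \qquad b_{si} = \mu_{si}\,\mathrm{Linear}_b(\hat{m}_{si}), \qquad a_{si}, b_{si} \in \mathbb{R}^{c} \\
o_{ij} &= \mathrm{flatten}\!\left(\frac{\sum_{s} a_{si} \otimes b_{sj}}{\epsilon + \sum_{s} \mu_{si}\,\mu_{sj}}\right) \in \mathbb{R}^{c \cdot c} \\
z_{ij} &= \mathrm{Linear}_o(o_{ij}) \in \mathbb{R}^{c_z}
\end{aligned}
$$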

```python
import numpy as np

# Set random seed for reproducibility
np.random.seed(42)
```

## NumPy Implementation for Testing

```python
def layer_norm(x, axis=-1, eps=1e-5):
    """Layer normalization without learnable scale/offset.

    hk.LayerNorm in AF2 also learns a scale and an offset; at initialization
    they are 1 and 0, so this matches the untrained module.
    """
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def outer_product_mean(msa_act, msa_mask, num_outer_channel=32, num_output_channel=128):
    """
    Outer Product Mean implementation.

    Algorithm 10 from the AlphaFold2 supplementary materials. The projection
    weights are random and untrained, so only shapes and masking behaviour
    are meaningful.

    Args:
        msa_act: MSA activations, shape [N_seq, N_res, c_m]
        msa_mask: MSA mask, shape [N_seq, N_res]
        num_outer_channel: Intermediate projection dimension (c in the paper)
        num_output_channel: Output pair representation dimension (c_z)

    Returns:
        Pair representation update, shape [N_res, N_res, c_z]
    """
    N_seq, N_res, c_m = msa_act.shape

    # Step 1: Layer normalization over the channel axis
    act = layer_norm(msa_act, axis=-1)

    # Expand mask for broadcasting
    mask = msa_mask[:, :, None]  # [N_seq, N_res, 1]

    # Step 2: Left and right projections (two independent linear layers)
    left_proj_w = np.random.randn(c_m, num_outer_channel) * 0.01
    right_proj_w = np.random.randn(c_m, num_outer_channel) * 0.01

    # Project and zero out padded entries
    left_act = mask * np.einsum('sra,ac->src', act, left_proj_w)    # [N_seq, N_res, c]
    right_act = mask * np.einsum('sra,ac->src', act, right_proj_w)  # [N_seq, N_res, c]

    print(f"Left projection shape: {left_act.shape}")
    print(f"Right projection shape: {right_act.shape}")

    # Step 3: Outer product left_act[s, i] ⊗ right_act[s, j], summed over s.
    # Contracting over s directly avoids materializing the per-sequence tensor
    # [N_seq, N_res, N_res, c, c], which is ~4 GB in float64 at the test sizes.
    # tensordot returns [N_res, c, N_res, c]; reorder to [N_res, N_res, c, c].
    outer_sum = np.tensordot(left_act, right_act, axes=([0], [0])).transpose(0, 2, 1, 3)
    print(f"Summed outer product shape: {outer_sum.shape}")

    # Step 4: Mean over sequences.
    # norm[i, j] = number of sequences in which both residue i and residue j are valid
    epsilon = 1e-3  # same constant as the AF2 source; avoids division by zero
    norm = np.einsum('si,sj->ij', msa_mask, msa_mask)  # [N_res, N_res]
    outer_mean = outer_sum / (epsilon + norm)[:, :, None, None]  # [N_res, N_res, c, c]
    print(f"Outer mean shape: {outer_mean.shape}")

    # Step 5: Output projection. Flatten c x c to c*c features and project to c_z.
    # (The AF2 source divides by epsilon + norm after this projection, so its
    # bias is scaled too; with the zero bias used here both orderings agree.)
    outer_flat = outer_mean.reshape(N_res, N_res, num_outer_channel * num_outer_channel)
    output_w = np.random.randn(num_outer_channel * num_outer_channel, num_output_channel) * 0.01
    output_b = np.zeros(num_output_channel)

    output = np.einsum('ijc,cd->ijd', outer_flat, output_w) + output_b
    print(f"Output shape: {output.shape}")

    return output
```
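The contraction in Steps 3-4 (outer product, then masked mean over sequences) is easy to get subtly wrong. A small loop-based reference is a useful cross-check. The sketch below compares a vectorized version against explicit loops over residue pairs and sequences. Both are hypothetical helpers: `masked_mean_outer` sums over sequences with `np.tensordot`, and `masked_mean_outer_loop` iterates directly. The sizes and mask pattern are arbitrary:

```python
import numpy as np

def masked_mean_outer(left, right, mask, eps=1e-3):
    """Vectorized: [S, R, c] x [S, R, c] -> [R, R, c, c] (inputs already masked)."""
    outer_sum = np.tensordot(left, right, axes=([0], [0])).transpose(0, 2, 1, 3)
    norm = np.einsum('si,sj->ij', mask, mask)
    return outer_sum / (eps + norm)[:, :, None, None]

def masked_mean_outer_loop(left, right, mask, eps=1e-3):
    """Reference: explicit loops over residue pairs (i, j) and sequences s."""
    S, R, c = left.shape
    out = np.zeros((R, R, c, c))
    for i in range(R):
        for j in range(R):
            acc = np.zeros((c, c))
            count = 0.0
            for s in range(S):
                acc += np.outer(left[s, i], right[s, j])
                count += mask[s, i] * mask[s, j]
            out[i, j] = acc / (eps + count)
    return out

rng = np.random.default_rng(1)
S, R, c = 5, 4, 3
mask = np.ones((S, R))
mask[-1, :] = 0   # one padded sequence
mask[:, -1] = 0   # one padded residue
left = rng.standard_normal((S, R, c)) * mask[:, :, None]
right = rng.standard_normal((S, R, c)) * mask[:, :, None]

fast = masked_mean_outer(left, right, mask)
slow = masked_mean_outer_loop(left, right, mask)
print(np.allclose(fast, slow))   # True
```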

## Test Example

```python
# Test parameters
N_seq = 128    # Number of sequences in MSA
N_res = 64     # Number of residues (sequence length)
c_m = 256      # MSA channel dimension
c_z = 128      # Pair channel dimension
c = 32         # Outer product channel dimension

# Create test inputs
msa_act = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
msa_mask = np.ones((N_seq, N_res), dtype=np.float32)

# Add some padding (mask out last 4 residues)
msa_mask[:, -4:] = 0

print(f"MSA activations shape: {msa_act.shape}")
print(f"MSA mask shape: {msa_mask.shape}")
print(f"Valid residues: {int(msa_mask[0].sum())} / {N_res}")
print()
```

```python
# Run the algorithm
pair_update = outer_product_mean(
    msa_act,
    msa_mask,
    num_outer_channel=c,
    num_output_channel=c_z
)

print(f"\nFinal output shape: {pair_update.shape}")
print(f"Expected shape: [{N_res}, {N_res}, {c_z}]")
print(f"Output statistics: mean={pair_update.mean():.6f}, std={pair_update.std():.6f}")
```

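## Verification: Check Masking Behavior

```python
# Residues 60-63 are padding. Their left/right projections are zeroed, so the
# summed outer product for any pair involving them is zero, and the output for
# those pairs equals the output bias (zero here).
n_valid = int(msa_mask[0].sum())  # 60

valid_region = pair_update[:n_valid, :n_valid, :]  # both residues valid
masked_rows = pair_update[n_valid:, :, :]          # residue i is padding
masked_cols = pair_update[:, n_valid:, :]          # residue j is padding

print(f"Valid region norm: {np.linalg.norm(valid_region):.4f}")
print(f"Max |value| in masked rows: {np.abs(masked_rows).max():.4f}")
print(f"Max |value| in masked columns: {np.abs(masked_cols).max():.4f}")
assert np.allclose(masked_rows, 0.0) and np.allclose(masked_cols, 0.0)
```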
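The outer product mean is not symmetric in general, so `z_ij` and `z_ji` differ. The reason is that the left and right projections use different weights. The minimal sketch below uses random weights, no padding, and no output projection (all simplifying assumptions). It shows that swapping residues `i` and `j` transposes the `c × c` block only when the same projection is used on both sides:

```python
import numpy as np

rng = np.random.default_rng(2)
S, R, c_m, c = 8, 5, 16, 4
msa = rng.standard_normal((S, R, c_m))

def mean_outer(left, right):
    """[S, R, c] x [S, R, c] -> [R, R, c, c], plain mean over S (no padding)."""
    return np.einsum('sia,sjb->ijab', left, right) / left.shape[0]

def swap_pairs(o):
    """Return o[j, i].T for every residue pair (i, j)."""
    return o.transpose(1, 0, 3, 2)

w_left = rng.standard_normal((c_m, c))
w_right = rng.standard_normal((c_m, c))

o_distinct = mean_outer(msa @ w_left, msa @ w_right)  # separate projections, as in Algorithm 10
o_shared = mean_outer(msa @ w_left, msa @ w_left)     # hypothetical shared projection

print(np.allclose(o_shared, swap_pairs(o_shared)))      # True
print(np.allclose(o_distinct, swap_pairs(o_distinct)))  # False
```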

## Source Code Reference

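```python
# From AF2-source-code/model/modules.py (abridged). The original also chunks
# the outer product with mapping.inference_subbatch to save memory.
import haiku as hk
import jax.numpy as jnp

from alphafold.model import common_modules  # import path used in the AF2 repository


class OuterProductMean(hk.Module):
  """Computes mean outer product.

  Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean"
  """

  def __init__(self, config, global_config, num_output_channel,
               name='outer_product_mean'):
    super().__init__(name=name)
    self.global_config = global_config
    self.config = config
    self.num_output_channel = num_output_channel

  def __call__(self, act, mask, is_training=True):
    gc = self.global_config
    c = self.config

    mask = mask[..., None]
    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)

    left_act = mask * common_modules.Linear(
        c.num_outer_channel, initializer='linear', name='left_projection')(act)

    right_act = mask * common_modules.Linear(
        c.num_outer_channel, initializer='linear', name='right_projection')(act)

    # Output projection: [c, c, c_z] weight and [c_z] bias.
    # (Simplified: AF2 chooses the weight initializer based on gc.zero_init.)
    output_w = hk.get_parameter(
        'output_w',
        shape=(c.num_outer_channel, c.num_outer_channel, self.num_output_channel),
        dtype=act.dtype, init=hk.initializers.Constant(0.0))
    output_b = hk.get_parameter(
        'output_b', shape=(self.num_output_channel,),
        dtype=act.dtype, init=hk.initializers.Constant(0.0))

    # Sum over sequences (a) of left_act[a, b] ⊗ right_act[a, d], then project
    def compute_chunk(left_act):
      left_act = jnp.transpose(left_act, [0, 2, 1])
      act = jnp.einsum('acb,ade->dceb', left_act, right_act)
      act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b
      return jnp.transpose(act, [1, 0, 2])

    # The original applies compute_chunk to slices of left_act; one call is equivalent.
    act = compute_chunk(left_act)

    # Normalize by number of sequences (after the projection, so the bias is scaled too)
    epsilon = 1e-3
    norm = jnp.einsum('abc,adc->bdc', mask, mask)
    act /= epsilon + norm

    return act
```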
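The two einsums in `compute_chunk` fuse the outer product, the sum over sequences, and the output projection into one pass. The NumPy sketch below reuses the same subscripts on small random arrays; the sizes and variable names are illustrative assumptions. It checks the fused form against the step-by-step "mean first, then project" form. The only difference is that dividing after the projection also scales the bias by `1 / (epsilon + norm)`. As a result, pairs with no valid sequence end up at `output_b / epsilon`.

```python
import numpy as np

rng = np.random.default_rng(3)
S, R, c, c_z = 6, 5, 3, 4
eps = 1e-3
mask = np.ones((S, R, 1))
mask[:, -1] = 0                                      # last residue is padding
left = rng.standard_normal((S, R, c)) * mask
right = rng.standard_normal((S, R, c)) * mask
output_w = rng.standard_normal((c, c, c_z))
output_b = rng.standard_normal(c_z)

# AF2 ordering: fused einsums (same subscripts as compute_chunk), bias, then divide
l_t = np.transpose(left, [0, 2, 1])                  # [S, c, R]
act = np.einsum('acb,ade->dceb', l_t, right)         # [R_j, c, c, R_i], summed over S
act = np.einsum('dceb,cef->dbf', act, output_w) + output_b
act = np.transpose(act, [1, 0, 2])                   # [R_i, R_j, c_z]
norm = np.einsum('abc,adc->bdc', mask, mask)         # [R, R, 1]
af2 = act / (eps + norm)

# Step-by-step ordering: mean over sequences first, then project and add bias
mean_outer = np.einsum('sia,sjb->ijab', left, right) / (eps + norm[..., None])
stepwise = np.einsum('ijab,abf->ijf', mean_outer, output_w) + output_b

# Weight terms agree; only the bias scaling differs
print(np.allclose(af2, stepwise - output_b + output_b / (eps + norm)))  # True
```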

## Key Insights

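1. **Evolutionary Covariance**: For each residue pair `(i, j)`, averaging `a_si ⊗ b_sj` over sequences yields a `c × c` matrix of (uncentered) second-moment statistics between the features at `i` and `j`. Positions that co-vary across the MSA produce large entries, and such co-variation is a classic signal of structural contact.

2. **Dimension Reduction**: Projecting from `c_m = 256` to `c = 32` channels before the outer product shrinks each pair's tensor from `c_m²` to `c²` elements. Total memory therefore drops from O(N_res² · c_m²) to O(N_res² · c²), a 64× reduction.

3. **Mean Normalization**: Dividing by the number of sequences in which both residue `i` and residue `j` are unmasked keeps the output scale independent of MSA depth.

4. **Masking**: Because the left and right projections are masked, padded residues and padded sequences contribute nothing to the sum. The same mask also excludes them from the count in the denominator.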
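The NumPy sketch below illustrates insights 2-4 with random inputs and a hypothetical helper `mean_outer_over_msa`. Duplicating every sequence leaves the mean unchanged up to the epsilon term, and appending fully padded sequences changes nothing. The element counts at the test sizes show the 64× saving from projecting before the outer product:

```python
import numpy as np

def mean_outer_over_msa(left, right, mask, eps=1e-3):
    """[S, R, c] x [S, R, c] -> [R, R, c, c], averaged over unmasked sequences."""
    m = mask[..., None]
    outer_sum = np.einsum('sia,sjb->ijab', left * m, right * m)
    norm = np.einsum('si,sj->ij', mask, mask)
    return outer_sum / (eps + norm)[:, :, None, None]

rng = np.random.default_rng(4)
S, R, c = 10, 6, 3
left = rng.standard_normal((S, R, c))
right = rng.standard_normal((S, R, c))
mask = np.ones((S, R))
base = mean_outer_over_msa(left, right, mask)

# Mean Normalization: duplicating the MSA doubles both the sum and the count
deep = mean_outer_over_msa(np.concatenate([left, left]), np.concatenate([right, right]),
                           np.concatenate([mask, mask]))
print(np.allclose(base, deep, rtol=1e-3))  # True (equal up to the epsilon term)

# Masking: appending padded sequences (mask = 0) with arbitrary features changes nothing
pad = rng.standard_normal((4, R, c))
padded = mean_outer_over_msa(np.concatenate([left, pad]), np.concatenate([right, pad]),
                             np.concatenate([mask, np.zeros((4, R))]))
print(np.allclose(base, padded))  # True

# Dimension Reduction: per-pair tensor elements summed over all pairs at the test sizes
N_res, c_m, c_small = 64, 256, 32
print(N_res**2 * c_m**2, N_res**2 * c_small**2)  # 268435456 4194304
```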