# Algorithm 18: Extra MSA Stack

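The Extra MSA Stack processes the additional MSA sequences that do not fit into the main MSA. It carries a narrower extra-MSA representation through its blocks (64 channels in AF2, versus 256 for the main MSA). Only the updated pair representation is passed on to the main Evoformer. Column attention is replaced by global column attention so that thousands of sequences remain affordable.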

## Algorithm Pseudocode

![ExtraMsaStack](../imgs/algorithms/ExtraMsaStack.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `ExtraMsaStack`, `ExtraMsaStackIteration`
- **Lines**: 1483-1560

## Overview

### Comparison: Main MSA vs Extra MSA

| Aspect | Main Evoformer | Extra MSA Stack |
|--------|----------------|------------------|
| MSA size | ~512 sequences | ~5000 sequences |
| Attention | Row + Column | Row + Global Column |
| Output | MSA + Pair | Pair only |
| Blocks | 48 | 4 |
| Purpose | Main processing | Extract pair info |
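The MSA sizes in the table explain why the Extra MSA Stack swaps standard column attention for global column attention. The short calculation below counts the attention logits one layer produces for each variant. N_res = 256 and 8 heads are illustrative assumptions, not values read from the AF2 config, and the helper names are hypothetical.

```python
# Back-of-the-envelope count of attention-logit entries for one column-attention
# layer (H heads, no masking). Standard column attention lets every sequence attend
# to every other sequence in the same column; global column attention uses a single
# pooled query per (column, head).

def column_attention_logits(n_seq, n_res, n_heads):
    # logits of shape [N_res, H, N_seq, N_seq]
    return n_res * n_heads * n_seq * n_seq

def global_column_attention_logits(n_seq, n_res, n_heads):
    # logits of shape [N_res, H, N_seq]: one query per column and head
    return n_res * n_heads * n_seq

n_res, n_heads = 256, 8  # illustrative values (assumption)
for n_seq in (512, 5120):
    full = column_attention_logits(n_seq, n_res, n_heads)
    glob = global_column_attention_logits(n_seq, n_res, n_heads)
    print(f"N_seq={n_seq:5d}: full={full:,}  global={glob:,}  ratio={full // glob}")
```

The ratio is exactly N_seq. With 5120 sequences, standard column attention would need about 5.4 × 10^10 logits per layer, which is over 200 GB in float32 if materialized at once. The global variant needs roughly 10^7.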

### Algorithm Steps

1. MSA Row Attention with Pair Bias
2. MSA Column Global Attention (Algorithm 19)
3. MSA Transition
4. Outer Product Mean → Pair update
5. Triangle Multiplicative Updates (outgoing, incoming) + Triangle Attention (starting node, ending node)
6. Pair Transition

```python
import numpy as np

np.random.seed(42)
```

## NumPy Implementation

```python
def layer_norm(x, axis=-1, eps=1e-5):
    """Layer normalization without learnable scale/offset."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    """Numerically stable softmax (subtracts the max before exponentiating)."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def relu(x):
    """ReLU."""
    return np.maximum(0, x)


def sigmoid(x):
    """Sigmoid, with inputs clipped to [-20, 20] to avoid overflow in exp."""
    return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))
```
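The helpers above ignore padding. Real MSAs are padded to a fixed depth, and, as we read the AF2 source, padding is handled with a mask. Attention logits get a large negative bias where the mask is 0, and the global query is pooled with a masked mean. The following is a minimal sketch of both ideas. `masked_softmax` and `masked_mean` are hypothetical helper names, not functions from this notebook or from AF2.

```python
import numpy as np

# Hypothetical helpers: mask == 1 marks a real entry, mask == 0 marks padding.

def masked_softmax(logits, mask, axis=-1):
    """Softmax that gives (near-)zero weight where mask == 0."""
    bias = 1e9 * (mask - 1.0)            # 0 for valid entries, -1e9 for padded ones
    x = logits + bias
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def masked_mean(x, mask, axis=0, eps=1e-10):
    """Mean over `axis` that ignores padded entries; mask broadcasts against x."""
    return np.sum(x * mask, axis=axis) / (np.sum(mask, axis=axis) + eps)

logits = np.array([2.0, 1.0, 3.0, 0.5])
mask = np.array([1.0, 1.0, 0.0, 1.0])      # third entry is padding
w = masked_softmax(logits, mask)
print(w.round(3), w.sum())

m = np.array([[1.0], [3.0], [100.0]])       # [N_seq, c]; last row is padding
seq_mask = np.array([[1.0], [1.0], [0.0]])
print(masked_mean(m, seq_mask))             # mean of the two real rows only
```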

```python
def msa_row_attention(m, z, c=8, num_heads=8):
    """MSA row attention with pair bias (simplified).

    Simplifications: no gating, no masking, no biases, and fresh random
    weights on every call (a shape/flow demo, not a trained model).
    """
    N_seq, N_res, c_m = m.shape
    c_z = z.shape[-1]

    m_norm = layer_norm(m)
    z_norm = layer_norm(z)

    W_q = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    W_k = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    W_v = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    W_o = np.random.randn(num_heads, c, c_m) * ((num_heads * c) ** -0.5)

    q = np.einsum('src,chd->srhd', m_norm, W_q)   # [N_seq, N_res, H, c]
    k = np.einsum('src,chd->srhd', m_norm, W_k)
    v = np.einsum('src,chd->srhd', m_norm, W_v)
    b = np.einsum('ijc,ch->ijh', z_norm, W_b)     # pair bias [N_res, N_res, H]

    # Attend along each row (over residues j), biased by the pair representation
    attn_logits = np.einsum('sihd,sjhd->sijh', q, k) / np.sqrt(c)
    attn_logits = attn_logits + b[None, :, :, :]
    attn_weights = softmax(attn_logits, axis=2)   # normalize over j

    attended = np.einsum('sijh,sjhd->sihd', attn_weights, v)
    output = np.einsum('sihd,hdc->sic', attended, W_o)

    return output


def msa_column_global_attention(m, c=8, num_heads=8):
    """MSA column global attention (Algorithm 19, simplified: no masking, no biases).

    The query is pooled over sequences (one per column and head); keys and
    values use a single head shared by all query heads, as in Algorithm 19.
    """
    N_seq, N_res, c_m = m.shape

    m_norm = layer_norm(m)

    W_q = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    W_k = np.random.randn(c_m, c) * (c_m ** -0.5)   # single key head
    W_v = np.random.randn(c_m, c) * (c_m ** -0.5)   # single value head
    W_g = np.random.randn(c_m, num_heads, c) * (c_m ** -0.5)
    W_o = np.random.randn(num_heads, c, c_m) * ((num_heads * c) ** -0.5)

    # Global query: mean over sequences in each column
    q_input = m_norm.mean(axis=0)                         # [N_res, c_m]
    q = np.einsum('rc,chd->rhd', q_input, W_q)            # [N_res, H, c]

    k = np.einsum('src,cd->srd', m_norm, W_k)             # [N_seq, N_res, c]
    v = np.einsum('src,cd->srd', m_norm, W_v)             # [N_seq, N_res, c]
    g = sigmoid(np.einsum('src,chd->srhd', m_norm, W_g))  # per-sequence gates

    # One query per column attends over all N_seq sequences: O(N_seq) logits
    attn_logits = np.einsum('rhd,srd->rhs', q, k) / np.sqrt(c)
    attn_weights = softmax(attn_logits, axis=-1)          # normalize over s

    attended = np.einsum('rhs,srd->rhd', attn_weights, v)  # [N_res, H, c]

    # Broadcast the column summary back to every sequence, then gate per sequence
    gated = g * attended[None, :, :, :]
    output = np.einsum('srhd,hdc->src', gated, W_o)

    return output


def outer_product_mean(m, c=32, c_z=128):
    """Outer Product Mean (Algorithm 10, simplified: no masking, no biases)."""
    N_seq, N_res, c_m = m.shape

    m_norm = layer_norm(m)

    W_l = np.random.randn(c_m, c) * (c_m ** -0.5)
    W_r = np.random.randn(c_m, c) * (c_m ** -0.5)
    W_o = np.random.randn(c * c, c_z) * ((c * c) ** -0.5)

    left = np.einsum('src,cd->srd', m_norm, W_l)    # [N_seq, N_res, c]
    right = np.einsum('src,cd->srd', m_norm, W_r)   # [N_seq, N_res, c]

    # Contract over sequences directly. Materializing [N_seq, N_res, N_res, c, c]
    # first needs ~1e9 entries (~8.6 GB in float64) for N_seq=1024, N_res=32, c=32.
    outer_mean = np.einsum('sic,sjd->ijcd', left, right, optimize=True) / N_seq
    outer_flat = outer_mean.reshape(N_res, N_res, c * c)

    output = np.einsum('ijc,cd->ijd', outer_flat, W_o)
    return output
```
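The mean over sequences commutes with the outer product. That is why `outer_product_mean` can contract over `s` in a single `einsum` instead of building the full `[N_seq, N_res, N_res, c, c]` tensor. The check below confirms on toy sizes that both routes agree. The sizes are arbitrary.

```python
import numpy as np

# Compare the memory-hungry route (materialize all outer products, then average)
# with the direct contraction over the sequence axis.
rng = np.random.default_rng(0)
n_seq, n_res, c = 16, 5, 4
left = rng.standard_normal((n_seq, n_res, c))
right = rng.standard_normal((n_seq, n_res, c))

naive = np.einsum('sic,sjd->sijcd', left, right).mean(axis=0)             # large intermediate
direct = np.einsum('sic,sjd->ijcd', left, right, optimize=True) / n_seq    # no intermediate

print(naive.shape, np.allclose(naive, direct))
```

At the sizes used in Test 1 (N_seq = 1024, N_res = 32, c = 32), the naive intermediate has 1024 · 32^4 ≈ 1.07 × 10^9 entries. The direct contraction only ever holds the `[N_res, N_res, c, c]` result.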

```python
def extra_msa_stack(m_extra, z, num_blocks=4):
    """
    Extra MSA Stack - Algorithm 18 (simplified).

    Processes a large extra MSA to update the pair representation. The
    extra-MSA activations are discarded at the end; only z is returned.
    All weights are drawn fresh from np.random on every call, so seed the
    global RNG beforehand if two runs must share parameters.

    Args:
        m_extra: Extra MSA representation [N_extra, N_res, c_m]
        z: Pair representation [N_res, N_res, c_z]
        num_blocks: Number of stack iterations

    Returns:
        Updated pair representation z
    """
    N_extra, N_res, c_m = m_extra.shape
    c_z = z.shape[-1]

    print("Extra MSA Stack")
    print("=" * 50)
    print(f"Extra MSA: [{N_extra}, {N_res}, {c_m}]")
    print(f"Pair: [{N_res}, {N_res}, {c_z}]")
    print(f"Blocks: {num_blocks}")

    for block_idx in range(num_blocks):
        print(f"\nBlock {block_idx + 1}:")

        # Step 1: MSA Row Attention with Pair Bias
        m_extra = m_extra + msa_row_attention(m_extra, z, c=8, num_heads=8)
        print("  MSA Row Attention")

        # Step 2: MSA Column Global Attention (Algorithm 19)
        m_extra = m_extra + msa_column_global_attention(m_extra, c=8, num_heads=8)
        print("  MSA Column Global Attention")

        # Step 3: MSA Transition (two-layer MLP with 4x expansion)
        m_norm = layer_norm(m_extra)
        W1 = np.random.randn(c_m, 4 * c_m) * (c_m ** -0.5)
        W2 = np.random.randn(4 * c_m, c_m) * ((4 * c_m) ** -0.5)
        m_extra = m_extra + relu(m_norm @ W1) @ W2
        print("  MSA Transition")

        # Step 4: Outer Product Mean -> Update Pair
        z = z + outer_product_mean(m_extra, c=32, c_z=c_z)
        print("  Outer Product Mean -> Pair")

        # Steps 5-6: triangle updates and pair transition.
        # PLACEHOLDER: a scaled LayerNorm residual stands in for all of them;
        # it does not model any triangle interactions.
        z_norm = layer_norm(z)
        z = z + z_norm * 0.1
        print("  Triangle Operations (placeholder)")

        print(f"  MSA norm: {np.linalg.norm(m_extra):.2f}, Pair norm: {np.linalg.norm(z):.2f}")

    return z
```
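`extra_msa_stack` uses a scaled LayerNorm residual in place of the triangle operations. A simplified triangle multiplicative update, the outgoing/incoming pair of Algorithms 11 and 12, could be sketched as follows. The sketch omits masks, biases, dropout and learned LayerNorm parameters, and it draws its own random weights. The name `triangle_multiplication`, the `seed` argument and the default `c=32` are illustrative assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable scale/offset."""
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def triangle_multiplication(z, c=32, outgoing=True, seed=0):
    """Hypothetical simplified triangle multiplicative update on z [N, N, c_z]."""
    rng = np.random.default_rng(seed)
    c_z = z.shape[-1]
    s = c_z ** -0.5
    W_ap, W_ag, W_bp, W_bg = (rng.standard_normal((c_z, c)) * s for _ in range(4))
    W_g = rng.standard_normal((c_z, c_z)) * s
    W_o = rng.standard_normal((c, c_z)) * c ** -0.5

    z_norm = layer_norm(z)
    a = sigmoid(z_norm @ W_ag) * (z_norm @ W_ap)   # gated "left" edge features [N, N, c]
    b = sigmoid(z_norm @ W_bg) * (z_norm @ W_bp)   # gated "right" edge features [N, N, c]
    g = sigmoid(z_norm @ W_g)                      # output gate [N, N, c_z]

    if outgoing:
        # Edge (i, j) combines edges (i, k) and (j, k) over every third node k
        x = np.einsum('ikc,jkc->ijc', a, b)
    else:
        # Incoming variant: combines edges (k, i) and (k, j)
        x = np.einsum('kic,kjc->ijc', a, b)
    return g * (layer_norm(x) @ W_o)

z = np.random.default_rng(1).standard_normal((6, 6, 8))
update_out = triangle_multiplication(z, c=4, outgoing=True)
update_in = triangle_multiplication(z, c=4, outgoing=False)
print(update_out.shape, update_in.shape)
```

Inside the stack, the placeholder could be replaced by two residual calls. The first would be `z = z + triangle_multiplication(z, outgoing=True)`, followed by the same call with `outgoing=False`.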

## Test Examples

```python
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("="*60)

np.random.seed(42)
N_extra, N_res, c_m = 1024, 32, 64
c_z = 128

m_extra = np.random.randn(N_extra, N_res, c_m).astype(np.float32)
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

z_updated = extra_msa_stack(m_extra, z.copy(), num_blocks=2)

print(f"\nOutput pair shape: {z_updated.shape}")
print(f"Pair shape preserved: {z_updated.shape == z.shape}")
```

```python
# Test 2: Global attention scaling with N_seq
print("\nTest 2: Global Attention Scaling")
print("=" * 60)

import time

N_res, c_m = 32, 64

for N_seq in [256, 512, 1024, 2048]:
    m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)

    start = time.perf_counter()
    _ = msa_column_global_attention(m, c=8, num_heads=8)   # includes weight init
    elapsed = (time.perf_counter() - start) * 1000

    # Expected to grow roughly linearly in N_seq; wall-clock timings are noisy
    print(f"N_seq={N_seq:4d}: {elapsed:.2f} ms ({elapsed / N_seq * 1000:.2f} us per sequence)")
```
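Wall-clock timings are noisy, so a deterministic property of global column attention is also worth checking. The query is a mean over sequences and the softmax runs over sequences. Reordering the rows of the extra MSA should therefore simply reorder the output rows (permutation equivariance). The self-contained sketch below follows the single-head key/value layout of Algorithm 19. `make_params` and `global_column_attention` are hypothetical names, and masks and biases are omitted.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def make_params(c_m, c=8, num_heads=8, seed=0):
    """Fixed random weights so two calls share parameters (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    s = c_m ** -0.5
    return dict(W_q=rng.standard_normal((c_m, num_heads, c)) * s,
                W_k=rng.standard_normal((c_m, c)) * s,   # one key head shared by all query heads
                W_v=rng.standard_normal((c_m, c)) * s,   # one value head shared by all query heads
                W_g=rng.standard_normal((c_m, num_heads, c)) * s,
                W_o=rng.standard_normal((num_heads, c, c_m)) * (num_heads * c) ** -0.5)

def global_column_attention(m, p):
    c = p['W_k'].shape[1]
    m = layer_norm(m)
    q = np.einsum('rc,chd->rhd', m.mean(axis=0), p['W_q'])   # pooled query [R, H, c]
    k = m @ p['W_k']                                         # [S, R, c]
    v = m @ p['W_v']                                         # [S, R, c]
    w = softmax(np.einsum('rhd,srd->rhs', q, k) / np.sqrt(c), axis=-1)
    o = np.einsum('rhs,srd->rhd', w, v)                      # column summary [R, H, c]
    g = 1.0 / (1.0 + np.exp(-np.einsum('src,chd->srhd', m, p['W_g'])))
    return np.einsum('srhd,hdc->src', g * o[None], p['W_o'])

rng = np.random.default_rng(1)
m = rng.standard_normal((64, 10, 16))
p = make_params(c_m=16)
perm = rng.permutation(64)
out = global_column_attention(m, p)
out_perm = global_column_attention(m[perm], p)
print(np.allclose(out[perm], out_perm))
```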

```python
# Test 3: Information flow from MSA to pair
print("\nTest 3: Information Flow MSA -> Pair")
print("=" * 60)

np.random.seed(42)
N_extra, N_res, c_m = 512, 24, 32
c_z = 64

# Random (unstructured) extra MSA
m_extra = np.random.randn(N_extra, N_res, c_m).astype(np.float32)

# Start with an all-zero pair representation
z_zero = np.zeros((N_res, N_res, c_z), dtype=np.float32)

# Run stack; any nonzero pair output must originate from the MSA via the outer product mean
z_result = extra_msa_stack(m_extra, z_zero, num_blocks=2)

print(f"\nStarting pair norm: {np.linalg.norm(z_zero):.4f}")
print(f"Final pair norm: {np.linalg.norm(z_result):.4f}")
print(f"Information transferred from MSA to pair: {np.linalg.norm(z_result) > 0}")
```

## Verification: Key Properties

```python
print("Verification: Key Properties")
print("=" * 60)

np.random.seed(42)
N_extra, N_res, c_m = 256, 20, 32
c_z = 64

m_extra = np.random.randn(N_extra, N_res, c_m).astype(np.float32)
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

np.random.seed(0)  # fixes the weights drawn inside the stack
z_output = extra_msa_stack(m_extra.copy(), z.copy(), num_blocks=2)

# Property 1: Pair shape preserved
shape_preserved = z_output.shape == z.shape
print(f"Property 1 - Pair shape preserved: {shape_preserved}")

# Property 2: Finite output
output_finite = np.isfinite(z_output).all()
print(f"Property 2 - Output finite: {output_finite}")

# Property 3: Non-trivial update
not_identity = not np.allclose(z_output, z)
print(f"Property 3 - Non-trivial update: {not_identity}")

# Property 4: Depends on extra MSA. Reseed so both runs use identical weights;
# otherwise a difference could come from the weights alone.
m_modified = m_extra + np.random.randn(*m_extra.shape).astype(np.float32)
np.random.seed(0)
z_modified = extra_msa_stack(m_modified, z.copy(), num_blocks=2)
depends_on_msa = not np.allclose(z_output, z_modified)
print(f"Property 4 - Depends on extra MSA: {depends_on_msa}")

# Property 5: Output/input norm ratio (informational; no threshold is checked)
scale_ratio = np.linalg.norm(z_output) / np.linalg.norm(z)
print(f"Property 5 - Scale ratio: {scale_ratio:.4f}")
```

## Source Code Reference

```python
# Abridged from AF2-source-code/model/modules.py (arguments elided as ...)

class ExtraMsaStack(hk.Module):
  """Extra MSA Stack.

  Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
  """

  def __call__(self, msa_act, pair_act, msa_mask, pair_mask, is_training):
    for _ in range(self.config.num_block):
      pair_act, msa_act = ExtraMsaStackIteration(
          self.config, self.global_config)(msa_act, pair_act, ...)
    return pair_act


class ExtraMsaStackIteration(hk.Module):
  """Single iteration of Extra MSA Stack.
  
  Key difference: Uses GlobalAttention instead of regular column attention.
  """
  
  def __call__(self, msa_act, pair_act, ...):
    # MSA row attention with pair bias
    msa_act += MSARowAttentionWithPairBias(...)(msa_act, msa_mask, pair_act)
    
    # MSA column global attention (Algorithm 19)
    msa_act += MSAColumnGlobalAttention(...)(msa_act, msa_mask)
    
    # MSA transition
    msa_act += Transition(...)(msa_act, msa_mask)
    
    # Outer product mean -> pair update
    pair_act += OuterProductMean(...)(msa_act, msa_mask)
    
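    # Triangle operations on pair, then pair transition
    pair_act += TriangleMultiplication(...)(pair_act, pair_mask)  # outgoing
    pair_act += TriangleMultiplication(...)(pair_act, pair_mask)  # incoming
    pair_act += TriangleAttention(...)(pair_act, pair_mask)       # starting node
    pair_act += TriangleAttention(...)(pair_act, pair_mask)       # ending node
    pair_act += Transition(...)(pair_act, pair_mask)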
    
    return pair_act, msa_act
```
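The excerpt applies triangle attention twice, once around the starting node and once around the ending node (Algorithms 13 and 14). The NumPy sketch below shows the idea under simplifying assumptions. For edge (i, j), the starting-node variant attends over edges (i, k), with a bias taken from edge (j, k). The ending-node variant is the same computation applied to the transposed pair tensor. Masks, projection biases and dropout are omitted, and the function names, `seed` argument and default sizes are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def triangle_attention_starting(z, c=16, num_heads=4, seed=0):
    """Hypothetical simplified starting-node triangle attention on z [N, N, c_z]."""
    rng = np.random.default_rng(seed)
    c_z = z.shape[-1]
    s = c_z ** -0.5
    W_q, W_k, W_v, W_g = (rng.standard_normal((c_z, num_heads, c)) * s for _ in range(4))
    W_b = rng.standard_normal((c_z, num_heads)) * s
    W_o = rng.standard_normal((num_heads, c, c_z)) * (num_heads * c) ** -0.5

    z = layer_norm(z)
    q = np.einsum('ijc,chd->ijhd', z, W_q)
    k = np.einsum('ijc,chd->ijhd', z, W_k)
    v = np.einsum('ijc,chd->ijhd', z, W_v)
    b = np.einsum('jkc,ch->jkh', z, W_b)                    # bias from edge (j, k)
    g = 1.0 / (1.0 + np.exp(-np.einsum('ijc,chd->ijhd', z, W_g)))

    # Edge (i, j) attends over edges (i, k) that share its starting node i
    logits = np.einsum('ijhd,ikhd->ijkh', q, k) / np.sqrt(c) + b[None]
    a = softmax(logits, axis=2)                             # normalize over k
    o = g * np.einsum('ijkh,ikhd->ijhd', a, v)
    return np.einsum('ijhd,hdc->ijc', o, W_o)

def triangle_attention_ending(z, **kwargs):
    # Ending-node variant: run the starting-node computation on the transposed pair tensor
    return triangle_attention_starting(z.transpose(1, 0, 2), **kwargs).transpose(1, 0, 2)

z = np.random.default_rng(2).standard_normal((7, 7, 12))
start = triangle_attention_starting(z)
end = triangle_attention_ending(z)
print(start.shape, end.shape)
```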

## Key Insights

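1. **Handling Large MSAs**: Deep MSAs can have 10,000+ sequences, far more than the ~512 kept in the main MSA. The Extra MSA Stack lets thousands of the remaining sequences (~5000, per the table above) still contribute to the pair representation.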

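2. **Global Attention**: In each column, a single pooled query per head attends over all N_seq sequences. Per column, the cost therefore grows as O(N_seq) rather than the O(N_seq²) of standard column attention, which is what makes large MSAs tractable.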

3. **Pair-Only Output**: Unlike the main Evoformer, only the pair representation is updated and passed on.

4. **Fewer Blocks**: 4 blocks vs 48 for Evoformer - the extra MSA provides supplementary information.

5. **Co-evolution Signal**: The outer product mean extracts pairwise co-evolution signals from the extra MSA.

6. **Split MSA Processing**: Dividing the alignment into a main MSA (full Evoformer processing) and an extra MSA (cheaper, pair-only processing) lets the model exploit deep alignments without paying the full Evoformer cost for every sequence.
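The co-evolution point above can be made concrete with a toy example. Strip away the learned projections and LayerNorm (an intentional simplification) and give each residue a single ±1 feature. The outer product mean then reduces to the average of x_si · x_sj over sequences. Columns that co-vary across the alignment produce large entries, while independent columns produce entries near zero. The data is synthetic.

```python
import numpy as np

# Synthetic MSA features: columns 0 and 1 co-vary perfectly, column 2 is independent.
rng = np.random.default_rng(0)
n_seq = 2000
col0 = rng.choice([-1.0, 1.0], size=n_seq)
col2 = rng.choice([-1.0, 1.0], size=n_seq)
msa = np.stack([col0, col0.copy(), col2], axis=1)[:, :, None]   # [N_seq, N_res=3, c=1]

# Outer product mean with identity projections: mean_s x_si * x_sj
opm = np.einsum('sic,sjd->ijcd', msa, msa) / n_seq               # [3, 3, 1, 1]
print(opm[..., 0, 0].round(2))
```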