# Algorithm 18: Extra MSA Stack

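The Extra MSA Stack processes the additional MSA sequences that do not fit in the main (clustered) MSA. It is a smaller, Evoformer-like stack. The extra MSA activations are updated inside each block, but only the updated pair representation is returned and passed on to the main Evoformer. The extra MSA itself is discarded.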
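Within the stack, the extra MSA writes into the pair representation in exactly one place: the outer product mean step. For every residue pair (i, j), it takes the outer product of projected MSA features at columns i and j and averages it over sequences. The sketch below is a minimal NumPy version that assumes random projection weights in place of learned ones, no MSA mask, and a projection width `c` of 8. `outer_product_mean` and its parameter names are hypothetical, not the AlphaFold API.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalise over the channel axis; learned scale/offset omitted
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def outer_product_mean(m, w_a, w_b, w_out):
    """Toy outer product mean: [N_seq, N_res, c_m] -> [N_res, N_res, c_z].

    w_a, w_b: [c_m, c] projections; w_out: [c * c, c_z] output projection.
    """
    n_seq, n_res, _ = m.shape
    m = layer_norm(m)
    a = m @ w_a                                   # [N_seq, N_res, c]
    b = m @ w_b                                   # [N_seq, N_res, c]
    # Outer product per sequence, averaged over the sequence axis s
    o = np.einsum('sic,sjd->ijcd', a, b) / n_seq  # [N_res, N_res, c, c]
    o = o.reshape(n_res, n_res, -1)               # flatten each c x c block
    return o @ w_out                              # [N_res, N_res, c_z]

rng = np.random.default_rng(0)
n_seq, n_res, c_m, c, c_z = 16, 5, 64, 8, 128
m = rng.standard_normal((n_seq, n_res, c_m))
w_a = rng.standard_normal((c_m, c)) / np.sqrt(c_m)
w_b = rng.standard_normal((c_m, c)) / np.sqrt(c_m)
w_out = rng.standard_normal((c * c, c_z)) / c
update = outer_product_mean(m, w_a, w_b, w_out)
print(update.shape)  # (5, 5, 128)
```

Averaging over sequences is what lets thousands of extra rows be summarised into a fixed-size `[N_res, N_res, c_z]` update, at a cost that grows only linearly with the number of sequences.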

## Algorithm Pseudocode

![ExtraMsaStack](../imgs/algorithms/ExtraMsaStack.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `ExtraMsaStack`, `ExtraMsaStackIteration`
- **Lines**: 1483-1560

## Key Differences from Main Evoformer

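| Aspect | Evoformer | Extra MSA Stack |
|--------|-----------|------------------|
| MSA size | ~512 sequences | ~5000 sequences |
| MSA channels | 256 | 64 |
| Blocks | 48 | 4 |
| MSA attention | Row + Column | Row + Global Column |
| Output | MSA + Pair | Pair only (extra MSA discarded) |
| Purpose | Build the MSA and pair representations used downstream | Fold statistics from the extra sequences into the pair representation |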
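The switch to global column attention is about cost. Regular column attention builds an N_seq × N_seq matrix of logits for every residue column and head. Global column attention first averages the queries over sequences, leaving only N_seq logits per column and head. The back-of-envelope count below is a sketch: the residue count (384) and head count (8) are illustrative assumptions, not values taken from this page, and the function names are hypothetical.

```python
def column_attention_logits(n_seq, n_res, n_head):
    # Regular column attention: every sequence attends to every other
    # sequence within each residue column -> N_seq x N_seq logits per head
    return n_res * n_head * n_seq * n_seq

def global_column_attention_logits(n_seq, n_res, n_head):
    # Global column attention: one averaged query per column and head
    # attends over the N_seq keys -> N_seq logits per head
    return n_res * n_head * n_seq

n_res, n_head = 384, 8
for n_seq in (512, 5120):
    full = column_attention_logits(n_seq, n_res, n_head)
    glob = global_column_attention_logits(n_seq, n_res, n_head)
    print(f"N_seq={n_seq:5d}  full={full:>14,d}  global={glob:>10,d}  ratio={full // glob}")
```

The ratio is exactly N_seq. At 5120 sequences the full version would need about 80 billion logits per layer (over 300 GB in float32), which is why the Extra MSA Stack cannot reuse the Evoformer's column attention.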

```python
import numpy as np

np.random.seed(42)
```

```python
def layer_norm(x, eps=1e-5):
    """LayerNorm over the channel axis (no learned scale/offset in this toy)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def extra_msa_stack(m_extra, z, num_blocks=4, c_opm=8):
    """
    Extra MSA Stack - Algorithm 18 (toy walkthrough).

    Shapes and data flow follow Algorithm 18. The attention, transition and
    triangle updates are placeholders (a scaled LayerNorm residual), and all
    weights are random rather than learned.

    Args:
        m_extra: Extra MSA representation [N_extra, N_res, c_m]
        z: Pair representation [N_res, N_res, c_z]
        num_blocks: Number of stack iterations
        c_opm: Hidden width of the outer product mean projections

    Returns:
        Updated pair representation z (the updated extra MSA is discarded)
    """
    N_extra, N_res, c_m = m_extra.shape
    c_z = z.shape[-1]
    rng = np.random.default_rng(0)  # random stand-ins for learned weights

    print("Extra MSA Stack")
    print(f"  Extra MSA: [{N_extra}, {N_res}, {c_m}]")
    print(f"  Pair: [{N_res}, {N_res}, {c_z}]")
    print(f"  Blocks: {num_blocks}")

    for block_idx in range(num_blocks):
        print(f"\n  Block {block_idx + 1}:")

        # Step 1: MSA row attention with pair bias (Algorithm 7), placeholder
        m_extra = m_extra + 0.1 * layer_norm(m_extra)
        print("    MSA Row Attention")

        # Step 2: MSA column global attention (Algorithm 19), placeholder.
        # Global attention keeps the number of logits linear in N_extra.
        m_extra = m_extra + 0.1 * layer_norm(m_extra)
        print("    MSA Column Global Attention")

        # Step 3: MSA transition, placeholder
        m_extra = m_extra + 0.1 * layer_norm(m_extra)
        print("    MSA Transition")

        # Step 4: Outer product mean, the only path from the extra MSA into z.
        # Project, take per-sequence outer products, average over sequences,
        # then project the flattened c_opm x c_opm block to c_z channels.
        m_norm = layer_norm(m_extra)
        a = m_norm @ (rng.standard_normal((c_m, c_opm)) / np.sqrt(c_m))
        b = m_norm @ (rng.standard_normal((c_m, c_opm)) / np.sqrt(c_m))
        outer = np.einsum('sic,sjd->ijcd', a, b, optimize=True) / N_extra
        w_out = rng.standard_normal((c_opm * c_opm, c_z)) / c_opm
        z = z + outer.reshape(N_res, N_res, c_opm * c_opm) @ w_out
        print("    Outer Product Mean -> Pair update")

        # Step 5: pair stack, same as the Evoformer (triangle multiplication
        # outgoing/incoming, triangle attention starting/ending node,
        # pair transition), placeholder
        z = z + 0.1 * layer_norm(z)
        print("    Triangle Multiplication + Attention + Transition")

    return z
```

```python
# Test
N_extra, N_res, c_m = 1024, 32, 64  # Large extra MSA
c_z = 128

m_extra = np.random.randn(N_extra, N_res, c_m)
z = np.random.randn(N_res, N_res, c_z)

print("Test Extra MSA Stack")
print("="*50)

z_updated = extra_msa_stack(m_extra, z.copy(), num_blocks=2)
print(f"\nOutput pair shape: {z_updated.shape}")
print(f"Pair shape preserved: {z_updated.shape == z.shape}")
```
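Step 2 of the walkthrough is only a placeholder. Global column attention (Algorithm 19) differs from ordinary column attention in two ways: the per-head queries are averaged over the sequences in a column, giving one query per column and head, and the keys and values use a single head shared by all query heads. The result is one summary vector per column and head, which a per-sequence sigmoid gate then broadcasts back to every sequence. The sketch below assumes random weights, no mask, and a zero gate bias; `msa_column_global_attention`, `params` and the weight names are hypothetical.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def msa_column_global_attention(m, params, n_head=8, c=8):
    """Toy global column attention over m: [N_seq, N_res, c_m]."""
    n_seq, n_res, _ = m.shape
    m = layer_norm(m)
    q = (m @ params["w_q"]).reshape(n_seq, n_res, n_head, c)  # per-head queries
    k = m @ params["w_k"]              # [N_seq, N_res, c], shared by all heads
    v = m @ params["w_v"]              # [N_seq, N_res, c], shared by all heads
    q_avg = q.mean(axis=0)             # [N_res, H, c]: one query per column
    logits = np.einsum('ihc,tic->iht', q_avg, k) / np.sqrt(c)  # [N_res, H, N_seq]
    attn = softmax(logits, axis=-1)    # normalise over sequences t
    o = np.einsum('iht,tic->ihc', attn, v)  # [N_res, H, c] column summary
    g = 1 / (1 + np.exp(-(m @ params["w_g"] + params["b_g"])))
    g = g.reshape(n_seq, n_res, n_head, c)  # per-sequence gate
    o = g * o[None]                    # broadcast summary back to every sequence
    return o.reshape(n_seq, n_res, n_head * c) @ params["w_o"]

rng = np.random.default_rng(0)
n_seq, n_res, c_m, n_head, c = 64, 6, 64, 8, 8

def w(*shape):
    # Random stand-in for a learned weight, scaled by fan-in
    return rng.standard_normal(shape) / np.sqrt(shape[0])

params = {"w_q": w(c_m, n_head * c), "w_k": w(c_m, c), "w_v": w(c_m, c),
          "w_g": w(c_m, n_head * c), "b_g": np.zeros(n_head * c),
          "w_o": w(n_head * c, c_m)}
m = rng.standard_normal((n_seq, n_res, c_m))
update = msa_column_global_attention(m, params)
print(update.shape)  # (64, 6, 64)
```

Because the averaged query and the softmax over sequences do not depend on sequence order, shuffling the input sequences simply shuffles the output rows.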

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class ExtraMsaStack(hk.Module):
  """Extra MSA Stack.

  Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
  """

  def __call__(self, msa_act, pair_act, msa_mask, pair_mask, is_training):
    for _ in range(self.config.num_block):
      pair_act, msa_act = ExtraMsaStackIteration(...)
    return pair_act


class ExtraMsaStackIteration(hk.Module):
  """Single iteration of Extra MSA Stack."""
  
  # Uses GlobalAttention instead of regular column attention
  # for efficiency with large MSAs
```
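In the reference code each iteration takes both `msa_act` and `pair_act` because the first MSA update reads the pair: row attention with pair bias (Algorithm 7) attends between residues within each sequence and adds a per-head bias projected from z to the logits. The sketch below assumes random weights, no mask and no dropout; `msa_row_attention_with_pair_bias`, `params` and the weight names are hypothetical.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def msa_row_attention_with_pair_bias(m, z, params, n_head=8, c=8):
    """Toy row attention: m [N_seq, N_res, c_m], z [N_res, N_res, c_z]."""
    n_seq, n_res, _ = m.shape
    m = layer_norm(m)
    q = (m @ params["w_q"]).reshape(n_seq, n_res, n_head, c)
    k = (m @ params["w_k"]).reshape(n_seq, n_res, n_head, c)
    v = (m @ params["w_v"]).reshape(n_seq, n_res, n_head, c)
    bias = layer_norm(z) @ params["w_b"]                      # [N_res, N_res, H]
    logits = np.einsum('sihc,sjhc->shij', q, k) / np.sqrt(c)  # residue-residue logits per row
    logits = logits + bias.transpose(2, 0, 1)[None]           # same pair bias for every sequence
    attn = softmax(logits, axis=-1)                           # normalise over residues j
    o = np.einsum('shij,sjhc->sihc', attn, v)                 # [N_seq, N_res, H, c]
    g = 1 / (1 + np.exp(-(m @ params["w_g"])))                # gate, [N_seq, N_res, H*c]
    o = g * o.reshape(n_seq, n_res, n_head * c)
    return o @ params["w_o"]                                  # [N_seq, N_res, c_m]

rng = np.random.default_rng(0)
n_seq, n_res, c_m, c_z, n_head, c = 4, 6, 64, 128, 8, 8

def w(*shape):
    # Random stand-in for a learned weight, scaled by fan-in
    return rng.standard_normal(shape) / np.sqrt(shape[0])

params = {"w_q": w(c_m, n_head * c), "w_k": w(c_m, n_head * c),
          "w_v": w(c_m, n_head * c), "w_b": w(c_z, n_head),
          "w_g": w(c_m, n_head * c), "w_o": w(n_head * c, c_m)}
m = rng.standard_normal((n_seq, n_res, c_m))
z = rng.standard_normal((n_res, n_res, c_z))
update = msa_row_attention_with_pair_bias(m, z, params)
print(update.shape)  # (4, 6, 64)
```

The pair representation is only read here, never written; within the stack, z is written only by the outer product mean and the pair stack.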