# Algorithm 19: MSA Column Global Attention

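MSA Column Global Attention is an efficient attention mechanism for processing large MSAs. Regular column attention computes a separate query for every sequence, which costs O(N_seq²) per column. Global attention instead averages over the sequences to form a single query per column and attends from that query to all sequences, which costs O(N_seq) per column.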

## Algorithm Pseudocode

![MSAColumnGlobalAttention](../imgs/algorithms/MSAColumnGlobalAttention.png)
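
For readers who cannot view the image, the steps can be written out as equations. This is a transcription of the NumPy implementation in this notebook, not of the paper's pseudocode, so details such as per-head keys and values and the absence of bias terms are assumptions carried over from that code. Here $s, t$ index sequences, $i$ indexes residues, $h$ indexes heads, and $c$ is the head dimension.

$$
\begin{aligned}
\bar{m}_{si} &= \mathrm{LayerNorm}(m_{si}) \\
q^{h}_{i} &= W^{h}_{q}\Big(\tfrac{1}{N_{seq}} \sum_{s} \bar{m}_{si}\Big) \\
k^{h}_{si} &= W^{h}_{k}\,\bar{m}_{si}, \qquad v^{h}_{si} = W^{h}_{v}\,\bar{m}_{si} \\
g^{h}_{si} &= \mathrm{sigmoid}\big(W^{h}_{g}\,\bar{m}_{si}\big) \\
a^{h}_{ti} &= \mathrm{softmax}_{t}\Big(\tfrac{1}{\sqrt{c}}\; q^{h\top}_{i} k^{h}_{ti}\Big) \\
o^{h}_{si} &= g^{h}_{si} \odot \sum_{t} a^{h}_{ti}\, v^{h}_{ti} \\
\tilde{m}_{si} &= \sum_{h} W^{h}_{o}\, o^{h}_{si}
\end{aligned}
$$

Only the gate $g^{h}_{si}$ depends on the sequence index $s$; the attention weights and the weighted sum of values are shared by every sequence in a column.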

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `GlobalAttention`
- **Lines**: 717-795

## Comparison with Regular Column Attention

| Aspect | Regular Column (Alg 8) | Global Column (Alg 19) |
|--------|------------------------|------------------------|
| Complexity (per column) | O(N_seq²) | O(N_seq) |
| Query | One per sequence | One per column (mean over sequences) |
| Use case | Main MSA (~512) | Extra MSA (~5000) |
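
To make the complexity row concrete, the sketch below counts how many attention logits each variant computes. The helper name `column_attention_logit_counts` and the choice of 256 residues are illustrative assumptions, and the count ignores projections and constant factors.

```python
def column_attention_logit_counts(n_seq, n_res, num_heads=8):
    """Count attention logits per layer (hypothetical helper, illustration only).

    Regular column attention scores every (query sequence, key sequence) pair
    in each column; global attention scores one query against every key.
    """
    regular = n_res * num_heads * n_seq * n_seq
    global_ = n_res * num_heads * n_seq
    return regular, global_


# Sizes taken from the table above; n_res=256 is an arbitrary example length.
for n_seq in (512, 5000):
    regular, global_ = column_attention_logit_counts(n_seq, n_res=256)
    print(f"N_seq={n_seq}: regular={regular:,}  global={global_:,}  "
          f"ratio={regular // global_}x")
```

The ratio between the two counts is exactly N_seq, which is why the saving matters most for the extra MSA stack.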

In [None]:
import numpy as np

np.random.seed(42)

In [None]:
def msa_column_global_attention(m, num_heads=8, c=8):
    """
    MSA Column Global Attention - Algorithm 19.
    
    A single query per column, built from the mean over sequences, attends
    to the keys and values of every sequence in that column.
    
    Args:
        m: MSA representation [N_seq, N_res, c_m]
        num_heads: Number of attention heads
        c: Head dimension
    
    Returns:
        Updated m [N_seq, N_res, c_m]
    """
    N_seq, N_res, c_m = m.shape
    
    print("MSA Column Global Attention")
    print(f"  Input: [{N_seq}, {N_res}, {c_m}]")
    print(f"  Heads: {num_heads}, Head dim: {c}")
    
    # Layer norm over the channel axis (learned scale and offset omitted)
    m_norm = (m - m.mean(axis=-1, keepdims=True)) / np.sqrt(m.var(axis=-1, keepdims=True) + 1e-5)
    
    # Weights (random here for illustration; learned parameters in the real model)
    W_q = np.random.randn(c_m, num_heads, c) * 0.02
    W_k = np.random.randn(c_m, num_heads, c) * 0.02
    W_v = np.random.randn(c_m, num_heads, c) * 0.02
    W_g = np.random.randn(c_m, num_heads, c) * 0.02
    W_o = np.random.randn(num_heads, c, c_m) * 0.02
    
    # KEY DIFFERENCE: Global query from mean
    # Instead of q for each sequence, compute mean query
    q_input = m_norm.mean(axis=0)  # [N_res, c_m] - mean over sequences
    q = np.einsum('rc,chd->rhd', q_input, W_q)  # [N_res, H, c]
    print(f"  Global query: [{N_res}, {num_heads}, {c}]")
    
    # K, V from all sequences
    k = np.einsum('src,chd->srhd', m_norm, W_k)  # [N_seq, N_res, H, c]
    v = np.einsum('src,chd->srhd', m_norm, W_v)
    
    # Gating (per sequence)
    g = 1 / (1 + np.exp(-np.einsum('src,chd->srhd', m_norm, W_g)))
    
    # Attention: global query attends to all keys
    # For each residue position, attention over sequences
    # q: [N_res, H, c], k: [N_seq, N_res, H, c]
    attn_logits = np.einsum('rhd,srhd->rhs', q, k) / np.sqrt(c)  # [N_res, H, N_seq]
    
    # Softmax over sequences
    attn_logits_max = attn_logits.max(axis=-1, keepdims=True)
    attn_weights = np.exp(attn_logits - attn_logits_max)
    attn_weights /= attn_weights.sum(axis=-1, keepdims=True)
    
    # Apply attention
    # weighted sum of values: [N_res, H, c]
    attended = np.einsum('rhs,srhd->rhd', attn_weights, v)
    
    # Apply the per-sequence gate; NumPy broadcasts the shared result
    # attended[None]: [1, N_res, H, c] against g: [N_seq, N_res, H, c]
    gated = g * attended[None]
    
    # Output projection
    output = np.einsum('srhd,hdc->src', gated, W_o)
    
    print(f"  Output: {output.shape}")
    
    return output

In [None]:
# Test
N_seq, N_res, c_m = 1024, 32, 64  # Large MSA

m = np.random.randn(N_seq, N_res, c_m)

print("Test MSA Column Global Attention")
print("="*50)

output = msa_column_global_attention(m, num_heads=8, c=8)
print(f"\nShape preserved: {output.shape == m.shape}")
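
One way to read the mean-query trick is that it behaves like ordinary column attention in which every sequence happens to use the same query. The sketch below checks this numerically on small random tensors. The sizes, variable names, and the local `softmax` helper are illustrative assumptions and are independent of the function defined earlier in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n_seq, n_res, heads, c = 6, 3, 2, 4  # small illustrative sizes

q_global = rng.normal(size=(n_res, heads, c))   # one query per column and head
k = rng.normal(size=(n_seq, n_res, heads, c))
v = rng.normal(size=(n_seq, n_res, heads, c))


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


# Global attention: N_seq logits per column and head.
w_global = softmax(np.einsum('rhd,srhd->rhs', q_global, k) / np.sqrt(c))
out_global = np.einsum('rhs,srhd->rhd', w_global, v)        # [N_res, H, c]

# Regular column attention with the same query copied to every sequence:
# N_seq x N_seq logits per column and head.
q_tiled = np.broadcast_to(q_global, (n_seq, n_res, heads, c))
w_full = softmax(np.einsum('srhd,trhd->rhst', q_tiled, k) / np.sqrt(c))
out_full = np.einsum('rhst,trhd->srhd', w_full, v)          # [N_seq, N_res, H, c]

# Every sequence receives the same attended vector as the global version.
print(np.allclose(out_full, out_global[None]))
```

The quadratic version repeats identical work for every sequence, so computing the result once and broadcasting it loses nothing.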

## Source Code Reference

```python
# Abridged from AF2-source-code/model/modules.py

class GlobalAttention(hk.Module):
  """Global attention.

  Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
  """

  def __call__(self, msa_act, msa_mask, is_training=False):
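    # Compute mean query; q_avg is the masked mean of msa_act over the
    # sequence axis (its computation is elided from this excerpt)
    q = hk.Linear(self.num_head * self.key_dim)(q_avg)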
    
    # K, V from all sequences
    k = hk.Linear(self.num_head * self.key_dim)(msa_act)
    v = hk.Linear(self.num_head * self.value_dim)(msa_act)
    
    # Global attention computation
    # ...
```
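
The excerpt above takes an `msa_mask`, whereas the NumPy function in this notebook treats every MSA entry as valid. A minimal sketch of how a mask could enter the two places where sequences are aggregated, the mean query and the softmax, is shown below. The function names, the `eps` value, and the large-negative-bias masking are assumptions chosen for illustration, not a transcription of the reference code.

```python
import numpy as np


def masked_mean_query_input(m, msa_mask, eps=1e-10):
    """Mean over valid sequences only (hypothetical helper).

    m:        [N_seq, N_res, c_m]
    msa_mask: [N_seq, N_res], 1.0 for valid entries and 0.0 for padding
    returns:  [N_res, c_m]
    """
    weights = msa_mask[..., None]                       # [N_seq, N_res, 1]
    return (m * weights).sum(axis=0) / (weights.sum(axis=0) + eps)


def masked_softmax_over_sequences(attn_logits, msa_mask):
    """Softmax over the sequence axis that ignores padded entries.

    attn_logits: [N_res, H, N_seq]
    msa_mask:    [N_seq, N_res]
    """
    # A large negative bias drives the weight of padded entries to zero.
    bias = 1e9 * (msa_mask.T[:, None, :] - 1.0)         # [N_res, 1, N_seq]
    shifted = attn_logits + bias
    shifted = shifted - shifted.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


# Usage: the last two of four sequences are padding.
m_demo = np.arange(24, dtype=float).reshape(4, 3, 2)
mask_demo = np.ones((4, 3))
mask_demo[2:] = 0.0
print(np.allclose(masked_mean_query_input(m_demo, mask_demo), m_demo[:2].mean(axis=0)))
```

Without the mask, padded rows would both shift the mean query and receive attention weight.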