# Algorithm 19: MSA Column Global Attention

MSA Column Global Attention is an efficient attention mechanism for processing large MSAs. Instead of full O(N_seq²) attention between every pair of sequences in a column, it uses a single global query per column (the mean over sequences) that attends to all sequences, reducing the complexity to O(N_seq).

## Algorithm Pseudocode

![MSAColumnGlobalAttention](../imgs/algorithms/MSAColumnGlobalAttention.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `GlobalAttention`
- **Lines**: 717-795

## Overview

### Comparison: Regular vs Global Attention

| Aspect | Regular Column (Alg 8) | Global Column (Alg 19) |
|--------|------------------------|------------------------|
| Complexity | O(N_seq²) | O(N_seq) |
| Query | All sequences | Mean of sequences |
| Memory | O(N² × N_res) | O(N × N_res) |
| Use case | Main MSA (~512) | Extra MSA (~5000) |

### Algorithm Steps

1. Compute global query as mean over sequences
2. Compute keys and values for all sequences
3. Attend from global query to all sequences
4. Broadcast attended result to all sequences
5. Apply gating

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Standardize ``x`` along ``axis`` to zero mean and unit variance.

    No learned scale or offset is applied. ``eps`` is added to the variance
    before the square root so a constant slice does not divide by zero.
    """
    centered = x - np.mean(x, axis=axis, keepdims=True)
    variance = np.var(x, axis=axis, keepdims=True)
    return centered / np.sqrt(variance + eps)


def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow; this leaves the result unchanged.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(shifted)
    normalizer = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / normalizer


def sigmoid(x):
    """Return the logistic sigmoid of ``x``.

    The input is clipped to [-20, 20] first, so the exponential cannot
    overflow and the output saturates just inside (0, 1).
    """
    clipped = np.clip(x, -20, 20)
    denominator = 1.0 + np.exp(-clipped)
    return 1.0 / denominator


def msa_column_global_attention(m, msa_mask=None, num_heads=8, c=8, *,
                                verbose=True, rng=None):
    """
    MSA Column Global Attention - Algorithm 19.

    Efficient attention using mean query: for every residue column a single
    query is built from the masked mean over sequences, it attends to the
    keys/values of all sequences, and the one attended vector is shared by
    every sequence after a per-sequence sigmoid gate.

    The projection weights are freshly sampled on every call (this is a
    demonstration, not a trained layer), so results are only reproducible
    when the random source is seeded.

    Args:
        m: MSA representation [N_seq, N_res, c_m]
        msa_mask: Valid sequence mask, broadcastable to [N_seq, N_res];
            entries > 0 are treated as valid. Defaults to all valid.
        num_heads: Number of attention heads
        c: Head dimension
        verbose: If True (default), print the shape of each intermediate
            step. Pass False to run silently.
        rng: Optional random source exposing ``standard_normal(shape)``
            (e.g. ``np.random.default_rng(0)``). When None, the global
            ``np.random`` state is used, so ``np.random.seed`` still
            controls the sampled weights.

    Returns:
        Attention update [N_seq, N_res, c_m]. This is the layer output only;
        it is not added back onto ``m``.

    Raises:
        ValueError: If ``m`` is not 3-dimensional or ``msa_mask`` cannot be
            broadcast to [N_seq, N_res].
    """
    m = np.asarray(m)
    if m.ndim != 3:
        raise ValueError(
            f"m must have shape [N_seq, N_res, c_m]; got {m.ndim} dimension(s)"
        )
    N_seq, N_res, c_m = m.shape

    # Route all progress output through one switch so callers can silence it.
    log = print if verbose else (lambda *args, **kwargs: None)

    log("MSA Column Global Attention")
    log("=" * 50)
    log(f"Input: [{N_seq}, {N_res}, {c_m}]")
    log(f"Heads: {num_heads}, Head dim: {c}")

    if msa_mask is None:
        msa_mask = np.ones((N_seq, N_res), dtype=np.float32)
    else:
        # broadcast_to raises ValueError for an incompatible mask shape,
        # instead of failing later inside an einsum/where call.
        msa_mask = np.broadcast_to(np.asarray(msa_mask), (N_seq, N_res))

    # Step 1: Layer normalization
    m_norm = layer_norm(m, axis=-1)
    log(f"\nStep 1 - Layer norm: {m_norm.shape}")

    # Step 2: Initialize weights
    # Draw order (q, k, v, g, o) is kept fixed so a given seed always yields
    # the same weights.
    sampler = np.random if rng is None else rng
    in_scale = c_m ** -0.5
    W_q = sampler.standard_normal((c_m, num_heads, c)) * in_scale
    W_k = sampler.standard_normal((c_m, num_heads, c)) * in_scale
    W_v = sampler.standard_normal((c_m, num_heads, c)) * in_scale
    W_g = sampler.standard_normal((c_m, num_heads, c)) * in_scale
    W_o = sampler.standard_normal((num_heads, c, c_m)) * ((num_heads * c) ** -0.5)

    # Step 3: GLOBAL QUERY - mean over sequences (Line 2)
    # Key difference from regular attention!
    mask_expanded = msa_mask[:, :, None]  # [N_seq, N_res, 1]
    m_masked = m_norm * mask_expanded

    # Mean over valid sequences for each position; the max() guards columns
    # in which every sequence is masked (n_valid == 0).
    n_valid = mask_expanded.sum(axis=0)  # [N_res, 1]
    q_input = m_masked.sum(axis=0) / np.maximum(n_valid, 1.0)  # [N_res, c_m]

    # Project to query
    q = np.einsum('rc,chd->rhd', q_input, W_q)  # [N_res, H, c]
    log(f"Step 3 - Global query: {q.shape}")

    # Step 4: Keys and values from all sequences (Lines 3-4)
    k = np.einsum('src,chd->srhd', m_norm, W_k)  # [N_seq, N_res, H, c]
    v = np.einsum('src,chd->srhd', m_norm, W_v)
    log(f"Step 4 - K: {k.shape}, V: {v.shape}")

    # Step 5: Gating (per sequence) (Line 5)
    g = sigmoid(np.einsum('src,chd->srhd', m_norm, W_g))
    log(f"Step 5 - Gates: {g.shape}")

    # Step 6: Attention (Lines 6-7)
    # Global query attends to all sequences
    # q: [N_res, H, c], k: [N_seq, N_res, H, c]
    attn_logits = np.einsum('rhd,srhd->rhs', q, k) / np.sqrt(c)  # [N_res, H, N_seq]

    # Apply mask: masked sequences get a large negative logit so they receive
    # (numerically) zero weight after the softmax.
    mask_attn = msa_mask.T[:, None, :]  # [N_res, 1, N_seq]
    attn_logits = np.where(mask_attn > 0, attn_logits, -1e9)

    # Softmax over sequences
    attn_weights = softmax(attn_logits, axis=-1)
    log(f"Step 6 - Attention weights: {attn_weights.shape}")

    # Step 7: Apply attention to values (Line 8)
    # attended[r,h,d] = sum_s attn[r,h,s] * v[s,r,h,d]
    attended = np.einsum('rhs,srhd->rhd', attn_weights, v)
    log(f"Step 7 - Attended (global): {attended.shape}")

    # Step 8: Broadcast to all sequences (Line 9)
    # broadcast_to returns a view, so no per-sequence copy is made here.
    attended_broadcast = np.broadcast_to(
        attended[None, :, :, :],
        (N_seq, N_res, num_heads, c)
    )
    log(f"Step 8 - Broadcast: {attended_broadcast.shape}")

    # Step 9: Apply gating (Line 10)
    gated = g * attended_broadcast
    log(f"Step 9 - Gated: {gated.shape}")

    # Step 10: Output projection (Line 11)
    output = np.einsum('srhd,hdc->src', gated, W_o)
    log(f"Step 10 - Output: {output.shape}")

    return output

## Test Examples

In [None]:
# Test 1: Basic functionality - run the layer once on a random MSA and
# report whether the output keeps the input's [N_seq, N_res, c_m] shape.
print("Test 1: Basic Functionality")
print("=" * 60)

N_seq, N_res, c_m = 1024, 32, 64
m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
output = msa_column_global_attention(m, num_heads=8, c=8)

shapes_match = output.shape == m.shape
print(f"\nInput shape: {m.shape}")
print(f"Output shape: {output.shape}")
print(f"Shape preserved: {shapes_match}")

In [None]:
# Test 2: Complexity comparison - time the layer for growing N_seq.
print("\nTest 2: Complexity Comparison")
print("="*60)

import contextlib
import io
import time

N_res, c_m = 32, 64

print("Global Attention (O(N)):")
for N_seq in [256, 512, 1024, 2048, 4096]:
    m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)

    # Capture the layer's step-by-step log in memory: it would otherwise
    # interleave with the timing table and add console I/O to the measurement.
    with contextlib.redirect_stdout(io.StringIO()):
        # perf_counter is monotonic and high-resolution; time.time() can
        # jump with system clock adjustments and is coarser on some platforms.
        start = time.perf_counter()
        _ = msa_column_global_attention(m, num_heads=8, c=8)
        elapsed = (time.perf_counter() - start) * 1000

    print(f"  N_seq={N_seq:4d}: {elapsed:6.2f}ms")

In [None]:
# Test 3: With masking - mark the second half of the sequences as invalid
# and check the layer still runs and returns the full-shape output.
print("\nTest 3: With Masking")
print("=" * 60)

N_seq, N_res, c_m = 512, 24, 32
m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)

# 1 = valid, 0 = masked; rows from N_seq // 2 onward are zeroed.
msa_mask = np.ones((N_seq, N_res), dtype=np.float32)
msa_mask[N_seq // 2:] = 0

output = msa_column_global_attention(m, msa_mask=msa_mask, num_heads=8, c=8)

# Every masked sequence contributes N_res zero entries to the mask.
n_masked_sequences = np.count_nonzero(msa_mask == 0) // N_res
print(f"\nMasked {n_masked_sequences} sequences")
print(f"Output shape: {output.shape}")

In [None]:
# Test 4: Global query behavior - scale one sequence up 10x and compare its
# output norm with that of a few ordinary sequences.
print("\nTest 4: Global Query Behavior")
print("=" * 60)

np.random.seed(42)
N_seq, N_res, c_m = 256, 16, 32

# Random MSA in which the first sequence is made an outlier.
m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
m[0] *= 10.0

output = msa_column_global_attention(m, num_heads=4, c=8)

first_in_norm = np.linalg.norm(m[0])
first_out_norm = np.linalg.norm(output[0])
# Reference group: sequences 1..9 only, not the whole MSA.
reference_norm = np.mean([np.linalg.norm(seq_out) for seq_out in output[1:10]])

print(f"First sequence input norm: {first_in_norm:.2f}")
print(f"First sequence output norm: {first_out_norm:.2f}")
print(f"Other sequences mean output norm: {reference_norm:.2f}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("=" * 60)

np.random.seed(42)
N_seq, N_res, c_m = 512, 20, 48
m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)

output = msa_column_global_attention(m, num_heads=8, c=8)

# Evaluate all five properties first, then report them in order.

# 1: the layer maps [N_seq, N_res, c_m] to the same shape.
shape_preserved = output.shape == m.shape

# 2: no NaN or inf anywhere in the output.
output_finite = np.all(np.isfinite(output))

# 3: the output is not simply a copy of the input.
not_identity = not np.allclose(output, m)

# 4: the attended vector is shared across sequences, but the per-sequence
#    gate should still make individual sequences' outputs differ.
output_seq0, output_seq1 = output[0], output[1]
sequences_differ = not np.allclose(output_seq0, output_seq1)

# 5: output spread relative to the input spread.
scale_ratio = np.std(output) / np.std(m)

print(f"Property 1 - Shape preserved: {shape_preserved}")
print(f"Property 2 - Output finite: {output_finite}")
print(f"Property 3 - Non-trivial: {not_identity}")
print(f"Property 4 - Sequence-specific output: {sequences_differ}")
print(f"Property 5 - Scale ratio: {scale_ratio:.4f}")

## Source Code Reference

```python
# Abridged sketch of GlobalAttention from AF2-source-code/model/modules.py (not verbatim)

class GlobalAttention(hk.Module):
  """Global attention.

  Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
  """

  def __call__(self, msa_act, msa_mask, is_training=False):
    c = self.config
    
    # Layer norm
    msa_act = common_modules.LayerNorm(...)(msa_act)
    
    # Compute mean query
    q_avg = utils.mask_mean(msa_mask[..., None], msa_act, axis=-3)
    q = common_modules.Linear(self.num_head * self.key_dim)(q_avg)
    
    # K, V from all sequences
    k = common_modules.Linear(self.num_head * self.key_dim)(msa_act)
    v = common_modules.Linear(self.num_head * self.value_dim)(msa_act)
    
    # Gating per sequence
    g = jax.nn.sigmoid(
        common_modules.Linear(self.num_head * self.value_dim)(msa_act))
    
    # Global attention computation
    logits = jnp.einsum('qhc,khc->hqk', q, k)
    weights = jax.nn.softmax(logits)
    weighted_avg = jnp.einsum('hqk,khc->qhc', weights, v)
    
    # Gate and project
    output = g * weighted_avg[None]  # Broadcast to all sequences
    output = common_modules.Linear(self.output_dim)(output)
    
    return output
```

## Key Insights

1. **O(N) Complexity**: The global query reduces attention complexity from O(N²) to O(N), enabling processing of very large MSAs.

2. **Mean Query**: Using the mean as the query represents a "consensus" that attends to all sequences.

3. **Gating for Diversity**: Every sequence receives the same attended vector for a given column; the per-sequence sigmoid gate then scales that vector element-wise, so the final update still differs from sequence to sequence.

4. **Broadcasting**: The same attended result is broadcast to all sequences, then modulated by sequence-specific gates.

5. **Memory Efficiency**: No need to store O(N_seq²) attention matrices; the attention weights, keys, and values for each column are all O(N_seq).

6. **Trade-off**: Less expressive than full attention, but sufficient for extracting coarse co-evolution signals from large MSAs.