# Algorithm 8: MSA Column-wise Gated Self-Attention

MSA Column Attention applies self-attention across all sequences at each residue position. This allows the model to learn evolutionary patterns by comparing how different sequences vary at the same position.

## Algorithm Pseudocode

![MSA Column Attention](../imgs/algorithms/MSAColumnAttention.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `MSAColumnAttention`
- **Lines**: 779-831

## Key Difference from Row Attention

| Aspect | Row Attention (Alg 7) | Column Attention (Alg 8) |
|--------|----------------------|-------------------------|
| Attention direction | Along residues (within sequence) | Along sequences (at each position) |
| Input shape | [N_seq, N_res, c_m] | [N_seq, N_res, c_m] |
| Attention axis | Axis 1 (N_res) | Axis 0 (N_seq) |
| Pair bias | Yes (from z) | No |
| Purpose | Intra-sequence relationships | Cross-sequence (evolutionary) patterns |

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Standardize ``x`` to zero mean and unit variance along ``axis``.

    No learned scale or offset is applied. ``eps`` is added to the variance
    before taking the square root so constant slices do not divide by zero.

    Args:
        x: Input array.
        axis: Axis over which mean and variance are computed.
        eps: Small constant added to the variance.

    Returns:
        Array with the same shape as ``x``.
    """
    mu = np.mean(x, axis=axis, keepdims=True)
    centered = x - mu
    spread = np.sqrt(np.var(x, axis=axis, keepdims=True) + eps)
    return centered / spread


def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    logits do not overflow; the result is unchanged by that shift.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numerators = np.exp(shifted)
    return numerators / np.sum(numerators, axis=axis, keepdims=True)


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    The input is clipped to [-500, 500] first so ``np.exp`` cannot overflow.
    """
    bounded = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-bounded))


def gated_attention(q_data, m_data, bias, num_head, with_gating=True):
    """
    Multi-head attention with an optional sigmoid gate on the per-head output.

    All projection weights are sampled from the global NumPy RNG
    (``np.random.randn`` scaled by 0.01) on every call, so two calls only
    share weights if the RNG is re-seeded to the same state before each call.

    Args:
        q_data: Queries [batch, N_queries, c]
        m_data: Keys/Values [batch, N_keys, c]
        bias: Attention bias [batch, 1, 1, N_keys]
        num_head: Number of attention heads
        with_gating: Whether to apply output gating

    Returns:
        Attention output [batch, N_queries, c]
    """
    # Tuple unpacking doubles as a rank-3 check on both inputs.
    _, _, channels = q_data.shape
    _, _, _ = m_data.shape

    per_head = channels // num_head
    proj_shape = (channels, num_head, per_head)

    # Projection weights. The draw order (q, k, v, then gate, then output)
    # determines the result for a given seed, so it must not be rearranged.
    w_query = np.random.randn(*proj_shape) * 0.01
    w_key = np.random.randn(*proj_shape) * 0.01
    w_value = np.random.randn(*proj_shape) * 0.01

    # Queries are pre-scaled by 1/sqrt(head_dim).
    scale = per_head ** -0.5
    queries = np.einsum('bqa,ahc->bqhc', q_data, w_query) * scale
    keys = np.einsum('bka,ahc->bkhc', m_data, w_key)
    values = np.einsum('bka,ahc->bkhc', m_data, w_value)

    # Scaled dot-product scores plus the additive bias, normalized over keys.
    scores = np.einsum('bqhc,bkhc->bhqk', queries, keys) + bias
    attn = softmax(scores, axis=-1)

    # Per-head weighted sum of the values.
    context = np.einsum('bhqk,bkhc->bqhc', attn, values)

    if with_gating:
        # Gate is computed from the query input; its bias starts at one.
        w_gate = np.random.randn(*proj_shape) * 0.01
        b_gate = np.ones((num_head, per_head))
        gate_logits = np.einsum('bqc,chv->bqhv', q_data, w_gate) + b_gate
        context = context * sigmoid(gate_logits)

    # Merge the heads back into c channels.
    w_out = np.random.randn(num_head, per_head, channels) * 0.01
    return np.einsum('bqhc,hco->bqo', context, w_out)

In [None]:
def msa_column_attention(msa_act, msa_mask, num_head=8):
    """
    MSA column-wise gated self-attention (Algorithm 8 of the AlphaFold2
    supplementary materials).

    Every residue position is handled as an independent batch element in
    which the sequences attend to one another. Progress is reported with
    ``print`` at each stage.

    Args:
        msa_act: MSA activations [N_seq, N_res, c_m]
        msa_mask: MSA mask [N_seq, N_res]
        num_head: Number of attention heads

    Returns:
        Updated MSA activations [N_seq, N_res, c_m]
    """
    N_seq, N_res, c_m = msa_act.shape

    print(f"Input shape: [{N_seq}, {N_res}, {c_m}]")

    # Move residues to the leading (batch) axis so attention runs over sequences:
    # [N_seq, N_res, c_m] -> [N_res, N_seq, c_m], and the mask to [N_res, N_seq].
    msa_act_t = np.transpose(msa_act, (1, 0, 2))
    mask_by_column = np.swapaxes(msa_mask, 0, 1)

    print(f"After transpose: [{msa_act_t.shape[0]}, {msa_act_t.shape[1]}, {msa_act_t.shape[2]}]")

    # Additive bias per key: 0 where the mask is 1, -1e9 where the mask is 0.
    key_bias = 1e9 * (mask_by_column - 1.)
    key_bias = key_bias[:, None, None, :]  # [N_res, 1, 1, N_seq]

    # Normalize over channels, then self-attend (queries and keys/values are the same tensor).
    msa_act_t = layer_norm(msa_act_t, axis=-1)
    output = gated_attention(msa_act_t, msa_act_t, key_bias, num_head, with_gating=True)

    print(f"Attention output: {output.shape}")

    # Restore the [N_seq, N_res, c_m] layout.
    output = np.transpose(output, (1, 0, 2))

    print(f"Final output: {output.shape}")

    return output

## Test Example

In [None]:
# Demo dimensions: sequences in the MSA, residues, MSA channels
N_seq, N_res, c_m = 128, 64, 256

# Random float32 activations and an all-ones float32 mask
msa_act = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
msa_mask = np.ones((N_seq, N_res), dtype=np.float32)

# Treat the last 10 sequences as padding by zeroing their mask rows
msa_mask[N_seq - 10:] = 0.0

print(f"MSA activations: {msa_act.shape}")
print(f"MSA mask: {msa_mask.shape}")
print(f"Valid sequences: {int(msa_mask[:, 0].sum())} / {N_seq}")
print()

In [None]:
# Run column attention on the demo inputs.
# gated_attention samples its weights from the global NumPy RNG, so the
# statistics printed below depend on the RNG state at the time of this call.
output = msa_column_attention(msa_act, msa_mask, num_head=8)

print(f"\nOutput statistics: mean={output.mean():.6f}, std={output.std():.6f}")

## Verify Masking Works

In [None]:
# Check that masked sequences don't affect output.
# Run with different masked content; output for valid sequences should be the same.
#
# gated_attention draws fresh random weights from the global NumPy RNG on
# every call. Both runs therefore have to start from the same seed; otherwise
# the outputs differ because of the weights, not because of the masked content.
WEIGHT_SEED = 0

msa_act_modified = msa_act.copy()
msa_act_modified[-10:, :, :] = np.random.randn(10, N_res, c_m) * 100  # Large random values

# Reference run on the unmodified activations
np.random.seed(WEIGHT_SEED)
output_reference = msa_column_attention(msa_act, msa_mask, num_head=8)

# Same weights (same seed), different content in the masked sequences
np.random.seed(WEIGHT_SEED)
output_modified = msa_column_attention(msa_act_modified, msa_mask, num_head=8)

# Compare outputs for valid sequences
valid_output_diff = np.abs(output_reference[:-10] - output_modified[:-10]).max()
print(f"Max difference in valid sequence outputs: {valid_output_diff:.6f}")
print("(Should be very small if masking works correctly)")

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class MSAColumnAttention(hk.Module):
  """MSA per-column attention.

  Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
  """

  def __call__(self, msa_act, msa_mask, is_training=False):
    c = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert c.orientation == 'per_column'

    # Transpose: [N_seq, N_res, c_m] -> [N_res, N_seq, c_m]
    msa_act = jnp.swapaxes(msa_act, -2, -3)
    msa_mask = jnp.swapaxes(msa_mask, -1, -2)

    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]

    msa_act = hk.LayerNorm(axis=[-1], ...)(msa_act)

    attn_mod = Attention(c, self.global_config, msa_act.shape[-1])
    msa_act = mapping.inference_subbatch(
        attn_mod, ...,
        batched_args=[msa_act, msa_act, bias],
        nonbatched_args=[])

    # Transpose back: [N_res, N_seq, c_m] -> [N_seq, N_res, c_m]
    msa_act = jnp.swapaxes(msa_act, -2, -3)

    return msa_act
```

## Key Insights

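1. **Evolutionary Information**: Column attention lets every sequence attend to all other sequences at the same residue position, capturing how that position is conserved or varies across the alignment. This evolutionary variation is a strong signal for structure prediction.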

2. **No Pair Bias**: Unlike row attention, column attention doesn't use pair representation bias - it's purely sequence-based.

3. **Transpose Trick**: The same attention module is reused by transposing the data.

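4. **Efficiency**: Column attention costs O(N_res · N_seq²) while row attention costs O(N_seq · N_res²), so for long sequences with few MSA sequences (N_seq < N_res), column attention is more efficient than row attention.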