# Algorithm 7: MSA Row-wise Gated Self-Attention with Pair Bias

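This algorithm implements row-wise self-attention on the MSA representation, with an additional bias term derived from the pair representation. This bias is the only point in each Evoformer block where information flows from the pair representation into the MSA representation, which makes it the key mechanism for bringing structural information into MSA processing.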

## Algorithm Pseudocode

![MSA Row Attention with Pair Bias](../imgs/algorithms/MSARowAttentionWithPairBias.png)

## Key Concepts

### Input/Output Shapes
- **Input MSA**: `m` with shape `[N_seq, N_res, c_m]`
- **Input Pair**: `z` with shape `[N_res, N_res, c_z]`
- **Output**: An *update* to the MSA with the same shape `[N_seq, N_res, c_m]`; the caller adds it to `m` as a residual

### Core Ideas
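1. **Row-wise Attention**: Within each MSA sequence (row), every residue position attends to all residue positions of that same row; different rows do not interact
2. **Pair Bias**: A learned linear projection of the pair representation adds one scalar per head to each attention logit `(i, j)`. The same bias is shared by all MSA rows and carries the residue-residue relationships held in the pair representation
3. **Gating**: The attention output is multiplied element-wise by a sigmoid gate, computed by a linear projection of the normalized MSA input, for controlled information flow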
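To make the three ideas concrete, here is a minimal NumPy sketch of the whole forward pass with random weights. It is not the AF2 code: the parameter names (`w_q`, `w_k`, `w_v`, `w_b`, `w_g`, `b_g`, `w_o`, `b_o`) are hypothetical, the LayerNorms omit their learned scale and offset, and dropout and the residual add are left to the caller.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the channel axis; learned scale/offset omitted (scale=1, offset=0).
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def msa_row_attention_with_pair_bias(m, z, mask, p):
    """m: [N_seq, N_res, c_m], z: [N_res, N_res, c_z], mask: [N_seq, N_res]."""
    m = layer_norm(m)
    # Per-head query/key/value projections: [N_seq, N_res, num_head, c]
    q = np.einsum('sia,ahc->sihc', m, p['w_q'])
    k = np.einsum('sja,ahc->sjhc', m, p['w_k'])
    v = np.einsum('sja,ahc->sjhc', m, p['w_v'])
    # Pair bias [num_head, N_res, N_res]; it has no MSA-row index, so every row shares it
    b = np.einsum('ijc,ch->hij', layer_norm(z), p['w_b'])
    # Sigmoid gate computed from the normalized MSA input
    g = sigmoid(np.einsum('sia,ahc->sihc', m, p['w_g']) + p['b_g'])
    # Logits [N_seq, num_head, N_res (i), N_res (j)]; padded keys get -1e9
    logits = (np.einsum('sihc,sjhc->shij', q, k) / np.sqrt(q.shape[-1])
              + b[None]
              + (1e9 * (mask - 1.0))[:, None, None, :])
    a = softmax(logits, axis=-1)  # normalize over key positions j
    o = g * np.einsum('shij,sjhc->sihc', a, v)
    # Output projection over the concatenated heads: [N_seq, N_res, c_m]
    return np.einsum('sihc,hco->sio', o, p['w_o']) + p['b_o']


rng = np.random.default_rng(0)
N_seq, N_res, c_m, c_z, num_head, c = 3, 5, 8, 4, 2, 4
p = {
    'w_q': rng.normal(size=(c_m, num_head, c)),
    'w_k': rng.normal(size=(c_m, num_head, c)),
    'w_v': rng.normal(size=(c_m, num_head, c)),
    'w_b': rng.normal(size=(c_z, num_head)),
    'w_g': rng.normal(size=(c_m, num_head, c)),
    'b_g': np.ones((num_head, c)),
    'w_o': rng.normal(size=(num_head, c, c_m)),
    'b_o': np.zeros(c_m),
}
m = rng.normal(size=(N_seq, N_res, c_m))
z = rng.normal(size=(N_res, N_res, c_z))
mask = np.ones((N_seq, N_res))
mask[:, -1] = 0.0  # treat the last residue as padding

update = msa_row_attention_with_pair_bias(m, z, mask, p)
print(update.shape)  # (3, 5, 8)
```

Changing a padded residue's features leaves every other position's update unchanged, and changing one MSA row leaves the other rows unchanged; both follow from the masking and the row-wise structure.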

## Source Code Implementation

From `AF2-source-code/model/modules.py`:

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import mapping


class MSARowAttentionWithPairBias(hk.Module):
  """MSA per-row attention biased by the pair representation.

  Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias"
  """

  def __init__(self, config, global_config,
               name='msa_row_attention_with_pair_bias'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
```

### Forward Pass Implementation

```python
def __call__(self,
             msa_act,
             msa_mask,
             pair_act,
             is_training=False):
    """Builds MSARowAttentionWithPairBias module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      pair_act: [N_res, N_res, c_z] pair representation.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m].
    """
    c = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert c.orientation == 'per_row'
```

### Step 1: Create Attention Mask (implementation detail, not in the pseudocode)

```python
    # Create bias from mask: large negative value for masked positions.
    # [N_seq, N_res] -> [N_seq, 1, 1, N_res], broadcasting over heads and queries.
    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4
```

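The mask is converted to an additive bias on the attention logits, applied along the key axis:
- Valid positions (mask = 1) → bias = 0
- Padded positions (mask = 0) → bias = -1e9, which acts as -∞ before the softmax, so padded keys receive zero attention weight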
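A small NumPy check of this conversion (NumPy stands in for `jax.numpy`, and the softmax is written out by hand). Because the bias has shape `[N_seq, 1, 1, N_res]`, it broadcasts over heads and query positions and masks only *keys*: padded positions still act as queries and still produce outputs, which this step does not zero out.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


msa_mask = np.array([[1., 1., 1., 0.],   # sequence 0: last residue is padding
                     [1., 1., 0., 0.]])  # sequence 1: last two residues are padding
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]  # [N_seq, 1, 1, N_res]

# Equal raw logits everywhere: [N_seq, num_head, N_res (queries), N_res (keys)]
logits = np.zeros((2, 1, 4, 4))
weights = softmax(logits + bias)

print(weights[0, 0, 0])  # approximately [1/3, 1/3, 1/3, 0]
print(weights[1, 0, 0])  # approximately [0.5, 0.5, 0, 0]
```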

### Step 2: Layer Normalization (Line 1; the pair LayerNorm is part of Line 3)

```python
    # Normalize MSA input
    msa_act = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
            msa_act)

    # Normalize pair representation for bias computation
    pair_act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='feat_2d_norm')(
            pair_act)
```
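For reference, a NumPy sketch of what each `hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True)` call computes, assuming Haiku's defaults (`eps=1e-5`, scale initialized to ones, offset initialized to zeros). Each residue, or residue pair, is normalized independently over its channel vector; the two calls create separate parameters (`query_norm` of size `c_m`, `feat_2d_norm` of size `c_z`).

```python
import numpy as np


def layer_norm(x, scale, offset, eps=1e-5):
    """Normalize over the last (channel) axis, then apply a learned scale and offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + offset


rng = np.random.default_rng(0)
N_seq, N_res, c_m, c_z = 3, 5, 8, 16
msa_act = rng.normal(loc=2.0, scale=3.0, size=(N_seq, N_res, c_m))
pair_act = rng.normal(size=(N_res, N_res, c_z))

# Separate parameters per call, at their assumed initial values (scale=1, offset=0).
msa_norm = layer_norm(msa_act, np.ones(c_m), np.zeros(c_m))     # like 'query_norm'
pair_norm = layer_norm(pair_act, np.ones(c_z), np.zeros(c_z))   # like 'feat_2d_norm'
```

At initialization every channel vector therefore has mean 0 and variance close to 1, whatever the scale of the raw activations.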

### Step 3: Compute Pair Bias (Line 3)

```python
    # Project pair representation to attention heads
    init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
    weights = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_act.shape[-1], c.num_head),
        init=hk.initializers.RandomNormal(stddev=init_factor))
    
    # Compute non-batched bias: [N_res, N_res, c_z] -> [num_head, N_res, N_res]
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
```

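This projects the pair representation `[N_res, N_res, c_z]` to `[num_head, N_res, N_res]`, creating a bias for each attention head. The bias has no MSA-row axis: it is shared by all `N_seq` rows, which is why it is passed to the attention as a non-batched argument.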
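The projection is a per-pair matrix product over the channel axis. The NumPy sketch below (random values standing in for the normalized pair activations and the learned weights) checks the einsum against a plain `@` followed by a transpose, and shows that adding a leading axis of size 1 lets one bias serve every MSA row, mirroring the `jnp.expand_dims(nonbatched_bias, axis=0)` in the `Attention` module.

```python
import numpy as np

rng = np.random.default_rng(0)
N_seq, N_res, c_z, num_head = 3, 5, 4, 2
pair_act = rng.normal(size=(N_res, N_res, c_z))   # stands in for the normalized pair_act
weights = rng.normal(scale=1 / np.sqrt(c_z), size=(c_z, num_head))

# [N_res, N_res, c_z] x [c_z, num_head] -> [num_head, N_res, N_res]
nonbatched_bias = np.einsum('qkc,ch->hqk', pair_act, weights)

# Same thing as a matrix product over the channel axis, with heads moved first.
via_matmul = (pair_act @ weights).transpose(2, 0, 1)

# A leading axis of size 1 broadcasts the one bias over every MSA row.
row_logits = np.zeros((N_seq, num_head, N_res, N_res)) + nonbatched_bias[None]
```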

### Step 4: Apply Attention (Lines 2 and 4-7)

```python
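    # Create attention module; its output dimension is c_m
    attn_mod = Attention(
        c, self.global_config, msa_act.shape[-1])

    # Apply attention. msa_act is passed twice because this is self-attention
    # (query data and key/value data are the same). When low_memory is True
    # (inference), the MSA rows are processed in subbatches to bound peak memory.
    msa_act = mapping.inference_subbatch(
        attn_mod,
        self.global_config.subbatch_size,
        batched_args=[msa_act, msa_act, bias],
        nonbatched_args=[nonbatched_bias],
        low_memory=not is_training)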

    return msa_act
```
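The sketch below assumes that `mapping.inference_subbatch` splits the batched arguments along their leading axis (here `N_seq`) into chunks of `subbatch_size` when `low_memory=True`, passes the non-batched arguments unchanged, and concatenates the results, while calling the module once during training. Because MSA rows do not interact, chunking should not change the result. `subbatched_apply` and the single-head `toy_row_attention` are hypothetical stand-ins written to illustrate that property, not the AF2 implementation.

```python
import numpy as np


def subbatched_apply(fn, subbatch_size, batched_args, nonbatched_args):
    """Hypothetical helper: apply fn to chunks of the leading axis and concatenate."""
    n = batched_args[0].shape[0]
    outputs = []
    for start in range(0, n, subbatch_size):
        chunk = [a[start:start + subbatch_size] for a in batched_args]
        outputs.append(fn(*chunk, *nonbatched_args))
    return np.concatenate(outputs, axis=0)


def toy_row_attention(q_data, m_data, bias, nonbatched_bias):
    """Hypothetical single-head attention; every leading-axis row is independent."""
    # [B, N_res, N_res] -> [B, 1, N_res, N_res], plus mask bias and shared pair bias
    logits = np.einsum('bqc,bkc->bqk', q_data, m_data)[:, None] + bias + nonbatched_bias[None]
    logits = logits - logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    w = w / w.sum(axis=-1, keepdims=True)
    return np.einsum('bhqk,bkc->bqc', w, m_data)


rng = np.random.default_rng(0)
msa = rng.normal(size=(7, 5, 4))             # N_seq=7 is not a multiple of the chunk size
mask = np.ones((7, 5))
mask[:, -1] = 0.0
bias = (1e9 * (mask - 1.))[:, None, None, :]
pair_bias = rng.normal(size=(1, 5, 5))       # [num_head=1, N_res, N_res]

full = toy_row_attention(msa, msa, bias, pair_bias)              # like low_memory=False
chunked = subbatched_apply(toy_row_attention, 3, [msa, msa, bias], [pair_bias])
```

The trade-off is speed for memory: each chunk materializes logits of shape `[chunk, num_head, N_res, N_res]` instead of `[N_seq, num_head, N_res, N_res]`.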

## The Attention Module

The underlying `Attention` class implements standard multi-head attention with optional gating (enabled by `config.gating`):

```python
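class Attention(hk.Module):
  """Multihead attention."""

  def __init__(self, config, global_config, output_dim, name='attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

  def __call__(self, q_data, m_data, bias, nonbatched_bias=None):
    # Total key/value widths default to the input widths, then are split across heads
    key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
    value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
    num_head = self.config.num_head
    key_dim = key_dim // num_head
    value_dim = value_dim // num_head

    # Per-head projection weights [c_in, num_head, per-head dim]
    # (glorot_uniform is a helper defined in modules.py)
    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], num_head, value_dim),
        init=glorot_uniform())

    # Compute Q, K, V projections; queries are pre-scaled by 1/sqrt(per-head key dim)
    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
    k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
    v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)

    # Attention logits [b, h, q, k] plus mask bias [b, 1, 1, k]
    logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias
    if nonbatched_bias is not None:
      # Pair bias [h, q, k] gets a leading batch axis and broadcasts over MSA rows
      logits += jnp.expand_dims(nonbatched_bias, axis=0)

    # Softmax over keys and weighted sum of values
    weights = jax.nn.softmax(logits)
    weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)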
```
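To check the einsum index strings, the NumPy sketch below compares them with an explicit loop over batch rows and heads (random weights; mask and pair bias omitted for brevity). It also makes visible that the `key_dim**(-0.5)` factor uses the *per-head* key dimension, i.e. $d_k$ in the Mathematical Formulation section.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
B, N_res, c_in, num_head = 2, 5, 8, 2
key_dim = c_in // num_head                 # per-head dimension, as in the excerpt
q_data = rng.normal(size=(B, N_res, c_in))
m_data = q_data                            # self-attention: same rows supply queries and keys
q_w = rng.normal(size=(c_in, num_head, key_dim))
k_w = rng.normal(size=(c_in, num_head, key_dim))
v_w = rng.normal(size=(c_in, num_head, key_dim))

# Einsum formulation mirroring the excerpt (mask and pair bias omitted).
q = np.einsum('bqa,ahc->bqhc', q_data, q_w) * key_dim**(-0.5)
k = np.einsum('bka,ahc->bkhc', m_data, k_w)
v = np.einsum('bka,ahc->bkhc', m_data, v_w)
weights = softmax(np.einsum('bqhc,bkhc->bhqk', q, k))
weighted_avg = np.einsum('bhqk,bkhc->bqhc', weights, v)

# Reference: explicit loop over batch rows and heads with ordinary matrix products.
ref = np.empty_like(weighted_avg)
for b in range(B):
    for h in range(num_head):
        Q = q_data[b] @ q_w[:, h]          # [N_res, key_dim]
        K = m_data[b] @ k_w[:, h]
        V = m_data[b] @ v_w[:, h]
        ref[b, :, h] = softmax(Q @ K.T / np.sqrt(key_dim)) @ V
```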

### Gating Mechanism (Lines 4 and 6)

```python
    if self.config.gating:
      # Compute gate values from query data
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
                               gating_weights) + gating_bias
      gate_values = jax.nn.sigmoid(gate_values)

      # Apply gating
      weighted_avg *= gate_values
```

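The gating mechanism:
- The gating weights are initialized to 0 and the bias to 1, so at initialization every gate equals sigmoid(1) ≈ 0.73, independent of the input
- Because sigmoid outputs lie in (0, 1), each gate can only attenuate or pass through its channel; training learns, per position, head, and channel, how much of the attention output to let through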
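A quick NumPy check of the initialization claim, using the parameter shapes from the excerpt (`gating_w` zero, `gating_b` one) and a hand-written sigmoid. The second part swaps in random weights, a hypothetical stand-in for trained ones, to show that the gates then vary across inputs but stay strictly between 0 and 1.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
B, N_res, c_in, num_head, value_dim = 2, 5, 8, 2, 4
q_data = rng.normal(size=(B, N_res, c_in))

gating_w = np.zeros((c_in, num_head, value_dim))  # Constant(0.0) initializer
gating_b = np.ones((num_head, value_dim))         # Constant(1.0) initializer
gate_values = sigmoid(np.einsum('bqc,chv->bqhv', q_data, gating_w) + gating_b)
print(gate_values[0, 0, 0, 0])  # approximately 0.7311, i.e. sigmoid(1), for every entry

# Hypothetical stand-in for trained weights: gates now vary but stay in (0, 1).
trained_w = rng.normal(scale=0.5, size=gating_w.shape)
g = sigmoid(np.einsum('bqc,chv->bqhv', q_data, trained_w) + gating_b)
```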

### Output Projection (Line 7)

```python
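    # Project back to output dimension (equivalent to concatenating heads, then Linear)
    if self.global_config.zero_init:
      init = hk.initializers.Constant(0.0)
    else:
      init = glorot_uniform()
    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        init=init)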
    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
                              init=hk.initializers.Constant(0.0))

    output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
    return output
```
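The einsum `'bqhc,hco->bqo'` is the pseudocode's "concatenate heads, then Linear" written without an explicit reshape. The NumPy sketch below checks that equivalence with random values. It also assumes zero-initialized output weights and bias (as when AF2's `zero_init` global option is enabled) to show that the module's initial output is exactly zero, so the residual update starts as a no-op.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N_res, num_head, value_dim, output_dim = 2, 5, 2, 4, 8
weighted_avg = rng.normal(size=(B, N_res, num_head, value_dim))
o_weights = rng.normal(size=(num_head, value_dim, output_dim))
o_bias = rng.normal(size=(output_dim,))

out = np.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias

# Same result: concatenate heads along channels, then one ordinary Linear layer.
concat = weighted_avg.reshape(B, N_res, num_head * value_dim)
w_flat = o_weights.reshape(num_head * value_dim, output_dim)
out_concat = concat @ w_flat + o_bias

# With zero-initialized weights and bias, the initial output (the residual update) is zero.
out_zero_init = np.einsum('bqhc,hco->bqo', weighted_avg, np.zeros_like(o_weights)) + np.zeros(output_dim)
```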

## Mathematical Formulation

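For a single MSA row and a single head, the attention computation (before gating and the output projection) can be written as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + B_{\text{pair}} + B_{\text{mask}}\right) V$$

Where:
- $Q, K, V$: per-head query, key, and value projections of the layer-normalized MSA row
- $d_k$: key dimension per head
- $B_{\text{pair}}$: per-head bias projected from the pair representation, shared by all MSA rows
- $B_{\text{mask}}$: mask bias, $0$ for valid key positions and $-10^9$ (effectively $-\infty$) for padded ones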
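Written per element, with $s$ indexing the MSA row, $i$ and $j$ the query and key residues, and $h$ the head, the gating and output projection that the formula above leaves out can be made explicit. This is a restatement of the code excerpts in index notation; the symbols $W_g$, $w_g$, $W_o$, $w_o$ are introduced here for the gating and output parameters:

$$
\begin{aligned}
a^h_{sij} &= \operatorname{softmax}_j\left(\frac{q^{h\top}_{si}\, k^h_{sj}}{\sqrt{d_k}} + b^h_{ij} + \beta_{sj}\right) \\
g^h_{si} &= \sigma\left(W^h_g\, m_{si} + w^h_g\right) \\
o^h_{si} &= g^h_{si} \odot \sum_j a^h_{sij}\, v^h_{sj} \\
\tilde{m}_{si} &= W_o \operatorname{concat}_h\left(o^h_{si}\right) + w_o
\end{aligned}
$$

Here $b^h_{ij}$ is an entry of $B_{\text{pair}}$ and has no $s$ index because one pair bias is shared by every row, $\beta_{sj}$ is an entry of $B_{\text{mask}}$, and $m_{si}$ is the layer-normalized MSA activation.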

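The pair bias is a learned projection, so its content is not prescribed; it can carry whatever the pair representation has learned, which is expected to include: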
- Residue-residue distance information
- Evolutionary coupling signals
- Structural constraints

## Summary

MSA Row Attention with Pair Bias serves as a bridge between:
- **Sequence information** (MSA representation)
- **Structural information** (pair representation)

Key features:
1. Row-wise attention lets each residue refine its representation using the other residues of the same MSA sequence
2. Pair bias injects structural knowledge into the attention pattern
3. Gating provides learnable control over information flow
4. At inference time, subbatching over MSA rows reduces peak memory, which matters for deep MSAs and long sequences