# Algorithm 9: MSA Transition

The MSA Transition is a simple feed-forward network (two-layer MLP) applied to each position independently. It provides additional non-linear transformation capacity between attention layers.

## Algorithm Pseudocode

![MSA Transition](../imgs/algorithms/MSATransition.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `Transition`
- **Lines**: 476-529

**Note**: The same `Transition` class is used for both MSA Transition (Alg 9) and Pair Transition (Alg 15).

## Structure

The transition block follows a standard pattern:
1. **Layer Normalization**: Normalize input
2. **Linear Expansion**: Project to 4x wider hidden dimension
3. **ReLU Activation**: Non-linearity
4. **Linear Projection**: Project back to original dimension
5. **Residual Connection**: Add to input (applied in wrapper)
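
In symbols, the five steps above can be summarised for the activation vector $x \in \mathbb{R}^{c}$ at a single position. The names $W_1, b_1, W_2, b_2$ for the two linear layers and $n$ for the expansion factor are notation assumed here for illustration, and dropout in the wrapper is ignored:

$$
\begin{aligned}
\hat{x} &= \mathrm{LayerNorm}(x), & \hat{x} &\in \mathbb{R}^{c} \\
a &= W_1 \hat{x} + b_1, & a &\in \mathbb{R}^{n c} \\
u &= W_2\,\mathrm{relu}(a) + b_2, & u &\in \mathbb{R}^{c} \\
x_{\text{out}} &= x + u
\end{aligned}
$$

With $n = 4$ and $c = 256$, the hidden vector $a$ has 1024 entries.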

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5, gamma=None, beta=None):
    """Layer normalization with optional scale and offset."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    normalized = (x - mean) / np.sqrt(var + eps)
    
    if gamma is not None:
        normalized = normalized * gamma
    if beta is not None:
        normalized = normalized + beta
    
    return normalized


def relu(x):
    return np.maximum(x, 0)


def transition(act, mask, num_intermediate_factor=4):
    """
    MSA Transition / Pair Transition.
    
    Algorithm 9 (MSA Transition) and Algorithm 15 (Pair Transition)
    from AlphaFold2 supplementary materials.
    
    A simple 2-layer MLP with expansion and contraction.
    
    Args:
        act: Input activations [..., N, c]
        mask: Mask [..., N]
        num_intermediate_factor: Expansion factor (default 4)
    
    Returns:
        Output activations [..., N, c] (same shape as input)
    """
    c = act.shape[-1]  # channel dimension (last axis)
    num_intermediate = int(c * num_intermediate_factor)
    
    print(f"Input shape: {act.shape}")
    print(f"Channel dimension: {c}")
    print(f"Intermediate dimension: {num_intermediate}")
    
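    # Expand mask for broadcasting. Kept for parity with the reference code,
    # which also expands the mask but never applies it inside the transition.
    mask_expanded = mask[..., None]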
    
    # Step 1: Layer normalization (Line 1)
    gamma = np.ones(c)  # Learnable scale
    beta = np.zeros(c)  # Learnable offset
    act_norm = layer_norm(act, axis=-1, gamma=gamma, beta=beta)
    
    print(f"After LayerNorm: {act_norm.shape}")
    
    # Step 2: First linear layer - expansion (Line 2)
    # Expand from c to 4*c
    w1 = np.random.randn(c, num_intermediate) * np.sqrt(2.0 / c)  # He init for ReLU
    b1 = np.zeros(num_intermediate)
    
    hidden = np.einsum('...c,cd->...d', act_norm, w1) + b1
    
    print(f"After expansion: {hidden.shape}")
    
    # Step 3: ReLU activation (Line 3)
    hidden = relu(hidden)
    
    # Step 4: Second linear layer - contraction (Line 4)
    # Contract from 4*c back to c
    w2 = np.random.randn(num_intermediate, c) * 0.01  # Small init for residual
    b2 = np.zeros(c)
    
    output = np.einsum('...d,dc->...c', hidden, w2) + b2
    
    print(f"After contraction: {output.shape}")
    
    # Note: Residual connection is applied in the wrapper function (dropout_wrapper)
    # output = input + dropout(transition(input))
    
    return output

## Test: MSA Transition

In [None]:
# Test parameters for MSA
N_seq = 128    # Number of sequences
N_res = 64     # Number of residues
c_m = 256      # MSA channel dimension

# Create test inputs
msa_act = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
msa_mask = np.ones((N_seq, N_res), dtype=np.float32)

print("MSA Transition Test")
print("="*50)
print(f"Input: {msa_act.shape}")
print()

# Run transition
msa_update = transition(msa_act, msa_mask, num_intermediate_factor=4)

print(f"\nOutput statistics:")
print(f"  Mean: {msa_update.mean():.6f}")
print(f"  Std: {msa_update.std():.6f}")

## Test: Pair Transition

In [None]:
# Test parameters for Pair representation
N_res = 64     # Number of residues
c_z = 128      # Pair channel dimension

# Create test inputs
pair_act = np.random.randn(N_res, N_res, c_z).astype(np.float32)
pair_mask = np.ones((N_res, N_res), dtype=np.float32)

print("\nPair Transition Test")
print("="*50)
print(f"Input: {pair_act.shape}")
print()

# Run transition
pair_update = transition(pair_act, pair_mask, num_intermediate_factor=4)

print(f"\nOutput statistics:")
print(f"  Mean: {pair_update.mean():.6f}")
print(f"  Std: {pair_update.std():.6f}")

## Full Transition with Residual

In [None]:
def transition_with_residual(act, mask, dropout_rate=0.0, num_intermediate_factor=4):
    """
    Complete transition block including residual connection.
    
    This matches how it's called in the Evoformer via dropout_wrapper.
    """
    # Compute transition
    update = transition(act, mask, num_intermediate_factor)
    
    # Apply inverted dropout (callers pass dropout_rate=0.0 at inference)
    if dropout_rate > 0:
        mask_drop = np.random.binomial(1, 1 - dropout_rate, update.shape)
        update = update * mask_drop / (1 - dropout_rate)
    
    # Residual connection
    output = act + update
    
    return output


# Test with residual
print("\nTransition with Residual Connection")
print("="*50)

# Before
input_norm = np.linalg.norm(msa_act)
print(f"Input norm: {input_norm:.4f}")

# Apply transition
output = transition_with_residual(msa_act, msa_mask)

# After
output_norm = np.linalg.norm(output)
update_norm = np.linalg.norm(output - msa_act)

print(f"Output norm: {output_norm:.4f}")
print(f"Update norm: {update_norm:.4f}")
print(f"Update relative to input: {update_norm / input_norm:.4f}")
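
The small relative update seen in the residual test comes from the small second-layer weights. As a limiting case, the sketch below sets the second layer to exactly zero and checks that the residual block then returns its input unchanged. The helper `transition_update` and its explicit weight arguments are hypothetical names introduced for this illustration, and treating the final layer as zero-initialised is an assumption about how `final_init` is configured.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalise over the channel (last) axis, without scale or offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def transition_update(act, w1, b1, w2, b2):
    """Hypothetical helper: LayerNorm -> Linear -> ReLU -> Linear, no residual."""
    hidden = np.maximum(layer_norm(act) @ w1 + b1, 0.0)
    return hidden @ w2 + b2


rng = np.random.default_rng(0)
c, factor = 8, 4
act = rng.standard_normal((3, 5, c))

w1 = rng.standard_normal((c, c * factor)) * np.sqrt(2.0 / c)  # He-style init
b1 = np.zeros(c * factor)
w2_zero = np.zeros((c * factor, c))  # assumed zero-initialised final layer
b2 = np.zeros(c)

# Residual connection: the update is exactly zero, so output equals input.
out = act + transition_update(act, w1, b1, w2_zero, b2)
identity_at_init = np.array_equal(out, act)
print(identity_at_init)
```

Under this assumption the block contributes nothing at the start of training and only moves away from the identity as the second layer's weights are updated.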

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class Transition(hk.Module):
  """Transition layer.

  Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
  Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
  """

  def __call__(self, act, mask, is_training=True):
    _, _, nc = act.shape

    num_intermediate = int(nc * self.config.num_intermediate_factor)
    mask = jnp.expand_dims(mask, axis=-1)

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='input_layer_norm')(act)

    transition_module = hk.Sequential([
        common_modules.Linear(
            num_intermediate,
            initializer='relu',
            name='transition1'),
        jax.nn.relu,
        common_modules.Linear(
            nc,
            initializer=utils.final_init(self.global_config),
            name='transition2')
    ])

    act = mapping.inference_subbatch(
        transition_module,
        self.global_config.subbatch_size,
        batched_args=[act],
        nonbatched_args=[],
        low_memory=not is_training)

    return act
```
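
Reading the layer sizes off the reference code above, the learnable parameters of one transition block can be tallied with a few lines of arithmetic. This is a rough sketch: `transition_param_count` is a hypothetical helper, and it assumes both `Linear` layers carry a bias term, which the listing does not show explicitly. The channel sizes 256 and 128 are the ones used in the tests in this notebook.

```python
def transition_param_count(c, factor=4, use_bias=True):
    """Count parameters of LayerNorm + Linear(c -> factor*c) + Linear(factor*c -> c)."""
    hidden = int(c * factor)
    layer_norm_params = 2 * c  # scale and offset
    linear1 = c * hidden + (hidden if use_bias else 0)
    linear2 = hidden * c + (c if use_bias else 0)
    return layer_norm_params + linear1 + linear2


msa_params = transition_param_count(256)   # MSA channels (c_m)
pair_params = transition_param_count(128)  # pair channels (c_z)
print(msa_params, pair_params)  # 526080 131968
```

Almost all of the parameters sit in the two weight matrices, which together hold `2 * factor * c**2` entries.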

## Key Insights

1. **Expansion Factor**: Default is 4x, meaning the hidden dimension is 4 times the input dimension (256 → 1024 for MSA).

2. **Position-wise**: Applied independently to each position (same as Transformer FFN).

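3. **Weight Initialization**:
   - First layer uses the `'relu'` initializer (He initialization)
   - Second layer uses `utils.final_init(self.global_config)`, which selects zero initialization when `zero_init` is enabled in the global config; the NumPy version above stands in for it with a small random init so the residual update starts small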

4. **Shared Implementation**: Same class handles both MSA (Algorithm 9) and Pair (Algorithm 15) transitions.

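5. **Subbatching**: For memory efficiency, `mapping.inference_subbatch` can apply the transition in chunks of `subbatch_size`; the call passes `low_memory=not is_training`, so chunking is only requested outside training.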
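
Points 2 and 5 are linked: because the MLP treats every position independently, it can be evaluated in chunks, or on reordered positions, without changing the result. The sketch below checks both properties numerically. `transition_update` and `chunked_update` are hypothetical helpers written for this illustration; `chunked_update` is a plain Python loop, not the `mapping.inference_subbatch` routine from the reference code.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalise over the channel (last) axis, without scale or offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def transition_update(act, w1, b1, w2, b2):
    """Hypothetical helper: LayerNorm -> Linear -> ReLU -> Linear, no residual."""
    hidden = np.maximum(layer_norm(act) @ w1 + b1, 0.0)
    return hidden @ w2 + b2


def chunked_update(act, params, chunk_size):
    """Apply the transition to slices of the leading axis and stitch the results."""
    chunks = [
        transition_update(act[i:i + chunk_size], *params)
        for i in range(0, act.shape[0], chunk_size)
    ]
    return np.concatenate(chunks, axis=0)


rng = np.random.default_rng(0)
c, factor = 8, 4
params = (
    rng.standard_normal((c, c * factor)) * np.sqrt(2.0 / c),  # w1
    np.zeros(c * factor),                                      # b1
    rng.standard_normal((c * factor, c)) * 0.01,               # w2
    np.zeros(c),                                               # b2
)
act = rng.standard_normal((6, 5, c))

full = transition_update(act, *params)

# Chunking along the first axis (chunk size need not divide the axis length).
chunked = chunked_update(act, params, chunk_size=4)
chunks_match = np.allclose(full, chunked)

# Permuting positions before the MLP equals permuting its outputs afterwards.
perm = rng.permutation(act.shape[1])
positionwise = np.allclose(full[:, perm], transition_update(act[:, perm], *params))

print(chunks_match, positionwise)
```

The trade-off is the usual one: smaller chunks lower peak memory for the 4x-wide hidden activations at the cost of more sequential calls.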