# Algorithm 15: Pair Transition

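The Pair Transition layer is a two-layer feed-forward network applied independently to every entry $z_{ij}$ of the pair representation. It is implemented by the same `Transition` module as MSA Transition (Algorithm 9). The only difference is the input: the pair representation of shape [N_res, N_res, c_z] instead of the MSA representation.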
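Every operation in the transition acts only on the last (channel) axis, so one function can serve both representations. The sketch below is a minimal illustration, not the AF2 code. The helper names `init_transition_params` and `transition` are hypothetical, and the weights are random stand-ins for trained parameters. It applies identical logic to an MSA-shaped array [N_seq, N_res, c_m] and a pair-shaped array [N_res, N_res, c_z].

```python
import numpy as np

def init_transition_params(c, n=4, seed=0):
    """Hypothetical helper: random demo weights for a transition of width c."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.standard_normal((c, n * c)) * 0.02,  # expand c -> n*c
        "b1": np.zeros(n * c),
        "W2": rng.standard_normal((n * c, c)) * 0.02,  # compress n*c -> c
        "b2": np.zeros(c),
    }

def transition(x, params, eps=1e-5):
    """LayerNorm -> Linear -> ReLU -> Linear over the last axis of x."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_norm = (x - mean) / np.sqrt(var + eps)
    hidden = np.maximum(0.0, x_norm @ params["W1"] + params["b1"])
    return hidden @ params["W2"] + params["b2"]

rng = np.random.default_rng(1)
msa = rng.standard_normal((8, 16, 256))    # [N_seq, N_res, c_m]
pair = rng.standard_normal((16, 16, 128))  # [N_res, N_res, c_z]

msa_out = transition(msa, init_transition_params(256))
pair_out = transition(pair, init_transition_params(128))
print(msa_out.shape, pair_out.shape)  # (8, 16, 256) (16, 16, 128)
```

Because nothing mixes information across positions, the output at pair (i, j) depends only on `z[i, j]`. That is why the same module can be reused unchanged for both representations.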

## Algorithm Pseudocode

![PairTransition](../imgs/algorithms/PairTransition.png)
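The same steps can be written out as equations for a single pair position $(i, j)$. They match the NumPy implementation below, with $n$ the expansion factor:

```latex
\begin{aligned}
&\textbf{def } \mathrm{PairTransition}(\{\mathbf{z}_{ij}\},\, n = 4):\\
&\quad 1:\ \mathbf{z}_{ij} \leftarrow \mathrm{LayerNorm}(\mathbf{z}_{ij})\\
&\quad 2:\ \mathbf{a}_{ij} = \mathrm{Linear}(\mathbf{z}_{ij}), \qquad \mathbf{a}_{ij} \in \mathbb{R}^{n \cdot c_z}\\
&\quad 3:\ \mathbf{z}_{ij} \leftarrow \mathrm{Linear}(\mathrm{relu}(\mathbf{a}_{ij})), \qquad \mathbf{z}_{ij} \in \mathbb{R}^{c_z}\\
&\quad 4:\ \textbf{return } \{\mathbf{z}_{ij}\}
\end{aligned}
```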

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `Transition`
- **Lines**: 1162-1195

In [None]:
import numpy as np

np.random.seed(42)

In [None]:
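def pair_transition(z, n=4):
    """
    Pair Transition Layer - Algorithm 15.

    Feed-forward network applied independently to every pair position (i, j):
    LayerNorm -> Linear(c_z -> n*c_z) -> ReLU -> Linear(n*c_z -> c_z)

    Args:
        z: Pair representation [N_res, N_res, c_z]
        n: Expansion factor of the hidden layer (default 4, as in Algorithm 15)

    Returns:
        Transition output [N_res, N_res, c_z]. In the Evoformer block this
        output is added to z as a residual update (z = z + PairTransition(z)).

    Note: the weights are freshly sampled from the global NumPy RNG on every
    call; they stand in for trained parameters for demonstration only.
    """
    N_res, _, c_z = z.shape
    c_hidden = n * c_z

    print("Pair Transition Layer")
    print(f"  Input: [{N_res}, {N_res}, {c_z}]")
    print(f"  Hidden: [{N_res}, {N_res}, {c_hidden}]")

    # Step 1: Layer normalization over the channel axis
    # (eps inside the square root, as in standard LayerNorm; no learnable scale/offset here)
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    z_norm = (z - mean) / np.sqrt(var + 1e-5)

    # Step 2: Expand channels c_z -> n*c_z (Linear1)
    W1 = np.random.randn(c_z, c_hidden) * 0.02
    b1 = np.zeros(c_hidden)
    a = z_norm @ W1 + b1

    # Step 3: ReLU activation
    a = np.maximum(0, a)

    # Step 4: Compress channels n*c_z -> c_z (Linear2)
    W2 = np.random.randn(c_hidden, c_z) * 0.02
    b2 = np.zeros(c_z)
    output = a @ W2 + b2

    print(f"  Output: {output.shape}")

    return output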

In [None]:
# Test
N_res, c_z = 32, 128
z = np.random.randn(N_res, N_res, c_z)

print("Test Pair Transition")
print("="*50)

output = pair_transition(z, n=4)
print(f"\nShape preserved: {output.shape == z.shape}")
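The hidden activation has shape [N_res, N_res, n·c_z], which is n times larger than z itself and dominates memory for long sequences. Because the transition is position-wise, it can be evaluated one block of rows at a time with the same result. The sketch below illustrates this under stated assumptions: `chunked_transition` and its `chunk_size` parameter are hypothetical names, not AF2's API, and the weights are random. It also shows the residual update `z + PairTransition(z)` applied in the Evoformer block.

```python
import numpy as np

def transition(z, W1, b1, W2, b2, eps=1e-5):
    """Position-wise LayerNorm -> Linear -> ReLU -> Linear on the last axis."""
    z_norm = (z - z.mean(-1, keepdims=True)) / np.sqrt(z.var(-1, keepdims=True) + eps)
    return np.maximum(0.0, z_norm @ W1 + b1) @ W2 + b2

def chunked_transition(z, params, chunk_size=8):
    """Hypothetical helper: apply the transition to row blocks z[i:i+chunk_size].

    Peak hidden memory drops from N_res*N_res*n*c_z to chunk_size*N_res*n*c_z.
    """
    out = np.empty_like(z)
    for start in range(0, z.shape[0], chunk_size):
        rows = slice(start, start + chunk_size)  # last block may be shorter
        out[rows] = transition(z[rows], *params)
    return out

rng = np.random.default_rng(0)
N_res, c_z, n = 32, 128, 4
z = rng.standard_normal((N_res, N_res, c_z))
params = (rng.standard_normal((c_z, n * c_z)) * 0.02, np.zeros(n * c_z),
          rng.standard_normal((n * c_z, c_z)) * 0.02, np.zeros(c_z))

full = transition(z, *params)
chunked = chunked_transition(z, params, chunk_size=8)
print("Chunked matches full:", np.allclose(full, chunked))

# Residual update, as applied to the pair representation in the Evoformer block
z_updated = z + full
print(z_updated.shape)
```

The comparison uses `np.allclose` rather than exact equality because matrix products over different block sizes can round slightly differently.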

## Source Code Reference

```python
# Abridged from AF2-source-code/model/modules.py (arguments elided as `...`)

class Transition(hk.Module):
  """Transition layer.

  Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
  Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
  """

  def __call__(self, act, mask, is_training=False):
    act = common_modules.LayerNorm(...)(act)
    act = common_modules.Linear(num_intermediate)(act)
    act = jax.nn.relu(act)
    act = common_modules.Linear(num_output)(act)
    return act
```
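Counting parameters layer by layer, in the order of the excerpt above (LayerNorm, Linear, ReLU, Linear), shows where the weights live. This sketch rests on two assumptions: the LayerNorm has one learnable scale and one offset per channel, and both Linear layers include a bias. `transition_param_count` is a hypothetical helper.

```python
def transition_param_count(c, n=4):
    """Hypothetical helper: parameters in LayerNorm -> Linear -> ReLU -> Linear."""
    hidden = n * c
    layer_norm = 2 * c              # scale + offset per channel (assumed)
    linear1 = c * hidden + hidden   # weights + bias, c -> n*c
    linear2 = hidden * c + c        # weights + bias, n*c -> c
    return layer_norm + linear1 + linear2  # ReLU has no parameters

print(transition_param_count(128))  # pair transition with c_z = 128 -> 131968
print(transition_param_count(256))  # MSA transition with c_m = 256 -> 526080
```

Almost all parameters sit in the two Linear layers, and the count grows with c², so the MSA transition at width 256 has roughly four times the parameters of the pair transition at width 128.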