# Algorithm 15: Pair Transition

The Pair Transition layer is a feed-forward network applied to the pair representation. It's similar to MSA Transition (Algorithm 9) but operates on the pair representation, providing non-linear transformations after attention layers.

## Algorithm Pseudocode

![PairTransition](../imgs/algorithms/PairTransition.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `Transition`
- **Lines**: 1162-1195

## Overview

The Pair Transition implements a standard transformer-style feed-forward network:

```
Input z [N_res, N_res, c_z]
    ↓
Layer Normalization
    ↓
Linear (c_z → n*c_z)  [Expand]
    ↓
ReLU Activation
    ↓
Linear (n*c_z → c_z)  [Compress]
    ↓
Output z [N_res, N_res, c_z]
```

Here `n = 4` is the expansion factor, the usual choice in transformer feed-forward blocks.

In [None]:
import numpy as np

# Seed NumPy's global RNG so the random inputs and weights drawn by the cells
# below (np.random.randn) are reproducible across notebook runs.
np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Standardize ``x`` along ``axis`` to zero mean and unit variance.

    No learned scale/offset is applied here; callers add their own affine
    parameters if they need them. ``eps`` is added to the variance before the
    square root so constant slices do not divide by zero.
    """
    centre = np.mean(x, axis=axis, keepdims=True)
    denom = np.sqrt(np.var(x, axis=axis, keepdims=True) + eps)
    return (x - centre) / denom


def relu(x):
    """Rectified linear unit applied element-wise.

    Positive entries pass through unchanged; everything else becomes 0.
    Accepts Python scalars or NumPy arrays.
    """
    rectified = np.maximum(0, x)
    return rectified


def pair_transition(z, n=4, *, params=None, verbose=True):
    """
    Pair Transition Layer - Algorithm 15.

    Feed-forward network: LayerNorm -> Linear(expand) -> ReLU -> Linear(compress).
    Every position z[i, j] is transformed independently (LayerNorm and both
    linear layers act only on the last axis). The residual add
    (z + pair_transition(z)) is left to the caller.

    Args:
        z: Pair representation [N_res, N_res, c_z]. Any 3-D [N_i, N_j, c_z]
            array works, since only the channel axis is mixed.
        n: Expansion factor (default: 4); the hidden width is n * c_z.
        params: Optional (W1, b1, W2, b2) with shapes [c_z, n*c_z], [n*c_z],
            [n*c_z, c_z] and [c_z]. If omitted, fresh weights are drawn from
            NumPy's global RNG on every call (W1 first, then W2), so two calls
            only share weights if the seed is reset between them. Pass
            explicit weights to run several inputs through the same layer.
        verbose: If True (default), print a shape trace of every step.

    Returns:
        Updated z with the same shape as the input.

    Raises:
        ValueError: If z is not 3-D or a supplied weight has the wrong shape.
    """
    if z.ndim != 3:
        raise ValueError(f"z must be 3-D [N_i, N_j, c_z], got shape {z.shape}")
    n_i, n_j, c_z = z.shape
    c_hidden = n * c_z

    if params is None:
        # Random init scaled by 1/sqrt(fan_in) for each layer; zero biases.
        W1 = np.random.randn(c_z, c_hidden) * (c_z ** -0.5)
        b1 = np.zeros(c_hidden)
        W2 = np.random.randn(c_hidden, c_z) * (c_hidden ** -0.5)
        b2 = np.zeros(c_z)
    else:
        W1, b1, W2, b2 = (np.asarray(p) for p in params)
        for label, arr, shape in (
            ("W1", W1, (c_z, c_hidden)),
            ("b1", b1, (c_hidden,)),
            ("W2", W2, (c_hidden, c_z)),
            ("b2", b2, (c_z,)),
        ):
            if arr.shape != shape:
                raise ValueError(f"{label} has shape {arr.shape}, expected {shape}")

    # No-op logger when verbose=False keeps the step code free of branches.
    log = print if verbose else (lambda *args, **kwargs: None)

    log("Pair Transition")
    log("=" * 50)
    log(f"Input: [{n_i}, {n_j}, {c_z}]")
    log(f"Hidden: [{n_i}, {n_j}, {c_hidden}]")
    log(f"Expansion factor: {n}x")

    # Step 1: Layer normalization over the channel axis (Line 1)
    z_norm = layer_norm(z, axis=-1)
    log(f"\nStep 1 - Layer norm: {z_norm.shape}")

    # Step 2: First linear, expand c_z -> n*c_z (Line 2)
    a = np.einsum('ijc,cd->ijd', z_norm, W1) + b1
    log(f"Step 2 - Expand: {a.shape}")

    # Step 3: ReLU activation (Line 3)
    a = relu(a)
    log(f"Step 3 - ReLU: {a.shape}")

    # Step 4: Second linear, compress n*c_z -> c_z (Line 4)
    output = np.einsum('ijd,dc->ijc', a, W2) + b2
    log(f"Step 4 - Compress: {output.shape}")

    return output


class PairTransition:
    """
    Pair Transition (Algorithm 15) whose weights live on the instance.

    Weights are drawn once in ``__init__`` from NumPy's global RNG, so every
    call of the same instance applies the same layer.
    """

    def __init__(self, c_z, n=4):
        """Create identity LayerNorm parameters and randomly initialized linears.

        Args:
            c_z: Channel width of the pair representation.
            n: Expansion factor; the hidden width is ``n * c_z``.
        """
        hidden = n * c_z
        self.c_z = c_z
        self.c_hidden = hidden
        self.n = n

        # LayerNorm affine parameters start as the identity transform.
        self.ln_scale = np.ones(c_z)
        self.ln_bias = np.zeros(c_z)

        # Expansion layer c_z -> hidden, scaled by 1/sqrt(fan_in). Drawn first.
        self.W1 = np.random.randn(c_z, hidden) * (c_z ** -0.5)
        self.b1 = np.zeros(hidden)

        # Compression layer hidden -> c_z, scaled by 1/sqrt(fan_in). Drawn second.
        self.W2 = np.random.randn(hidden, c_z) * (hidden ** -0.5)
        self.b2 = np.zeros(c_z)

    def __call__(self, z):
        """Return LayerNorm -> expand -> ReLU -> compress applied to ``z``.

        Only the update is returned; the residual add is left to the caller.
        """
        normed = self.ln_scale * layer_norm(z, axis=-1) + self.ln_bias
        hidden = relu(np.einsum('ijc,cd->ijd', normed, self.W1) + self.b1)
        return np.einsum('ijd,dc->ijc', hidden, self.W2) + self.b2

    def count_parameters(self):
        """Return per-component parameter counts plus a ``'total'`` entry."""
        width, hidden = self.c_z, self.c_hidden
        counts = {}
        counts['layer_norm'] = 2 * width  # scale + bias
        counts['linear_1'] = width * hidden + hidden
        counts['linear_2'] = hidden * width + width
        counts['total'] = sum(counts.values())
        return counts

## Test Examples

In [None]:
# Test 1: Basic functionality -- one forward pass of the functional layer,
# then confirm the output keeps the input's [N_res, N_res, c_z] shape.
print("Test 1: Basic Functionality")
print("=" * 60)

N_res, c_z = 32, 128
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

output = pair_transition(z, n=4)

for label, value in (
    ("\nInput shape", z.shape),
    ("Output shape", output.shape),
    ("Shape preserved", output.shape == z.shape),
):
    print(f"{label}: {value}")

In [None]:
# Test 2: Object-oriented version -- weights live on a PairTransition
# instance; report shapes and the per-component parameter budget.
print("\nTest 2: Object-Oriented Version")
print("=" * 60)

np.random.seed(42)
N_res, c_z = 32, 128
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

transition = PairTransition(c_z=c_z, n=4)
output = transition(z)

print(f"Input shape: {z.shape}")
print(f"Output shape: {output.shape}")

params = transition.count_parameters()
print("\nParameter counts:")
print("\n".join(f"  {name}: {count:,}" for name, count in params.items()))

In [None]:
# Test 3: Different expansion factors -- reseed before each run so every
# factor starts from the same RNG state; pair_transition prints its own trace.
print("\nTest 3: Different Expansion Factors")
print("=" * 60)

N_res, c_z = 16, 64
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

expansion_factors = (1, 2, 4, 8)
for factor in expansion_factors:
    np.random.seed(42)
    output = pair_transition(z, n=factor)
    print()

In [None]:
# Test 4: Verify ReLU effect -- replay LayerNorm and the expansion layer by
# hand, then compare activation statistics before and after ReLU.
print("\nTest 4: Verify ReLU Effect")
print("=" * 60)

np.random.seed(42)
N_res, c_z = 16, 64
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

c_hidden = 4 * c_z
z_norm = layer_norm(z, axis=-1)
W1 = np.random.randn(c_z, c_hidden) * (c_z ** -0.5)
b1 = np.zeros(c_hidden)
a = np.einsum('ijc,cd->ijd', z_norm, W1) + b1
a_relu = relu(a)

print("Pre-ReLU statistics:")
print(f"  Mean: {a.mean():.4f}")
print(f"  Std: {a.std():.4f}")
print(f"  Negative fraction: {(a < 0).mean():.2%}")

print("\nPost-ReLU statistics:")
print(f"  Mean: {a_relu.mean():.4f}")
print(f"  Zero fraction: {(a_relu == 0).mean():.2%}")

In [None]:
# Test 5: Residual connection pattern -- in the Evoformer the transition
# output is added back onto its input: z <- z + Transition(z).
print("\nTest 5: Residual Connection Pattern")
print("=" * 60)

np.random.seed(42)
N_res, c_z = 16, 64
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

transition = PairTransition(c_z=c_z, n=4)
delta = transition(z)
z_updated = z + delta

for label, arr in (("Input", z), ("Delta", delta), ("Output", z_updated)):
    print(f"{label} mean: {arr.mean():.4f}, std: {arr.std():.4f}")
print(f"\nResidual ratio (delta/input std): {delta.std() / z.std():.4f}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_res, c_z = 24, 128
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

output = pair_transition(z, n=4)

# Property 1: Shape preserved
shape_preserved = output.shape == z.shape
print(f"Property 1 - Shape preserved: {shape_preserved}")

# Property 2: Finite output
output_finite = np.isfinite(output).all()
print(f"Property 2 - Output finite: {output_finite}")

# Property 3: Non-trivial transformation
not_identity = not np.allclose(output, z)
print(f"Property 3 - Non-trivial: {not_identity}")

# Property 4: Position-wise operation (no cross-position mixing).
# LayerNorm normalizes each position over its own channels and both linear
# layers act only on the channel axis, so perturbing z[0, 0] must change
# output[0, 0] and nothing else.
# Both forward passes must use the SAME weights. pair_transition draws fresh
# random weights on every call, so a PairTransition instance is used instead.
transition = PairTransition(c_z=c_z, n=4)
z_mod = z.copy()
# Use a non-constant perturbation. Adding the same value to every channel is
# removed by LayerNorm's mean subtraction and would leave the output unchanged.
z_mod[0, 0, :] += np.linspace(-1.0, 1.0, c_z, dtype=np.float32)

diff = np.abs(transition(z_mod) - transition(z)).sum(axis=-1)  # [N_res, N_res]
max_diff_pos = tuple(int(i) for i in np.unravel_index(np.argmax(diff), diff.shape))
others = diff.copy()
others[0, 0] = 0.0
others_unchanged = bool(np.all(others <= 1e-6))
print(f"Property 4 - Most affected position: {max_diff_pos} (expected: (0, 0))")
print(f"Property 4 - All other positions unchanged: {others_unchanged}")

# Property 5: Reasonable output scale.
# LayerNorm output has ~unit variance and W1 is scaled by 1/sqrt(c_z), so the
# pre-ReLU activations have variance ~1. ReLU keeps about half of the second
# moment, and W2 is scaled by 1/sqrt(c_hidden), so the output std is ~sqrt(1/2).
# The input here is standard normal, so the expected ratio is about 0.7.
output_scale = output.std()
input_scale = z.std()
scale_ratio = output_scale / input_scale
print(f"Property 5 - Scale ratio: {scale_ratio:.4f} (expected ~0.7, i.e. ~1/sqrt(2), for this init)")

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class Transition(hk.Module):
  """Transition layer.

  Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
  Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
  
  This class is used for both MSA and Pair transitions.
  """

  def __call__(self, act, mask, is_training=False):
    # NOTE(review): `mask` and `is_training` are not referenced anywhere in
    # this excerpt, which appears abridged -- confirm against modules.py.
    c = self.config
    # Hidden width = input channel count * configured expansion factor.
    num_intermediate = int(act.shape[-1] * c.num_intermediate_factor)
    
    # Layer norm over the channel axis, with learned scale and offset
    act = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                       name='input_layer_norm')(act)
    
    # Expand to num_intermediate channels
    act = common_modules.Linear(
        num_intermediate,
        initializer='relu',
        name='transition1')(act)
    
    # ReLU
    act = jax.nn.relu(act)
    
    # Compress back to c.num_channel channels.
    # NOTE(review): `final_init` presumably zero-initializes this layer
    # upstream (so the block starts as a no-op under the residual add) --
    # confirm; the NumPy code in this notebook uses a random init instead.
    act = common_modules.Linear(
        c.num_channel,
        initializer=utils.final_init(self.global_config),
        name='transition2')(act)
    
    return act
```

## Key Insights

1. **Shared Architecture**: The same `Transition` class is used for both MSA (Algorithm 9) and Pair (Algorithm 15) transitions.

2. **Expansion Factor**: The default expansion factor of 4x is a common transformer design choice, allowing more expressive non-linear transformations.

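3. **Position-Independent**: Unlike attention layers, the transition never mixes information across positions. LayerNorm normalizes each position over its own channels, and both linear layers act only on the channel axis, so each `z[i, j]` is transformed independently.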

4. **ReLU Sparsity**: With zero biases and zero-mean pre-activations (as at initialization here), ReLU zeros out roughly 50% of the hidden activations, creating sparse intermediate representations.

5. **Residual Connection**: In practice, the transition output is added to the input (residual connection), not used directly.

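6. **Parameter Cost**: Because of the 4x expansion, the transition layers contribute significantly to the parameter count. Each transition has `2·n·c_z²` weights plus `n·c_z + 3·c_z` bias and LayerNorm parameters. For `c_z = 128` and `n = 4`, that is 131,968 parameters per block.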