# Algorithm 14: Pair Transition (AlphaFold3)

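Position-wise feed-forward transition for the pair representation: layer normalization, a SwiGLU-gated expansion from `c_z` to `n * c_z` channels, then a linear projection back to `c_z`.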
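
Written out for a single pair (i, j), the computation implemented in `pair_transition` below can be summarized as follows. The symbols $W_a$ and $W_b$ (the two halves of `W1`) and $W_{\mathrm{out}}$ (`W2`) are notation introduced here for illustration, and the layer norm in this notebook has no learnable scale or offset.

$$
\begin{aligned}
\hat{z}_{ij} &= \mathrm{LayerNorm}(z_{ij}) \\
a_{ij} &= \hat{z}_{ij} W_a, \qquad b_{ij} = \hat{z}_{ij} W_b, \qquad W_a, W_b \in \mathbb{R}^{c_z \times n c_z} \\
\Delta z_{ij} &= \big(\mathrm{swish}(a_{ij}) \odot b_{ij}\big)\, W_{\mathrm{out}}, \qquad W_{\mathrm{out}} \in \mathbb{R}^{n c_z \times c_z}
\end{aligned}
$$

Here $\mathrm{swish}(x) = x \cdot \sigma(x)$ and $\odot$ is the elementwise product.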

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/modules.py`
- **Class**: `TransitionBlock`

In [None]:
import numpy as np
np.random.seed(42)

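def layer_norm(x, eps=1e-5):
    """Normalize over the last (channel) axis. No learnable scale or offset is applied here."""
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def swish(x):
    """Swish (SiLU): x * sigmoid(x). The input to exp is clipped so it cannot overflow."""
    return x / (1 + np.exp(-np.clip(x, -500, 500)))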

In [None]:
def pair_transition(z, n=4):
    """
    Pair Transition block with SwiGLU.
    
    Args:
        z: Pair representation [N, N, c_z]
        n: Expansion factor
    
    Returns:
        Update to pair representation [N, N, c_z]
    """
    N = z.shape[0]
    c_z = z.shape[-1]
    c_hidden = c_z * n
    
    print("Pair Transition")
    print("=" * 50)
    print(f"Pair: [{N}, {N}, {c_z}]")
    print(f"Expansion: {c_z} → {c_hidden} → {c_z}")
    
    z_norm = layer_norm(z)
    
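    # Up projection to 2 * c_hidden channels. The weights are random
    # placeholders for learned parameters and are redrawn on every call.
    W1 = np.random.randn(c_z, c_hidden * 2) * (c_z ** -0.5)
    hidden = z_norm @ W1
    
    # SwiGLU: split into a gate half (a) and a value half (b), then swish(a) * b
    a, b = np.split(hidden, 2, axis=-1)
    hidden = swish(a) * b
    
    # Down projection back to c_z channels
    W2 = np.random.randn(c_hidden, c_z) * (c_hidden ** -0.5)
    output = hidden @ W2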
    
    print(f"Output: {output.shape}")
    
    return output

In [None]:
# Test
print("Test: Pair Transition")
print("="*60)

N = 32
c_z = 128

z = np.random.randn(N, N, c_z)

output = pair_transition(z, n=4)

print(f"\nInput norm: {np.linalg.norm(z):.2f}")
print(f"Output norm: {np.linalg.norm(output):.2f}")
print(f"Output finite: {np.isfinite(output).all()}")

## Key Insights

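1. **SwiGLU**: The up projection is split into two halves, `a` and `b`, and the hidden activation is `swish(a) * b`, so `a` gates `b`.
2. **4x Expansion**: The hidden width is `n * c_z` with `n = 4` by default, so the up projection produces `2 * n * c_z` channels before gating.
3. **Residual**: `pair_transition` returns only the update; the Pairformer block adds it to its input `z`.
4. **Same as MSA Transition**: The MSA transition uses the same architecture, applied to the MSA representation instead of the pair representation.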
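
The residual update and the cost of the 4x expansion can be sketched in a few lines. This is a minimal illustration and not the reference implementation: `transition_update`, `w_up`, and `w_down` are hypothetical names introduced here, the weights are random placeholders, and the sizes are deliberately small.

```python
import numpy as np

def transition_update(z, w_up, w_down, eps=1e-5):
    """Hypothetical helper: LayerNorm -> SwiGLU -> down projection."""
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    z_norm = (z - mean) / np.sqrt(var + eps)
    a, b = np.split(z_norm @ w_up, 2, axis=-1)
    gated = a / (1.0 + np.exp(-a)) * b  # swish(a) * b
    return gated @ w_down

rng = np.random.default_rng(0)
n_tokens, c_z, n = 8, 16, 4  # small illustrative sizes (assumed)
c_hidden = n * c_z

# Random placeholder weights; a trained model would learn these.
w_up = rng.standard_normal((c_z, 2 * c_hidden)) * c_z ** -0.5
w_down = rng.standard_normal((c_hidden, c_z)) * c_hidden ** -0.5

z = rng.standard_normal((n_tokens, n_tokens, c_z))

# Residual connection: the block adds the transition output to its input.
z_next = z + transition_update(z, w_up, w_down)

# Bias-free parameter count: 2 * n * c_z^2 (up) + n * c_z^2 (down).
n_params = w_up.size + w_down.size
print(z_next.shape, n_params)
```

Because the helper only operates on the last axis, the same function also accepts a tensor with different leading dimensions, such as an MSA-shaped array, provided the channel width matches the weights.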