# Algorithm 11: Transition Block (Boltz)

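Position-wise feed-forward transition: layer-normalize the input, expand the channel dimension by a factor `n`, apply a SwiGLU gate, and project back to the original channel width.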

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/layers/transition.py`

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` over its last axis to zero mean and unit variance.

    No learnable scale or shift is applied.

    Args:
        x: Array-like input; statistics are taken over the last axis.
        eps: Constant added to the variance before taking the square root.

    Returns:
        The normalized values, broadcast to the shape of ``x``.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    # eps keeps the denominator strictly positive when a row is constant.
    sigma = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    centered = x - mu
    return centered / sigma

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` element-wise.

    Args:
        x: Scalar or array-like input.

    Returns:
        Values in ``(0, 1)`` with the same shape as ``x``.
    """
    # Clip first so that exp() cannot overflow for large-magnitude inputs.
    bounded = np.clip(x, -500, 500)
    denominator = 1 + np.exp(-bounded)
    return 1 / denominator

def swish(x):
    """Return the Swish (SiLU) activation ``x * sigmoid(x)`` element-wise.

    Args:
        x: Scalar or array input.

    Returns:
        The input scaled by its own sigmoid gate.
    """
    gate = sigmoid(x)
    return x * gate

In [None]:
def transition_block(x, n=4, use_glu=True):
    """
    Run a feed-forward transition on ``x`` and print a short trace.

    The input is layer-normalized, projected up, passed through a
    SwiGLU gate (or a plain Swish when ``use_glu`` is false), and
    projected back down to the input channel width. Both weight
    matrices are freshly sampled from the global NumPy RNG on every
    call, so the result depends on the current random state.

    Args:
        x: Input tensor [..., c]
        n: Expansion factor; the hidden width is ``c * n``
        use_glu: If true, use the gated SwiGLU form; otherwise apply
            Swish directly to the up-projection

    Returns:
        Output tensor [..., c]
    """
    c = x.shape[-1]
    c_hidden = c * n

    print("Transition Block")
    print("=" * 50)
    print(f"Input: {x.shape}")
    print(f"Expansion: {c} -> {c_hidden} -> {c}")

    normed = layer_norm(x)

    # The gated variant needs two hidden-width halves (gate and value),
    # so its up-projection is twice as wide.
    up_width = c_hidden * 2 if use_glu else c_hidden
    w_up = np.random.randn(c, up_width) * (c ** -0.5)
    projected = normed @ w_up

    if use_glu:
        # SwiGLU: first half gates (through swish) the second half.
        gate = projected[..., :c_hidden]
        value = projected[..., c_hidden:]
        hidden = swish(gate) * value
        print("Using SwiGLU")
    else:
        hidden = swish(projected)
        print("Using Swish")

    # Project back down to the original channel width.
    w_down = np.random.randn(c_hidden, c) * (c_hidden ** -0.5)
    output = hidden @ w_down

    print(f"Output: {output.shape}")
    return output

In [None]:
# Test: run the transition block on a single and a pair representation and
# verify that each output keeps the shape of its input.
print("Test: Transition Block")
print("="*60)

# Single representation
s = np.random.randn(32, 128)
output_s = transition_block(s, n=4)
print(f"Single output: {output_s.shape}")
# The down projection must restore the input channel width.
assert output_s.shape == s.shape, (
    f"single: expected {s.shape}, got {output_s.shape}"
)

print()

# Pair representation
z = np.random.randn(24, 24, 64)
output_z = transition_block(z, n=2)
print(f"Pair output: {output_z.shape}")
# Leading (pair) dimensions must pass through untouched as well.
assert output_z.shape == z.shape, (
    f"pair: expected {z.shape}, got {output_z.shape}"
)

## Key Insights

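1. **SwiGLU**: `swish(a) * b` gating, where `a` and `b` are the two halves of a single up-projection
2. **Expansion**: 4x for single, 2x for pair typically
3. **Layer Norm**: Applied to the input before the up-projection (no learnable scale or shift in this NumPy version)
4. **Residual**: `transition_block` returns only the update; the enclosing block adds this output to its input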