# Algorithm 7: MSA Transition (AlphaFold3)

A position-wise feed-forward transition layer for the MSA representation, using a SwiGLU activation.
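In equation form, the SwiGLU path of `msa_transition` in this notebook computes the following. This is a sketch in our own notation: $W_a$ and $W_b$ stand for the left and right column halves of the fused matrix `W1`, $W_o$ stands for `W2`, and biases are omitted, matching the code.

```latex
\begin{aligned}
\mathbf{x} &\leftarrow \mathrm{LayerNorm}(\mathbf{x}) \\
\mathbf{a} &= \mathbf{x} W_a, \qquad \mathbf{b} = \mathbf{x} W_b, \qquad W_a, W_b \in \mathbb{R}^{c_m \times n c_m} \\
\mathbf{x} &\leftarrow \big(\mathrm{swish}(\mathbf{a}) \odot \mathbf{b}\big)\, W_o, \qquad W_o \in \mathbb{R}^{n c_m \times c_m} \\
\mathrm{swish}(z) &= z\,\sigma(z) = \frac{z}{1 + e^{-z}}
\end{aligned}
```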

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/modules.py`
- **Class**: `TransitionBlock`

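```python
import numpy as np
np.random.seed(42)  # reproducible random weights and inputs

def layer_norm(x, eps=1e-5):
    """LayerNorm over the channel (last) axis, without learnable scale/offset."""
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def swish(x):
    """Swish / SiLU: x * sigmoid(x). Clipping avoids overflow in np.exp."""
    return x / (1 + np.exp(-np.clip(x, -500, 500)))
```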
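A quick sanity check of the two helpers, repeated here so the check runs on its own: `layer_norm` should give every channel vector roughly zero mean and unit variance, and `swish` should agree with z·σ(z), with σ written through the identity σ(z) = (1 + tanh(z/2)) / 2. The input shape and random generator are arbitrary choices for illustration.

```python
import numpy as np

# Same helpers as the cell above, repeated so this check is self-contained.
def layer_norm(x, eps=1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def swish(x):
    return x / (1 + np.exp(-np.clip(x, -500, 500)))

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 8, 16))

# LayerNorm: each feature vector (last axis) has ~zero mean and ~unit variance.
y = layer_norm(x)
print(np.allclose(y.mean(axis=-1), 0.0, atol=1e-7))
print(np.allclose(y.var(axis=-1), 1.0, atol=1e-4))

# swish(z) should equal z * sigmoid(z), with sigmoid computed via tanh.
z = np.linspace(-10, 10, 101)
sigmoid = 0.5 * (1 + np.tanh(z / 2))
print(np.allclose(swish(z), z * sigmoid))
```

The small `eps` makes the variance slightly below 1, which is why the variance check uses a loose tolerance.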

```python
def msa_transition(msa, n=4, use_glu=True):
    """
    MSA transition block with SwiGLU.

    Weights are freshly drawn at random on every call (illustration only;
    a trained model uses learned parameters).

    Args:
        msa: MSA representation [N_msa, N_token, c_m]
        n: Expansion factor for the hidden width (c_hidden = n * c_m)
        use_glu: If True, use SwiGLU gating; otherwise a plain swish MLP

    Returns:
        Transition output [N_msa, N_token, c_m]; the caller adds it to the
        input as a residual update (msa + msa_transition(msa)).
    """
    c_m = msa.shape[-1]
    c_hidden = c_m * n

    print("MSA Transition")
    print("=" * 50)
    print(f"Input: {msa.shape}")
    print(f"Expansion: {c_m} → {c_hidden} → {c_m}")

    # Pre-norm: LayerNorm over the channel axis
    msa_norm = layer_norm(msa)

    if use_glu:
        # SwiGLU: one fused projection to 2 * c_hidden, split into a and b,
        # then combine as swish(a) * b
        W1 = np.random.randn(c_m, c_hidden * 2) * (c_m ** -0.5)
        hidden = msa_norm @ W1  # [..., c_hidden * 2]
        a, b = np.split(hidden, 2, axis=-1)
        hidden = swish(a) * b  # [..., c_hidden]
        print("Using SwiGLU activation")
    else:
        # Non-gated variant: swish(x @ W1)
        W1 = np.random.randn(c_m, c_hidden) * (c_m ** -0.5)
        hidden = swish(msa_norm @ W1)

    # Down projection back to c_m channels
    W2 = np.random.randn(c_hidden, c_m) * (c_hidden ** -0.5)
    output = hidden @ W2

    print(f"Output: {output.shape}")

    return output
```
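`msa_transition` draws one fused matrix `W1` of width `2 * c_hidden` and splits the projected result in two. The minimal check below, using a small hypothetical size (`c_m = 8`) and a fixed generator, confirms this equals two separate bias-free projections; `W_a` and `W_b` are our names for the left and right column blocks of `W1`.

```python
import numpy as np

def swish(x):
    return x / (1 + np.exp(-np.clip(x, -500, 500)))

rng = np.random.default_rng(1)
c_m, n = 8, 4
c_hidden = n * c_m
x = rng.normal(size=(5, 3, c_m))  # stands in for the layer-normed MSA

# Fused form, as in msa_transition: one [c_m, 2*c_hidden] matrix, then split.
W1 = rng.normal(size=(c_m, 2 * c_hidden)) * c_m ** -0.5
a_fused, b_fused = np.split(x @ W1, 2, axis=-1)
fused = swish(a_fused) * b_fused

# Unfused form: W_a and W_b are the left and right column blocks of W1.
W_a, W_b = W1[:, :c_hidden], W1[:, c_hidden:]
separate = swish(x @ W_a) * (x @ W_b)

print(fused.shape)                   # (5, 3, 32)
print(np.allclose(fused, separate))  # True
```

Fusing is only a layout choice: one wider matrix multiply instead of two narrower ones, with identical results up to floating-point rounding.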

```python
# Test
print("Test: MSA Transition")
print("="*60)

N_msa = 64
N_token = 32
c_m = 64

msa = np.random.randn(N_msa, N_token, c_m)

output = msa_transition(msa, n=4)

print(f"\nInput norm: {np.linalg.norm(msa):.2f}")
print(f"Output norm: {np.linalg.norm(output):.2f}")
print(f"Output finite: {np.isfinite(output).all()}")
```
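The test above calls the block once with fresh random weights. To check two further properties with fixed weights, the sketch below uses a hypothetical helper `transition_with_weights` (not part of the AF3 code): the transition acts on each (MSA row, token) vector independently, and its output is added back to the MSA as a residual update, `m ← m + Transition(m)`, as in the AF3 MSA stack. Sizes are small illustrative choices.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def swish(x):
    return x / (1 + np.exp(-np.clip(x, -500, 500)))

def transition_with_weights(x, W1, W2):
    """Hypothetical helper: SwiGLU transition with externally supplied weights."""
    a, b = np.split(layer_norm(x) @ W1, 2, axis=-1)
    return (swish(a) * b) @ W2

rng = np.random.default_rng(2)
N_msa, N_token, c_m, n = 6, 5, 16, 4
W1 = rng.normal(size=(c_m, 2 * n * c_m)) * c_m ** -0.5
W2 = rng.normal(size=(n * c_m, c_m)) * (n * c_m) ** -0.5
msa = rng.normal(size=(N_msa, N_token, c_m))

update = transition_with_weights(msa, W1, W2)

# Position-wise: transforming one (row, token) vector alone gives the same result.
single = transition_with_weights(msa[2, 3], W1, W2)
print(np.allclose(single, update[2, 3]))  # True

# Residual update: m <- m + Transition(m).
msa_new = msa + update
print(msa_new.shape)  # (6, 5, 16)
```

Because nothing mixes information across MSA rows or tokens, the transition's cost grows linearly with `N_msa * N_token`; cross-position communication is left to the other layers of the MSA stack.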

## Key Insights

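1. **SwiGLU**: Uses a Swish-gated linear unit in place of the ReLU used in the AlphaFold2 transition
2. **Expansion Factor**: The hidden width is n × c_m, with n = 4 by default
3. **GLU Structure**: Two parallel linear projections, a and b, are combined elementwise as swish(a) * b before the down-projection
4. **Layer Norm**: Applied to the input at the start of the block (pre-norm), before the up-projection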
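The expansion factor and the gating together determine the block's size. A small sketch with a hypothetical helper `transition_weight_count`, counting only the two bias-free projections in `msa_transition` (its `layer_norm` has no parameters): at the same n, the gated block has 3·n·c_m² weights versus 2·n·c_m² without gating.

```python
def transition_weight_count(c_m, n, use_glu=True):
    """Hypothetical helper: weights in the up- and down-projections of msa_transition."""
    c_hidden = n * c_m
    # Up-projection is twice as wide when gating (a and b paths).
    up = c_m * (2 * c_hidden if use_glu else c_hidden)
    down = c_hidden * c_m
    return up + down

for use_glu in (True, False):
    print(use_glu, transition_weight_count(64, 4, use_glu))
# True 49152
# False 32768
```

With the test settings (c_m = 64, n = 4), gating therefore costs 1.5× the weights of the non-gated variant.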