# Algorithm 16: Adaptive LayerNorm (AlphaFold3)

Adaptive LayerNorm (AdaLN) is a key component of the Diffusion Transformer in AlphaFold3. It modulates normalized features based on conditioning signals, following the approach from the DiT (Diffusion Transformer) paper.

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/diffusion_transformer.py`
- **Function**: `adaptive_layernorm`

## Overview

### Standard LayerNorm vs Adaptive LayerNorm

**Standard LayerNorm:**
```
y = γ * (x - μ) / σ + β
```
where γ, β are learnable parameters.

**Adaptive LayerNorm:**
```
y = scale(c) * LayerNorm(x) + bias(c)
```
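where scale and bias are functions of the conditioning c. In AlphaFold3, scale(c) = sigmoid(Linear(LayerNorm(c))) and bias(c) = LinearNoBias(LayerNorm(c)), and the LayerNorm applied to x has no learnable γ or β of its own.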
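The following minimal sketch makes the difference concrete. It assumes random stand-in weights; the helper `normalize` and the arrays `gamma`, `beta`, `W_scale`, `b_scale` and `W_bias` are hypothetical and are not taken from the official code. Two tokens with identical features `x` but different conditioning vectors get identical outputs from standard LayerNorm, because γ and β are shared by every token. They get different outputs from AdaLN, because the scale and bias are computed per token from `c`.

```python
import numpy as np

def normalize(x, eps=1e-5):
    # Parameter-free normalization over the last axis: (x - mean) / sqrt(var + eps)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
c, c_cond = 8, 4

# Two tokens with identical features but different conditioning
x = np.tile(rng.standard_normal(c), (2, 1))   # [2, c], rows identical
cond = rng.standard_normal((2, c_cond))        # [2, c_cond], rows differ

# Standard LayerNorm: one learned gamma/beta shared by every token
gamma = rng.standard_normal(c)
beta = rng.standard_normal(c)
y_standard = gamma * normalize(x) + beta

# Adaptive LayerNorm: scale/bias computed per token from the conditioning
W_scale = rng.standard_normal((c_cond, c))    # hypothetical stand-in weights
b_scale = np.zeros(c)
W_bias = rng.standard_normal((c_cond, c))
c_norm = normalize(cond)
y_adaptive = sigmoid(c_norm @ W_scale + b_scale) * normalize(x) + c_norm @ W_bias

print(np.allclose(y_standard[0], y_standard[1]))   # True: same x -> same output
print(np.allclose(y_adaptive[0], y_adaptive[1]))   # False: different c -> different output
```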

### Why Adaptive?

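In diffusion models, one network denoises inputs at every noise level, so it must behave differently depending on how noisy its input is. AdaLN lets the conditioning signal modulate the normalized activations of every layer it is used in. In AlphaFold3, that signal is derived from the noise level and the single representation.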

```python
import numpy as np

np.random.seed(42)
```

## NumPy Implementation

```python
def layer_norm(x, axis=-1, eps=1e-5):
    """
    Layer Normalization without learnable parameters (no gamma/beta).

    Normalizes along `axis` (the last axis by default):
    (x - mean) / sqrt(var + eps).
    """
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)
```
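Two properties of this `layer_norm` are easy to check. First, each row ends up with mean ≈ 0 and standard deviation ≈ 1. The standard deviation is slightly below 1 because of `eps`. Second, a constant row normalizes to all zeros, since x - μ = 0 everywhere. The second property means that conditioning vectors which are constant across channels carry no information once they pass through the conditioning LayerNorm. The short check below re-defines the same function so that it runs on its own.

```python
import numpy as np

def layer_norm(x, axis=-1, eps=1e-5):
    # Same parameter-free LayerNorm as above
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = 3.0 * rng.standard_normal((4, 16)) + 5.0   # shifted, scaled input

y = layer_norm(x)
print(np.abs(y.mean(axis=-1)).max())   # ~0
print(y.std(axis=-1))                  # each entry just below 1

# A constant row (all ones, or all minus ones) maps to exactly zero
const = np.array([[1.0] * 16, [-1.0] * 16])
print(layer_norm(const))               # all zeros
```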

```python
def adaptive_layer_norm(x, single_cond):
    """
    Adaptive LayerNorm - Algorithm 16.

    Modulates normalized features based on conditioning.

    For illustration only, the projection weights are sampled randomly on
    every call, so two calls never share weights. In the real model they are
    learned parameters (zero-initialized in the official code).

    Args:
        x: Input features [*, c]
        single_cond: Conditioning signal [*, c_cond] or None

    Returns:
        Modulated features [*, c] (same shape as x)
    """
    c = x.shape[-1]

    print("Adaptive LayerNorm")
    print(f"  Input: {x.shape}")

    # Step 1: LayerNorm on x without learnable scale/offset
    x_norm = layer_norm(x)
    print("  Step 1: LayerNorm applied")

    if single_cond is None:
        # The official code uses a LayerNorm with learnable scale/offset here;
        # at initialization (scale=1, offset=0) that equals plain normalization.
        print("  No conditioning - returning normalized")
        return x_norm

    c_cond = single_cond.shape[-1]
    print(f"  Conditioning: {single_cond.shape}")

    # Step 2: Normalize conditioning (the official LayerNorm here has a
    # learnable scale, initialized to 1, but no offset)
    single_cond_norm = layer_norm(single_cond)
    print("  Step 2: Conditioning normalized")

    # Step 3: Compute scale = sigmoid(Linear(s)); its width must equal c
    # so that it broadcasts against x_norm
    W_scale = np.random.randn(c_cond, c) * 0.02
    b_scale = np.zeros(c)  # sigmoid(0) = 0.5, so the scale starts near 0.5

    scale_logits = single_cond_norm @ W_scale + b_scale
    scale = 1.0 / (1.0 + np.exp(-scale_logits))  # sigmoid, always in (0, 1)
    print(f"  Step 3: Scale computed (mean={scale.mean():.3f})")

    # Step 4: Compute bias = LinearNoBias(s)
    W_bias = np.random.randn(c_cond, c) * 0.02
    bias = single_cond_norm @ W_bias
    print(f"  Step 4: Bias computed (mean={bias.mean():.3f})")

    # Step 5: Apply modulation
    output = scale * x_norm + bias
    print("  Step 5: Modulation applied")
    print(f"  Output: {output.shape}")

    return output
```
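`adaptive_layer_norm` above redraws its weights on every call, so two calls never share parameters. A closer match to the official code keeps the parameters in one place and zero-initializes both projections. The sketch below also assumes that the scale bias starts at zero. Under that assumption, at initialization the scale is exactly sigmoid(0) = 0.5 and the bias is exactly 0. The parameter store is a plain dict of NumPy arrays, and `init_adaln_params` and `adaln_apply` are hypothetical names that are not part of AlphaFold3.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_adaln_params(c, c_cond):
    """Hypothetical init mirroring the zero initializers in the official code."""
    return {
        "cond_ln_scale": np.ones(c_cond),   # learnable scale on LN(cond), no offset
        "W_scale": np.zeros((c_cond, c)),   # Linear with bias, zero-initialized
        "b_scale": np.zeros(c),             # assumed to start at zero
        "W_bias": np.zeros((c_cond, c)),    # LinearNoBias, zero-initialized
    }

def adaln_apply(params, x, cond):
    x_norm = layer_norm(x)                            # no scale/offset on x
    s = layer_norm(cond) * params["cond_ln_scale"]    # scale only on cond
    scale = sigmoid(s @ params["W_scale"] + params["b_scale"])
    bias = s @ params["W_bias"]
    return scale * x_norm + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 16))
cond = rng.standard_normal((5, 8))

params = init_adaln_params(c=16, c_cond=8)
out = adaln_apply(params, x, cond)

# At initialization the conditioning has no effect yet: out == 0.5 * LN(x)
print(np.allclose(out, 0.5 * layer_norm(x)))   # True
```

Because the parameters live outside the function, the same `params` can be applied to different conditioning signals. Any difference in the outputs then comes from the conditioning alone.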

```python
def adaptive_zero_init(x, c_out, single_cond, name=''):
    """
    Adaptive Zero Init (AF3's counterpart to DiT's AdaLN-Zero).

    Projects x to c_out channels and, when conditioning is given, gates the
    result with sigmoid(Linear(single_cond)), whose bias starts at -2. Used on
    the output of residual branches so that each block starts close to the
    identity. For illustration, weights are sampled randomly on every call.

    Args:
        x: Input after processing [*, c]
        c_out: Output dimension
        single_cond: Conditioning signal [*, c_cond] or None
        name: Unused here

    Returns:
        Gated output [*, c_out] (small at initialization)
    """
    c = x.shape[-1]

    print("Adaptive Zero Init")
    print(f"  Input: {x.shape}")

    # Linear projection to c_out channels
    W = np.random.randn(c, c_out) * (c ** -0.5)
    output = x @ W

    if single_cond is None:
        # Stand-in for a zero-initialized final projection: shrink the output
        output = output * 0.01
        print("  No conditioning - small init")
    else:
        c_cond = single_cond.shape[-1]

        # Gate bias starts at -2, so at init sigmoid(-2) ≈ 0.12
        W_scale = np.random.randn(c_cond, c_out) * 0.02
        b_scale = np.full(c_out, -2.0)

        scale_logits = single_cond @ W_scale + b_scale
        scale = 1.0 / (1.0 + np.exp(-scale_logits))  # sigmoid

        output = output * scale
        print(f"  Conditioning applied (scale mean={scale.mean():.3f})")

    print(f"  Output: {output.shape}")
    return output
```
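The purpose of adaptive zero init shows up in the residual update x ← x + branch(x). The sketch below assumes a toy branch made of AdaLN at initialization followed by a ReLU MLP. The gate weights start at zero and the gate bias at -2, matching the comment in `adaptive_zero_init`. The MLP shape, the weights `W1`/`W2` and the function `residual_block` are hypothetical. At initialization every gate equals sigmoid(-2) ≈ 0.12, so the branch adds only a small fraction to x and the block starts close to the identity.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
N, c, c_cond = 32, 64, 32

# Hypothetical toy branch weights (fan-in scaled)
W1 = rng.standard_normal((c, 2 * c)) * c ** -0.5
W2 = rng.standard_normal((2 * c, c)) * (2 * c) ** -0.5
# Gate projection: zero weights, bias -2 (the adaptive-zero-init recipe)
W_gate = np.zeros((c_cond, c))
b_gate = np.full(c, -2.0)

def residual_block(x, cond):
    h = 0.5 * layer_norm(x)                 # AdaLN output at init (scale 0.5, bias 0)
    h = np.maximum(h @ W1, 0.0) @ W2        # toy ReLU MLP branch
    gate = sigmoid(cond @ W_gate + b_gate)  # exactly sigmoid(-2) everywhere at init
    return x + gate * h, gate

x = rng.standard_normal((N, c))
cond = rng.standard_normal((N, c_cond))
y, gate = residual_block(x, cond)

print(f"gate = {gate.mean():.4f}")   # gate = 0.1192
print(f"relative update = {np.linalg.norm(y - x) / np.linalg.norm(x):.3f}")
```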

## Test Examples

```python
# Test 1: Basic Adaptive LayerNorm
print("Test 1: Basic Adaptive LayerNorm")
print("=" * 60)

N = 32
c = 128
c_cond = 64

x = np.random.randn(N, c).astype(np.float32) * 2  # Large variance
single_cond = np.random.randn(N, c_cond).astype(np.float32)

output = adaptive_layer_norm(x, single_cond)

print(f"\nInput std: {x.std():.3f}")
# Roughly 0.5, because the scale starts near sigmoid(0) = 0.5
print(f"Output std: {output.std():.3f}")
```

```python
# Test 2: Without conditioning (standard LayerNorm)
print("\nTest 2: Without Conditioning")
print("=" * 60)

x = np.random.randn(32, 128)
output = adaptive_layer_norm(x, None)

# Should be normalized
print(f"\nOutput mean: {output.mean(axis=-1).mean():.6f} (expected ~0)")
print(f"Output std: {output.std(axis=-1).mean():.6f} (expected ~1)")
```

```python
# Test 3: Adaptive Zero Init
print("\nTest 3: Adaptive Zero Init")
print("=" * 60)

x = np.random.randn(32, 128)
single_cond = np.random.randn(32, 64)

output = adaptive_zero_init(x, c_out=128, single_cond=single_cond)

print(f"\nInput norm: {np.linalg.norm(x):.3f}")
print(f"Output norm: {np.linalg.norm(output):.3f} (should be much smaller)")
```

```python
# Test 4: Effect of different conditioning
print("\nTest 4: Effect of Different Conditioning")
print("=" * 60)

np.random.seed(42)
N = 16
c = 64
c_cond = 32

x = np.random.randn(N, c)

# Different conditioning signals.
# Note: a constant vector normalizes to all zeros in the conditioning
# LayerNorm, so cond_1 and cond_2 both give scale = sigmoid(0) = 0.5 and
# bias = 0, i.e. identical outputs. Each call also redraws the weights,
# so differences between calls mix conditioning and weight effects.
cond_1 = np.ones((N, c_cond))  # Uniform positive
cond_2 = -np.ones((N, c_cond))  # Uniform negative
cond_3 = np.random.randn(N, c_cond)  # Random

out_1 = adaptive_layer_norm(x, cond_1)
out_2 = adaptive_layer_norm(x, cond_2)
out_3 = adaptive_layer_norm(x, cond_3)

print(f"\nConditioning 1 (positive): output mean={out_1.mean():.3f}, std={out_1.std():.3f}")
print(f"Conditioning 2 (negative): output mean={out_2.mean():.3f}, std={out_2.std():.3f}")
print(f"Conditioning 3 (random):   output mean={out_3.mean():.3f}, std={out_3.std():.3f}")
print(f"out_1 equals out_2: {np.allclose(out_1, out_2)}")
```

## Verification: Key Properties

```python
print("Verification: Key Properties")
print("=" * 60)

np.random.seed(42)
N = 32
c = 64
c_cond = 32

x = np.random.randn(N, c)
cond = np.random.randn(N, c_cond)

output = adaptive_layer_norm(x, cond)

# Property 1: Shape preserved
shape_preserved = output.shape == x.shape
print(f"Property 1 - Shape preserved: {shape_preserved}")

# Property 2: Finite output
all_finite = np.isfinite(output).all()
print(f"Property 2 - All finite: {all_finite}")

# Property 3: Without conditioning, output is normalized
output_no_cond = adaptive_layer_norm(x, None)
is_normalized = (
    np.allclose(output_no_cond.mean(axis=-1), 0, atol=1e-5)
    and np.allclose(output_no_cond.std(axis=-1), 1, atol=1e-5)
)
print(f"Property 3 - No-cond is normalized: {is_normalized}")

# Property 4: AdaLN-Zero starts small
x_proc = np.random.randn(N, c)
output_zero = adaptive_zero_init(x_proc, c, cond)
is_small = np.linalg.norm(output_zero) < np.linalg.norm(x_proc) * 0.5
print(f"Property 4 - AdaLN-Zero is small: {is_small}")
```

## Source Code Reference

```python
# From AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/diffusion_transformer.py

def adaptive_layernorm(x, single_cond, name):
  """Adaptive LayerNorm."""
  # Adopted from Scalable Diffusion Models with Transformers
  # https://arxiv.org/abs/2212.09748
  if single_cond is None:
    x = hm.LayerNorm(name=f'{name}layer_norm', use_fast_variance=False)(x)
  else:
    x = hm.LayerNorm(
        name=f'{name}layer_norm',
        use_fast_variance=False,
        create_scale=False,
        create_offset=False,
    )(x)
    single_cond = hm.LayerNorm(
        name=f'{name}single_cond_layer_norm',
        use_fast_variance=False,
        create_offset=False,
    )(single_cond)
    single_scale = hm.Linear(
        x.shape[-1],
        initializer='zeros',
        use_bias=True,
        name=f'{name}single_cond_scale',
    )(single_cond)
    single_bias = hm.Linear(
        x.shape[-1], initializer='zeros', name=f'{name}single_cond_bias'
    )(single_cond)
    x = jax.nn.sigmoid(single_scale) * x + single_bias
  return x
```
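The conditioned branch of the Haiku code above maps line by line onto three equations. The notation is as follows:

- $\mathbf{a}$ is the input `x` and $\mathbf{s}$ is `single_cond`.
- $\odot$ is elementwise multiplication.
- $W_{\text{scale}}, \mathbf{b}_{\text{scale}}, W_{\text{bias}}$ are the weights of the two `hm.Linear` layers. These symbol names are chosen here for readability and do not appear in the code.
- $\operatorname{sigmoid}$ is written out in full to avoid a clash with the standard deviation $\sigma$ used earlier.

$$
\begin{aligned}
\hat{\mathbf{a}} &= \mathrm{LayerNorm}(\mathbf{a}) && \text{(no scale, no offset)} \\
\hat{\mathbf{s}} &= \mathrm{LayerNorm}(\mathbf{s}) && \text{(learnable scale, no offset)} \\
\mathbf{a}' &= \operatorname{sigmoid}\!\left(\hat{\mathbf{s}} W_{\text{scale}} + \mathbf{b}_{\text{scale}}\right) \odot \hat{\mathbf{a}} + \hat{\mathbf{s}} W_{\text{bias}}
\end{aligned}
$$

With zero-initialized weights, and assuming the scale bias also starts at zero, the last line reduces to $\mathbf{a}' = 0.5\,\hat{\mathbf{a}}$ at initialization.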

## Key Insights

1. **Conditioning Modulation**: AdaLN allows the model to dynamically adjust layer behavior based on timestep and sequence information.

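2. **Scale via Sigmoid**: Passing the scale through a sigmoid keeps it in (0, 1). The conditioning can therefore shrink (gate) each normalized channel but never amplify it, which limits how far it can push activations.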

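3. **Zero Initialization**: The companion adaptive zero init is AF3's counterpart to DiT's AdaLN-Zero. It gates each residual branch with a sigmoid whose bias starts at -2. At initialization each branch therefore passes only about 12% of its output (sigmoid(-2) ≈ 0.12), every block starts close to the identity, and early training is more stable.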

4. **From DiT Paper**: This technique is adopted from "Scalable Diffusion Models with Transformers" (Peebles & Xie, 2022).

5. **Per-Token Modulation**: Each token can receive different scale/bias based on its conditioning, enabling fine-grained control.

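6. **Different Conditioning Source**: In class-conditional image diffusion (e.g. DiT), AdaLN conditions on timestep and class-label embeddings. In AF3, it conditions on a per-token (or, in the atom transformer, per-atom) single conditioning vector derived from the single representation and the noise level.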
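Insight 2 can be checked directly. Because sigmoid maps any logit into [0, 1] in floating point, |scale · x̂| ≤ |x̂| holds for every channel, whatever the logits are. The check below stresses the bound with deliberately extreme random logits and a random stand-in for LayerNorm(x); both choices are assumptions. For contrast, the original DiT modulates with an unbounded `1 + scale` factor.

```python
import numpy as np

rng = np.random.default_rng(0)
x_norm = rng.standard_normal((16, 32))          # stand-in for LayerNorm(x)
logits = 10.0 * rng.standard_normal((16, 32))   # deliberately extreme scale logits
scale = 1.0 / (1.0 + np.exp(-logits))           # sigmoid

modulated = scale * x_norm
print(scale.min() >= 0.0, scale.max() <= 1.0)           # True True
print(np.all(np.abs(modulated) <= np.abs(x_norm)))      # True: never amplified
```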