# Algorithm 15: Diffusion Module (AlphaFold3)

The Diffusion Module is the core structure-prediction component of AlphaFold3, replacing AlphaFold2's Structure Module, which was built on Invariant Point Attention (IPA). It uses a denoising diffusion process to generate 3D atomic coordinates.

## Key Difference from AlphaFold2

| Aspect | AlphaFold2 | AlphaFold3 |
|--------|------------|------------|
| Prediction | Direct backbone frames + torsions | Denoising diffusion on all atoms |
| Output | Single structure | Ensemble of structures |
| Iterations | 8 refinement steps | 200 diffusion steps |

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/diffusion_head.py`
- **Class**: `DiffusionHead`

## Overview

### Diffusion Process

```
Forward Process (Training):
x₀ (clean) → x₁ → x₂ → ... → x_T (noise)
           +ε   +ε        +ε

Reverse Process (Inference):
x_T (noise) → x_{T-1} → ... → x₁ → x₀ (predicted)
            denoise    denoise   denoise
```
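
In closed form, the forward process sketched above can be written as follows. This is the standard DDPM notation that the NumPy cells in this notebook follow, with $\bar{\alpha}_t$ standing for `alphas_cumprod[t]`, $\epsilon$ for the sampled noise and $\hat{\epsilon}$ for the predicted noise. It is a reference for the simplified code here, not a statement of AlphaFold3's exact parameterization.

```latex
% Forward process: noise the clean coordinates x_0 to level t
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)

% Inverting the same relation with a predicted noise \hat{\epsilon}
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \hat{\epsilon}}{\sqrt{\bar{\alpha}_t}}
```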

### AlphaFold3 Diffusion

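1. **Noise Schedule**: Defines how much noise is present at each step
2. **Conditioning**: Single and pair representations from the trunk condition the denoising
3. **Atom Cross Attention**: Aggregates per-atom features into token-level features and broadcasts token-level updates back to atoms
4. **Diffusion Transformer**: Processes the token-level activations derived from the noisy coordinates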

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def get_noise_schedule(num_steps=200, s=0.008):
    """
    Cosine noise schedule for diffusion.
    
    Args:
        num_steps: Number of diffusion steps
        s: Small offset to prevent singularity
    
    Returns:
        alphas_cumprod: Cumulative product of alphas, shape [num_steps + 1];
            equals 1.0 at t=0 and falls to ~0 at t=num_steps
    """
    t = np.linspace(0, num_steps, num_steps + 1)
    f_t = np.cos((t / num_steps + s) / (1 + s) * np.pi / 2) ** 2
    alphas_cumprod = f_t / f_t[0]
    return alphas_cumprod


def add_noise(x0, t, alphas_cumprod):
    """
    Add noise to clean coordinates at timestep t.
    
    x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * epsilon
    """
    alpha_t = alphas_cumprod[t]
    noise = np.random.randn(*x0.shape)
    x_t = np.sqrt(alpha_t) * x0 + np.sqrt(1 - alpha_t) * noise
    return x_t, noise
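
The cumulative schedule can also be turned back into per-step quantities. The sketch below recomputes the same cosine schedule and derives per-step betas and a signal-to-noise ratio from it. The helper names `cosine_alphas_cumprod` and `betas_from_alphas_cumprod` and the `max_beta` clip are assumptions made for illustration (the clip is a common convention in DDPM-style code), not part of the AlphaFold3 source.

```python
import numpy as np


def cosine_alphas_cumprod(num_steps=200, s=0.008):
    """Cosine cumulative schedule, same formula as get_noise_schedule above."""
    t = np.linspace(0, num_steps, num_steps + 1)
    f_t = np.cos((t / num_steps + s) / (1 + s) * np.pi / 2) ** 2
    return f_t / f_t[0]


def betas_from_alphas_cumprod(alphas_cumprod, max_beta=0.999):
    """Recover per-step betas: beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}.

    The clip at max_beta is an assumed convention; it keeps the final
    step from having beta = 1.
    """
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(betas, 0.0, max_beta)


alphas_cumprod = cosine_alphas_cumprod(200)
betas = betas_from_alphas_cumprod(alphas_cumprod)

# Signal-to-noise ratio alpha_bar / (1 - alpha_bar); t=0 is skipped
# because the ratio is infinite for clean data.
snr = alphas_cumprod[1:] / (1.0 - alphas_cumprod[1:])

print(f"beta range: [{betas.min():.2e}, {betas.max():.3f}]")
print(f"SNR at t=1: {snr[0]:.1f}, SNR at t=200: {snr[-1]:.2e}")
```

The signal-to-noise ratio falls monotonically with `t`, which is the property the sampler relies on when it walks the schedule backwards.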

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Layer normalization."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def adaptive_layer_norm(x, single_cond):
    """
    Adaptive LayerNorm (AdaLN) from DiT paper.
    
    Modulates normalized features based on conditioning.
    """
    x_norm = layer_norm(x)
    
    if single_cond is None:
        return x_norm
    
    single_cond_norm = layer_norm(single_cond)
    
    c = x.shape[-1]
    # Random weights stand in for learned parameters; they are re-drawn on every call.
    W_scale = np.random.randn(single_cond.shape[-1], c) * 0.02
    W_bias = np.random.randn(single_cond.shape[-1], c) * 0.02
    
    scale = 1.0 / (1.0 + np.exp(-single_cond_norm @ W_scale))  # sigmoid
    bias = single_cond_norm @ W_bias
    
    return scale * x_norm + bias
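
Because the function above draws new weights on each call, two calls with the same inputs give different outputs. A minimal sketch of the same computation with weights created once is shown below. The class name `AdaLNSketch` and its constructor arguments are hypothetical, and the random weights are still only placeholders for learned parameters.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learned scale or offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


class AdaLNSketch:
    """Hypothetical AdaLN layer that owns its (random placeholder) weights."""

    def __init__(self, c_cond, c, seed=0):
        rng = np.random.default_rng(seed)
        self.W_scale = rng.normal(scale=0.02, size=(c_cond, c))
        self.W_bias = rng.normal(scale=0.02, size=(c_cond, c))

    def __call__(self, x, cond):
        cond_norm = layer_norm(cond)
        # Sigmoid gate in (0, 1) scales the normalized features.
        scale = 1.0 / (1.0 + np.exp(-(cond_norm @ self.W_scale)))
        bias = cond_norm @ self.W_bias
        return scale * layer_norm(x) + bias


rng = np.random.default_rng(1)
x = rng.normal(size=(10, 64))      # [N, c] features
cond = rng.normal(size=(10, 32))   # [N, c_cond] conditioning
adaln = AdaLNSketch(c_cond=32, c=64)

out_a = adaln(x, cond)
out_b = adaln(x, cond)             # same weights, so the same output
print(out_a.shape, np.allclose(out_a, out_b))
```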

In [None]:
def noise_level_embedding(t, c=128, max_t=200):
    """
    Sinusoidal embedding for noise level (timestep).
    
    Similar to positional encoding in Transformers.
    """
    half_c = c // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half_c) / half_c)
    
    # Normalize timestep to [0, 1]
    t_normalized = t / max_t
    
    angles = t_normalized * freqs * 1000  # rescale [0, 1] to [0, 1000] before applying the frequencies
    embedding = np.concatenate([np.sin(angles), np.cos(angles)])
    
    return embedding
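
When several noise levels are needed at once, for example one per training example, the same embedding can be computed in a single vectorized call. The function name `noise_level_embedding_batch` is an assumption for illustration; the formula is the one used above.

```python
import numpy as np


def noise_level_embedding_batch(ts, c=128, max_t=200):
    """Vectorized variant: ts has shape [B], the result has shape [B, c]."""
    half_c = c // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half_c) / half_c)  # [half_c]
    t_normalized = np.asarray(ts, dtype=float) / max_t              # [B]
    angles = t_normalized[:, None] * freqs[None, :] * 1000          # [B, half_c]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


emb = noise_level_embedding_batch([0, 50, 100, 200], c=128)
print(emb.shape)
# sin^2 + cos^2 = 1 for every frequency, so each row has squared norm c / 2.
print((emb ** 2).sum(axis=-1))
```

The constant row norm means the embedding encodes the timestep only through the direction of the vector, not its magnitude.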

In [None]:
def diffusion_module_forward(x_t, t, single, pair, c=128):
    """
    Simplified Diffusion Module forward pass.
    
    Args:
        x_t: Noisy atom coordinates [N_atoms, 3]
        t: Current timestep
        single: Single representation [N_tokens, c_s]
        pair: Pair representation [N_tokens, N_tokens, c_z] (not used by this simplified version)
        c: Hidden dimension
    
    Returns:
        Predicted noise [N_atoms, 3]
    """
    N_atoms = x_t.shape[0]
    N_tokens = single.shape[0]
    c_s = single.shape[-1]
    
    print("Diffusion Module Forward")
    print("=" * 50)
    print(f"Atoms: {N_atoms}, Tokens: {N_tokens}")
    print(f"Timestep: {t}")
    
    # Step 1: Noise level embedding
    t_emb = noise_level_embedding(t, c=c_s)
    single_cond = single + t_emb[None, :]  # Broadcast to all tokens
    print(f"\nStep 1: Noise embedding added to single")
    
    # Step 2: Project coordinates to hidden dimension
    W_in = np.random.randn(3, c) * 0.1
    h = x_t @ W_in  # [N_atoms, c]
    print(f"Step 2: Coordinates projected: {h.shape}")
    
    # Step 3: Adaptive LayerNorm with conditioning
    # Simplified: use mean of single as conditioning
    cond = single_cond.mean(axis=0)  # [c_s]
    W_cond = np.random.randn(c_s, c) * 0.1
    cond_proj = cond @ W_cond  # [c]
    
    h = adaptive_layer_norm(h, np.tile(cond_proj, (N_atoms, 1)))
    print(f"Step 3: Adaptive LayerNorm applied")
    
    # Step 4: Simplified self-attention (would be Diffusion Transformer)
    W_q = np.random.randn(c, c) * (c ** -0.5)
    W_k = np.random.randn(c, c) * (c ** -0.5)
    W_v = np.random.randn(c, c) * (c ** -0.5)
    
    q = h @ W_q
    k = h @ W_k
    v = h @ W_v
    
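    # NumPy has no np.softmax; compute a numerically stable softmax by
    # subtracting the row-wise maximum before exponentiating.
    logits = q @ k.T / np.sqrt(c)
    logits = logits - logits.max(axis=-1, keepdims=True)
    attn = np.exp(logits)
    attn = attn / attn.sum(axis=-1, keepdims=True)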
    h = attn @ v
    print(f"Step 4: Self-attention computed")
    
    # Step 5: MLP / Transition
    W_up = np.random.randn(c, c * 4) * 0.1
    W_down = np.random.randn(c * 4, c) * 0.1
    h = h @ W_up
    h = np.maximum(0, h)  # ReLU
    h = h @ W_down
    print(f"Step 5: MLP transition applied")
    
    # Step 6: Project back to 3D
    W_out = np.random.randn(c, 3) * 0.02  # Small init for residual
    noise_pred = h @ W_out
    print(f"Step 6: Output projected: {noise_pred.shape}")
    
    return noise_pred
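
The simplified forward pass above never reads `pair`. One way a pair representation can enter an attention layer is as an additive bias on the attention logits, in the spirit of attention with pair bias. The sketch below shows that idea at token level with a single head. The function `pair_biased_attention`, its weight names, and the single-head layout are assumptions for illustration, and all weights are random placeholders.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def pair_biased_attention(h, pair, rng):
    """Single-head attention over tokens with an additive bias from `pair`.

    h:    [N_tokens, c]              token activations
    pair: [N_tokens, N_tokens, c_z]  pair representation
    All weights are random placeholders for learned parameters.
    """
    c = h.shape[-1]
    c_z = pair.shape[-1]
    W_q, W_k, W_v = (rng.normal(scale=c ** -0.5, size=(c, c)) for _ in range(3))
    W_b = rng.normal(scale=c_z ** -0.5, size=(c_z,))

    logits = (h @ W_q) @ (h @ W_k).T / np.sqrt(c)  # [N_tokens, N_tokens]
    logits = logits + pair @ W_b                   # pair entry (i, j) biases logit (i, j)
    attn = softmax(logits, axis=-1)
    return attn @ (h @ W_v), attn


rng = np.random.default_rng(0)
h = rng.normal(size=(16, 32))
pair = rng.normal(size=(16, 16, 8))
out, attn = pair_biased_attention(h, pair, rng)
print(out.shape, attn.sum(axis=-1)[:3])
```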

In [None]:
def sample_diffusion(single, pair, N_atoms, num_steps=20):
    """
    Sample structure using reverse diffusion process.
    
    Simplified version with fewer steps for demonstration.
    """
    print(f"\nSampling with {num_steps} diffusion steps")
    print(f"="*50)
    
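    # The cosine schedule reaches ~0 at t = num_steps; clip it so that the
    # division by sqrt(alpha_t) in the x_0 prediction below stays finite.
    alphas_cumprod = np.clip(get_noise_schedule(num_steps), 1e-4, 1.0)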
    
    # Start from pure noise
    x_t = np.random.randn(N_atoms, 3) * 10  # Scale for Angstroms
    print(f"Initial noise std: {x_t.std():.2f}")
    
    trajectory = [x_t.copy()]
    
    # Reverse diffusion
    for t in range(num_steps, 0, -1):
        # Predict noise
        noise_pred = diffusion_module_forward(x_t, t, single, pair)
        
        # Compute alpha values
        alpha_t = alphas_cumprod[t]
        alpha_t_prev = alphas_cumprod[t - 1]
        
        # Predict x_0
        x0_pred = (x_t - np.sqrt(1 - alpha_t) * noise_pred) / np.sqrt(alpha_t)
        
        # Compute x_{t-1}
        if t > 1:
            noise = np.random.randn(*x_t.shape)
            sigma = np.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))
            x_t = np.sqrt(alpha_t_prev) * x0_pred + \
                  np.sqrt(1 - alpha_t_prev - sigma**2) * noise_pred + \
                  sigma * noise
        else:
            x_t = x0_pred
        
        trajectory.append(x_t.copy())
        
        if t % 5 == 0:
            print(f"Step {t}: std = {x_t.std():.2f}")
    
    print(f"\nFinal structure std: {x_t.std():.2f}")
    return x_t, trajectory
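
A quick way to sanity-check the update rule is to replace the network with an oracle that returns the true noise. With a deterministic, DDIM-style update (sigma = 0, a swapped-in variant of the stochastic update used above), the loop should then recover the clean coordinates up to rounding. The sketch below is self-contained; the clip value `1e-4` and the variable names are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Cosine cumulative schedule, clipped away from 0 so divisions stay finite.
num_steps, s = 50, 0.008
t_grid = np.linspace(0, num_steps, num_steps + 1)
f_t = np.cos((t_grid / num_steps + s) / (1 + s) * np.pi / 2) ** 2
alphas_cumprod = np.clip(f_t / f_t[0], 1e-4, 1.0)

x0 = rng.normal(size=(20, 3))                      # "clean" coordinates
eps = rng.normal(size=x0.shape)                    # the noise that was added
a_T = alphas_cumprod[num_steps]
x_t = np.sqrt(a_T) * x0 + np.sqrt(1 - a_T) * eps   # start at the noisiest level

for t in range(num_steps, 0, -1):
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
    eps_pred = eps                                 # oracle in place of a network
    x0_pred = (x_t - np.sqrt(1 - a_t) * eps_pred) / np.sqrt(a_t)
    # Deterministic update: re-noise x0_pred to level t-1 with the same noise.
    x_t = np.sqrt(a_prev) * x0_pred + np.sqrt(1 - a_prev) * eps_pred

print(f"max |x_recovered - x0| = {np.abs(x_t - x0).max():.2e}")
```

With a real, imperfect noise predictor the recovered structure would of course differ from `x0`; the check only confirms that the algebra of the update is self-consistent.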

## Test Examples

In [None]:
# Test 1: Noise schedule visualization
print("Test 1: Noise Schedule")
print("="*60)

alphas = get_noise_schedule(200)
print(f"Alpha at t=0: {alphas[0]:.4f} (clean)")
print(f"Alpha at t=100: {alphas[100]:.4f}")
print(f"Alpha at t=200: {alphas[200]:.4f} (noisy)")

In [None]:
# Test 2: Forward diffusion (adding noise)
print("\nTest 2: Forward Diffusion")
print("="*60)

# Clean coordinates (simple helix)
N_atoms = 50
t_param = np.linspace(0, 4 * np.pi, N_atoms)
x0 = np.stack([
    t_param * 0.5,
    np.cos(t_param) * 5,
    np.sin(t_param) * 5
], axis=1)

print(f"Clean structure: mean={x0.mean():.2f}, std={x0.std():.2f}")

alphas = get_noise_schedule(200)
for t in [0, 50, 100, 150, 200]:
    x_t, _ = add_noise(x0, t, alphas)
    print(f"t={t:3d}: mean={x_t.mean():.2f}, std={x_t.std():.2f}")

In [None]:
# Test 3: Single forward pass
print("\nTest 3: Diffusion Module Forward Pass")
print("="*60)

N_atoms = 100
N_tokens = 32
c_s = 384
c_z = 128

# Random noisy coordinates
x_t = np.random.randn(N_atoms, 3) * 5

# Conditioning
single = np.random.randn(N_tokens, c_s)
pair = np.random.randn(N_tokens, N_tokens, c_z)

# Forward pass
noise_pred = diffusion_module_forward(x_t, t=100, single=single, pair=pair)

In [None]:
# Test 4: Full sampling (simplified)
print("\nTest 4: Sampling (Simplified)")
print("="*60)

np.random.seed(42)
N_atoms = 50
N_tokens = 16

single = np.random.randn(N_tokens, 128)
pair = np.random.randn(N_tokens, N_tokens, 64)

# Use fewer steps for speed
final_coords, trajectory = sample_diffusion(single, pair, N_atoms, num_steps=10)

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

# Property 1: Noise schedule starts at 1, ends near 0
alphas = get_noise_schedule(200)
schedule_valid = alphas[0] > 0.99 and alphas[-1] < 0.01
print(f"Property 1 - Valid noise schedule: {schedule_valid}")

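# Property 2: Adding noise moves x_t further from x0 as t grows.
# (Total variance is not a useful check here: x0 already has unit variance,
# and this schedule keeps the variance of x_t at ~1 for every t.)
x0 = np.random.randn(50, 3)
x_50, _ = add_noise(x0, 50, alphas)
x_100, _ = add_noise(x0, 100, alphas)
noise_increases = np.mean((x_50 - x0) ** 2) <= np.mean((x_100 - x0) ** 2)
print(f"Property 2 - Noise increases with t: {noise_increases}")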

# Property 3: Noise embedding has correct dimension
emb = noise_level_embedding(100, c=128)
emb_dim_correct = emb.shape == (128,)
print(f"Property 3 - Embedding dimension correct: {emb_dim_correct}")

# Property 4: Forward pass preserves atom count
N_atoms = 64
x_t = np.random.randn(N_atoms, 3)
single = np.random.randn(16, 128)
pair = np.random.randn(16, 16, 64)
noise_pred = diffusion_module_forward(x_t, 50, single, pair)
shape_preserved = noise_pred.shape == (N_atoms, 3)
print(f"Property 4 - Output shape preserved: {shape_preserved}")

## Key Insights

1. **Diffusion vs Direct Prediction**: AlphaFold3 generates structures through iterative denoising rather than direct frame prediction, so repeated sampling yields a distribution of plausible structures instead of a single output.

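2. **Noise Schedule**: The cosine schedule used in this notebook changes the noise level smoothly between steps, which helps keep training and sampling stable. It is a simplification: the AlphaFold3 paper follows the EDM-style formulation of Karras et al., which parameterizes noise by its standard deviation rather than by a cumulative alpha.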

3. **Conditioning**: Single and pair representations condition the denoising process, incorporating sequence and evolutionary information.

4. **Adaptive LayerNorm**: From DiT (Diffusion Transformer) paper, modulates features based on timestep and conditioning.

5. **200 Steps**: Full inference uses 200 denoising steps, much more than AF2's 8 refinement iterations.

6. **Ensemble Generation**: Multiple samples can be generated for uncertainty estimation.
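
As a rough illustration of the last point, the spread of each atom across several samples can serve as a simple per-atom uncertainty signal. The sketch below uses synthetic samples in place of real sampler output. The function name `per_atom_spread` is hypothetical, and the samples are only centred, not rotationally aligned, which a real analysis would likely need.

```python
import numpy as np


def per_atom_spread(samples):
    """Per-atom RMS deviation from the ensemble mean.

    samples: [K, N_atoms, 3] coordinates from K independent sampling runs.
    Each sample is centred on its own centroid first; no rotational
    alignment is performed.
    """
    centred = samples - samples.mean(axis=1, keepdims=True)
    mean_structure = centred.mean(axis=0)                      # [N_atoms, 3]
    sq_dev = ((centred - mean_structure) ** 2).sum(axis=-1)    # [K, N_atoms]
    return np.sqrt(sq_dev.mean(axis=0))                        # [N_atoms]


# Synthetic stand-in for sampler output: a shared structure plus noise whose
# size grows along the chain, so later atoms should show a larger spread.
rng = np.random.default_rng(0)
K, N_atoms = 200, 30
base = rng.normal(size=(N_atoms, 3)) * 5
noise_scale = np.linspace(0.1, 2.0, N_atoms)[None, :, None]
samples = base[None] + rng.normal(size=(K, N_atoms, 3)) * noise_scale

spread = per_atom_spread(samples)
print(f"spread of first atom: {spread[0]:.2f}, last atom: {spread[-1]:.2f}")
```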