# Algorithm 22: Diffusion Loss (AlphaFold3)

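The diffusion loss trains the model to recover clean atomic coordinates from structures corrupted at various noise levels. This notebook illustrates it with a simplified DDPM-style setup (discrete timesteps and a cosine noise schedule).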

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/diffusion.py`

## Overview

### Diffusion Training Objective

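The denoising network f learns to predict clean coordinates x_0 from noisy coordinates x_t at timestep t, with the expectation taken over sampled timesteps t and Gaussian noise: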

```
L_diffusion = E_{t, noise}[ ||x_0 - f(x_t, t)||^2 ]
```
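In practice the expectation is estimated by averaging the squared error over many sampled timesteps and noise draws. The minimal batched sketch below assumes a copy of this notebook's cosine schedule and a hypothetical `toy_denoiser` that returns the ground truth plus Gaussian error with standard deviation 0.5, so the expected per-coordinate squared error is 0.25. Like the notebook's own loss, it averages over coordinates (`np.mean`) rather than summing.

```python
import numpy as np

def cosine_alpha_bar(num_steps=1000, s=0.008):
    # Copy of the notebook's cosine schedule (alpha_bar[0] = 1)
    t = np.linspace(0, 1, num_steps + 1)
    f_t = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    return f_t / f_t[0]

def toy_denoiser(x_t, t, x_0, rng, err_std=0.5):
    # Hypothetical stand-in for f(x_t, t): it cheats by reading x_0 and
    # only adds Gaussian error of a fixed size, shaped like x_t
    return x_0 + err_std * rng.standard_normal(x_t.shape)

def monte_carlo_diffusion_loss(x_0, alpha_bar, num_samples, rng):
    num_steps = len(alpha_bar) - 1
    # One timestep per sample, uniform over {1, ..., num_steps - 1}
    t = rng.integers(1, num_steps, size=num_samples)
    a = alpha_bar[t][:, None, None]                          # [B, 1, 1]
    noise = rng.standard_normal((num_samples,) + x_0.shape)  # [B, N, 3]
    x_t = np.sqrt(a) * x_0 + np.sqrt(1 - a) * noise          # forward process
    x_pred = toy_denoiser(x_t, t, x_0, rng)                  # [B, N, 3]
    return np.mean((x_pred - x_0) ** 2)                      # averages over t, noise, atoms

rng = np.random.default_rng(0)
x_0 = rng.standard_normal((32, 3)) * 10
alpha_bar = cosine_alpha_bar()
est = monte_carlo_diffusion_loss(x_0, alpha_bar, num_samples=2000, rng=rng)
print(f"Estimated loss: {est:.4f}")  # close to 0.25 = 0.5**2
```

Because this toy predictor's error does not depend on t, the estimate converges to 0.25. A real network's error grows with the noise level, which is what the timestep weighting later in the notebook is meant to balance.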

### Key Components
1. **Noise Schedule**: Defines how much noise is added at each timestep t
2. **Forward Process**: Adds scheduled Gaussian noise to the ground-truth coordinates
3. **Loss Weighting**: Assigns different weights to the loss at different timesteps t

In [None]:
import numpy as np
np.random.seed(42)

In [None]:
def get_noise_schedule(num_steps=1000, s=0.008):
    """
    Cosine noise schedule (Nichol & Dhariwal, 2021).
    
    Returns alpha_bar values (cumulative product of per-step alphas),
    shape [num_steps + 1], with alpha_bar[0] = 1.
    """
    t = np.linspace(0, 1, num_steps + 1)
    f_t = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f_t / f_t[0]
    return alpha_bar
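The docstring describes `alpha_bar` as a cumulative product of per-step alphas. Those per-step values can be recovered as α_t = ᾱ_t / ᾱ_{t−1}, with β_t = 1 − α_t. The sketch below does this and runs a few sanity checks on the schedule. `betas_from_alpha_bar` is a hypothetical helper, and clipping β_t at 0.999 is an assumption borrowed from the improved-DDPM cosine-schedule recipe; the notebook itself never uses betas directly.

```python
import numpy as np

def get_noise_schedule(num_steps=1000, s=0.008):
    # Same cosine schedule as above: alpha_bar[0] = 1, alpha_bar[-1] ~ 0
    t = np.linspace(0, 1, num_steps + 1)
    f_t = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    return f_t / f_t[0]

def betas_from_alpha_bar(alpha_bar, max_beta=0.999):
    # alpha_bar_t = prod_{i<=t} alpha_i  =>  alpha_t = alpha_bar_t / alpha_bar_{t-1}
    alphas = alpha_bar[1:] / alpha_bar[:-1]
    # Clip so the final step (where alpha_bar is essentially 0) does not give beta = 1
    return np.clip(1.0 - alphas, 0.0, max_beta)

alpha_bar = get_noise_schedule(num_steps=1000)
betas = betas_from_alpha_bar(alpha_bar)

print(alpha_bar[0], alpha_bar[-1] < 1e-10)  # 1.0 True
print(np.all(np.diff(alpha_bar) <= 0))      # True: alpha_bar never increases
print(betas[:3], betas[-1])                 # tiny betas early; last one clipped to 0.999
```

Multiplying the recovered `1 - betas` back together reproduces `alpha_bar` at every step except the clipped last one.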

In [None]:
def forward_diffusion(x_0, t, alpha_bar):
    """
    Add noise to clean coordinates.
    
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    
    Args:
        x_0: Clean coordinates [N, 3]
        t: Noise level index (0 to num_steps)
        alpha_bar: Cumulative alpha schedule
    
    Returns:
        x_t: Noisy coordinates [N, 3]
        noise: The noise that was added [N, 3]
    """
    alpha_bar_t = alpha_bar[t]
    
    noise = np.random.randn(*x_0.shape)
    x_t = np.sqrt(alpha_bar_t) * x_0 + np.sqrt(1 - alpha_bar_t) * noise
    
    return x_t, noise
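Once the noise is known, the forward process is an invertible affine map. This means predicting the clean coordinates and predicting the noise carry the same information: x_0 = (x_t − sqrt(1 − ᾱ_t)·noise) / sqrt(ᾱ_t). The small check below assumes a copy of the cosine schedule, and `x0_from_noise` is a hypothetical helper name.

```python
import numpy as np

def cosine_alpha_bar(num_steps=1000, s=0.008):
    # Copy of the notebook's cosine schedule
    t = np.linspace(0, 1, num_steps + 1)
    f_t = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    return f_t / f_t[0]

def x0_from_noise(x_t, noise, alpha_bar_t):
    # Invert x_t = sqrt(a) * x_0 + sqrt(1 - a) * noise for x_0
    return (x_t - np.sqrt(1 - alpha_bar_t) * noise) / np.sqrt(alpha_bar_t)

rng = np.random.default_rng(0)
alpha_bar = cosine_alpha_bar()
x_0 = rng.standard_normal((32, 3)) * 10

max_errors = {}
for t in [100, 500, 900]:
    noise = rng.standard_normal(x_0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x_0 + np.sqrt(1 - alpha_bar[t]) * noise
    x_0_rec = x0_from_noise(x_t, noise, alpha_bar[t])
    max_errors[t] = np.max(np.abs(x_0_rec - x_0))

print(max_errors)  # all values at floating-point round-off level
```

Note the division by sqrt(ᾱ_t). Near the end of the schedule, where ᾱ_t approaches 0, any error in a *predicted* noise is strongly amplified when it is converted back into coordinates.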

In [None]:
def diffusion_loss(x_pred, x_0, t, alpha_bar, loss_type='l2', weighting='uniform'):
    """
    Compute diffusion loss.
    
    Args:
        x_pred: Predicted clean coordinates [N, 3]
        x_0: Ground truth coordinates [N, 3]
        t: Noise level index
        alpha_bar: Noise schedule
        loss_type: 'l2' or 'l1'
        weighting: 'uniform' or 'snr' (signal-to-noise ratio)
    
    Returns:
        Weighted loss scalar
    """
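    print("Diffusion Loss")
    print("=" * 50)
    print(f"t={t}, alpha_bar_t={alpha_bar[t]:.4f}")
    
    # Compute raw (unweighted) loss
    if loss_type == 'l2':
        loss = np.mean((x_pred - x_0) ** 2)
    elif loss_type == 'l1':
        loss = np.mean(np.abs(x_pred - x_0))
    else:
        raise ValueError(f"Unknown loss_type: {loss_type!r}")
    
    # Apply weighting
    if weighting == 'snr':
        # weight = 1 / (SNR + 1) = 1 - alpha_bar_t: noisier samples
        # (lower alpha_bar_t) get a higher weight
        snr = alpha_bar[t] / (1 - alpha_bar[t])
        weight = 1 / (snr + 1)
    elif weighting == 'uniform':
        weight = 1.0
    else:
        raise ValueError(f"Unknown weighting: {weighting!r}")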
    
    weighted_loss = weight * loss
    
    print(f"Raw loss: {loss:.4f}")
    print(f"Weight: {weight:.4f}")
    print(f"Weighted loss: {weighted_loss:.4f}")
    
    return weighted_loss
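The `'snr'` weighting has a simple closed form. With SNR_t = ᾱ_t / (1 − ᾱ_t), the weight is 1 / (SNR_t + 1) = (1 − ᾱ_t) / (ᾱ_t + 1 − ᾱ_t) = 1 − ᾱ_t, which is exactly the noise variance at step t. The vectorized check below covers timesteps 1 through num_steps and skips t = 0, where ᾱ_t = 1 and the SNR is infinite. It assumes a copy of the cosine schedule, and `snr_weight` is a hypothetical helper that mirrors the `'snr'` branch above.

```python
import numpy as np

def cosine_alpha_bar(num_steps=1000, s=0.008):
    # Copy of the notebook's cosine schedule
    t = np.linspace(0, 1, num_steps + 1)
    f_t = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    return f_t / f_t[0]

def snr_weight(alpha_bar_t):
    # Mirrors the 'snr' branch of diffusion_loss; works on scalars or arrays
    snr = alpha_bar_t / (1 - alpha_bar_t)
    return 1 / (snr + 1)

alpha_bar = cosine_alpha_bar()
a = alpha_bar[1:]               # skip t = 0 (alpha_bar = 1, infinite SNR)
w = snr_weight(a)

print(np.allclose(w, 1 - a))    # True: weight equals the noise variance 1 - alpha_bar_t
print(np.all(np.diff(w) >= 0))  # True: weight never decreases with t
```

Writing the weight as `1 - alpha_bar[t]` also avoids the division by zero at t = 0.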

In [None]:
def sample_training_step(x_0, alpha_bar, mock_predict):
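    """
    Simulate a single training step.
    
    Args:
        x_0: Ground truth coordinates [N, 3]
        alpha_bar: Noise schedule
        mock_predict: Mock prediction function f(x_t, t) -> predicted x_0
    
    Returns:
        (loss, t): the weighted loss and the sampled timestep
    """
    num_steps = len(alpha_bar) - 1
    
    # Sample a timestep uniformly from {1, ..., num_steps - 1}
    # (np.random.randint excludes the upper bound)
    t = np.random.randint(1, num_steps)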
    
    # Add noise
    x_t, noise = forward_diffusion(x_0, t, alpha_bar)
    
    # Predict clean coordinates
    x_pred = mock_predict(x_t, t)
    
    # Compute loss
    loss = diffusion_loss(x_pred, x_0, t, alpha_bar)
    
    return loss, t

In [None]:
# Test
print("Test: Diffusion Loss")
print("="*60)

N = 32
x_0 = np.random.randn(N, 3) * 10  # Ground truth

alpha_bar = get_noise_schedule(num_steps=1000)

# Mock predictor: ignores its inputs and returns the ground truth plus
# Gaussian noise (std 0.5), standing in for a trained network
def mock_predict(x_t, t):
    return x_0 + np.random.randn(*x_0.shape) * 0.5

# Training step
loss, t = sample_training_step(x_0, alpha_bar, mock_predict)

print(f"\nSampled t: {t}")
print(f"Final loss: {loss:.4f}")

In [None]:
# Test: Loss at different noise levels
print("\nTest: Loss vs Noise Level")
print("="*60)

for t_test in [100, 500, 900]:
    x_t, noise = forward_diffusion(x_0, t_test, alpha_bar)
    x_pred = x_0 + np.random.randn(*x_0.shape) * 0.5  # mock prediction, independent of x_t
    
    print(f"\nt={t_test}:")
    print(f"  x_t norm: {np.linalg.norm(x_t):.2f}")
    print(f"  noise level: {np.sqrt(1 - alpha_bar[t_test]):.4f}")
    _ = diffusion_loss(x_pred, x_0, t_test, alpha_bar, weighting='snr')

## Key Insights

1. **x_0 Prediction**: AF3 predicts clean coordinates directly (not noise)
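2. **Random t**: This notebook samples the timestep t uniformly during training; AF3 itself samples a continuous noise level from a log-normal distribution
3. **Loss Weighting**: The loss can be weighted per noise level (here via the SNR) to balance easy (low-noise) and hard (high-noise) samples
4. **Smooth Schedule**: The cosine schedule used here gives a smooth noise progression; AF3 uses an EDM-style (Karras et al., 2022) noise parameterization rather than a discrete cosine schedule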
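For comparison, here is our reading of the AF3 supplementary information. AF3 follows an EDM-style formulation instead of a discrete cosine schedule. A continuous noise level is drawn as σ = σ_data · exp(−1.2 + 1.5·N(0,1)) with σ_data = 16 Å, and noise is added without rescaling (x_noisy = x_0 + σ·noise). The MSE term is then weighted by (σ² + σ_data²)/(σ + σ_data)². The sketch below covers only these pieces. It omits the rigid alignment, per-atom-type weights, bond loss and smooth-LDDT terms of the full loss, and all function names are hypothetical. Check the constants against the paper before relying on them.

```python
import numpy as np

SIGMA_DATA = 16.0  # assumed sigma_data in Angstrom, per our reading of the AF3 SI

def sample_noise_level(rng, size=None, sigma_data=SIGMA_DATA):
    # Log-normal noise levels: sigma = sigma_data * exp(-1.2 + 1.5 * N(0, 1))
    return sigma_data * np.exp(-1.2 + 1.5 * rng.standard_normal(size))

def edm_style_loss_weight(sigma, sigma_data=SIGMA_DATA):
    # Assumed weight on the MSE term; equals 0.5 at sigma = sigma_data, tends to 1 at both extremes
    return (sigma**2 + sigma_data**2) / (sigma + sigma_data) ** 2

def weighted_mse_term(x_pred, x_0, sigma, sigma_data=SIGMA_DATA):
    # Plain per-coordinate MSE; AF3's rigid alignment and per-atom weights are omitted
    mse = np.mean((x_pred - x_0) ** 2)
    return edm_style_loss_weight(sigma, sigma_data) * mse

rng = np.random.default_rng(0)
x_0 = rng.standard_normal((32, 3)) * 10
sigma = sample_noise_level(rng)
x_noisy = x_0 + sigma * rng.standard_normal(x_0.shape)  # network input: variance-exploding noise, no rescaling
x_pred = x_0 + 0.5 * rng.standard_normal(x_0.shape)     # hypothetical denoiser output
weight = edm_style_loss_weight(sigma)
print(f"sigma={sigma:.2f}, weight={weight:.3f}, "
      f"loss={weighted_mse_term(x_pred, x_0, sigma):.4f}")
```

Unlike the cosine schedule, there is no discrete timestep index here: the network is conditioned on σ itself, and the median sampled noise level is σ_data · exp(−1.2).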