# Algorithm 7: Diffusion v2 (Boltz-2)

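A simplified NumPy walkthrough of the Boltz-2 diffusion module. It is a denoiser that predicts clean atom coordinates from noisy ones, wrapped in EDM preconditioning. The network weights are drawn at random, so the notebook illustrates data flow and shapes rather than a trained model.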

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/diffusionv2.py`

In [None]:
import numpy as np
# Fix the global NumPy RNG so the random weights and test inputs below are reproducible.
np.random.seed(seed=42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along its last axis.

    This is a parameter-free layer norm: no learnable gain or bias is applied.

    Args:
        x: Array whose last axis is normalized.
        eps: Small constant added to the variance to avoid division by zero.

    Returns:
        Array with the same shape as ``x``.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    scale = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return centered / scale

In [None]:
def edm_preconditioning(sigma, sigma_data=16.0):
    """Compute EDM-style preconditioning coefficients for noise level ``sigma``.

    Args:
        sigma: Noise level (NumPy operations make this work elementwise on arrays too).
        sigma_data: Assumed standard deviation of the clean data.

    Returns:
        Tuple ``(c_skip, c_out, c_in, c_noise)`` with
            c_skip  = sigma_data^2 / (sigma^2 + sigma_data^2)
            c_out   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
            c_in    = 1 / sqrt(sigma^2 + sigma_data^2)
            c_noise = 0.25 * log(sigma + 1e-8)
    """
    total_var = sigma ** 2 + sigma_data ** 2
    norm = np.sqrt(total_var)

    # The 1e-8 keeps the log finite when sigma == 0.
    # NOTE(review): AlphaFold3-style implementations use 0.25 * log(sigma / sigma_data);
    # this version does not divide by sigma_data -- confirm against the referenced source.
    c_noise = 0.25 * np.log(sigma + 1e-8)
    c_in = 1 / norm
    c_out = sigma * sigma_data / norm
    c_skip = sigma_data ** 2 / total_var

    return c_skip, c_out, c_in, c_noise

In [None]:
def _sinusoidal_embedding(c_noise, dim):
    """Return a ``dim``-length [cos | sin] encoding of the scalar ``c_noise``.

    Frequencies decay geometrically from 1 towards 1/10000, as in Transformer
    positional encodings. ``dim`` is expected to be even.
    """
    half = dim // 2
    freqs = np.exp(-np.log(10000) * np.arange(half) / half)
    args = c_noise * freqs
    return np.concatenate([np.cos(args), np.sin(args)])


def diffusion_v2_forward(x_t, sigma, s, z, sigma_data=16.0, *, embed_dim=128,
                         rng=None, verbose=True):
    """Run a simplified Diffusion v2 denoiser with EDM preconditioning.

    The randomly initialised network sees the scaled coordinates
    ``c_in * x_t`` together with the single representation ``s`` conditioned
    on a sinusoidal embedding of ``c_noise``. Its output ``F_theta`` is
    combined as ``x_pred = c_skip * x_t + c_out * F_theta``.

    Atom ``i`` is paired with token row ``s[i]`` for ``i < min(N_atoms, N)``.
    Atoms without a matching token row get ``F_theta = 0``, so their
    prediction is just ``c_skip * x_t``.

    Args:
        x_t: Noisy coordinates, shape [N_atoms, 3].
        sigma: Scalar noise level.
        s: Single representation, shape [N, c_s].
        z: Pair representation [N, N, c_z]. Accepted for interface
            compatibility; not used by this simplified network.
        sigma_data: Data standard deviation used by the preconditioning.
        embed_dim: Length of the noise-level embedding (positive, even).
        rng: Optional ``numpy.random.Generator`` used to draw the random
            weights. When ``None``, the global ``np.random`` state is used in
            the same order as before, so ``np.random.seed`` controls results.
        verbose: If True, print progress information.

    Returns:
        Predicted clean coordinates, shape [N_atoms, 3].

    Raises:
        ValueError: If ``x_t`` is not [N_atoms, 3], or ``embed_dim`` is not a
            positive even integer.
    """
    x_t = np.asarray(x_t)
    if x_t.ndim != 2 or x_t.shape[1] != 3:
        raise ValueError(f"x_t must have shape [N_atoms, 3], got {x_t.shape}")
    if embed_dim <= 0 or embed_dim % 2:
        raise ValueError(
            f"embed_dim must be a positive even integer, got {embed_dim}"
        )

    N_atoms = x_t.shape[0]
    c_s = s.shape[-1]
    # Weights are re-drawn on every call. With rng=None this consumes the
    # global NumPy RNG exactly like np.random.randn(*shape) did.
    draw = np.random.standard_normal if rng is None else rng.standard_normal

    if verbose:
        print("Diffusion v2 Forward")
        print("=" * 50)
        print(f"Atoms: {N_atoms}, sigma: {sigma:.4f}")

    # EDM preconditioning
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
    if verbose:
        print(f"  c_skip: {c_skip:.4f}, c_out: {c_out:.4f}")

    # Rescale the input (unit variance if the clean data has std sigma_data).
    x_scaled = c_in * x_t

    # Condition the single representation on the noise level.
    t_emb = _sinusoidal_embedding(c_noise, embed_dim)
    W_t = draw((embed_dim, c_s)) * (embed_dim ** -0.5)
    s_cond = s + t_emb @ W_t

    # Simplified network: embed coordinates, add per-token conditioning,
    # normalise, then project back to 3-D.
    N = min(N_atoms, s.shape[0])
    W_x = draw((3, c_s)) * (3 ** -0.5)
    h = layer_norm(x_scaled[:N] @ W_x + s_cond[:N])
    W_out = draw((c_s, 3)) * (c_s ** -0.5)
    F_theta = h @ W_out

    if N < N_atoms:
        # Atoms beyond the last token row receive no network correction.
        F_theta = np.concatenate([F_theta, np.zeros((N_atoms - N, 3))])

    # EDM output: skip connection plus scaled network residual.
    x_pred = c_skip * x_t + c_out * F_theta

    if verbose:
        print(f"  Output: {x_pred.shape}")

    return x_pred

In [None]:
# Test: run one forward pass on random inputs and sanity-check the output.
print("Test: Diffusion v2")
print("="*60)

N_atoms = 64
N = 32  # fewer tokens than atoms: atoms 32..63 get no network correction
c_s = 128
c_z = 64

# Clean coordinates, then Gaussian noise at level sigma.
x_0 = np.random.randn(N_atoms, 3) * 10
sigma = 5.0
x_t = x_0 + np.random.randn(N_atoms, 3) * sigma

s = np.random.randn(N, c_s)
z = np.random.randn(N, N, c_z)

x_pred = diffusion_v2_forward(x_t, sigma, s, z)

print(f"\nPrediction finite: {np.isfinite(x_pred).all()}")

# Fail loudly instead of only printing the result.
assert x_pred.shape == x_t.shape, f"unexpected output shape {x_pred.shape}"
assert np.isfinite(x_pred).all(), "prediction contains non-finite values"

## Key Insights

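1. **EDM Framework**: Preconditioning from "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM; Karras et al., 2022). The network input, output, and noise level are rescaled by the noise-dependent coefficients `c_in`, `c_out`, `c_skip`, and `c_noise`.
2. **Enhanced Embedding**: The noise level enters the network through a 128-dimensional sinusoidal embedding of `c_noise ≈ 0.25 * log(sigma)`.
3. **Skip Connection**: The prediction is `c_skip * x_t + c_out * F_theta`. At low noise `c_skip → 1` and `c_out → 0`, so the network only has to learn a small correction to its input.
4. **Conditioning**: The noise embedding is projected and added to the single representation `s`, which is then combined with the embedded coordinates. The pair representation `z` is accepted but not used in this simplified version.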