# Algorithm 19: Diffusion Loss (Boltz)

Training loss for diffusion-based structure prediction: a noise-weighted MSE on (optionally) rigidly aligned coordinates plus a smooth LDDT term.

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/loss/diffusion.py`

In [None]:
import numpy as np
# Seed NumPy's global RNG so the np.random draws in this notebook are reproducible.
np.random.seed(42)

In [None]:
def weighted_rigid_align(pred, true, weights=None):
    """
    Weighted Kabsch alignment.

    Rigidly superimposes the predicted coordinates onto the true coordinates:
    both point sets are centred on their weighted centroids, the proper
    rotation (no reflection) that best maps the centred prediction onto the
    centred target is obtained from an SVD of the weighted covariance, and
    the rotated prediction is shifted to the true centroid.

    Args:
        pred: Predicted coordinates [N, 3]
        true: True coordinates [N, 3]
        weights: Per-atom weights [N]; uniform weights are used when None

    Returns:
        Aligned predicted coordinates [N, 3]
    """
    pred = np.asarray(pred)
    true = np.asarray(true)

    N = pred.shape[0]
    if weights is None:
        weights = np.ones(N)
    weights = np.asarray(weights)

    # Center both point sets on their weighted centroids
    w_sum = weights.sum()
    pred_center = (pred * weights[:, None]).sum(axis=0) / w_sum
    true_center = (true * weights[:, None]).sum(axis=0) / w_sum

    pred_centered = pred - pred_center
    true_centered = true - true_center

    # Weighted covariance: H = sum_i w_i * p_i q_i^T  (3 x 3)
    H = (pred_centered * weights[:, None]).T @ true_centered

    # SVD
    U, S, Vt = np.linalg.svd(H)

    # Rotation for column vectors (R @ p ~= q); d flips the last axis when
    # needed so that det(R) = +1, i.e. a rotation rather than a reflection.
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1, 1, d]) @ U.T

    # Points are stored as rows, so applying R to every point is a right
    # multiplication by R.T (multiplying by R would apply the inverse
    # rotation instead).
    pred_aligned = pred_centered @ R.T + true_center

    return pred_aligned

In [None]:
def mse_loss(pred, true, mask=None):
    """
    Squared-error loss averaged over points.

    The squared differences are summed over the last axis to give one value
    per point. Without a mask the plain mean of those values is returned;
    with a mask the values are multiplied by it and the sum is divided by
    the mask total.
    """
    delta = pred - true
    per_point = np.sum(delta * delta, axis=-1)

    if mask is None:
        return per_point.mean()

    masked_total = (per_point * mask).sum()
    return masked_total / mask.sum()

In [None]:
def smooth_lddt_loss(pred, true, cutoff=15.0, temperature=0.5):
    """
    Smooth LDDT loss.

    Differentiable version of LDDT for training. Pairwise distances are
    computed for both structures; pairs are softly selected by a sigmoid on
    the true distance around `cutoff`, and the absolute distance error of
    each pair is softly compared against the thresholds 0.5, 1, 2 and 4.
    The mask-weighted fractions are averaged over the thresholds.

    Args:
        pred: Predicted coordinates [N, 3]
        true: True coordinates [N, 3]
        cutoff: True-distance midpoint of the soft local-contact mask
        temperature: Softness of the sigmoids (each threshold sigmoid uses
            temperature * threshold as its width)

    Returns:
        1 - smooth LDDT
    """
    thresholds = [0.5, 1.0, 2.0, 4.0]

    def soft_step(x):
        # Same value as 1 / (1 + exp(x)), written as exp(-log(1 + exp(x)))
        # so that large positive x decays to 0 instead of overflowing exp.
        return np.exp(-np.logaddexp(0.0, x))

    # Compute distances
    pred_dist = np.sqrt(np.sum((pred[:, None] - pred[None, :]) ** 2, axis=-1))
    true_dist = np.sqrt(np.sum((true[:, None] - true[None, :]) ** 2, axis=-1))

    # Mask for local contacts
    mask = soft_step((true_dist - cutoff) / temperature)
    mask = mask * (1 - np.eye(len(true)))  # Exclude diagonal
    mask_total = mask.sum() + 1e-8  # epsilon guards against an all-zero mask

    # Distance difference
    dist_diff = np.abs(pred_dist - true_dist)

    # Smooth thresholds
    scores = []
    for thresh in thresholds:
        within = soft_step((dist_diff - thresh) / (temperature * thresh))
        scores.append((within * mask).sum() / mask_total)

    lddt = np.mean(scores)
    return 1 - lddt  # Loss = 1 - LDDT

In [None]:
def diffusion_loss(pred, true, sigma, sigma_data=16, align=True):
    """
    Diffusion training loss.

    Optionally passes the prediction through weighted_rigid_align, then
    combines mse_loss and smooth_lddt_loss on the resulting coordinates.
    The MSE term is scaled by (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2;
    the smooth LDDT term is added unscaled. Each component is printed.

    Args:
        pred: Predicted coordinates [N, 3]
        true: True coordinates [N, 3]
        sigma: Noise level
        sigma_data: Data standard deviation
        align: Whether to align before computing loss

    Returns:
        Total loss
    """
    print("Diffusion Loss")
    print("=" * 50)
    print(f"sigma: {sigma:.4f}")

    # Use the aligned prediction only when alignment is requested
    coords = weighted_rigid_align(pred, true) if align else pred

    # Coordinate term and structural term on the same coordinates
    mse = mse_loss(coords, true)
    lddt_loss = smooth_lddt_loss(coords, true)

    # EDM-style noise-level weight, applied to the MSE term only
    numerator = sigma ** 2 + sigma_data ** 2
    denominator = (sigma * sigma_data) ** 2
    weight = numerator / denominator

    total_loss = weight * mse + lddt_loss

    for label, value in (
        ("MSE", mse),
        ("LDDT loss", lddt_loss),
        ("Weight", weight),
        ("Total", total_loss),
    ):
        print(f"  {label}: {value:.4f}")

    return total_loss

In [None]:
# Demo: evaluate the loss for a lightly and a heavily perturbed prediction
print("Test: Diffusion Loss")
print("=" * 60)

N = 32

# Reference coordinates: N random points, scaled by 10
true = 10 * np.random.randn(N, 3)

# Lightly perturbed copy of the reference (noise scale 0.5)
pred_good = 0.5 * np.random.randn(N, 3) + true
print("Good prediction:")
loss_good = diffusion_loss(pred_good, true, sigma=0.5)

print()

# Heavily perturbed copy of the reference (noise scale 5.0)
pred_bad = 5.0 * np.random.randn(N, 3) + true
print("Bad prediction:")
loss_bad = diffusion_loss(pred_bad, true, sigma=5.0)

## Key Insights

1. **Alignment**: Rigid alignment before loss computation
2. **MSE + LDDT**: Combines coordinate and structural losses
3. **Noise Weighting**: EDM-style weight $(\sigma^2 + \sigma_{\text{data}}^2) / (\sigma \, \sigma_{\text{data}})^2$, applied to the MSE term only
4. **Smooth LDDT**: Differentiable version for training