# Algorithm 20: Confidence Loss (Boltz)

Training loss for the confidence heads: pLDDT (predicted LDDT) and PAE (predicted aligned error).

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/loss/confidence.py`

In [None]:
import numpy as np
np.random.seed(42)

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    logits cannot overflow; the shift cancels out in the normalisation.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(np.subtract(x, peak))
    return unnormalized / unnormalized.sum(axis=axis, keepdims=True)

In [None]:
def compute_lddt(pred, true, cutoff=15.0, thresholds=(0.5, 1.0, 2.0, 4.0)):
    """
    Compute per-residue LDDT.

    For each point i, the neighbours are all other points j whose true
    distance to i is below ``cutoff``. For each threshold, the score is the
    fraction of those neighbours whose predicted distance deviates from the
    true distance by less than the threshold. The per-threshold scores are
    then averaged.

    Args:
        pred: Predicted coordinates [N, 3]
        true: True coordinates [N, 3]
        cutoff: Inclusion radius in the true structure (same units as the
            coordinates)
        thresholds: Tolerances on |d_pred - d_true|; defaults to the
            standard LDDT set (0.5, 1, 2, 4)

    Returns:
        Per-residue LDDT [N], in [0, 1]. Points with no neighbour inside
        ``cutoff`` score 0.

    Raises:
        ValueError: If ``thresholds`` is empty.
    """
    thresholds = tuple(thresholds)
    if not thresholds:
        raise ValueError("thresholds must contain at least one value")

    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    n = true.shape[0]

    pred_dist = np.sqrt(np.sum((pred[:, None] - pred[None, :]) ** 2, axis=-1))
    true_dist = np.sqrt(np.sum((true[:, None] - true[None, :]) ** 2, axis=-1))

    # Neighbour pairs in the true structure. Exclude self-pairs by index:
    # a `true_dist > 0` test would also drop distinct points that happen to
    # share coordinates.
    mask = (true_dist < cutoff) & ~np.eye(n, dtype=bool)
    # Clamp to 1 so isolated points give 0 / 1 = 0 instead of 0 / 0.
    n_neighbors = np.maximum(mask.sum(axis=1), 1)

    dist_diff = np.abs(pred_dist - true_dist)

    # One row per threshold: fraction of neighbour distances preserved.
    scores = [
        ((dist_diff < thresh) & mask).sum(axis=1) / n_neighbors
        for thresh in thresholds
    ]
    return np.mean(scores, axis=0)

In [None]:
def plddt_loss(logits, pred_coords, true_coords, num_bins=50):
    """
    pLDDT prediction loss.

    First compute the per-residue LDDT of the prediction against the ground
    truth. Discretise it into ``num_bins`` equal-width bins over [0, 1] and
    use the bin index as the classification target for the logits.

    Args:
        logits: Predicted pLDDT logits [N, num_bins]
        pred_coords: Predicted coordinates [N, 3]
        true_coords: True coordinates [N, 3]
        num_bins: Number of LDDT bins; must equal ``logits.shape[-1]``

    Returns:
        Cross-entropy loss averaged over residues (0.0 when N == 0)

    Raises:
        ValueError: If ``logits`` is not shaped [N, num_bins].
    """
    logits = np.asarray(logits, dtype=float)
    n = len(pred_coords)
    if logits.shape != (n, num_bins):
        raise ValueError(
            f"expected pLDDT logits of shape {(n, num_bins)}, got {logits.shape}"
        )
    if n == 0:
        return 0.0

    # Compute true LDDT
    true_lddt = compute_lddt(pred_coords, true_coords)  # [N], in [0, 1]

    # floor(lddt * num_bins); an LDDT of exactly 1.0 goes to the last bin.
    bin_indices = np.clip((true_lddt * num_bins).astype(int), 0, num_bins - 1)

    # Stable log-softmax: no log(0) and no epsilon clamping when a bin's
    # probability underflows.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Log-probability assigned to each residue's target bin.
    target_log_probs = np.take_along_axis(log_probs, bin_indices[:, None], axis=-1)
    return -target_log_probs.mean()

In [None]:
def pae_loss(logits, pred_coords, true_coords, num_bins=64, max_error=32.0):
    """
    PAE prediction loss.

    The target for pair (i, j) is the position error of token j after the
    predicted and true structures are superimposed on token i. This
    simplified version aligns by translation only, with no per-token
    rotation frames:

        e_ij = || (pred_j - pred_i) - (true_j - true_i) ||

    Hence e_ii = 0, and a perfect or rigidly translated prediction has zero
    error everywhere. Errors are binned uniformly over [0, max_error), and
    larger errors go to the last bin.

    Args:
        logits: Predicted PAE logits [N, N, num_bins]
        pred_coords: Predicted coordinates [N, 3]
        true_coords: True coordinates [N, 3]
        num_bins: Number of error bins; must equal ``logits.shape[-1]``
        max_error: Upper edge of the binned error range (> 0)

    Returns:
        Cross-entropy loss averaged over all N * N pairs (0.0 when N == 0)

    Raises:
        ValueError: If ``logits`` is not shaped [N, N, num_bins] or
            ``max_error`` is not positive.
    """
    pred_coords = np.asarray(pred_coords, dtype=float)
    true_coords = np.asarray(true_coords, dtype=float)
    logits = np.asarray(logits, dtype=float)
    N = pred_coords.shape[0]

    if logits.shape != (N, N, num_bins):
        raise ValueError(
            f"expected PAE logits of shape {(N, N, num_bins)}, got {logits.shape}"
        )
    if max_error <= 0:
        raise ValueError(f"max_error must be positive, got {max_error}")
    if N == 0:
        return 0.0

    # Compare the displacement i -> j in both structures. Comparing pred_i
    # with true_j directly would mix different atoms, so even a perfect
    # prediction would get the inter-atomic distances as its targets.
    pred_disp = pred_coords[None, :] - pred_coords[:, None]  # [N, N, 3]: x_j - x_i
    true_disp = true_coords[None, :] - true_coords[:, None]
    errors = np.sqrt(np.sum((pred_disp - true_disp) ** 2, axis=-1))  # [N, N]

    # Uniform bins of width max_error / num_bins; overflow goes to the last bin.
    bin_indices = np.clip(
        np.floor(errors * num_bins / max_error).astype(int), 0, num_bins - 1
    )

    # Stable log-softmax over the bin axis (no epsilon clamping).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Log-probability of each pair's target bin, averaged over the N * N pairs.
    target_log_probs = np.take_along_axis(log_probs, bin_indices[..., None], axis=-1)
    return -target_log_probs.mean()

In [None]:
def confidence_loss(plddt_logits, pae_logits, pred_coords, true_coords, *, verbose=True):
    """
    Combined confidence loss: the unweighted sum of the pLDDT and PAE losses.

    Args:
        plddt_logits: pLDDT logits, forwarded to ``plddt_loss``
        pae_logits: PAE logits, forwarded to ``pae_loss``
        pred_coords: Predicted coordinates [N, 3]
        true_coords: True coordinates [N, 3]
        verbose: If True (the default, matching the original behaviour),
            print each term and the total. Pass False to call this inside a
            training or evaluation loop without console output.

    Returns:
        pLDDT loss + PAE loss
    """
    # No-op logger when silent; keeps the print order interleaved with the
    # computations exactly as before when verbose.
    log = print if verbose else (lambda *args, **kwargs: None)

    log("Confidence Loss")
    log("=" * 50)

    l_plddt = plddt_loss(plddt_logits, pred_coords, true_coords)
    log(f"  pLDDT loss: {l_plddt:.4f}")

    l_pae = pae_loss(pae_logits, pred_coords, true_coords)
    log(f"  PAE loss: {l_pae:.4f}")

    total = l_plddt + l_pae
    log(f"  Total: {total:.4f}")

    return total

In [None]:
# Smoke test: evaluate the combined loss on a random toy structure.
print("Test: Confidence Loss")
print("=" * 60)

# Problem sizes; the bin counts match the loss functions' default num_bins.
N, num_plddt_bins, num_pae_bins = 24, 50, 64

# Ground truth, then a "prediction" perturbed by unit Gaussian noise per axis.
true_coords = 10 * np.random.randn(N, 3)
pred_coords = true_coords + 1.0 * np.random.randn(N, 3)

# Untrained (random) confidence-head logits.
plddt_logits = np.random.randn(N, num_plddt_bins)
pae_logits = np.random.randn(N, N, num_pae_bins)

loss = confidence_loss(plddt_logits, pae_logits, pred_coords, true_coords)

## Key Insights

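1. **Supervised by Structure**: Targets are computed on the fly by comparing the predicted coordinates with the ground truth, so the confidence heads need no extra labels.
2. **Binned Classification**: Each head predicts a distribution over discrete bins and is trained with cross-entropy against the bin that contains the true value.
3. **LDDT as Target**: The pLDDT head is trained to predict each residue's actual LDDT score.
4. **PAE as Target**: The PAE head is trained to predict the aligned error between pairs of tokens.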