# Algorithm 9: Distogram v2 (Boltz-2)

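Enhanced distogram prediction for Boltz-2: a NumPy sketch of the distogram head (pair representation → per-pair distance-bin logits) and its masked cross-entropy loss.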

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/loss/distogramv2.py`

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along its last axis.

    No learnable scale or shift is applied.

    Args:
        x: Array-like input; statistics are taken over the last axis.
        eps: Small constant added to the variance for numerical stability.

    Returns:
        Array of the same shape as ``x``.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return (x - mu) / sigma

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    inputs do not overflow.

    Args:
        x: Array-like scores.
        axis: Axis along which the probabilities sum to one.

    Returns:
        Array of the same shape as ``x`` with non-negative entries.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)

In [None]:
def distogram_head_v2(z, num_bins=64, min_dist=2.0, max_dist=22.0, *, rng=None, verbose=True):
    """
    Distogram Head v2.

    Symmetrizes the pair representation, applies a two-layer MLP
    (Linear -> ReLU -> LayerNorm -> Linear) to every pair independently, and
    converts the resulting logits into an expected distance per pair.

    Args:
        z: Pair representation [N, N, c_z]
        num_bins: Number of equal-width distance bins on [min_dist, max_dist)
        min_dist: Lower edge of the first distance bin
        max_dist: Upper edge of the last distance bin
        rng: Optional object with a ``standard_normal(size)`` method (e.g. a
            ``np.random.Generator``) used to draw the random projection
            weights. Defaults to NumPy's global random state.
        verbose: Print a short summary when True (default).

    Returns:
        Tuple ``(logits, expected_dist)`` with shapes [N, N, num_bins] and
        [N, N].
    """
    c_z = z.shape[-1]
    # np.random.standard_normal draws from the same global stream as
    # np.random.randn, so the default keeps seeded runs reproducible.
    draw = np.random.standard_normal if rng is None else rng.standard_normal

    if verbose:
        print("Distogram Head v2")
        print("=" * 50)

    # Symmetrize first: the MLP acts on each pair independently, so a
    # symmetric input guarantees logits[i, j] == logits[j, i].
    z_sym = (z + z.transpose(1, 0, 2)) / 2
    z_norm = layer_norm(z_sym)

    # Two-layer MLP. The weights are random stand-ins for learned parameters,
    # scaled by 1/sqrt(c_z) to keep activations O(1).
    W1 = draw((c_z, c_z)) * (c_z ** -0.5)
    h = np.maximum(0, np.einsum('ijc,cd->ijd', z_norm, W1))
    h = layer_norm(h)

    W2 = draw((c_z, num_bins)) * (c_z ** -0.5)
    logits = np.einsum('ijc,cb->ijb', h, W2)

    # Expected distance under the predicted distribution. Bin centers are the
    # midpoints of the same num_bins + 1 equal-width edges that
    # distogram_loss_v2 uses to bin ground-truth distances, so each center
    # lies inside its own bin.
    probs = softmax(logits)
    bin_edges = np.linspace(min_dist, max_dist, num_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    expected_dist = np.sum(probs * bin_centers, axis=-1)

    # .min()/.max() raise on empty arrays, so skip the summary for N == 0.
    if verbose and expected_dist.size:
        print(f"Distance range: [{expected_dist.min():.2f}, {expected_dist.max():.2f}]A")

    return logits, expected_dist

In [None]:
def distogram_loss_v2(logits, true_distances, min_dist=2.0, max_dist=22.0, mask=None):
    """
    Distogram loss v2 with improved masking.

    Ground-truth distances are assigned to ``num_bins`` equal-width bins on
    ``[min_dist, max_dist)``. Distances below ``min_dist`` fall into the first
    bin; distances at or above ``max_dist`` fall into the last. The loss is
    the mean cross-entropy over unmasked, off-diagonal pairs.

    Args:
        logits: Predicted logits [..., N, N, num_bins]
        true_distances: Ground truth distances [..., N, N]
        min_dist: Lower edge of the first distance bin
        max_dist: Upper edge of the last distance bin
        mask: Valid pair mask [..., N, N]; defaults to all pairs valid

    Returns:
        Scalar cross-entropy loss (0.0 when no pair is valid).
    """
    logits = np.asarray(logits)
    num_bins = logits.shape[-1]
    N = logits.shape[-2]

    if mask is None:
        mask = np.ones(logits.shape[:-1])

    # Exclude diagonal (self-pairs); np.eye broadcasts over leading batch dims.
    mask = mask * (1 - np.eye(N))

    # Digitizing against the interior edges yields bin indices 0..num_bins-1
    # directly. Out-of-range distances land in the first/last bin without
    # needing a clip-by-epsilon.
    bin_edges = np.linspace(min_dist, max_dist, num_bins + 1)
    bin_indices = np.digitize(true_distances, bin_edges[1:-1])

    # Exact, numerically stable log-softmax. log(softmax + eps) would cap the
    # per-pair loss at -log(eps) whenever a probability underflows.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Cross-entropy = -log p(true bin); gather it directly instead of
    # materializing an [..., N, N, num_bins] one-hot target tensor.
    ce = -np.take_along_axis(log_probs, bin_indices[..., None], axis=-1)[..., 0]

    loss = (ce * mask).sum() / (mask.sum() + 1e-8)

    return loss

In [None]:
# Test: run the head on a random pair representation, then score its logits
# against pairwise distances of random 3D points. The head draws its weights
# from the global NumPy RNG seeded above, so the printed numbers are
# reproducible as long as cells run in order.
print("Test: Distogram v2")
print("="*60)

N = 32      # side length of the square pair matrix
c_z = 128   # channel width of the pair representation

# Random pair representation of shape [N, N, c_z].
z = np.random.randn(N, N, c_z)

logits, expected_dist = distogram_head_v2(z)

# Simulate ground truth: Euclidean distances between N random points in 3D,
# giving a symmetric [N, N] matrix with a zero diagonal.
coords = np.random.randn(N, 3) * 5
true_dist = np.sqrt(np.sum((coords[:, None] - coords[None, :]) ** 2, axis=-1))

loss = distogram_loss_v2(logits, true_dist)
print(f"\nLoss: {loss:.4f}")

## Key Insights

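1. **Enhanced MLP**: A two-layer MLP (Linear → ReLU → LayerNorm → Linear) maps each pair embedding to distance-bin logits.
2. **Proper Masking**: The loss averages only over valid pairs. It uses an optional pair mask and always excludes the diagonal (self-pairs).
3. **Symmetrization**: Averaging `z` with its transpose before the pointwise MLP guarantees symmetric predictions, i.e. `logits[i, j] == logits[j, i]`.
4. **Auxiliary Signal**: The distogram cross-entropy provides an auxiliary supervision signal on the pair representation during training.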