# Algorithm 17: Distogram Head (Boltz)

Predicts, for every pair of residues, a probability distribution over binned inter-residue distances (a distogram) from the pair representation.

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/trunk.py`
- **Class**: `DistogramModule`

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """
    Normalize the last axis of ``x`` to zero mean and unit variance.

    No learned scale or shift is applied.

    Args:
        x: Input array; statistics are taken over the last axis
        eps: Constant added to the variance before the square root

    Returns:
        Array with the same shape as ``x``
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps)

def softmax(x, axis=-1):
    """
    Numerically stable softmax along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow.

    Args:
        x: Input array of logits
        axis: Axis along which the outputs sum to one

    Returns:
        Array with the same shape as ``x``
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=axis, keepdims=True)

In [None]:
def distogram_head(z, num_bins=64, min_dist=2.0, max_dist=22.0):
    """
    Distogram prediction head.

    Symmetrizes the pair representation, layer-normalizes it over the channel
    axis and projects every pair to ``num_bins`` logits. The projection
    weights are freshly sampled from NumPy's global RNG on every call (demo
    weights, not trained parameters), and a short summary is printed.

    Args:
        z: Pair representation [N, N, c_z]
        num_bins: Number of distance bins
        min_dist: Minimum distance (Angstroms)
        max_dist: Maximum distance (Angstroms)

    Returns:
        Tuple ``(logits, expected_dist)``: logits of shape [N, N, num_bins]
        and the probability-weighted mean distance of shape [N, N].
    """
    c_z = z.shape[-1]

    print("Distogram Head")
    print("=" * 50)
    print(f"Bins: {num_bins} ({min_dist}A to {max_dist}A)")

    # Symmetrize: average z[i, j] with z[j, i] so the logits are symmetric
    z_sym = (z + z.transpose(1, 0, 2)) / 2
    z_norm = layer_norm(z_sym)

    # Project to bins (weights scaled by 1/sqrt(c_z))
    W = np.random.randn(c_z, num_bins) * (c_z ** -0.5)
    logits = np.einsum('ijc,cb->ijb', z_norm, W)

    # Expected distance. The bins are the num_bins equal-width intervals
    # between num_bins + 1 edges (the same edges distogram_loss uses to build
    # its targets), so each bin is represented by the midpoint of its edges.
    probs = softmax(logits)
    bin_edges = np.linspace(min_dist, max_dist, num_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    expected_dist = np.sum(probs * bin_centers, axis=-1)

    print(f"Logits: {logits.shape}")
    print(f"Expected distance range: [{expected_dist.min():.2f}, {expected_dist.max():.2f}]A")

    return logits, expected_dist

In [None]:
def distogram_loss(logits, true_distances, min_dist=2.0, max_dist=22.0):
    """
    Cross-entropy between predicted distance-bin logits and true distances.

    True distances are discretized into equal-width bins spanning
    ``[min_dist, max_dist]``; values outside that range are clipped into the
    first or last bin.

    Args:
        logits: Predicted logits [N, N, num_bins]
        true_distances: Ground truth distances [N, N]
        min_dist: Lower edge of the first bin (Angstroms)
        max_dist: Upper edge of the last bin (Angstroms)

    Returns:
        Cross-entropy summed over all pairs and divided by N ** 2
    """
    n_bins = logits.shape[-1]
    n_pairs = logits.shape[0] ** 2

    # Discretize the true distances: n_bins + 1 edges delimit n_bins bins
    edges = np.linspace(min_dist, max_dist, n_bins + 1)
    raw_bins = np.digitize(true_distances, edges) - 1
    target_bins = np.clip(raw_bins, 0, n_bins - 1)

    # Row lookup in an identity matrix yields the one-hot targets
    one_hot = np.eye(n_bins)[target_bins]

    # The small constant keeps the log finite when a probability is zero
    log_probs = np.log(softmax(logits) + 1e-8)
    return -np.sum(one_hot * log_probs) / n_pairs

In [None]:
# Test: smoke-run the distogram head and loss on random inputs.
print("Test: Distogram Head")
print("="*60)

# Problem size: N residues, c_z pair channels.
N = 32
c_z = 128

# Random pair representation [N, N, c_z].
z = np.random.randn(N, N, c_z)

# distogram_head draws its projection weights from the same global RNG,
# so the order of these calls determines the values that follow.
logits, expected_dist = distogram_head(z)

# Simulate ground truth: random 3-D coordinates (scaled by 5), turned into an
# [N, N] matrix of pairwise Euclidean distances via broadcasting.
coords = np.random.randn(N, 3) * 5
true_dist = np.sqrt(np.sum((coords[:, None] - coords[None, :]) ** 2, axis=-1))

# The logits come from random weights, so this value only shows that the
# loss runs end to end; it is not a meaningful score.
loss = distogram_loss(logits, true_dist)
print(f"\nLoss: {loss:.4f}")

## Key Insights

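1. **Symmetrization**: The distance from residue i to j equals the distance from j to i, so the pair representation is averaged with its transpose before projection, which makes the predicted logits symmetric.
2. **64 Bins**: The 2-22 Angstrom range is divided into 64 distance bins; true distances outside this range are clipped into the first or last bin.
3. **Auxiliary Loss**: The cross-entropy between the predicted bin distribution and the true distance bin provides an additional training signal for the pair representation.
4. **Expected Distance**: A point estimate for each pair can be computed as the probability-weighted average of the bin centers.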