# Algorithm 20: Distogram Head (AlphaFold3)

Predicts distance distributions between residue pairs from the pair representation.

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/heads.py`

## Overview

### Purpose
- Predicts pairwise distance distributions between the tokens' representative atoms (C-beta for protein residues, C-alpha for glycine)
- Provides auxiliary supervision during training
- Helps learn informative pair representations

### Distance Bins
- 64 equal-width bins spanning 2–22 Å; distances outside this range are clipped into the first or last bin
- Bin width: (22 − 2) / 64 = 0.3125 Å

In [None]:
import numpy as np
# Seed NumPy's global RNG so that the random projection weights drawn in
# distogram_head and the random test inputs further down are reproducible.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance over its last axis.

    No learnable scale or offset is applied. ``eps`` is added to the variance
    to avoid dividing by zero for constant feature vectors.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps)

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow; this does not change the result.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(shifted)
    total = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / total

In [None]:
def distogram_head(z, num_bins=64, min_dist=2.0, max_dist=22.0, *,
                   weights=None, bias=None, verbose=True):
    """
    Distogram prediction head.

    Symmetrizes the pair representation, layer-normalizes it over the channel
    axis and linearly projects every (i, j) pair onto ``num_bins`` logits.

    Args:
        z: Pair representation [N, N, c_z]
        num_bins: Number of distance bins
        min_dist: Minimum distance in Angstroms (only reported here; the
            actual binning happens in the loss)
        max_dist: Maximum distance in Angstroms (only reported here)
        weights: Optional projection matrix [c_z, num_bins]. If None, a
            random matrix scaled by c_z ** -0.5 is drawn from NumPy's global
            RNG on every call (the original demo behaviour).
        bias: Optional bias vector [num_bins]; zeros if None.
        verbose: If True (default), print a short shape summary.

    Returns:
        Distogram logits [N, N, num_bins], symmetric in (i, j).

    Raises:
        ValueError: If ``z`` is not [N, N, c_z], or ``weights``/``bias`` do
            not match (c_z, num_bins) / (num_bins,).
    """
    z = np.asarray(z)
    if z.ndim != 3 or z.shape[0] != z.shape[1]:
        raise ValueError(f"z must have shape [N, N, c_z], got {z.shape}")
    N = z.shape[0]
    c_z = z.shape[-1]

    if verbose:
        print("Distogram Head")
        print("=" * 50)
        print(f"Pair: [{N}, {N}, {c_z}]")
        print(f"Bins: {num_bins} ({min_dist}A to {max_dist}A)")

    # Distance is symmetric (d_ij == d_ji), so average z with its transpose.
    # Layer norm and the projection act on each pair independently, so the
    # resulting logits are symmetric as well.
    z_sym = (z + z.transpose(1, 0, 2)) / 2
    z_norm = layer_norm(z_sym)

    if weights is None:
        # Stand-in for learned weights; the c_z ** -0.5 scale keeps the logit
        # variance roughly 1 given the unit-variance normalized input.
        W = np.random.randn(c_z, num_bins) * (c_z ** -0.5)
    else:
        W = np.asarray(weights)
        if W.shape != (c_z, num_bins):
            raise ValueError(
                f"weights must have shape ({c_z}, {num_bins}), got {W.shape}")

    b = np.zeros(num_bins) if bias is None else np.asarray(bias)
    if b.shape != (num_bins,):
        raise ValueError(f"bias must have shape ({num_bins},), got {b.shape}")

    logits = np.einsum('ijc,cb->ijb', z_norm, W) + b

    if verbose:
        print(f"Logits: {logits.shape}")

    return logits

In [None]:
def _log_softmax(x, axis=-1):
    """Return log(softmax(x)) along ``axis``, computed stably via log-sum-exp."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


def distogram_loss(logits, distances, num_bins=64, min_dist=2.0, max_dist=22.0,
                   mask=None):
    """
    Compute distogram cross-entropy loss.

    Each true distance is assigned to one of ``num_bins`` equal-width bins on
    [min_dist, max_dist]. Distances outside that range are clipped into the
    first or last bin. The loss is the mean negative log-likelihood of the
    true bin over all pairs, or over the pairs selected by ``mask``.

    Args:
        logits: Predicted logits [N, N, num_bins]. Any [..., num_bins] array
            whose leading shape equals ``distances.shape`` is accepted.
        distances: Ground truth distances [N, N]
        num_bins: Number of distance bins; must equal ``logits.shape[-1]``.
        min_dist: Lower edge of the first bin, in Angstroms.
        max_dist: Upper edge of the last bin, in Angstroms.
        mask: Optional pair weights broadcastable to ``distances.shape``.
            Pairs with weight 0 do not contribute. If None, every pair
            has weight 1.

    Returns:
        Scalar loss (NumPy float). Returns 0.0 when there are no pairs or the
        mask sums to zero.

    Raises:
        ValueError: If ``logits.shape[-1] != num_bins``, or if the leading
            shape of ``logits`` does not match ``distances``.
    """
    logits = np.asarray(logits, dtype=float)
    distances = np.asarray(distances, dtype=float)
    if logits.shape[-1] != num_bins:
        raise ValueError(
            f"logits has {logits.shape[-1]} bins, expected num_bins={num_bins}")
    if logits.shape[:-1] != distances.shape:
        raise ValueError(
            f"logits leading shape {logits.shape[:-1]} does not match "
            f"distances shape {distances.shape}")

    bin_edges = np.linspace(min_dist, max_dist, num_bins + 1)
    # digitize is 1-based for in-range values. Clipping folds d < min_dist
    # into bin 0 and d >= max_dist into the last bin.
    bin_indices = np.clip(np.digitize(distances, bin_edges) - 1, 0, num_bins - 1)

    # Gather log p(true bin) directly rather than building a one-hot tensor
    # and taking log(softmax + 1e-8). That epsilon capped each pair's loss at
    # about 18.4, and softmax underflows to 0 for large logit gaps.
    log_probs = _log_softmax(logits, axis=-1)
    nll = -np.take_along_axis(log_probs, bin_indices[..., None], axis=-1)[..., 0]

    if mask is None:
        pair_weights = np.ones_like(nll)
    else:
        pair_weights = np.broadcast_to(np.asarray(mask, dtype=float), nll.shape)

    total = np.sum(pair_weights)
    if total == 0:
        return np.float64(0.0)
    # Normalize by the number of (weighted) pairs, not N**2, so that
    # non-square inputs and masks are averaged correctly.
    return np.sum(nll * pair_weights) / total

In [None]:
# Smoke test: run the head on random pair features, then score its logits
# against distances computed from random 3-D coordinates.
print("Test: Distogram Head")
print("=" * 60)

N, c_z = 32, 128

z = np.random.randn(N, N, c_z)

logits = distogram_head(z)
print(f"Logits finite: {np.isfinite(logits).all()}")

# Stand-in ground truth: pairwise Euclidean distances between random points.
coords = np.random.randn(N, 3) * 5
deltas = coords[:, np.newaxis, :] - coords[np.newaxis, :, :]
distances = np.sqrt((deltas ** 2).sum(axis=-1))

loss = distogram_loss(logits, distances)
print(f"\nLoss: {loss:.4f}")

## Key Insights

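1. **Symmetrization**: Distance is symmetric (d_ij = d_ji), so z is averaged with its transpose before the projection, which makes the predicted logits symmetric as well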
2. **Auxiliary Loss**: Provides additional training signal
3. **64 Bins**: Covers the 2–22 Å range; distances outside it fall into the first or last bin
4. **Cross-Entropy**: Standard classification loss over bins