# Algorithm 1: Affinity Module (Boltz-2)

The core new feature in Boltz-2: predicting the binding affinity between a protein and a small-molecule ligand.

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/affinity.py`
- **Class**: `AffinityModule`
- **Paper Algorithm**: Algorithm 31

## Overview

Boltz-2 introduces binding-affinity prediction that, according to its authors, approaches the accuracy of FEP (free energy perturbation) while being roughly **1000x faster**.

### Key Outputs

| Output | Description | Use Case |
|--------|-------------|----------|
| `affinity_pred_value` | log10(IC50), with IC50 in μM | Lead optimization |
| `affinity_probability_binary` | Binder probability (0-1) | Hit discovery |
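The regression output is a log value with IC50 in μM, so it can be converted to other common units. Below is a minimal sketch that assumes `y` is log10(IC50 / 1 μM). Under that assumption, pIC50 = 6 − y. If we also treat IC50 as a stand-in for the dissociation constant, which is only an approximation, then ΔG ≈ −RT ln(10) · pIC50. That works out to about −1.364 kcal/mol per log unit at 298 K. The helper name and the 298.15 K default are illustrative choices and are not part of Boltz-2.

```python
import math

# Gas constant in kcal/(mol*K)
R_KCAL = 1.987204e-3


def log10_ic50_um_to_units(y, temperature_k=298.15):
    """Convert a predicted log10(IC50 / uM) into pIC50 and an approximate dG.

    Hypothetical helper (not a Boltz-2 API). Treating IC50 as the
    dissociation constant is an approximation.
    """
    # IC50 in molar = 10**y * 1e-6, so pIC50 = -log10(IC50_M) = 6 - y
    pic50 = 6.0 - y
    # dG = RT ln(K) = -RT ln(10) * pIC50, in kcal/mol (negative for binders)
    dg_kcal = -R_KCAL * temperature_k * math.log(10) * pic50
    return pic50, dg_kcal


pic50, dg = log10_ic50_um_to_units(0.0)  # IC50 = 1 uM
print(f"pIC50 = {pic50:.2f}, dG ~ {dg:.2f} kcal/mol")
```

For y = 0 (IC50 = 1 μM), this prints pIC50 = 6.00 and ΔG ≈ −8.19 kcal/mol. A more negative `affinity_pred_value` therefore means tighter binding.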

```python
import numpy as np

np.random.seed(42)


def layer_norm(x, eps=1e-5):
    """Normalize over the last axis (no learned scale or shift)."""
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def sigmoid(x):
    # Clip the input so np.exp cannot overflow for large |x|
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))


def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
```

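```python
def gaussian_smearing(distances, start=0.0, stop=5.0, num_gaussians=50):
    """Expand distances onto evenly spaced Gaussian basis functions.

    Returns an array of shape distances.shape + (num_gaussians,).
    """
    offset = np.linspace(start, stop, num_gaussians)  # basis centers
    # Each Gaussian's width equals the spacing between neighboring centers
    coeff = -0.5 / (offset[1] - offset[0]) ** 2

    dist_expanded = distances[..., np.newaxis] - offset
    return np.exp(coeff * dist_expanded ** 2)
```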
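Written out, `gaussian_smearing` maps a distance $d$ onto $K$ Gaussian basis functions (`num_gaussians`). Their centers are evenly spaced between `start` and `stop`, and each Gaussian is as wide as the spacing between centers:

$$
\mu_k = d_{\mathrm{start}} + k\,\Delta, \qquad \Delta = \frac{d_{\mathrm{stop}} - d_{\mathrm{start}}}{K - 1}, \qquad k = 0, \dots, K - 1
$$

$$
e_k(d) = \exp\!\left(-\frac{(d - \mu_k)^2}{2\,\Delta^2}\right)
$$

The affinity module uses `stop=22.0` and `num_gaussians=64`, which gives $\Delta = 22/63 \approx 0.35$ Å. A given distance therefore produces a smooth bump spread over a few neighboring basis functions, rather than a one-hot bin.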

```python
def compute_interface_distances(x_pred, ligand_mask, protein_mask):
    """Pairwise Euclidean distances between ligand and protein tokens.

    Returns an array of shape [N_lig, N_prot].
    """
    ligand_pos = x_pred[ligand_mask]    # [N_lig, 3]
    protein_pos = x_pred[protein_mask]  # [N_prot, 3]

    # Broadcast to [N_lig, N_prot, 3], then take the norm over xyz
    diff = ligand_pos[:, None, :] - protein_pos[None, :, :]
    distances = np.sqrt(np.sum(diff ** 2, axis=-1))

    return distances
```
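Step 5 of `affinity_module` below averages over every ligand-protein pair, so a distant protein token counts as much as one in direct contact. One way to restrict attention to the binding site is to keep only protein tokens within a distance cutoff of any ligand token. The sketch below does this with a hypothetical 6 Å cutoff. The helper name and the cutoff are assumptions for illustration, and this is not the pocket definition used by Boltz-2.

```python
import numpy as np


def pocket_mask(x_pred, ligand_mask, protein_mask, cutoff=6.0):
    """Boolean mask [N] of protein tokens within `cutoff` of any ligand token.

    Hypothetical helper; `cutoff` uses the same units as x_pred.
    """
    ligand_pos = x_pred[ligand_mask]                  # [N_lig, 3]
    protein_idx = np.flatnonzero(protein_mask)        # indices of protein tokens
    diff = ligand_pos[:, None, :] - x_pred[protein_idx][None, :, :]
    dist = np.linalg.norm(diff, axis=-1)              # [N_lig, N_prot]
    in_contact = (dist < cutoff).any(axis=0)          # [N_prot]

    mask = np.zeros(len(x_pred), dtype=bool)
    mask[protein_idx[in_contact]] = True
    return mask


# Toy example: one ligand token at the origin, protein tokens at 3, 5 and 10 units
x = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [5.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
lig = np.array([True, False, False, False])
prot = ~lig
print(pocket_mask(x, lig, prot))  # [False  True  True False]
```

You could then pass the returned mask in place of `protein_mask` when pooling the interface features.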

```python
def affinity_module(s, z, x_pred, ligand_mask, protein_mask, token_z=128):
    """
    Affinity Module: predicts binding affinity (simplified NumPy sketch).

    All projection weights are sampled at random on each call, standing in
    for learned parameters, so the outputs are illustrative only.

    Args:
        s: Single representation [N, c_s]
        z: Pair representation [N, N, c_z]
        x_pred: Predicted coordinates [N, 3]
        ligand_mask: Boolean mask for ligand tokens [N]
        protein_mask: Boolean mask for protein tokens [N]
        token_z: Hidden pair dimension used inside the module

    Returns:
        affinity_value: Predicted log10(IC50), with IC50 in μM
        affinity_binary: Predicted binder probability in [0, 1]
    """
    N, c_s = s.shape
    c_z = z.shape[-1]

    print("Affinity Module (Boltz-2)")
    print("=" * 50)
    print(f"Ligand tokens: {ligand_mask.sum()}, Protein tokens: {protein_mask.sum()}")

    # 1. Normalize the pair representation and project it to token_z
    z_norm = layer_norm(z)
    W_z = np.random.randn(c_z, token_z) * (c_z ** -0.5)
    z_proj = np.einsum('ijc,cd->ijd', z_norm, W_z)  # [N, N, token_z]

    # 2. Add single-to-pair projections (row term + column term)
    W_s1 = np.random.randn(c_s, token_z) * (c_s ** -0.5)
    W_s2 = np.random.randn(c_s, token_z) * (c_s ** -0.5)

    s_proj1 = s @ W_s1  # [N, token_z]
    s_proj2 = s @ W_s2  # [N, token_z]

    z_enhanced = z_proj + s_proj1[:, None, :] + s_proj2[None, :, :]

    # 3. Compute ligand-protein distances and encode them with Gaussian smearing
    distances = compute_interface_distances(x_pred, ligand_mask, protein_mask)
    dist_smeared = gaussian_smearing(distances, stop=22.0, num_gaussians=64)  # [N_lig, N_prot, 64]

    print(f"Interface distances: {distances.shape}")
    print(f"  Min: {distances.min():.2f} Å, Max: {distances.max():.2f} Å")

    # 4. Extract interface pair features and add the projected distance encoding
    #    (the original version computed dist_smeared but never used it)
    interface_z = z_enhanced[ligand_mask][:, protein_mask]  # [N_lig, N_prot, token_z]
    n_rbf = dist_smeared.shape[-1]
    W_d = np.random.randn(n_rbf, token_z) * (n_rbf ** -0.5)
    interface_z = interface_z + dist_smeared @ W_d

    # 5. Pool interface features over all ligand-protein pairs
    interface_pooled = interface_z.mean(axis=(0, 1))  # [token_z]

    # 6. Pool single features per entity
    ligand_s = s[ligand_mask].mean(axis=0)    # [c_s]
    protein_s = s[protein_mask].mean(axis=0)  # [c_s]

    # 7. Combine all features
    combined = np.concatenate([ligand_s, protein_s, interface_pooled])
    combined = layer_norm(combined)  # [2 * c_s + token_z]

    # 8. Predict affinity value (regression head)
    W_val = np.random.randn(len(combined), 1) * (len(combined) ** -0.5)
    affinity_value = (combined @ W_val).item()

    # 9. Predict binder probability (classification head)
    W_bin = np.random.randn(len(combined), 1) * (len(combined) ** -0.5)
    affinity_binary = sigmoid((combined @ W_bin).item())

    print("\nPredictions:")
    print(f"  Affinity (log10 IC50): {affinity_value:.3f}")
    print(f"  Binder probability: {affinity_binary:.3f}")

    return affinity_value, affinity_binary
```

```python
# Test
print("Test: Affinity Module")
print("=" * 60)

N_lig = 15          # Ligand tokens
N_prot = 45         # Protein tokens
N = N_lig + N_prot  # Total tokens (60)
c_s = 128
c_z = 64

# Representations
s = np.random.randn(N, c_s)
z = np.random.randn(N, N, c_z)

# Predicted structure (random coordinates, std 10 per axis)
x_pred = np.random.randn(N, 3) * 10

# Masks: first N_lig tokens are ligand, the rest are protein
ligand_mask = np.zeros(N, dtype=bool)
ligand_mask[:N_lig] = True
protein_mask = np.zeros(N, dtype=bool)
protein_mask[N_lig:] = True

aff_val, aff_bin = affinity_module(s, z, x_pred, ligand_mask, protein_mask)
```

## Key Insights

1. **Dual Output**: regression of log10(IC50) plus binary binder classification
2. **Interface Focus**: pools pair features over ligand-protein token pairs
3. **Distance Encoding**: Gaussian smearing turns continuous interface distances into smooth feature vectors
4. **~1000x Faster**: than physics-based FEP methods, as reported by the Boltz-2 authors
5. **Drug Discovery Use Cases**: the binder probability supports hit discovery, and the affinity value supports lead optimization
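The table above suggests one way to combine the two heads. Use the binder probability to triage a library, then use the affinity value to rank the compounds that pass. The sketch below assumes predictions arrive as a hypothetical list of `(name, affinity_value, binder_probability)` tuples. The 0.5 threshold is an arbitrary illustrative cutoff, not a Boltz-2 default. Because a lower log10(IC50) means tighter binding, the survivors are sorted in ascending order.

```python
from typing import List, Tuple

# (compound name, predicted log10(IC50 / uM), predicted binder probability)
Prediction = Tuple[str, float, float]


def triage_and_rank(preds: List[Prediction], prob_threshold: float = 0.5) -> List[str]:
    """Keep likely binders, then rank them from tightest to weakest binding.

    Hypothetical helper; `prob_threshold` is an illustrative cutoff.
    """
    # Hit discovery: drop compounds the classifier considers unlikely binders
    hits = [p for p in preds if p[2] >= prob_threshold]
    # Lead optimization: a lower log10(IC50) means a more potent compound
    hits.sort(key=lambda p: p[1])
    return [name for name, _, _ in hits]


preds = [("cmpd_A", 0.8, 0.90), ("cmpd_B", -1.2, 0.75), ("cmpd_C", -2.0, 0.20)]
print(triage_and_rank(preds))  # ['cmpd_B', 'cmpd_A']
```

Filtering before ranking keeps a compound with an attractive predicted value but a low binder probability (such as `cmpd_C`) from rising to the top of the list.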