# Algorithm 10: B-Factor Prediction (Boltz-2)

Predicts crystallographic B-factors (temperature factors).

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/trunkv2.py`
- **Class**: `BFactorModule`
- **Loss**: `model/loss/bfactor.py`

## Overview

A B-factor quantifies how far an atom is displaced, on average, from its mean position in the crystal (units: Å²).
- High B-factor = flexible/disordered region
- Low B-factor = rigid/ordered region

Predicted B-factors are useful for:
- Drug binding site analysis
- Protein dynamics
- Structure quality assessment
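
For an isotropic atom, the standard crystallographic relation is B = 8π²⟨u²⟩, where ⟨u²⟩ is the mean-square displacement along one direction. The sketch below converts B-factors into RMS displacements to give a feel for the scale of "high" and "low" values. The helper name `bfactor_to_rms_displacement` is illustrative and not part of Boltz.

```python
import numpy as np

def bfactor_to_rms_displacement(bfactor):
    """Convert isotropic B-factors (Å^2) to RMS displacement along one axis (Å)."""
    bfactor = np.asarray(bfactor, dtype=float)
    # B = 8 * pi^2 * <u^2>  =>  sqrt(<u^2>) = sqrt(B / (8 * pi^2))
    return np.sqrt(bfactor / (8.0 * np.pi ** 2))

# A fairly rigid atom (B = 20) versus a mobile one (B = 80)
rms = bfactor_to_rms_displacement([20.0, 80.0])
print(rms)
```

Because the displacement scales with the square root of B, quadrupling the B-factor only doubles the RMS displacement.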

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

In [None]:
def bfactor_head(s, num_bins=64, max_bfactor=100.0):
    """
    B-Factor prediction head.
    
    Args:
        s: Single representation [N, c_s]
        num_bins: Number of B-factor bins
        max_bfactor: Maximum B-factor value
    
    Returns:
        B-factor logits and predictions
    """
    N, c_s = s.shape
    
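    print("B-Factor Head")
    print("=" * 50)
    
    s_norm = layer_norm(s)
    
    # Two-layer MLP. The weights are random placeholders standing in for
    # learned parameters, so the outputs below are not meaningful predictions.
    W1 = np.random.randn(c_s, c_s) * (c_s ** -0.5)
    h = np.maximum(0, s_norm @ W1)  # ReLU
    h = layer_norm(h)
    
    W2 = np.random.randn(c_s, num_bins) * (c_s ** -0.5)
    logits = h @ W2  # [N, num_bins]
    
    # Numerically stable softmax over bins
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)
    
    # Expected B-factor: probability-weighted mean of the bin values
    bin_centers = np.linspace(0, max_bfactor, num_bins)
    bfactor = np.sum(probs * bin_centers, axis=-1)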
    
    print(f"B-factor range: [{bfactor.min():.1f}, {bfactor.max():.1f}]")
    
    return logits, bfactor
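
The last step of the head reduces each row of logits to one number: a softmax over bins followed by the probability-weighted mean of the bin values. The sketch below isolates that step with hand-picked logits and only 5 bins for readability. The function name `expected_bfactor` and the toy values are illustrative.

```python
import numpy as np

def expected_bfactor(logits, max_bfactor=100.0):
    """Softmax over the last axis, then the expectation over the bin grid."""
    logits = np.asarray(logits, dtype=float)
    num_bins = logits.shape[-1]
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Same grid as bfactor_head: evenly spaced points from 0 to max_bfactor
    bin_centers = np.linspace(0.0, max_bfactor, num_bins)
    return (probs * bin_centers).sum(axis=-1)

toy_logits = np.array([
    [0.0, 0.0, 0.0, 0.0, 0.0],   # uniform distribution -> midpoint of the range
    [50.0, 0.0, 0.0, 0.0, 0.0],  # nearly all mass on the first bin
    [0.0, 0.0, 0.0, 0.0, 50.0],  # nearly all mass on the last bin
])
toy_pred = expected_bfactor(toy_logits)
print(toy_pred)
```

Note that `np.linspace(0, max_bfactor, num_bins)` places the first and last values exactly at 0 and `max_bfactor`, so the prediction always lies inside that range. An untrained head with roughly uniform probabilities will tend to predict values near the midpoint.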

In [None]:
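def bfactor_loss(pred_bfactor, true_bfactor, mask=None):
    """
    B-Factor prediction loss: MSE between z-score-normalized B-factors.
    
    Args:
        pred_bfactor: Predicted B-factors [N]
        true_bfactor: True B-factors [N]
        mask: Valid residue mask [N] (1 = valid, 0 = ignored)
    
    Returns:
        Masked MSE between the normalized profiles
    """
    if mask is None:
        mask = np.ones_like(pred_bfactor)
    mask = np.asarray(mask, dtype=float)
    n_valid = mask.sum() + 1e-8
    
    def masked_zscore(x):
        # Mean and std over valid residues only, so masked entries
        # cannot skew the normalization statistics
        mean = (x * mask).sum() / n_valid
        std = np.sqrt((((x - mean) ** 2) * mask).sum() / n_valid)
        return (x - mean) / (std + 1e-8)
    
    # Z-score normalize so that only the relative profile is compared
    pred_norm = masked_zscore(pred_bfactor)
    true_norm = masked_zscore(true_bfactor)
    
    # MSE over valid residues
    mse = ((pred_norm - true_norm) ** 2 * mask).sum() / n_valid
    
    return mse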
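
Because both vectors are z-scored before comparison, this loss ignores the absolute scale and offset of the predictions. With no mask, and ignoring the small epsilon, the normalized MSE equals 2(1 − r), where r is the Pearson correlation. The check below is a sketch on synthetic data; `zscore_mse` is an illustrative name.

```python
import numpy as np

def zscore_mse(pred, true, eps=1e-8):
    """MSE between z-scored vectors (population std, no mask)."""
    pred_z = (pred - pred.mean()) / (pred.std() + eps)
    true_z = (true - true.mean()) / (true.std() + eps)
    return np.mean((pred_z - true_z) ** 2)

rng = np.random.default_rng(0)
true = rng.uniform(10.0, 80.0, size=200)            # synthetic "true" B-factors
pred = 0.5 * true + rng.normal(0.0, 5.0, size=200)  # noisy, rescaled predictions

mse = zscore_mse(pred, true)
r = np.corrcoef(pred, true)[0, 1]

# An affine change of the predictions (positive scale) leaves the loss unchanged
rescaled_mse = zscore_mse(3.0 * pred + 7.0, true)

print(mse, 2.0 * (1.0 - r), rescaled_mse)
```

One consequence is that minimizing this loss teaches the relative flexibility profile along the chain, not the absolute B-factor magnitudes.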

In [None]:
# Test
print("Test: B-Factor Prediction")
print("="*60)

N = 50
c_s = 128

s = np.random.randn(N, c_s)

logits, pred_bfactor = bfactor_head(s)

# Simulate ground truth: a broad high-B-factor bump mid-chain, plus elevated termini
true_bfactor = 30 + 20 * np.exp(-((np.arange(N) - N//2) / (N/4)) ** 2)
true_bfactor[:5] += 20  # Higher at N-terminus
true_bfactor[-5:] += 20  # Higher at C-terminus

loss = bfactor_loss(pred_bfactor, true_bfactor)
print(f"\nLoss: {loss:.4f}")

# Correlation (the head's weights are random and untrained, so no meaningful correlation is expected)
corr = np.corrcoef(pred_bfactor, true_bfactor)[0, 1]
print(f"Correlation: {corr:.3f}")

## Key Insights

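1. **Flexibility Prediction**: The head predicts one B-factor per residue from the single representation, as a proxy for local flexibility
2. **Drug Design**: Predicted flexibility helps assess how rigid or mobile a binding site is
3. **Normalized Loss**: Predicted and true B-factors are z-score normalized before the MSE, so the loss compares relative profiles rather than absolute values
4. **Binned Prediction**: The head outputs logits over `num_bins` B-factor bins; the scalar prediction is the softmax-weighted mean of the bin values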
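
Since the head emits logits over bins, a classification loss is one natural way to train it: digitize each true B-factor into a bin index and apply cross-entropy. The sketch below is an assumption for illustration only. It is not confirmed to match `model/loss/bfactor.py`, and the names `bfactor_to_bin` and `bfactor_cross_entropy` are hypothetical.

```python
import numpy as np

def bfactor_to_bin(true_bfactor, num_bins=64, max_bfactor=100.0):
    """Map each B-factor to the index of the nearest point on the bin grid."""
    bin_centers = np.linspace(0.0, max_bfactor, num_bins)
    # Values outside [0, max_bfactor] fall into the first or last bin
    clipped = np.clip(true_bfactor, 0.0, max_bfactor)
    return np.abs(clipped[:, None] - bin_centers[None, :]).argmin(axis=-1)

def bfactor_cross_entropy(logits, target_bins, mask=None):
    """Masked mean cross-entropy between logits [N, num_bins] and targets [N]."""
    n = logits.shape[0]
    if mask is None:
        mask = np.ones(n)
    # Numerically stable log-softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(n), target_bins]
    return (nll * mask).sum() / (mask.sum() + 1e-8)

true_b = np.array([0.0, 30.0, 100.0, 250.0])
bins = bfactor_to_bin(true_b)

# Uniform logits give a loss of log(num_bins)
ce_uniform = bfactor_cross_entropy(np.zeros((4, 64)), bins)

# Logits that strongly favor the correct bin give a loss near zero
confident_logits = np.full((4, 64), -20.0)
confident_logits[np.arange(4), bins] = 20.0
ce_confident = bfactor_cross_entropy(confident_logits, bins)

print(bins, ce_uniform, ce_confident)
```

Unlike the normalized MSE, a loss of this kind is sensitive to absolute B-factor values, at the cost of treating neighboring bins as unrelated classes.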