# Algorithm 28: computeFAPE (Frame Aligned Point Error)

FAPE is the primary structural loss function in AlphaFold2. It compares predicted and ground-truth atom positions after expressing both in each residue's local frame (predicted positions in predicted frames, true positions in true frames), which makes it invariant to global rotations and translations.

## Algorithm Pseudocode

![computeFAPE](../imgs/algorithms/computeFAPE.png)

## Source Code Location
- **File**: `AF2-source-code/model/all_atom.py`
- **Function**: `frame_aligned_point_error`
- **Lines**: 1025-1100

## Key Concepts

### Why FAPE?

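RMSD measures error after a single global superposition of the two structures. FAPE instead:
1. Aligns the prediction and the ground truth on each residue's local frame in turn, and measures the error of all atoms under that alignment.
2. Averages these errors over all frames, so it reflects both local geometry and the global arrangement of the chain.
3. Needs no superposition step and is invariant to global rotations and translations, because every coordinate is expressed relative to a frame.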
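
A minimal sketch of point 3, using toy random points and identity-rotation frames (`to_local` and `fape` are hypothetical helpers, not AF2 code): rigidly moving the whole prediction produces a large raw coordinate error when no superposition is done, while FAPE stays at its floor of sqrt(ε)/Z.

```python
import numpy as np

def to_local(rot, trans, pts):
    # Express every point in every frame: R_i^T (x_j - t_i) -> [N_frames, N_points, 3]
    return np.einsum('fji,fpj->fpi', rot, pts[None, :, :] - trans[:, None, :])

def fape(rot_p, t_p, x_p, rot_t, t_t, x_t, z=10.0, d_clamp=10.0, eps=1e-4):
    # Unmasked FAPE: per-(frame, atom) error, clamped, averaged, divided by Z
    diff = to_local(rot_p, t_p, x_p) - to_local(rot_t, t_t, x_t)
    d = np.sqrt(np.sum(diff ** 2, axis=-1) + eps)
    return np.minimum(d, d_clamp).mean() / z

rng = np.random.default_rng(0)
x_true = rng.normal(size=(12, 3)) * 5.0           # 12 toy "atoms"
rot_true = np.tile(np.eye(3), (4, 1, 1))          # 4 toy frames with identity rotation
t_true = x_true[:4].copy()                        # frames centred on the first 4 atoms

# Rigidly move the whole prediction: 90 degrees about z, then a 20 Å shift
g_rot = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
g_trans = np.array([20.0, 0.0, 0.0])
x_pred = x_true @ g_rot.T + g_trans
t_pred = t_true @ g_rot.T + g_trans
rot_pred = g_rot @ rot_true                       # broadcasts over the frame axis

raw_error = np.linalg.norm(x_pred - x_true, axis=-1).mean()
fape_value = fape(rot_pred, t_pred, x_pred, rot_true, t_true, x_true)
print(f"mean raw coordinate error: {raw_error:.2f} Å")
print(f"FAPE: {fape_value:.6f}")
```

An RMSD computation would have to find and undo `g_rot`/`g_trans` first; FAPE never needs to.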

### Mathematical Definition

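$$\text{FAPE} = \frac{1}{Z} \cdot \frac{1}{N_{frames} \cdot N_{atoms}} \sum_i \sum_j \min\left(d_{clamp},\ \sqrt{\left\|(T_i^{pred})^{-1} \circ x_j^{pred} - (T_i^{true})^{-1} \circ x_j^{true}\right\|^2 + \epsilon}\right)$$

Where:
- $T_i^{pred} = (R_i, \vec{t}_i)$ and $T_i^{true}$: predicted and ground-truth rigid transformations (frames) of residue $i$; the inverse maps a global point into local coordinates, $T_i^{-1} \circ x = R_i^\top (x - \vec{t}_i)$
- $x_j^{pred}$, $x_j^{true}$: predicted and ground-truth positions of atom $j$
- $d_{clamp}$: clamping distance (10 Å by default)
- $Z$: length scale (10 Å) that makes the loss unitless
- $\epsilon$: small constant ($10^{-4}$ Å² in Alg. 28) that keeps the square root differentiable when the error is zero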
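
For checking vectorized code, the formula can also be transcribed literally as a double loop over frames $i$ and atoms $j$. This is a slow, unmasked reference sketch; `fape_reference` is a hypothetical name, and the defaults (Z = d_clamp = 10 Å, ε = 1e-4 Å²) follow Alg. 28.

```python
import numpy as np

def fape_reference(rot_pred, t_pred, x_pred, rot_true, t_true, x_true,
                   z=10.0, d_clamp=10.0, eps=1e-4):
    """Literal double loop over frames i and atoms j (no masking)."""
    n_frames, n_atoms = len(rot_pred), len(x_pred)
    total = 0.0
    for i in range(n_frames):
        for j in range(n_atoms):
            # T_i^{-1} x = R_i^T (x - t_i), applied with each structure's own frame
            local_pred = rot_pred[i].T @ (x_pred[j] - t_pred[i])
            local_true = rot_true[i].T @ (x_true[j] - t_true[i])
            d_ij = np.sqrt(np.sum((local_pred - local_true) ** 2) + eps)
            total += min(d_ij, d_clamp)
    return total / (n_frames * n_atoms) / z

# Two identity frames at the origin; atom 0 is misplaced by exactly 5 Å
rot = np.tile(np.eye(3), (2, 1, 1))
t = np.zeros((2, 3))
x_true = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
x_pred = x_true + np.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
value = fape_reference(rot, t, x_pred, rot, t, x_true)
print(value)  # ≈ 0.2505: (2 * (5.00001 + 0.01)) / 4 / 10
```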

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def apply_inverse_frame(rotation, translation, points):
    """
    Apply inverse rigid transformation: R^T @ (points - t)
    
    Args:
        rotation: [N_frames, 3, 3] rotation matrices
        translation: [N_frames, 3] translation vectors
        points: [N_points, 3] points to transform
    
    Returns:
        Transformed points [N_frames, N_points, 3]
    """
    # Subtract translation: [N_points, 3] - [N_frames, 1, 3] -> [N_frames, N_points, 3]
    centered = points[None, :, :] - translation[:, None, :]
    
    # Apply inverse rotation: R^T @ centered
    # [N_frames, 3, 3]^T @ [N_frames, N_points, 3] -> [N_frames, N_points, 3]
    return np.einsum('fji,fpj->fpi', rotation, centered)


def compute_fape(
    pred_frames_rot,      # [N_res, 3, 3] predicted frame rotations
    pred_frames_trans,    # [N_res, 3] predicted frame translations
    target_frames_rot,    # [N_res, 3, 3] ground truth frame rotations
    target_frames_trans,  # [N_res, 3] ground truth frame translations
    pred_positions,       # [N_atoms, 3] predicted atom positions
    target_positions,     # [N_atoms, 3] ground truth atom positions
    frames_mask,          # [N_res] mask for valid frames
    positions_mask,       # [N_atoms] mask for valid atoms
    length_scale=10.0,    # Z in Alg. 28: the final loss is divided by this (Å)
    l1_clamp_distance=10.0  # d_clamp (Å); pass None to disable clamping
):
    """
    Compute Frame Aligned Point Error (FAPE).
    
    Algorithm 28 from AlphaFold2 supplementary materials.
    
    Args:
        pred_frames_rot: Predicted frame rotations [N_res, 3, 3]
        pred_frames_trans: Predicted frame translations [N_res, 3]
        target_frames_rot: Ground truth frame rotations [N_res, 3, 3]
        target_frames_trans: Ground truth frame translations [N_res, 3]
        pred_positions: Predicted atom positions [N_atoms, 3]
        target_positions: Ground truth atom positions [N_atoms, 3]
        frames_mask: Mask for valid frames [N_res]
        positions_mask: Mask for valid atoms [N_atoms]
        length_scale: Scale factor for final loss
        l1_clamp_distance: Maximum distance for clamping
    
    Returns:
        FAPE loss (scalar)
    """
    N_frames = pred_frames_rot.shape[0]
    N_atoms = pred_positions.shape[0]
    
    print(f"Computing FAPE: {N_frames} frames, {N_atoms} atoms")
    
    # Step 1: Transform predicted positions to each predicted frame's local coordinates
    # Line 1: For each frame i, compute T_i^{-1} @ x_j for all atoms j
    pred_local = apply_inverse_frame(
        pred_frames_rot, pred_frames_trans, pred_positions
    )  # [N_frames, N_atoms, 3]
    
    print(f"Predicted positions in local frames: {pred_local.shape}")
    
    # Step 2: Transform target positions to each target frame's local coordinates
    # Line 2: For each frame i, compute T_i^{-1} @ x_j^{true} for all atoms j
    target_local = apply_inverse_frame(
        target_frames_rot, target_frames_trans, target_positions
    )  # [N_frames, N_atoms, 3]
    
    print(f"Target positions in local frames: {target_local.shape}")
    
    # Step 3: Compute distances between corresponding local positions
    # Line 3: d_ij = sqrt(||pred_local_ij - target_local_ij||^2 + eps), eps = 1e-4 Å^2
    distances = np.sqrt(
        np.sum((pred_local - target_local) ** 2, axis=-1) + 1e-4
    )  # [N_frames, N_atoms]
    
    print(f"Distances: {distances.shape}")
    print(f"  Mean distance: {distances.mean():.3f} Å")
    print(f"  Max distance: {distances.max():.3f} Å")
    
    # Step 4: Clamp distances (the min(d_clamp, d_ij) in Line 4)
    if l1_clamp_distance is not None:
        distances = np.minimum(distances, l1_clamp_distance)
        print(f"  After clamping (max {l1_clamp_distance} Å): max = {distances.max():.3f} Å")
    
    # Step 5: Apply masks (not part of Alg. 28; excludes missing frames/atoms)
    # Combined mask over (frame, atom) pairs: [N_frames, N_atoms]
    mask = frames_mask[:, None] * positions_mask[None, :]
    
    # Step 6: Masked mean over all valid (frame, atom) pairs (Line 4)
    masked_distances = distances * mask
    fape = np.sum(masked_distances) / (np.sum(mask) + 1e-8)
    
    # Divide by the length scale Z (Line 4)
    fape = fape / length_scale
    
    print(f"\nFinal FAPE loss: {fape:.6f}")
    
    return fape
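
A quick sanity check on the frame convention: `apply_inverse_frame` computes R^T (x - t), which assumes a frame maps local to global coordinates as x = R x_local + t. Applying that forward map to its output should therefore recover the input points. The function is repeated so the cell runs on its own; `apply_frame` and `random_rotation` are hypothetical helpers for this check.

```python
import numpy as np

def apply_inverse_frame(rotation, translation, points):
    # Same computation as above: R^T @ (points - t) for every frame -> [N_frames, N_points, 3]
    centered = points[None, :, :] - translation[:, None, :]
    return np.einsum('fji,fpj->fpi', rotation, centered)

def apply_frame(rotation, translation, local_points):
    # Forward map R @ x_local + t for per-frame points [N_frames, N_points, 3]
    return np.einsum('fij,fpj->fpi', rotation, local_points) + translation[:, None, :]

def random_rotation(rng):
    # Orthogonal matrix from a QR decomposition; flip one column if needed so det = +1
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(1)
rot = np.stack([random_rotation(rng) for _ in range(5)])   # 5 frames
trans = rng.normal(size=(5, 3)) * 10.0
pts = rng.normal(size=(7, 3)) * 10.0                        # 7 points

local = apply_inverse_frame(rot, trans, pts)                # [5, 7, 3]
recovered = apply_frame(rot, trans, local)                  # [5, 7, 3]
max_err = np.abs(recovered - pts[None, :, :]).max()
print(f"max round-trip error: {max_err:.2e}")               # floating-point noise only
```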

## Test Example

In [None]:
# Generate test data
N_res = 20  # Number of residues (frames)
N_atoms = N_res * 3  # 3 backbone atoms per residue (N, CA, C)

# Generate "ground truth" structure - a simple helix
def generate_helix(n_residues, rise_per_residue=1.5, radius=2.3, twist=100):
    """Generate an idealized alpha helix."""
    positions = []
    frames_rot = []
    frames_trans = []
    
    for i in range(n_residues):
        angle = np.radians(twist * i)
        z = rise_per_residue * i
        
        # CA position: on a circle of the given radius around the helix (z) axis
        ca = np.array([radius * np.cos(angle), radius * np.sin(angle), z])
        
        # N and C positions (simplified)
        n = ca + np.array([-1.46, 0, -0.5])
        c = ca + np.array([1.0, 0.5, 0.5])
        
        positions.extend([n, ca, c])
        
        # Simple frame (identity rotation, CA as translation)
        frames_rot.append(np.eye(3))
        frames_trans.append(ca)
    
    return np.array(positions), np.array(frames_rot), np.array(frames_trans)

# Ground truth structure
target_positions, target_frames_rot, target_frames_trans = generate_helix(N_res)

print(f"Ground truth structure:")
print(f"  Positions: {target_positions.shape}")
print(f"  Frames: {target_frames_rot.shape}, {target_frames_trans.shape}")
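
The test helix uses identity rotations, so it never exercises the rotation part of FAPE. AlphaFold2 instead builds each backbone frame from the N, CA and C atoms with a Gram-Schmidt step (Suppl. Alg. 21, rigidFrom3Points). The sketch below follows that construction for the notebook's [N, CA, C] atom ordering; `rigid_from_3_points` is a hypothetical name, not the AF2 API, and the two residues are made-up coordinates.

```python
import numpy as np

def rigid_from_3_points(x1, x2, x3):
    """Gram-Schmidt frames from (x1, x2, x3) = (N, CA, C), vectorised over residues.

    Returns rotations [N_res, 3, 3] with columns e1, e2, e3 and translations [N_res, 3] (CA).
    """
    v1 = x3 - x2                                             # CA -> C
    v2 = x1 - x2                                             # CA -> N
    e1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    u2 = v2 - e1 * np.sum(e1 * v2, axis=-1, keepdims=True)   # remove the e1 component of v2
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    e3 = np.cross(e1, e2)                                    # right-handed third axis
    rot = np.stack([e1, e2, e3], axis=-1)                    # basis vectors as columns
    return rot, x2.copy()

# Two made-up residues in the flat [N, CA, C, N, CA, C] layout that generate_helix returns
backbone = np.array([
    [-1.46, 0.0, -0.5], [0.0, 0.0, 0.0], [1.0, 0.5, 0.5],   # residue 0: N, CA, C
    [-1.0, 2.0, 1.0], [0.5, 2.2, 1.5], [1.2, 3.0, 2.0],     # residue 1: N, CA, C
])
n, ca, c = backbone[0::3], backbone[1::3], backbone[2::3]
rot, trans = rigid_from_3_points(n, ca, c)
print(np.round(rot @ np.swapaxes(rot, -1, -2), 6))   # identity per residue: orthonormal frames
print(np.linalg.det(rot))                            # +1 per residue: proper rotations
```

Frames built this way can be passed to `compute_fape` in place of the identity rotations; in each frame, the N atom lands in the local x-y plane with a positive y coordinate.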

In [None]:
# Create predictions with varying levels of error

# Case 1: Perfect prediction (FAPE should be ~0)
print("="*50)
print("Case 1: Perfect prediction")
print("="*50)

pred_positions = target_positions.copy()
pred_frames_rot = target_frames_rot.copy()
pred_frames_trans = target_frames_trans.copy()

frames_mask = np.ones(N_res)
positions_mask = np.ones(N_atoms)

fape_perfect = compute_fape(
    pred_frames_rot, pred_frames_trans,
    target_frames_rot, target_frames_trans,
    pred_positions, target_positions,
    frames_mask, positions_mask
)

In [None]:
# Case 2: Small random perturbation
print("\n" + "="*50)
print("Case 2: Small perturbation (0.5 Å noise)")
print("="*50)

noise_scale = 0.5
pred_positions = target_positions + np.random.randn(*target_positions.shape) * noise_scale
pred_frames_trans = target_frames_trans + np.random.randn(*target_frames_trans.shape) * noise_scale

fape_small = compute_fape(
    pred_frames_rot, pred_frames_trans,
    target_frames_rot, target_frames_trans,
    pred_positions, target_positions,
    frames_mask, positions_mask
)

In [None]:
# Case 3: Larger perturbation
print("\n" + "="*50)
print("Case 3: Larger perturbation (2.0 Å noise)")
print("="*50)

noise_scale = 2.0
pred_positions = target_positions + np.random.randn(*target_positions.shape) * noise_scale
pred_frames_trans = target_frames_trans + np.random.randn(*target_frames_trans.shape) * noise_scale

fape_large = compute_fape(
    pred_frames_rot, pred_frames_trans,
    target_frames_rot, target_frames_trans,
    pred_positions, target_positions,
    frames_mask, positions_mask
)

In [None]:
# Case 4: Very large errors (to test clamping)
print("\n" + "="*50)
print("Case 4: Large errors (20 Å noise) - tests clamping")
print("="*50)

noise_scale = 20.0
pred_positions = target_positions + np.random.randn(*target_positions.shape) * noise_scale
pred_frames_trans = target_frames_trans + np.random.randn(*target_frames_trans.shape) * noise_scale

fape_clamped = compute_fape(
    pred_frames_rot, pred_frames_trans,
    target_frames_rot, target_frames_trans,
    pred_positions, target_positions,
    frames_mask, positions_mask,
    l1_clamp_distance=10.0  # Clamp at 10 Å
)
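
Because each per-pair term is capped at d_clamp and the mean is divided by the length scale, FAPE can never exceed d_clamp / length_scale, which is 1.0 with the 10 Å defaults used here. A minimal sketch with a hypothetical unmasked helper that works directly on local coordinates shows the loss saturating at that bound:

```python
import numpy as np

def fape_unmasked(local_pred, local_true, z=10.0, d_clamp=10.0, eps=1e-4):
    # local_*: [N_frames, N_atoms, 3] coordinates already expressed in each frame
    d = np.sqrt(np.sum((local_pred - local_true) ** 2, axis=-1) + eps)
    return np.minimum(d, d_clamp).mean() / z

rng = np.random.default_rng(0)
local_true = rng.normal(size=(4, 6, 3))          # 4 frames x 6 atoms

results = {}
for shift in [1.0, 5.0, 50.0, 500.0]:
    local_pred = local_true + np.array([shift, 0.0, 0.0])   # every atom off by `shift` Å in every frame
    results[shift] = fape_unmasked(local_pred, local_true)
    print(f"shift {shift:5.1f} Å -> FAPE {results[shift]:.4f}")
# FAPE grows as shift / Z until shift reaches d_clamp, then stays at d_clamp / Z = 1.0
```

This is also why the 20 Å case above cannot produce a loss much larger than 1, however far the atoms are displaced.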

## FAPE Summary

In [None]:
print("\n" + "="*50)
print("FAPE Loss Summary")
print("="*50)
print(f"Perfect prediction:     {fape_perfect:.6f}")
print(f"Small noise (0.5 Å):    {fape_small:.6f}")
print(f"Large noise (2.0 Å):    {fape_large:.6f}")
print(f"Very large (clamped):   {fape_clamped:.6f}")

## SE(3) Invariance Test

In [None]:
# FAPE should be unchanged when a global rotation/translation is applied to the
# prediction alone (the target stays fixed).

def random_rotation():
    """Random rotation matrix via Rodrigues' formula (not uniformly distributed)."""
    theta = np.random.randn(3) * 0.5
    norm = np.linalg.norm(theta)
    if norm < 1e-6:
        return np.eye(3)
    axis = theta / norm
    K = np.array([[0, -axis[2], axis[1]],
                  [axis[2], 0, -axis[0]],
                  [-axis[1], axis[0], 0]])
    return np.eye(3) + np.sin(norm) * K + (1 - np.cos(norm)) * K @ K

global_rot = random_rotation()
global_trans = np.random.randn(3) * 10

# Transform the predicted positions and frames only: x -> G x + s, (R, t) -> (G R, G t + s)
pred_positions_transformed = (global_rot @ pred_positions.T).T + global_trans
pred_frames_trans_transformed = (global_rot @ pred_frames_trans.T).T + global_trans
pred_frames_rot_transformed = np.array([global_rot @ r for r in pred_frames_rot])

print("Testing SE(3) invariance...")
print(f"Applied rotation:\n{global_rot}")
print(f"Applied translation: {global_trans}")

fape_transformed = compute_fape(
    pred_frames_rot_transformed, pred_frames_trans_transformed,
    target_frames_rot, target_frames_trans,
    pred_positions_transformed, target_positions,
    frames_mask, positions_mask
)

print(f"\nOriginal FAPE:    {fape_clamped:.6f}")
print(f"Transformed FAPE: {fape_transformed:.6f}")
print(f"Difference: {abs(fape_clamped - fape_transformed):.9f}")
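
The invariance covers rotations and translations only. A mirror image has exactly the same internal distances but is not reachable by a rotation, so, unlike a purely distance-based loss, FAPE can tell a structure from its mirror. In this sketch, frames are rebuilt from each structure's N, CA and C atoms with a Gram-Schmidt construction (reimplemented here as hypothetical helpers `frames_from_backbone` and `backbone_fape`), and the helix is a toy one: a proper rotation gives FAPE at its floor, a mirror image gives a clearly nonzero value.

```python
import numpy as np

def frames_from_backbone(n, ca, c):
    # Gram-Schmidt frame per residue: e1 along CA->C, e2 towards N, e3 = e1 x e2
    e1 = (c - ca) / np.linalg.norm(c - ca, axis=-1, keepdims=True)
    v2 = n - ca
    u2 = v2 - e1 * np.sum(e1 * v2, axis=-1, keepdims=True)
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    return np.stack([e1, e2, np.cross(e1, e2)], axis=-1), ca

def backbone_fape(bb_pred, bb_true, z=10.0, d_clamp=10.0, eps=1e-4):
    # bb_*: [N_res, 3, 3] = residues x (N, CA, C) x xyz; frames rebuilt from each structure
    def to_local(bb):
        rot, t = frames_from_backbone(bb[:, 0], bb[:, 1], bb[:, 2])
        pts = bb.reshape(-1, 3)                                  # all backbone atoms
        return np.einsum('fji,fpj->fpi', rot, pts[None, :, :] - t[:, None, :])
    d = np.sqrt(np.sum((to_local(bb_pred) - to_local(bb_true)) ** 2, axis=-1) + eps)
    return np.minimum(d, d_clamp).mean() / z

# Toy right-handed helix: 36 atoms (12 residues x N, CA, C), 100 degrees and 1.5 Å rise per residue
m = np.arange(36)
phi = np.radians(100.0 / 3.0) * m
helix = np.stack([2.3 * np.cos(phi), 2.3 * np.sin(phi), 0.5 * m], axis=-1)
bb_true = helix.reshape(12, 3, 3)

# Proper rotation (90 degrees about z) plus a shift
rz = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
bb_rotated = bb_true @ rz.T + np.array([5.0, -3.0, 2.0])

# Mirror image (z -> -z): identical internal distances, opposite handedness
bb_mirror = bb_true * np.array([1.0, 1.0, -1.0])

fape_rot = backbone_fape(bb_rotated, bb_true)
fape_mirror = backbone_fape(bb_mirror, bb_true)
print(f"rotated: {fape_rot:.6f}   mirrored: {fape_mirror:.6f}")
```

Under the mirror, the rebuilt frames flip their third axis, so every atom's local out-of-plane coordinate changes sign; only atoms lying exactly in a frame's plane escape the penalty.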

## Source Code Reference

```python
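# Simplified paraphrase of AF2-source-code/model/all_atom.py (not verbatim).
# `apply_inverse` is shorthand for the r3 helpers that invert each frame and
# apply it to every position.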

def frame_aligned_point_error(
    pred_frames: r3.Rigids,
    target_frames: r3.Rigids,
    frames_mask: jnp.ndarray,
    pred_positions: r3.Vecs,
    target_positions: r3.Vecs,
    positions_mask: jnp.ndarray,
    l1_clamp_distance: Optional[float] = None,
    length_scale: float = 1.0,
) -> jnp.ndarray:
  """Measure point error under different alignments.

  Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE"
  """
  # Transform predicted and target positions to local frames
  local_pred_pos = pred_frames.apply_inverse(pred_positions)
  local_target_pos = target_frames.apply_inverse(target_positions)

  # Compute L2 distance
  error_dist = jnp.sqrt(
      r3.vecs_squared_distance(local_pred_pos, local_target_pos) + 1e-8)

  # Clamp
  if l1_clamp_distance:
    error_dist = jnp.minimum(error_dist, l1_clamp_distance)

  # Mask invalid (frame, atom) pairs, normalize, and return
  mask = frames_mask[:, None] * positions_mask[None, :]
  return jnp.sum(error_dist * mask) / (jnp.sum(mask) + 1e-8) / length_scale
```
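
Both the NumPy implementation and the reference code divide the masked error sum by the number of valid (frame, atom) pairs, i.e. the sum of the outer-product mask. For 0/1 masks that count equals sum(frames_mask) * sum(positions_mask), and the result is just the mean over the valid pairs. A small check with made-up errors and masks:

```python
import numpy as np

rng = np.random.default_rng(0)
error_dist = rng.uniform(0.0, 10.0, size=(5, 8))          # stand-in clamped errors [N_frames, N_atoms]
frames_mask = np.array([1, 1, 0, 1, 1], dtype=float)       # 4 valid frames
positions_mask = np.array([1, 0, 1, 1, 1, 1, 0, 1], dtype=float)  # 6 valid atoms

mask = frames_mask[:, None] * positions_mask[None, :]      # outer product of the two masks
masked_mean = np.sum(error_dist * mask) / (np.sum(mask) + 1e-8)

# Same quantity by explicitly dropping the invalid frames and atoms
valid = error_dist[frames_mask.astype(bool)][:, positions_mask.astype(bool)]
direct_mean = valid.mean()

print(np.sum(mask), frames_mask.sum() * positions_mask.sum())  # 24.0 24.0
print(masked_mean, direct_mean)
```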

## Key Insights

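1. **SE(3) Invariance**: FAPE is invariant to global rotations and translations of either the prediction or the ground truth, because every error is computed in local frames.

2. **All-to-All**: Each frame evaluates the error of ALL atoms, not just the atoms of its own residue, so errors in the relative placement of distant parts of the chain show up in many frames.

3. **Clamping**: Clamping each per-pair error at $d_{clamp}$ (10 Å by default) keeps a few badly misplaced atoms from dominating the loss.

4. **Length Scale**: Dividing by the length scale $Z$ (`length_scale`, 10 Å in Alg. 28 and in the NumPy implementation above) makes the loss unitless; with $d_{clamp} = Z$, every clamped, scaled term is at most 1.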
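
Point 2 can be illustrated with a hinge motion: if the second half of a chain is rigidly rotated relative to the first, all distances within each half are unchanged, yet frames in one half see the atoms of the other half displaced. The sketch below uses a toy straight chain with one atom and one identity-rotation frame per residue (all names hypothetical) and prints the per-(frame, atom) error matrix.

```python
import numpy as np

n = 10
x_true = np.stack([np.arange(n) * 3.8, np.zeros(n), np.zeros(n)], axis=-1)  # straight toy chain
rot_true = np.tile(np.eye(3), (n, 1, 1))
t_true = x_true.copy()                                # frame i sits on atom i

# Hinge: rotate residues 5..9 (atoms and frames) by 60 degrees about z,
# around a pivot halfway between atoms 4 and 5
c, s = np.cos(np.radians(60.0)), np.sin(np.radians(60.0))
g = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
pivot = 0.5 * (x_true[4] + x_true[5])
x_pred, t_pred, rot_pred = x_true.copy(), t_true.copy(), rot_true.copy()
x_pred[5:] = (x_true[5:] - pivot) @ g.T + pivot
t_pred[5:] = x_pred[5:]
rot_pred[5:] = g @ rot_true[5:]

def local(rot, t, x):
    # R_i^T (x_j - t_i) for all frames i and atoms j
    return np.einsum('fji,fpj->fpi', rot, x[None, :, :] - t[:, None, :])

err = np.linalg.norm(local(rot_pred, t_pred, x_pred) - local(rot_true, t_true, x_true), axis=-1)
print(np.round(err, 1))   # zero within each half, nonzero across the hinge
```

A loss restricted to each residue's own atoms would see no error here; the off-diagonal blocks are what let FAPE penalize the wrong relative placement.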