# Algorithm 22: Invariant Point Attention (IPA)

Invariant Point Attention is a key innovation in AlphaFold2's Structure Module. It enables attention between residues that is **invariant** to global rotations and translations: the attention weights depend only on relative positions in 3D space (through squared distances between points), not on absolute coordinates.

## Algorithm Pseudocode

![Invariant Point Attention](../imgs/algorithms/InvariantPointAttention.png)

## Source Code Location
- **File**: `AF2-source-code/model/folding.py`
- **Class**: `InvariantPointAttention`
- **Lines**: 37-279

## Key Concepts

### SE(3) Invariance
- The attention weights and outputs are invariant to a global rotation and translation applied to all residue frames
- Each residue has a local reference frame (rotation + translation)
- Points are computed in local frames and transformed to the global frame for distance computation

### Three Types of Attention Contributions
1. **Scalar attention**: Standard query-key dot product
2. **Point attention**: Negative squared Euclidean distance between query and key points, measured in the global frame
3. **Pair bias**: Bias from pair representation

In [None]:
import numpy as np

# Seed NumPy's global RNG so the random test inputs below, and any weights
# drawn from the global RNG, are reproducible when the cells run in order.
np.random.seed(42)

## Rigid Body Transformations

In [None]:
def rotation_matrix_from_angles(angles):
    """Return the 3x3 rotation matrix for an axis-angle (rotation) vector.

    ``angles`` is a 3-vector whose direction is the rotation axis and whose
    norm is the rotation angle in radians. It is NOT a triple of Euler angles.
    The matrix is built with Rodrigues' formula
    ``R = I + sin(theta) K + (1 - cos(theta)) K @ K``, where ``K`` is the
    skew-symmetric cross-product matrix of the unit rotation axis.

    Args:
        angles: Array-like of shape (3,); lists and tuples are accepted.

    Returns:
        ndarray of shape (3, 3). Vectors with norm below 1e-6 map to the
        identity matrix (the axis is undefined for a near-zero rotation).
    """
    # Convert first so plain lists/tuples support the element-wise division below.
    rotvec = np.asarray(angles, dtype=float)
    theta = np.linalg.norm(rotvec)
    if theta < 1e-6:
        return np.eye(3)
    axis = rotvec / theta
    K = np.array([[0.0, -axis[2], axis[1]],
                  [axis[2], 0.0, -axis[0]],
                  [-axis[1], axis[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)


def apply_rigid_transform(points, rotation, translation):
    """Map ``points`` through the rigid frame (rotation, translation).

    Computes ``rotation @ p + translation`` for every point ``p``; leading
    batch dimensions of ``points`` and ``rotation`` broadcast against each
    other, and the last axis holds the xyz coordinates.
    """
    rotated = np.einsum('...ab,...b->...a', rotation, points)
    return rotated + translation


def invert_rigid_transform(points, rotation, translation):
    """Map ``points`` back through the inverse of the rigid frame.

    Computes ``rotation.T @ (p - translation)`` for every point ``p``, i.e.
    the inverse of ``rotation @ p + translation`` when ``rotation`` is
    orthogonal. Leading batch dimensions broadcast; the last axis is xyz.
    """
    rotation_t = np.swapaxes(rotation, -1, -2)
    centered = points - translation
    return np.einsum('...ij,...j->...i', rotation_t, centered)

## NumPy Implementation of IPA

In [None]:
def softmax(x, axis=-1):
    """Softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating, so large logits
    do not overflow; the result is unchanged because softmax is shift-invariant.
    """
    shifted = x - x.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = weights.sum(axis=axis, keepdims=True)
    return weights / normalizer


def invariant_point_attention(
    inputs_1d,       # [N_res, c_s] single representation
    inputs_2d,       # [N_res, N_res, c_z] pair representation
    rotations,       # [N_res, 3, 3] rotation matrices
    translations,    # [N_res, 3] translation vectors
    mask,            # [N_res] residue mask
    num_head=12,
    num_scalar_qk=16,
    num_scalar_v=16,
    num_point_qk=4,
    num_point_v=8,
    *,
    rng=None,
    verbose=True,
):
    """
    Invariant Point Attention.

    Algorithm 22 from AlphaFold2 supplementary materials, as a didactic NumPy
    sketch. The projection weights are not learned: they are sampled
    (standard normal scaled by 0.01) from ``rng`` on every call. The per-head
    trainable point weight is fixed to 1.

    Args:
        inputs_1d: Single representation [N_res, c_s]
        inputs_2d: Pair representation [N_res, N_res, c_z]
        rotations: Rotation matrices for each residue [N_res, 3, 3]
        translations: Translation vectors for each residue [N_res, 3]
        mask: Residue mask [N_res]. Any (query, key) pair involving a residue
            with mask 0 receives a -1e5 logit penalty.
        num_head: Number of attention heads.
        num_scalar_qk: Scalar query/key channels per head.
        num_scalar_v: Scalar value channels per head.
        num_point_qk: Query/key 3D points per head.
        num_point_v: Value 3D points per head.
        rng: Source of the random weights; any object with
            ``standard_normal(size)``, e.g. ``np.random.default_rng(0)``.
            Defaults to NumPy's global RNG (the historical behaviour). Pass
            identically seeded generators to reuse the same weights across
            calls, e.g. when checking invariance to a global rigid transform.
        verbose: If True (default), print intermediate shapes.

    Returns:
        Output projected back to c_s channels [N_res, c_s]. Given identical
        weights, it is unchanged when the same global rotation and translation
        are applied to every residue frame.
    """
    if rng is None:
        # Module-level np.random functions share NumPy's global legacy RNG, and
        # np.random.standard_normal((a, b)) yields the same stream as
        # np.random.randn(a, b), so the default keeps the historical behaviour.
        rng = np.random

    inputs_1d = np.asarray(inputs_1d)
    inputs_2d = np.asarray(inputs_2d)
    rotations = np.asarray(rotations)
    translations = np.asarray(translations)
    # Float cast so boolean masks work with the `1.0 - mask_2d` penalty below.
    mask = np.asarray(mask, dtype=np.float64)

    def log(*args):
        if verbose:
            print(*args)

    def init_weight(fan_in, fan_out):
        # Stand-in for a learned (bias-free) linear weight matrix.
        return rng.standard_normal((fan_in, fan_out)) * 0.01

    def to_global(points_local):
        # R_i @ x + t_i for every head/point x of residue i.
        rotated = np.einsum('nij,nhpj->nhpi', rotations, points_local)
        return rotated + translations[:, None, None, :]

    def to_local(points_global):
        # Inverse frame map R_i^T @ (x - t_i) for every head/point of residue i.
        centered = points_global - translations[:, None, None, :]
        return np.einsum('nji,nhpj->nhpi', rotations, centered)

    N_res, c_s = inputs_1d.shape
    c_z = inputs_2d.shape[-1]

    log(f"Input shapes: 1D={inputs_1d.shape}, 2D={inputs_2d.shape}")

    # ========== Step 1: Compute scalar queries, keys, values ==========
    # NOTE: weights are drawn in a fixed order (q, k, v, q-points, kv-points,
    # pair bias, output projection) so a seeded rng reproduces them exactly.
    q_scalar_w = init_weight(c_s, num_head * num_scalar_qk)
    k_scalar_w = init_weight(c_s, num_head * num_scalar_qk)
    v_scalar_w = init_weight(c_s, num_head * num_scalar_v)

    q_scalar = (inputs_1d @ q_scalar_w).reshape(N_res, num_head, num_scalar_qk)
    k_scalar = (inputs_1d @ k_scalar_w).reshape(N_res, num_head, num_scalar_qk)
    v_scalar = (inputs_1d @ v_scalar_w).reshape(N_res, num_head, num_scalar_v)

    log(f"Scalar Q/K/V shapes: {q_scalar.shape}")

    # ========== Step 2: Compute point queries, keys, values ==========
    # Points are generated in each residue's local frame, then moved to the
    # global frame, where distances between residues are meaningful.
    q_point_local_w = init_weight(c_s, num_head * 3 * num_point_qk)
    kv_point_local_w = init_weight(c_s, num_head * 3 * (num_point_qk + num_point_v))

    q_point_local = (inputs_1d @ q_point_local_w).reshape(N_res, num_head, num_point_qk, 3)
    kv_point_local = (inputs_1d @ kv_point_local_w).reshape(
        N_res, num_head, num_point_qk + num_point_v, 3)
    k_point_local = kv_point_local[:, :, :num_point_qk, :]
    v_point_local = kv_point_local[:, :, num_point_qk:, :]

    q_point_global = to_global(q_point_local)  # [N_res, num_head, num_point_qk, 3]
    k_point_global = to_global(k_point_local)
    v_point_global = to_global(v_point_local)  # [N_res, num_head, num_point_v, 3]

    log(f"Point Q/K shapes: {q_point_global.shape}")

    # ========== Step 3: Compute attention logits ==========
    # Each of the three logit terms (scalar, point, pair) carries a sqrt(1/3) factor.

    # 3.1: Scalar attention -> [num_head, N_res, N_res]
    scalar_variance = max(num_scalar_qk, 1) * 1.0
    scalar_weights = np.sqrt(1.0 / (3.0 * scalar_variance))
    attn_scalar = scalar_weights * np.einsum('qhc,khc->hqk', q_scalar, k_scalar)

    log(f"Scalar attention shape: {attn_scalar.shape}")

    # 3.2: Point attention from squared distances in the global frame; these
    # distances are unchanged by a global rotation + translation of all frames.
    point_variance = max(num_point_qk, 1) * 9.0 / 2.0
    point_weights = np.sqrt(1.0 / (3.0 * point_variance))
    # Per-head trainable weight (softplus-parameterised in AF2); fixed to 1 here.
    trainable_point_weights = np.ones(num_head)

    q_point_t = q_point_global.transpose(1, 0, 2, 3)  # [num_head, N_res, num_point_qk, 3]
    k_point_t = k_point_global.transpose(1, 0, 2, 3)
    dist2 = np.sum(
        (q_point_t[:, :, None, :, :] - k_point_t[:, None, :, :, :]) ** 2,
        axis=-1,
    )  # [num_head, N_res, N_res, num_point_qk]

    attn_point = -0.5 * point_weights * trainable_point_weights[:, None, None, None] * dist2
    attn_point = attn_point.sum(axis=-1)  # [num_head, N_res, N_res]

    log(f"Point attention shape: {attn_point.shape}")

    # 3.3: Pair bias -> [num_head, N_res, N_res]
    attention_2d_w = init_weight(c_z, num_head)
    attention_2d = np.einsum('ijc,ch->hij', inputs_2d, attention_2d_w) * np.sqrt(1.0 / 3.0)

    log(f"Pair bias shape: {attention_2d.shape}")

    # 3.4: Combine, mask and normalise over keys.
    attn_logits = attn_scalar + attn_point + attention_2d
    mask_2d = mask[:, None] * mask[None, :]  # [N_res, N_res]
    attn_logits = attn_logits - 1e5 * (1.0 - mask_2d)

    # Numerically stable softmax over the key axis (inlined to keep this
    # function self-contained).
    attn = np.exp(attn_logits - attn_logits.max(axis=-1, keepdims=True))
    attn = attn / attn.sum(axis=-1, keepdims=True)  # [num_head, N_res, N_res]

    log(f"Attention weights shape: {attn.shape}")

    # ========== Step 4: Compute outputs ==========
    # 4.1: Scalar output -> [N_res, num_head * num_scalar_v]
    result_scalar = np.einsum('hqk,khc->qhc', attn, v_scalar).reshape(
        N_res, num_head * num_scalar_v)

    log(f"Scalar output shape: {result_scalar.shape}")

    # 4.2: Point output: average value points in the global frame, then map
    # the result back into each query residue's local frame.
    result_point_global = np.einsum('hqk,khpc->qhpc', attn, v_point_global)
    result_point_local = to_local(result_point_global)  # [N_res, num_head, num_point_v, 3]

    # Point norms in the local frame; 1e-8 keeps sqrt differentiable at 0.
    result_point_norm = np.sqrt(
        np.sum(result_point_local ** 2, axis=-1) + 1e-8
    ).reshape(N_res, num_head * num_point_v)
    result_point_local = result_point_local.reshape(N_res, num_head * num_point_v * 3)

    log(f"Point output shape: {result_point_local.shape}")

    # 4.3: Pair output -> [N_res, num_head * c_z]
    result_pair = np.einsum('hqk,qkc->qhc', attn, inputs_2d).reshape(N_res, num_head * c_z)

    log(f"Pair output shape: {result_pair.shape}")

    # ========== Step 5: Concatenate and project ==========
    output_features = np.concatenate(
        [result_scalar, result_point_local, result_point_norm, result_pair],
        axis=-1,
    )

    log(f"Concatenated features shape: {output_features.shape}")

    output_w = init_weight(output_features.shape[-1], c_s)
    output = output_features @ output_w

    log(f"Final output shape: {output.shape}")

    return output

## Test Example

In [None]:
# Problem size
N_res = 16     # residues
c_s = 384      # channels of the single representation
c_z = 128      # channels of the pair representation

# Random single and pair features, stored as float32
inputs_1d = np.random.randn(N_res, c_s).astype(np.float32)
inputs_2d = np.random.randn(N_res, N_res, c_z).astype(np.float32)

# One random rigid frame per residue: a small random rotation (axis-angle
# vector scaled by 0.1) and a translation with standard deviation 5 (Angstroms).
rotations = np.stack(
    [rotation_matrix_from_angles(0.1 * np.random.randn(3)) for _ in range(N_res)]
)
translations = 5 * np.random.randn(N_res, 3)

# Every residue is valid
mask = np.full(N_res, 1.0, dtype=np.float32)

print(f"Number of residues: {N_res}")
print(f"Single representation dim: {c_s}")
print(f"Pair representation dim: {c_z}")
print(f"Rotation matrices shape: {rotations.shape}")
print(f"Translations shape: {translations.shape}")
print()
print()

In [None]:
# Forward pass of IPA on the test inputs (8 heads instead of the default 12)
ipa_hparams = {
    "num_head": 8,
    "num_scalar_qk": 16,
    "num_scalar_v": 16,
    "num_point_qk": 4,
    "num_point_v": 8,
}
output = invariant_point_attention(
    inputs_1d, inputs_2d, rotations, translations, mask, **ipa_hparams
)

out_mean = output.mean()
out_std = output.std()
print(f"\nOutput statistics: mean={out_mean:.6f}, std={out_std:.6f}")

## Verify SE(3) Invariance

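Apply a global rotation and translation to all residue frames. Provided both runs use the same projection weights, the attention weights and the final output should be unchanged up to floating-point error.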

In [None]:
# Global rigid transformation (G, g) applied to every residue frame.
global_rotation = rotation_matrix_from_angles(np.array([0.5, 0.3, 0.7]))
global_translation = np.array([10.0, 20.0, 30.0])

# New frames (G @ R_i, G @ t_i + g); matmul broadcasts over the residue axis.
rotations_transformed = global_rotation @ rotations
translations_transformed = translations @ global_rotation.T + global_translation

ipa_kwargs = dict(num_head=8, num_scalar_qk=16, num_scalar_v=16, num_point_qk=4, num_point_v=8)

# invariant_point_attention samples its projection weights from NumPy's
# global RNG on every call. Two plain calls would therefore use *different*
# weights, and their difference would say nothing about invariance. Rewind
# the global RNG to the same state before each call so both runs share weights.
rng_state = np.random.get_state()
output_reference = invariant_point_attention(
    inputs_1d, inputs_2d, rotations, translations, mask, **ipa_kwargs
)
np.random.set_state(rng_state)
output_transformed = invariant_point_attention(
    inputs_1d, inputs_2d, rotations_transformed, translations_transformed, mask, **ipa_kwargs
)

# With shared weights the outputs agree up to floating-point error:
# - squared point distances do not change under the global transform;
# - aggregated value points are mapped back into each residue's local frame.
print(f"Original output mean: {output_reference.mean():.6f}")
print(f"Transformed output mean: {output_transformed.mean():.6f}")
print(f"Difference norm: {np.linalg.norm(output_reference - output_transformed):.6f}")
assert np.allclose(output_reference, output_transformed, atol=1e-8), \
    "IPA output changed under a global rigid transform"

## Source Code Reference

```python
# Condensed sketch of AF2-source-code/model/folding.py (simplified, not verbatim)

class InvariantPointAttention(hk.Module):
  """Invariant Point attention module.

  Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
  """

  def __call__(self, inputs_1d, inputs_2d, mask, affine):
    # Scalar queries/keys/values
    q_scalar = Linear(num_head * num_scalar_qk)(inputs_1d)
    k_scalar, v_scalar = split(Linear(...)(inputs_1d))
    
    # Point queries/keys/values in local frame
    q_point_local = Linear(num_head * 3 * num_point_qk)(inputs_1d)
    # Transform to global frame
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    
    # Compute attention from:
    # 1. Scalar dot product
    # 2. Point distance (invariant to global transform)
    # 3. Pair bias
    attn_logits = attn_scalar + attn_point + attention_2d
    attn = softmax(attn_logits)
    
    # Aggregate values and project
    result = concat([scalar_result, point_result, pair_result])
    return output_projection(result)
```

## Key Insights

1. **SE(3) Invariance**: Point attention uses squared distances, which are invariant to global rotations/translations.

2. **Three Attention Sources**: Scalar (sequence), point (3D geometry), and pair (pairwise features) attention are combined.

3. **Local/Global Frames**: Points are generated in local frames but distances are computed in global frame for invariance.

4. **Value Transformation**: Aggregated point values are transformed back into each residue's local frame, so the output features are invariant to global rotations/translations.