# Algorithm 23: Backbone Update

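The Backbone Update maps each residue's single representation to a rigid-body update (a rotation and a translation) for that residue's backbone frame. Composing this update with the current frame on every Structure Module iteration is how the model iteratively refines the 3D structure.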

## Algorithm Pseudocode

![BackboneUpdate](../imgs/algorithms/BackboneUpdate.png)

## Source Code Location
- **File**: `AF2-source-code/model/folding.py`
- **Function**: `FoldIteration` (within Structure Module)
- **Lines**: 263-280

## Overview

The Backbone Update:

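1. Takes each residue's single representation as input
2. Predicts a rotation update (as a quaternion) and a translation update
3. Composes the update with the current backbone frame (in the supplement, this composition step is written in the Structure Module, Algorithm 20)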

### Frame Representation

Each residue's backbone is represented as a rigid frame (rotation + translation):
- **R**: 3×3 rotation matrix
- **t**: 3D translation vector
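A frame (R, t) maps a point from the residue's local coordinates to global coordinates by rotating and then translating. The sketch below illustrates this and the inverse map; the helper names `apply_frame` and `invert_frame` are hypothetical and not taken from the AF2 code.

```python
import numpy as np

def apply_frame(R, t, x_local):
    """Local -> global coordinates: x_global = R @ x_local + t."""
    return R @ x_local + t

def invert_frame(R, t):
    """Inverse rigid transform (R^T, -R^T t); valid because R is orthogonal."""
    R_inv = R.T
    return R_inv, -R_inv @ t

# Example frame: 90° rotation about z, origin shifted to (1, 0, 0)
R = np.array([[0.0, -1.0, 0.0],
              [1.0,  0.0, 0.0],
              [0.0,  0.0, 1.0]])
t = np.array([1.0, 0.0, 0.0])

x_local = np.array([1.0, 0.0, 0.0])           # one unit along the local x-axis
x_global = apply_frame(R, t, x_local)         # -> [1, 1, 0]
R_inv, t_inv = invert_frame(R, t)
x_back = apply_frame(R_inv, t_inv, x_global)  # -> back to [1, 0, 0]
print(x_global, x_back)
```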

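The update is predicted as a 6D vector per residue:
- Rotation: 3 values (b, c, d), read as the quaternion (1, b, c, d) and normalized to unit length before conversion to a rotation matrix
- Translation: 3 values (t_x, t_y, t_z)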
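Algorithm 23 can be sketched in a few lines of NumPy. This is a minimal sketch, not the AF2 implementation: the function name `backbone_update_sketch` and the hand-picked toy weights are hypothetical, and the linear layer is a plain matrix product. It checks that (b, c, d) = (0, 0, 1) gives a 90° rotation about z, that a zero output gives the identity, and that arbitrary raw outputs still give proper rotations.

```python
import numpy as np

def backbone_update_sketch(s, W, bias):
    """Sketch of Algorithm 23: single representation [N, c_s] -> frames (R, t).

    A linear layer yields 6 numbers per residue: (b, c, d) for the quaternion
    (1, b, c, d) and (t_x, t_y, t_z) for the translation.
    """
    out = s @ W + bias                                   # [N, 6]
    bcd, t = out[:, :3], out[:, 3:]
    q = np.concatenate([np.ones((out.shape[0], 1)), bcd], axis=-1)
    q = q / np.linalg.norm(q, axis=-1, keepdims=True)    # unit quaternion
    a, b, c, d = np.moveaxis(q, -1, 0)
    R = np.stack([
        np.stack([a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)], axis=-1),
        np.stack([2*(b*c + a*d), a*a - b*b + c*c - d*d, 2*(c*d - a*b)], axis=-1),
        np.stack([2*(b*d - a*c), 2*(c*d + a*b), a*a - b*b - c*c + d*d], axis=-1),
    ], axis=-2)                                          # [N, 3, 3]
    return R, t

# Toy weights (hypothetical): c_s = 3, first three outputs copy s directly
W = np.zeros((3, 6))
W[:, :3] = np.eye(3)

# s = (0, 0, 1) -> (b, c, d) = (0, 0, 1) -> q ∝ (1, 0, 0, 1): 90° about z
R_z, t_z = backbone_update_sketch(np.array([[0.0, 0.0, 1.0]]), W, np.zeros(6))

# Zero output -> identity rotation and zero translation
R_0, t_0 = backbone_update_sketch(np.zeros((2, 3)), W, np.zeros(6))

# Arbitrary (even large) raw outputs still give proper rotations
rng = np.random.default_rng(0)
R_rand, _ = backbone_update_sketch(rng.standard_normal((100, 8)) * 5,
                                   rng.standard_normal((8, 6)), np.zeros(6))
print(np.round(R_z[0], 6))
```

Fixing the real part to 1 is what lets three unconstrained numbers always produce a valid rotation.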

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def quaternion_to_rotation_matrix(q):
    """
    Convert quaternion to rotation matrix.
    
    Args:
        q: Quaternion [w, x, y, z] with shape [..., 4]
    
    Returns:
        R: Rotation matrix [..., 3, 3]
    """
    # Normalize quaternion
    q = q / (np.linalg.norm(q, axis=-1, keepdims=True) + 1e-8)
    
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    
    # Build rotation matrix
    R = np.zeros(q.shape[:-1] + (3, 3))
    
    R[..., 0, 0] = 1 - 2*(y**2 + z**2)
    R[..., 0, 1] = 2*(x*y - w*z)
    R[..., 0, 2] = 2*(x*z + w*y)
    
    R[..., 1, 0] = 2*(x*y + w*z)
    R[..., 1, 1] = 1 - 2*(x**2 + z**2)
    R[..., 1, 2] = 2*(y*z - w*x)
    
    R[..., 2, 0] = 2*(x*z - w*y)
    R[..., 2, 1] = 2*(y*z + w*x)
    R[..., 2, 2] = 1 - 2*(x**2 + y**2)
    
    return R


def compose_frames(R1, t1, R2, t2):
    """
    Compose two rigid transformations.
    
    T_composed = T1 ∘ T2
    R_composed = R1 @ R2
    t_composed = R1 @ t2 + t1
    
    Args:
        R1, t1: First transformation
        R2, t2: Second transformation (applied first)
    
    Returns:
        R_composed, t_composed
    """
    R_composed = np.einsum('nij,njk->nik', R1, R2)
    t_composed = np.einsum('nij,nj->ni', R1, t2) + t1
    
    return R_composed, t_composed

In [None]:
def backbone_update(s, R_curr, t_curr, position_scale=10.0):
    """
    Backbone Update - Algorithm 23.
    
    Predicts and applies frame updates from single representation.
    
    Args:
        s: Single representation [N_res, c_s]
        R_curr: Current rotation matrices [N_res, 3, 3]
        t_curr: Current translations [N_res, 3]
        position_scale: Scale factor for translation predictions
    
    Returns:
        R_new: Updated rotation matrices [N_res, 3, 3]
        t_new: Updated translations [N_res, 3]
    """
    N_res, c_s = s.shape
    
    print(f"Backbone Update")
    print(f"="*50)
    print(f"Single representation: [{N_res}, {c_s}]")
    print(f"Current frames: R {R_curr.shape}, t {t_curr.shape}")
    
    # Step 1: Linear projection to a 6D update (Algorithm 23, line 1)
    # Output: 3 quaternion components (b, c, d) + 3 translation components
    W = np.zeros((c_s, 6))  # Initialized to zeros in AF2
    b = np.zeros(6)
    
    # With zero weights the update would be exactly zero, so add a small
    # random perturbation to stand in for a trained layer (demo only)
    updates = s @ W + b + np.random.randn(N_res, 6) * 0.01
    
    print(f"\nStep 1 - Raw updates: {updates.shape}")
    
    # Step 2: Split into rotation and translation parts
    # First 3 values: (b, c, d) of the quaternion (1, b, c, d)
    # Last 3 values: translation; position_scale applies to translations only
    rot_update = updates[:, :3]
    trans_update = updates[:, 3:] * position_scale
    
    print(f"Step 2 - Rotation update norm: {np.linalg.norm(rot_update):.4f}")
    print(f"         Translation update norm: {np.linalg.norm(trans_update):.4f}")
    
    # Step 3: Build the quaternion (1, b, c, d) (Algorithm 23, lines 2-3)
    # quaternion_to_rotation_matrix normalizes it to unit length, so any
    # (b, c, d) yields a valid rotation and (0, 0, 0) gives the identity
    quat = np.zeros((N_res, 4))
    quat[:, 0] = 1.0  # Real part fixed to 1
    quat[:, 1:] = rot_update  # (b, c, d) from the linear layer
    
    # Convert to rotation matrix
    delta_R = quaternion_to_rotation_matrix(quat)
    
    print(f"Step 3 - Delta rotation shape: {delta_R.shape}")
    
    # Step 4: Compose with the current frame (Structure Module, Algorithm 20)
    # New frame = Current frame ∘ Delta frame, i.e. the update is applied in
    # the local frame: R_new = R_curr @ delta_R, t_new = R_curr @ dt + t_curr
    R_new, t_new = compose_frames(R_curr, t_curr, delta_R, trans_update)
    
    print(f"Step 4 - New frames: R {R_new.shape}, t {t_new.shape}")
    
    # Verify rotation matrices are valid
    det_mean = np.abs(np.linalg.det(R_new)).mean()
    ortho_error = np.abs(np.einsum('nij,nkj->nik', R_new, R_new) - np.eye(3)).mean()
    print(f"\nVerification:")
    print(f"  Mean |det(R)|: {det_mean:.6f} (should be 1)")
    print(f"  Orthogonality error: {ortho_error:.6f} (should be ~0)")
    
    return R_new, t_new

## Test Examples

In [None]:
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("="*60)

N_res, c_s = 32, 384

# Single representation
s = np.random.randn(N_res, c_s).astype(np.float32)

# Initialize frames as identity
R_init = np.tile(np.eye(3), (N_res, 1, 1))
t_init = np.zeros((N_res, 3))

R_new, t_new = backbone_update(s, R_init, t_init, position_scale=10.0)

print(f"\nOutput shapes: R={R_new.shape}, t={t_new.shape}")

In [None]:
# Test 2: Iterative refinement
print("\nTest 2: Iterative Refinement")
print("="*60)

N_res, c_s = 16, 128
num_iterations = 8

s = np.random.randn(N_res, c_s).astype(np.float32)
R = np.tile(np.eye(3), (N_res, 1, 1))
t = np.zeros((N_res, 3))

print(f"Simulating {num_iterations} structure module iterations:")

for i in range(num_iterations):
    R, t = backbone_update(s, R, t, position_scale=10.0)
    
    # CA positions: each backbone frame is centered on the CA atom,
    # so the frame translation t is the CA position
    ca_positions = t
    ca_dist = np.linalg.norm(np.diff(ca_positions, axis=0), axis=1)
    
    print(f"\nIteration {i+1}:")
    print(f"  t range: [{t.min():.2f}, {t.max():.2f}]")
    print(f"  Mean CA-CA distance: {ca_dist.mean():.2f}Å")

In [None]:
# Test 3: Quaternion to rotation matrix
print("\nTest 3: Quaternion Validation")
print("="*60)

# Test identity quaternion
q_identity = np.array([[1, 0, 0, 0]])
R_identity = quaternion_to_rotation_matrix(q_identity)
print(f"Identity quaternion [1,0,0,0]:")
print(f"  R = Identity: {np.allclose(R_identity[0], np.eye(3))}")

# Test 90-degree rotation around z-axis
# q = (cos(45°), 0, 0, sin(45°))
angle = np.pi / 2
q_90z = np.array([[np.cos(angle/2), 0, 0, np.sin(angle/2)]])
R_90z = quaternion_to_rotation_matrix(q_90z)
print(f"\n90° rotation around z:")
print(f"  R[0,1] ≈ -1: {np.allclose(R_90z[0, 0, 1], -1, atol=0.01)}")
print(f"  R[1,0] ≈ 1: {np.allclose(R_90z[0, 1, 0], 1, atol=0.01)}")

In [None]:
# Test 4: Frame composition
print("\nTest 4: Frame Composition")
print("="*60)

N = 4

# Create two sets of frames: identity rotations, random translations
R1 = np.tile(np.eye(3), (N, 1, 1))
t1 = np.random.randn(N, 3)

R2 = np.tile(np.eye(3), (N, 1, 1))
t2 = np.random.randn(N, 3)

R_composed, t_composed = compose_frames(R1, t1, R2, t2)

# For identity rotations, composition is just addition of translations
print(f"With identity rotations:")
print(f"  t_composed = t1 + t2: {np.allclose(t_composed, t1 + t2)}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_res, c_s = 24, 256

s = np.random.randn(N_res, c_s).astype(np.float32)
R_init = np.tile(np.eye(3), (N_res, 1, 1))
t_init = np.zeros((N_res, 3))

R_new, t_new = backbone_update(s, R_init, t_init)

# Property 1: Rotation matrices are orthogonal
RRT = np.einsum('nij,nkj->nik', R_new, R_new)
is_orthogonal = np.allclose(RRT, np.tile(np.eye(3), (N_res, 1, 1)), atol=0.01)
print(f"Property 1 - R is orthogonal: {is_orthogonal}")

# Property 2: Determinant is 1 (proper rotation)
dets = np.linalg.det(R_new)
det_is_one = np.allclose(dets, 1.0, atol=0.01)
print(f"Property 2 - det(R) = 1: {det_is_one}")

# Property 3: Output shape preserved
shape_ok = R_new.shape == R_init.shape and t_new.shape == t_init.shape
print(f"Property 3 - Shape preserved: {shape_ok}")

# Property 4: Finite values
finite = np.isfinite(R_new).all() and np.isfinite(t_new).all()
print(f"Property 4 - Finite values: {finite}")

# Property 5: Translation change is small (demo noise std 0.01, times position_scale)
t_change = np.linalg.norm(t_new - t_init)
print(f"Property 5 - Translation change: {t_change:.4f}")

## Source Code Reference

```python
# From AF2-source-code/model/folding.py (FoldIteration)

# Current frames, stored as a 7D tensor (quaternion + translation)
affine = quat_affine.QuatAffine.from_tensor(activations['affine'])

# ... IPA and transition layers update `act` ...

# Jumper et al. (2021) Suppl. Alg. 23 "Backbone update"
affine_update_size = 6  # 3 quaternion components (b, c, d) + 3 translation
affine_update = common_modules.Linear(
    affine_update_size,
    initializer=final_init(self.global_config),  # zeros when zero_init is set
    name='affine_update')(act)

# Apply the update in the local frame (pre-compose)
affine = affine.pre_compose(affine_update)

# position_scale is applied to translations only, when frames are passed
# to later stages, e.g. affine.scale_translation(c.position_scale)
```
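Pre-composition can also be carried out entirely in quaternion form. The following is a minimal NumPy sketch with hypothetical helper names (`quat_multiply`, `quat_to_rot`, `pre_compose_quat`), not a reproduction of `quat_affine`: the new rotation is q ⊗ (1, b, c, d), renormalized, and the translation update is rotated into global coordinates by the current rotation. Writing q ⊗ (1, v) as q + q ⊗ (0, v) shows the update as an additive correction to the current quaternion. The final check confirms agreement with the matrix form R_new = R_curr ΔR.

```python
import numpy as np

def quat_multiply(p, q):
    """Hamilton product p ⊗ q for quaternions [w, x, y, z], shape [N, 4]."""
    pw, px, py, pz = np.moveaxis(p, -1, 0)
    qw, qx, qy, qz = np.moveaxis(q, -1, 0)
    return np.stack([
        pw*qw - px*qx - py*qy - pz*qz,
        pw*qx + px*qw + py*qz - pz*qy,
        pw*qy - px*qz + py*qw + pz*qx,
        pw*qz + px*qy - py*qx + pz*qw,
    ], axis=-1)

def quat_to_rot(q):
    """Unit-normalize q = [w, x, y, z] and return rotation matrices [N, 3, 3]."""
    q = q / np.linalg.norm(q, axis=-1, keepdims=True)
    w, x, y, z = np.moveaxis(q, -1, 0)
    return np.stack([
        np.stack([1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y)], axis=-1),
        np.stack([2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x)], axis=-1),
        np.stack([2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y)], axis=-1),
    ], axis=-2)

def pre_compose_quat(q, t, update):
    """Apply a 6D update (b, c, d, t_x, t_y, t_z) in the local frame of (q, t)."""
    q_upd = np.concatenate([np.ones((update.shape[0], 1)), update[:, :3]], axis=-1)
    q_new = quat_multiply(q, q_upd)                       # = q + q ⊗ (0, b, c, d)
    q_new = q_new / np.linalg.norm(q_new, axis=-1, keepdims=True)
    t_new = t + np.einsum('nij,nj->ni', quat_to_rot(q), update[:, 3:])
    return q_new, t_new

rng = np.random.default_rng(1)
N = 4
q = rng.standard_normal((N, 4))
q = q / np.linalg.norm(q, axis=-1, keepdims=True)
t = rng.standard_normal((N, 3))
update = 0.1 * rng.standard_normal((N, 6))

q_new, t_new = pre_compose_quat(q, t, update)

# Matrix form of the same step: R_new = R_curr @ dR, t_new = t + R_curr @ dt
dR = quat_to_rot(np.concatenate([np.ones((N, 1)), update[:, :3]], axis=-1))
R_expected = quat_to_rot(q) @ dR
print(np.allclose(quat_to_rot(q_new), R_expected))
```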

## Key Insights

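1. **Zero Initialization**: The final linear layer is initialized to zeros, so at the start of training every predicted update is the identity transform (quaternion (1, 0, 0, 0), zero translation). Independently, the Structure Module starts all backbone frames at the identity (the "black hole initialization"), placing every residue at the origin.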

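2. **Local Frame Updates**: Updates are pre-composed, T_new = T_curr ∘ ΔT, so the predicted update is expressed in the residue's own coordinate axes: R_new = R_curr ΔR and t_new = t_curr + R_curr Δt. The network never needs to know where the global axes point.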

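3. **Quaternion Representation**: Fixing the real part to 1 and normalizing (1, b, c, d) maps any three unconstrained network outputs to a valid rotation, with no orthogonality constraint to enforce, and an all-zero output maps to the identity rotation. Unlike Euler angles, the parameterization also has no gimbal lock.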

4. **Iterative Refinement**: The Structure Module runs 8 iterations, each refining the backbone frames.

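5. **Position Scale**: `position_scale` (10.0 in the AF2 configuration) is a dimensionless factor applied to translations only, not to rotations. The network works with translations roughly ten times smaller than their Ångström values, and multiplying by `position_scale` brings them to Ångström scale.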

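6. **Equivariance**: The single representation is invariant to global rigid motions (IPA produces invariant features), so the predicted ΔT_i is invariant too. For globally moved input frames T_g ∘ T_i, pre-composition yields T_g ∘ T_i ∘ ΔT_i: the output frames move exactly like the inputs, so the update is SE(3)-equivariant.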
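For the Zero Initialization point, here is a minimal sketch (random `s` stands in for a real single representation; the sizes are arbitrary) showing that a zero-initialized layer produces the identity update for every residue, so pre-composing it leaves identity starting frames unchanged.

```python
import numpy as np

N_res, c_s = 8, 32
rng = np.random.default_rng(0)
s = rng.standard_normal((N_res, c_s))    # any single representation

# Zero-initialized final layer: the 6D update is zero for every residue
W, bias = np.zeros((c_s, 6)), np.zeros(6)
update = s @ W + bias

# (b, c, d) = 0 gives the identity quaternion (1, 0, 0, 0), i.e. dR = I
q = np.concatenate([np.ones((N_res, 1)), update[:, :3]], axis=-1)
dR = np.tile(np.eye(3), (N_res, 1, 1))
dt = update[:, 3:]

# Identity start: every frame is (I, 0), so every CA sits at the origin
R = np.tile(np.eye(3), (N_res, 1, 1))
t = np.zeros((N_res, 3))

# Pre-compose T ∘ dT: the frames are left unchanged
R_new = R @ dR
t_new = t + np.einsum('nij,nj->ni', R, dt)
print(np.allclose(R_new, R), np.allclose(t_new, t))
```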
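For the Local Frame Updates point, the small example below contrasts pre-composition (update read in the residue's own axes) with post-composition (update read in global axes). The frame and update values are made up for illustration; the rotation update is the identity so that only the translation differs.

```python
import numpy as np

# Current frame: rotated 90° about z, positioned at (5, 0, 0)
R = np.array([[0.0, -1.0, 0.0],
              [1.0,  0.0, 0.0],
              [0.0,  0.0, 1.0]])
t = np.array([5.0, 0.0, 0.0])

dR = np.eye(3)                    # no rotation change, to isolate the translation
dt = np.array([1.0, 0.0, 0.0])    # "one unit along the x-axis"

# Pre-compose T ∘ dT (local): dt is read in the residue's own axes
t_local = t + R @ dt              # -> [5, 1, 0]: along the frame's x-axis (global y)

# Post-compose dT ∘ T (global): dt is read in global axes
t_global = dR @ t + dt            # -> [6, 0, 0]
print(t_local, t_global)
```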
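For the Position Scale point, a sketch assuming `position_scale = 10.0` and a 7D [quaternion, translation] frame layout (the layout and the two frames are assumptions for illustration). Only the translation components are multiplied, so rotations are untouched, and an internal CA-CA offset of 0.38 becomes the familiar 3.8 Å.

```python
import numpy as np

position_scale = 10.0  # value in the AF2 configuration

# Two hypothetical frames as 7D vectors [w, x, y, z, t_x, t_y, t_z] in internal units
frames = np.array([
    [1.0, 0.0, 0.0, 0.0, 0.00, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.38, 0.0, 0.0],
])

# Scale translations only; the quaternion part is multiplied by 1
scaled = frames * np.array([1.0] * 4 + [position_scale] * 3)

ca_ca = np.linalg.norm(scaled[1, 4:] - scaled[0, 4:])
print(scaled[:, :4], ca_ca)   # quaternions unchanged; distance close to 3.8
```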
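For the Equivariance point, a numerical check, assuming (as stated above) that ΔT_i comes from invariant features and is therefore unchanged when a global motion T_g is applied to the inputs. With pre-composition, "move then update" equals "update then move"; with post-composition it generally does not. The rotation angles and the `rot_x`/`rot_z` helpers are arbitrary choices for the demo.

```python
import numpy as np

def compose(R1, t1, R2, t2):
    """T1 ∘ T2 = (R1 R2, R1 t2 + t1)."""
    return R1 @ R2, R1 @ t2 + t1

def rot_z(angle):
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(angle):
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

rng = np.random.default_rng(0)
R_i, t_i = rot_z(0.3), rng.standard_normal(3)                # a residue frame
dR, dt = rot_x(0.1), rng.standard_normal(3)                  # invariant update
R_g, t_g = rot_z(1.0) @ rot_x(0.5), rng.standard_normal(3)   # global rigid motion

# Pre-compose: update the moved frame vs. move the updated frame
pre_a = compose(*compose(R_g, t_g, R_i, t_i), dR, dt)
pre_b = compose(R_g, t_g, *compose(R_i, t_i, dR, dt))

# Post-compose (dT ∘ T) for comparison
post_a = compose(dR, dt, *compose(R_g, t_g, R_i, t_i))
post_b = compose(R_g, t_g, *compose(dR, dt, R_i, t_i))

pre_equal = all(np.allclose(x, y) for x, y in zip(pre_a, pre_b))
post_equal = all(np.allclose(x, y) for x, y in zip(post_a, post_b))
print(pre_equal, post_equal)   # pre-composition commutes with T_g; post-composition does not
```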