# Algorithm 20: Structure Module

The Structure Module converts the abstract representations produced by the Evoformer into 3D atomic coordinates. It iteratively refines per-residue backbone frames using Invariant Point Attention (IPA) and predicts sidechain conformations.

## Algorithm Pseudocode

![Structure Module](../imgs/algorithms/StructureModule.png)

## Source Code Location
- **File**: `AF2-source-code/model/folding.py`
- **Class**: `StructureModule`, `FoldIteration`
- **Lines**: 281-559

## Overview

The Structure Module consists of:
1. **Initial frame generation**: Start with identity transforms for all residues
2. **Iterative refinement**: 8 iterations of IPA + backbone updates
3. **Sidechain prediction**: Predict torsion angles for sidechains
4. **Coordinate generation**: Convert frames and torsion angles to atom positions

In [None]:
import numpy as np

np.random.seed(42)

## Rigid Body Representation (QuatAffine)

In [None]:
class RigidTransform:
    """Simplified rigid body transformation (rotation + translation)."""
    
    def __init__(self, rotation, translation):
        """
        Args:
            rotation: [N, 3, 3] rotation matrices
            translation: [N, 3] translation vectors
        """
        self.rotation = rotation
        self.translation = translation
    
    @staticmethod
    def identity(n_residues):
        """Create identity transforms for all residues."""
        rotation = np.tile(np.eye(3), (n_residues, 1, 1))
        translation = np.zeros((n_residues, 3))
        return RigidTransform(rotation, translation)
    
    def apply(self, points):
        """Apply transformation to points: R @ p + t"""
        return np.einsum('nij,nj->ni', self.rotation, points) + self.translation
    
    def compose(self, update_rotation, update_translation):
        """Compose with another transform: self ∘ update"""
        new_rotation = np.einsum('nij,njk->nik', self.rotation, update_rotation)
        new_translation = self.translation + np.einsum('nij,nj->ni', self.rotation, update_translation)
        return RigidTransform(new_rotation, new_translation)
    
    def to_tensor(self):
        """Convert to flat tensor [N, 12] (flattened 3x3 rotation + translation)."""
        # Simplified: AlphaFold's QuatAffine stores a quaternion plus translation ([N, 7]);
        # here the rotation matrix is flattened instead.
        return np.concatenate([
            self.rotation.reshape(-1, 9),
            self.translation
        ], axis=-1)
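
The composition rule used in `compose` can be checked numerically: applying the composed transform should match applying the update first and the current frame second. The following is a minimal standalone sketch; the helper names `rot_z`, `apply_rigid`, and `compose_rigid` are illustrative and are not part of the AlphaFold source.

```python
import numpy as np

def rot_z(angles):
    """Rotation matrices about the z-axis, shape [N, 3, 3]."""
    c, s = np.cos(angles), np.sin(angles)
    zeros, ones = np.zeros_like(angles), np.ones_like(angles)
    return np.stack([
        np.stack([c, -s, zeros], axis=-1),
        np.stack([s, c, zeros], axis=-1),
        np.stack([zeros, zeros, ones], axis=-1),
    ], axis=-2)

def apply_rigid(rot, trans, points):
    """Apply x -> R @ x + t to one point per residue."""
    return np.einsum('nij,nj->ni', rot, points) + trans

def compose_rigid(rot_a, trans_a, rot_b, trans_b):
    """Return A ∘ B: apply B first, then A."""
    rot = np.einsum('nij,njk->nik', rot_a, rot_b)
    trans = trans_a + np.einsum('nij,nj->ni', rot_a, trans_b)
    return rot, trans

rng = np.random.default_rng(0)
n = 4
rot_a, trans_a = rot_z(rng.uniform(-np.pi, np.pi, n)), rng.normal(size=(n, 3))
rot_b, trans_b = rot_z(rng.uniform(-np.pi, np.pi, n)), rng.normal(size=(n, 3))
points = rng.normal(size=(n, 3))

rot_ab, trans_ab = compose_rigid(rot_a, trans_a, rot_b, trans_b)
composed = apply_rigid(rot_ab, trans_ab, points)
sequential = apply_rigid(rot_a, trans_a, apply_rigid(rot_b, trans_b, points))
max_composition_error = np.abs(composed - sequential).max()
print(f"max |composed - sequential| = {max_composition_error:.2e}")
```

The two results should agree up to floating-point error, which is the property the iterative frame update relies on.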

## Structure Module Components

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def relu(x):
    return np.maximum(x, 0)


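def simplified_ipa(single_act, pair_act, affine, mask):
    """
    Simplified stand-in for Invariant Point Attention (Algorithm 22).

    Uses only the pair representation as an attention bias; `affine` is
    accepted for interface compatibility but is not used here.
    Returns the update to the single representation.
    """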
    N_res, c_s = single_act.shape
    
    # Simplified: just use pair information as attention bias
    pair_reduced = np.mean(pair_act, axis=-1)  # [N_res, N_res]
    
    # Simple attention
    attn = np.exp(pair_reduced * 0.1)
    attn = attn * mask[:, None] * mask[None, :]
    attn = attn / (np.sum(attn, axis=-1, keepdims=True) + 1e-8)
    
    # Aggregate
    update = attn @ single_act
    
    return update


def transition_block(act, num_layers=3):
    """Simple transition MLP with freshly sampled (untrained) random weights."""
    c = act.shape[-1]
    for i in range(num_layers):
        w = np.random.randn(c, c) * 0.01
        act = act @ w
        if i < num_layers - 1:
            act = relu(act)
    return act


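def backbone_update(act):
    """
    Algorithm 23: Backbone Update (simplified)

    Predict a 6-DOF update to each backbone frame.
    The rotation is parameterized here as an axis-angle vector and converted
    to a matrix; AlphaFold itself predicts three quaternion components.
    Returns rotation matrices [N_res, 3, 3] and translation updates [N_res, 3].
    """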
    c = act.shape[-1]
    N_res = act.shape[0]
    
    # Project to 6 values: 3 for rotation (axis-angle), 3 for translation
    w = np.random.randn(c, 6) * 0.001  # Small init for stability
    update = act @ w
    
    # Split into rotation and translation
    rot_update = update[:, :3]  # Axis-angle
    trans_update = update[:, 3:]  # Translation
    
    # Convert axis-angle to rotation matrix
    rot_matrices = np.zeros((N_res, 3, 3))
    for i in range(N_res):
        theta = np.linalg.norm(rot_update[i])
        if theta < 1e-6:
            rot_matrices[i] = np.eye(3)
        else:
            axis = rot_update[i] / theta
            K = np.array([[0, -axis[2], axis[1]],
                          [axis[2], 0, -axis[0]],
                          [-axis[1], axis[0], 0]])
            rot_matrices[i] = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * K @ K
    
    return rot_matrices, trans_update
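
For comparison, Algorithm 23 in the AlphaFold supplement parameterizes the rotation with a quaternion whose first component is fixed to 1, rather than with axis-angle. The following is a minimal vectorized sketch of that parameterization; the function name `quaternion_backbone_update` and the random, untrained `weights` are assumptions for illustration, not the reference implementation.

```python
import numpy as np

def quaternion_backbone_update(act, weights):
    """Map activations [N, c] to rotation matrices [N, 3, 3] and translations [N, 3].

    `weights` is a hypothetical [c, 6] projection; in a trained model it is learned.
    """
    update = act @ weights
    bcd, trans = update[:, :3], update[:, 3:]

    # Build the quaternion (1, b, c, d) and normalize it to unit length.
    quat = np.concatenate([np.ones((act.shape[0], 1)), bcd], axis=-1)
    quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True)
    a, b, c, d = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]

    # Standard unit-quaternion to rotation-matrix conversion.
    rot = np.stack([
        np.stack([a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)], axis=-1),
        np.stack([2*(b*c + a*d), a*a - b*b + c*c - d*d, 2*(c*d - a*b)], axis=-1),
        np.stack([2*(b*d - a*c), 2*(c*d + a*b), a*a - b*b - c*c + d*d], axis=-1),
    ], axis=-2)
    return rot, trans

rng = np.random.default_rng(0)
act = rng.normal(size=(5, 16))
weights = rng.normal(size=(16, 6)) * 0.1
rot, trans = quaternion_backbone_update(act, weights)

# A valid rotation is orthonormal with determinant +1.
orthonormality_error = np.abs(rot @ np.swapaxes(rot, -1, -2) - np.eye(3)).max()
print(rot.shape, trans.shape, f"{orthonormality_error:.2e}")
```

One convenient property of this form is that zero activations give the quaternion (1, 0, 0, 0), so the update is exactly the identity without a special case for small angles.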

## Main Structure Module

In [None]:
def structure_module(
    single_repr,      # [N_res, c_s] from Evoformer
    pair_repr,        # [N_res, N_res, c_z] from Evoformer
    mask,             # [N_res] residue mask
    num_iterations=8  # Number of refinement iterations
):
    """
    Structure Module - Algorithm 20.
    
    Iteratively refines backbone frames using IPA.
    
    Args:
        single_repr: Single representation from Evoformer
        pair_repr: Pair representation from Evoformer
        mask: Residue mask
        num_iterations: Number of structure refinement iterations
    
    Returns:
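        Dictionary with:
        - 'frames': Final backbone frames [N_res, 12]
        - 'positions': Backbone atom positions [N_res, 3, 3] (N, CA, C)
        - 'trajectory': Frames at initialization and after each iteration
          [num_iterations + 1, N_res, 12]
        - 'single_act': Final single activations [N_res, c_s]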
    """
    N_res, c_s = single_repr.shape
    
    print(f"Structure Module: {N_res} residues, {num_iterations} iterations")
    
    # Step 1: Initialize (Lines 1-4)
    # Layer norm on inputs
    single_act = layer_norm(single_repr)
    pair_act = layer_norm(pair_repr)
    
    # Initial projection
    init_proj_w = np.random.randn(c_s, c_s) * 0.01
    act = single_act @ init_proj_w
    
    # Initialize backbone frames to identity (Line 5)
    affine = RigidTransform.identity(N_res)
    
    print(f"  Initial frames: identity transforms")
    
    # Store trajectory
    trajectory = [affine.to_tensor()]
    
    # Step 2: Iterative refinement (Lines 6-21)
    for iteration in range(num_iterations):
        print(f"  Iteration {iteration + 1}/{num_iterations}")
        
        # Line 7: Invariant Point Attention (Algorithm 22)
        ipa_update = simplified_ipa(act, pair_act, affine, mask)
        act = act + ipa_update
        act = layer_norm(act)
        
        # Line 8: Dropout (skip in inference)
        
        # Lines 9-13: Transition block
        # (named transition_update to avoid confusion with the translation update below)
        transition_update = transition_block(act)
        act = act + transition_update
        act = layer_norm(act)
        
        # Lines 14-15: Backbone update (Algorithm 23)
        rot_update, trans_update = backbone_update(act)
        
        # Line 16: Update affine (pre-compose)
        affine = affine.compose(rot_update, trans_update)
        
        # Store in trajectory
        trajectory.append(affine.to_tensor())
        
        # Line 17: Stop gradient on rotation (for next iteration)
        # In JAX: affine.rotation = stop_gradient(affine.rotation)
    
    # Step 3: Generate backbone positions
    # Standard backbone atom positions in local frame (Angstroms)
    N_pos_local = np.array([-0.527, 1.360, 0.000])  # N atom
    CA_pos_local = np.array([0.000, 0.000, 0.000])  # CA atom (origin)
    C_pos_local = np.array([1.526, 0.000, 0.000])   # C atom
    
    # Apply each residue's frame to the local template: R_i @ x_local + t_i
    local_atoms = np.stack([N_pos_local, CA_pos_local, C_pos_local])  # [3 atoms, 3 coords]
    backbone_positions = (
        np.einsum('nij,aj->nai', affine.rotation, local_atoms)
        + affine.translation[:, None, :]
    )  # [N_res, 3 atoms, 3 coords]
    
    print(f"  Final backbone positions computed")
    
    return {
        'frames': affine.to_tensor(),
        'positions': backbone_positions,
        'trajectory': np.stack(trajectory),
        'single_act': act
    }
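
The simplified attention used in this loop ignores the frames, whereas full IPA also compares points that each residue places in the global frame. Those comparisons depend only on distances, which do not change when the whole structure is rotated and translated. The following is a minimal sketch of that property, using hypothetical local query and key points and a helper `random_rotations` introduced here for illustration.

```python
import numpy as np

def random_rotations(n, rng):
    """Sample n proper rotation matrices via QR decomposition (illustrative helper)."""
    rots = np.empty((n, 3, 3))
    for i in range(n):
        q, r = np.linalg.qr(rng.normal(size=(3, 3)))
        q = q * np.sign(np.diag(r))      # fix column signs
        if np.linalg.det(q) < 0:         # flip one axis to get det = +1
            q[:, 0] = -q[:, 0]
        rots[i] = q
    return rots

def pairwise_point_distances(rot, trans, local_q, local_k):
    """Distances between globally placed query points (i) and key points (j)."""
    global_q = np.einsum('nij,nj->ni', rot, local_q) + trans
    global_k = np.einsum('nij,nj->ni', rot, local_k) + trans
    return np.linalg.norm(global_q[:, None, :] - global_k[None, :, :], axis=-1)

rng = np.random.default_rng(0)
n = 6
rot, trans = random_rotations(n, rng), rng.normal(size=(n, 3))
local_q, local_k = rng.normal(size=(n, 3)), rng.normal(size=(n, 3))

# Apply one global rigid motion to every frame: T_i -> T_global ∘ T_i.
g_rot, g_trans = random_rotations(1, rng)[0], rng.normal(size=3)
rot_moved = np.einsum('ij,njk->nik', g_rot, rot)
trans_moved = trans @ g_rot.T + g_trans

before = pairwise_point_distances(rot, trans, local_q, local_k)
after = pairwise_point_distances(rot_moved, trans_moved, local_q, local_k)
max_distance_change = np.abs(before - after).max()
print(f"max distance change under a global rigid motion: {max_distance_change:.2e}")
```

Because the attention logits built from such distances are unchanged by the global motion, the update to the single representation is unchanged as well.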

## Test Example

In [None]:
# Test parameters
N_res = 32      # Number of residues
c_s = 384       # Single representation dimension
c_z = 128       # Pair representation dimension

# Create test inputs (simulating Evoformer output)
single_repr = np.random.randn(N_res, c_s).astype(np.float32)
pair_repr = np.random.randn(N_res, N_res, c_z).astype(np.float32)
mask = np.ones(N_res, dtype=np.float32)

print(f"Single representation: {single_repr.shape}")
print(f"Pair representation: {pair_repr.shape}")
print()

In [None]:
# Run structure module
result = structure_module(
    single_repr,
    pair_repr,
    mask,
    num_iterations=8
)

print(f"\nOutput shapes:")
print(f"  Frames: {result['frames'].shape}")
print(f"  Backbone positions: {result['positions'].shape}")
print(f"  Trajectory: {result['trajectory'].shape}")

## Analyze Output

In [None]:
# Check backbone geometry
positions = result['positions']

# Compute CA-CA distances
ca_positions = positions[:, 1, :]  # CA atoms
ca_distances = np.sqrt(np.sum((ca_positions[1:] - ca_positions[:-1])**2, axis=-1))

print(f"CA-CA distances between consecutive residues (~3.8 Å in real proteins; not expected here with random weights):")
print(f"  Mean: {ca_distances.mean():.2f} Å")
print(f"  Std: {ca_distances.std():.2f} Å")
print(f"  Range: [{ca_distances.min():.2f}, {ca_distances.max():.2f}] Å")

# Compute N-CA and CA-C bond lengths
n_ca_dist = np.sqrt(np.sum((positions[:, 1] - positions[:, 0])**2, axis=-1))
ca_c_dist = np.sqrt(np.sum((positions[:, 2] - positions[:, 1])**2, axis=-1))

print(f"\nBond lengths:")
print(f"  N-CA: {n_ca_dist.mean():.3f} Å (ideal: 1.46 Å)")
print(f"  CA-C: {ca_c_dist.mean():.3f} Å (ideal: 1.53 Å)")
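
Because every residue is placed by a rigid transform of the same local template, the bond lengths and the N-CA-C angle are fixed by the template coordinates, regardless of the predicted frames. The following is a minimal sketch that computes them directly from the local coordinates used in `structure_module`; the helper `bond_angle_deg` is introduced here for illustration.

```python
import numpy as np

# Local backbone template from the structure_module sketch (Angstroms).
N_pos_local = np.array([-0.527, 1.360, 0.000])
CA_pos_local = np.array([0.000, 0.000, 0.000])
C_pos_local = np.array([1.526, 0.000, 0.000])

def bond_angle_deg(a, b, c):
    """Angle at vertex b (in degrees) formed by points a-b-c."""
    v1, v2 = a - b, c - b
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))

n_ca_length = np.linalg.norm(N_pos_local - CA_pos_local)
ca_c_length = np.linalg.norm(C_pos_local - CA_pos_local)
n_ca_c_angle = bond_angle_deg(N_pos_local, CA_pos_local, C_pos_local)

print(f"N-CA length: {n_ca_length:.3f} Å")
print(f"CA-C length: {ca_c_length:.3f} Å")
print(f"N-CA-C angle: {n_ca_c_angle:.1f}°")
```

If the bond lengths printed by the notebook differ from these template values, that would suggest a problem in how the frames are applied rather than an effect of the untrained weights.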

## Visualize Frame Evolution

In [None]:
# Track how frames change over iterations
trajectory = result['trajectory']

# Extract translations from trajectory
translations = trajectory[:, :, 9:12]  # Last 3 values are translations

print(f"Frame evolution over {trajectory.shape[0] - 1} iterations (entry 0 is the identity initialization):")
for i in range(trajectory.shape[0]):
    trans_norm = np.linalg.norm(translations[i], axis=-1).mean()
    print(f"  Iteration {i}: mean translation magnitude = {trans_norm:.4f}")
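
The rotational part of the trajectory can be summarized in the same way. Assuming the flattened layout produced by `to_tensor` (9 rotation entries followed by 3 translation entries), the rotation angle of each frame follows from the trace of its matrix. The following is a minimal sketch that uses a small synthetic trajectory in place of `result['trajectory']`; the helper names are illustrative.

```python
import numpy as np

def rotation_angles_deg(flat_frames):
    """Rotation angle per frame from a [..., 12] tensor (9 rotation + 3 translation)."""
    rot = flat_frames[..., :9].reshape(flat_frames.shape[:-1] + (3, 3))
    trace = np.trace(rot, axis1=-2, axis2=-1)
    # For a rotation matrix, trace = 1 + 2*cos(angle).
    cos_angle = np.clip((trace - 1.0) / 2.0, -1.0, 1.0)
    return np.degrees(np.arccos(cos_angle))

def z_rotation(angle_deg):
    """Single rotation matrix about the z-axis (illustrative helper)."""
    a = np.radians(angle_deg)
    return np.array([[np.cos(a), -np.sin(a), 0.0],
                     [np.sin(a), np.cos(a), 0.0],
                     [0.0, 0.0, 1.0]])

# Synthetic trajectory: 3 steps, 2 residues, rotated by 0, 10, and 20 degrees.
synthetic = np.stack([
    np.tile(np.concatenate([z_rotation(deg).reshape(9), np.zeros(3)]), (2, 1))
    for deg in (0.0, 10.0, 20.0)
])  # shape [3, 2, 12]

angles = rotation_angles_deg(synthetic)
for step, step_angles in enumerate(angles):
    print(f"  Step {step}: mean rotation angle = {step_angles.mean():.2f}°")
```

Passing `result['trajectory']` to `rotation_angles_deg` instead of the synthetic array should give the per-iteration rotation magnitude for the notebook's own run.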

## Source Code Reference

```python
# From AF2-source-code/model/folding.py

class StructureModule(hk.Module):
  """StructureModule as a network head.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
  """

  def __call__(self, representations, batch, is_training, safe_key=None):
    output = generate_affines(
        representations=representations,
        batch=batch,
        config=self.config,
        global_config=self.global_config,
        is_training=is_training,
        safe_key=safe_key)

    # Convert affines to atom positions
    atom14_pred_positions = output['sc']['atom_pos'][-1]
    atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions, batch)
    
    return {
        'final_atom_positions': atom37_pred_positions,
        'final_atom_mask': batch['atom37_atom_exists'],
        'final_affines': output['affine'][-1],
    }
```

## Key Insights

1. **Iterative Refinement**: 8 iterations of IPA + backbone updates progressively refine the structure.

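2. **SE(3) Equivariance**: IPA is invariant to global rotations and translations, so frame updates expressed in local coordinates make the predicted structure equivariant to rigid motions.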

3. **Stop Gradient**: Rotation gradients are stopped between iterations for training stability.

4. **Pre-composition**: Frame updates are applied via pre-composition: T_new = T_old ∘ ΔT.

5. **Sidechain Prediction**: After backbone, sidechains are predicted from torsion angles (Algorithm 24).
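
The sidechain step in item 5 is not implemented in this notebook. In AlphaFold, each residue gets seven torsion angles (omega, phi, psi, and up to four chi angles), each predicted as an unnormalized 2D vector that is then scaled to the unit circle. The following is a minimal sketch of that normalization, assuming a single hypothetical linear projection `w_torsion` with random weights in place of the trained network.

```python
import numpy as np

NUM_TORSIONS = 7  # omega, phi, psi, chi1..chi4

def predict_torsions(act, w_torsion, eps=1e-8):
    """Project activations [N, c] to unit (sin, cos) pairs [N, 7, 2] and angles [N, 7]."""
    raw = (act @ w_torsion).reshape(-1, NUM_TORSIONS, 2)
    norm = np.sqrt(np.sum(raw ** 2, axis=-1, keepdims=True) + eps)
    unit = raw / norm
    # Interpret each pair as (sin, cos); this ordering is an assumption of the sketch.
    angles = np.arctan2(unit[..., 0], unit[..., 1])
    return unit, angles

rng = np.random.default_rng(0)
act = rng.normal(size=(4, 32))
w_torsion = rng.normal(size=(32, NUM_TORSIONS * 2)) * 0.1  # hypothetical, untrained
unit, angles = predict_torsions(act, w_torsion)
print(unit.shape, angles.shape)
```

Predicting a 2D vector and normalizing it avoids the discontinuity at ±π that a directly regressed angle would have.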