# Algorithm 2: Inference

This is the main inference algorithm that orchestrates the entire AlphaFold2 pipeline, from input features to final structure prediction with recycling.

## Algorithm Pseudocode

![Inference](../imgs/algorithms/Inference.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `AlphaFold`, `AlphaFoldIteration`
- **Lines**: 123-391

## Pipeline Overview

```
Input Features
├── MSA features (sequences, deletion info)
├── Template features (known structures)
└── Sequence features (amino acid types, positions)
        ↓
┌────────────────────────────────────────┐
│   Recycling Loop (1 pass + 3 recycles) │
│  ┌──────────────────────────────────┐  │
│  │ 1. Input Embedding (Alg 3)       │  │
│  │    + Recycling Embedder (Alg 32) │  │
│  └──────────────┬───────────────────┘  │
│                 ↓                      │
│  ┌──────────────────────────────────┐  │
│  │ 2. Extra MSA Stack (Alg 18)      │  │
│  │    4 blocks, updates pair only   │  │
│  └──────────────┬───────────────────┘  │
│                 ↓                      │
│  ┌──────────────────────────────────┐  │
│  │ 3. Evoformer Stack (Alg 6)       │  │
│  │    48 blocks of MSA + Pair       │  │
│  └──────────────┬───────────────────┘  │
│                 ↓                      │
│  ┌──────────────────────────────────┐  │
│  │ 4. Structure Module (Alg 20)     │  │
│  │    8 iterations of IPA           │  │
│  └──────────────┬───────────────────┘  │
│                 ↓                      │
│  Update prev_pos, prev_msa, prev_pair  │
└────────────────────────────────────────┘
        ↓
Final Structure Prediction
├── Atom coordinates [N_res, 37, 3]
├── pLDDT confidence [N_res]
└── PAE (optional) [N_res, N_res]
```
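The `[N_res, 37, 3]` coordinate tensor uses AlphaFold's atom37 layout: one fixed slot for each of the 37 heavy-atom types found in the standard amino acids, filled or left empty depending on the residue type. The sketch below assumes that the first five slots are N, CA, C, CB, O (our reading of the atom37 ordering in AF2's `residue_constants`, worth verifying against the source) and that glycine is index 7 in the `ARNDCQEGHILKMFPSTWYV` residue ordering. It pulls out the pseudo-beta positions (CB, or CA for glycine, which has no CB) that recycling uses for its distance features. `pseudo_beta` is a hypothetical helper name, not the AF2 function.

```python
import numpy as np

# Assumed atom37 slot indices (first five entries of the atom37 ordering)
ATOM37_FIRST_FIVE = ["N", "CA", "C", "CB", "O"]
CA_IDX = ATOM37_FIRST_FIVE.index("CA")  # 1
CB_IDX = ATOM37_FIRST_FIVE.index("CB")  # 3

# Assumed residue-type ordering; glycine has no CB atom
RESTYPES = "ARNDCQEGHILKMFPSTWYV"
GLY_IDX = RESTYPES.index("G")  # 7


def pseudo_beta(aatype, atom_positions):
    """Return CB coordinates per residue, substituting CA for glycine.

    aatype:          [N_res] integer residue types
    atom_positions:  [N_res, 37, 3] coordinates in atom37 layout
    """
    is_gly = (aatype == GLY_IDX)[:, None]   # [N_res, 1]
    ca = atom_positions[:, CA_IDX, :]        # [N_res, 3]
    cb = atom_positions[:, CB_IDX, :]        # [N_res, 3]
    return np.where(is_gly, ca, cb)


# Usage: two residues, alanine followed by glycine
aatype = np.array([0, GLY_IDX])
pos = np.zeros((2, 37, 3))
pos[:, CA_IDX] = [[0.0, 0.0, 0.0], [3.8, 0.0, 0.0]]
pos[0, CB_IDX] = [-0.529, -0.774, -1.205]   # only alanine has a CB
print(pseudo_beta(aatype, pos))
```

Slots that a residue type does not have (for example CB in glycine) simply stay at zero; AF2 tracks them with a separate atom mask.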

```python
import numpy as np

np.random.seed(42)
```

## NumPy Implementation

```python
class AlphaFoldInference:
    """
    Simplified AlphaFold2 inference implementation.

    Algorithm 2 from the supplementary materials. Every learned layer is
    replaced by a random placeholder, so only shapes and control flow are
    meaningful.
    """

    # atom37 ordering starts N, CA, C, CB, O, ...
    ATOM_N, ATOM_CA, ATOM_C, ATOM_O = 0, 1, 2, 4

    def __init__(self, config=None):
        """Initialize with configuration."""
        self.config = config or {
            'c_m': 256,           # MSA channel dimension
            'c_z': 128,           # Pair channel dimension
            'c_s': 384,           # Single representation dimension
            'num_recycle': 3,     # Number of recycling iterations
            'num_evoformer_blocks': 48,
            'num_structure_iterations': 8,
        }
        # Placeholder for the learned Linear that maps the first MSA row to s.
        # Created once so every recycling pass uses the same weights.
        self.W_s = np.random.randn(self.config['c_m'], self.config['c_s']) * 0.01

    def input_embedder(self, batch, prev):
        """
        Algorithm 3 + Algorithm 32: Input and Recycling Embedding.

        Creates initial MSA (m) and pair (z) representations and adds the
        previous iteration's outputs for recycling.
        """
        N_res = batch['aatype'].shape[0]
        N_seq = batch.get('msa', np.zeros((128, N_res))).shape[0]
        c_m = self.config['c_m']
        c_z = self.config['c_z']

        # Initial embeddings (random placeholders for Alg 3)
        m = np.random.randn(N_seq, N_res, c_m) * 0.1
        z = np.random.randn(N_res, N_res, c_z) * 0.1

        # Recycling (Alg 32, simplified: AF2 applies LayerNorm to the
        # recycled MSA row and pair before adding them)
        if prev['prev_msa_first_row'] is not None:
            m[0] = m[0] + prev['prev_msa_first_row']

        if prev['prev_pair'] is not None:
            z = z + prev['prev_pair'] * 0.1

        if prev['prev_pos'] is not None:
            # Distance features from previous CA positions. AF2 instead uses
            # pseudo-beta (CB, or CA for glycine) distances, one-hot binned
            # and passed through a Linear layer.
            ca_pos = prev['prev_pos'][:, self.ATOM_CA, :]
            dist = np.linalg.norm(ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
            z[:, :, 0] += np.clip(dist / 10.0, 0, 1)  # normalized distance

        return m, z

    def evoformer(self, m, z, num_blocks=48):
        """
        Algorithm 6: Evoformer Stack.

        Iteratively refines MSA and pair representations
        (random placeholder updates stand in for each sub-module).
        """
        for block_idx in range(num_blocks):
            # MSA row attention with pair bias (Alg 7)
            m = m + np.random.randn(*m.shape) * 0.01

            # MSA column attention (Alg 8)
            m = m + np.random.randn(*m.shape) * 0.01

            # MSA transition (Alg 9)
            m = m + np.random.randn(*m.shape) * 0.01

            # Outer product mean (Alg 10) -> pair update
            z = z + np.random.randn(*z.shape) * 0.01

            # Triangle multiplication (Alg 11, 12)
            z = z + np.random.randn(*z.shape) * 0.01

            # Triangle attention (Alg 13, 14)
            z = z + np.random.randn(*z.shape) * 0.01

            # Pair transition (Alg 15)
            z = z + np.random.randn(*z.shape) * 0.01

        return m, z

    def structure_module(self, s, z, num_iterations=8):
        """
        Algorithm 20: Structure Module.

        Predicts 3D structure from the single and pair representations.
        """
        N_res = s.shape[0]

        # Backbone frames T_i = (R_i, t_i), initialized to identity
        R = np.tile(np.eye(3), (N_res, 1, 1))  # [N_res, 3, 3]
        t = np.zeros((N_res, 3))               # [N_res, 3]

        for iter_idx in range(num_iterations):
            # Invariant Point Attention (Alg 22), placeholder update
            s = s + np.random.randn(*s.shape) * 0.01

            # Backbone Update (Alg 23), placeholder: only translations move,
            # rotations stay identity in this sketch
            t = t + np.random.randn(N_res, 3) * 0.1

        # Backbone atom positions from the frames (Alg 24, simplified).
        # The frame origin is CA; N and C use approximately the idealized
        # alanine backbone offsets (Å). O is a rough placeholder about 1.23 Å
        # from C (AF2 places O through the psi torsion frame).
        local = {
            self.ATOM_N: np.array([-0.525, 1.363, 0.0]),
            self.ATOM_CA: np.zeros(3),
            self.ATOM_C: np.array([1.526, 0.0, 0.0]),
            self.ATOM_O: np.array([2.153, 1.062, 0.0]),
        }
        atom_positions = np.zeros((N_res, 37, 3))  # atom37 layout
        for idx, offset in local.items():
            atom_positions[:, idx] = t + np.einsum('nij,j->ni', R, offset)

        return atom_positions, s

    def predict_plddt(self, s, num_bins=50):
        """
        Algorithm 29: Predict per-residue lDDT (pLDDT).

        Per-residue confidence, computed as the expected lDDT over
        `num_bins` bins with centers 1, 3, ..., 99 on a 0-100 scale.
        """
        N_res = s.shape[0]
        # Placeholder logits (AF2 computes these from s with a small MLP)
        logits = np.random.randn(N_res, num_bins) * 0.5
        logits = logits - logits.max(axis=-1, keepdims=True)  # stable softmax
        probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
        bin_centers = (np.arange(num_bins) + 0.5) / num_bins
        plddt = (probs * bin_centers).sum(axis=-1) * 100
        return plddt

    def __call__(self, batch, num_recycle=None):
        """
        Main inference loop with recycling.

        Algorithm 2 + Algorithm 30 (Recycling Inference). Runs
        num_recycle + 1 passes and returns the outputs of the last one.
        """
        if num_recycle is None:
            num_recycle = self.config['num_recycle']

        N_res = batch['aatype'].shape[0]

        print("AlphaFold2 Inference")
        print("=" * 60)
        print(f"Sequence length: {N_res}")
        print(f"Recycling iterations: {num_recycle}")
        print()

        # No previous outputs on the first pass (AF2 uses zero tensors instead)
        prev = {
            'prev_pos': None,
            'prev_msa_first_row': None,
            'prev_pair': None,
        }

        # Recycling loop: num_recycle + 1 passes in total
        for recycle_idx in range(num_recycle + 1):
            is_final = (recycle_idx == num_recycle)
            print(f"Recycling iteration {recycle_idx}{' (final)' if is_final else ''}:")

            # Step 1: Input Embedding + Recycling
            m, z = self.input_embedder(batch, prev)
            print(f"  Input embedding: m={m.shape}, z={z.shape}")

            # Step 2: Evoformer Stack (the Extra MSA stack, Alg 18, is omitted here)
            m, z = self.evoformer(m, z, num_blocks=self.config['num_evoformer_blocks'])
            print(f"  Evoformer: m={m.shape}, z={z.shape}")

            # Step 3: Single representation = Linear projection of the first MSA row
            s = m[0] @ self.W_s  # [N_res, c_m] -> [N_res, c_s]
            print(f"  Single representation: s={s.shape}")

            # Step 4: Structure Module
            atom_positions, s_updated = self.structure_module(
                s, z, num_iterations=self.config['num_structure_iterations'])
            print(f"  Structure module: atoms={atom_positions.shape}")

            # Store outputs to recycle into the next pass
            if not is_final:
                prev['prev_pos'] = atom_positions.copy()
                prev['prev_msa_first_row'] = m[0].copy()
                prev['prev_pair'] = z.copy()

            print()

        # Step 5: Confidence from the final single representation
        plddt = self.predict_plddt(s_updated)

        print("Final outputs:")
        print(f"  Atom positions: {atom_positions.shape}")
        print(f"  pLDDT: mean={plddt.mean():.1f}, range=[{plddt.min():.1f}, {plddt.max():.1f}]")

        return {
            'final_atom_positions': atom_positions,
            'plddt': plddt,
            'msa_first_row': m[0],
            'pair_repr': z,
            'single_repr': s_updated,
        }
```
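The `input_embedder` above stands in for the recycled-position features with a single clipped distance channel. Algorithm 32 instead one-hot encodes pairwise distances between the previous pseudo-beta positions into distance bins, then projects them into the pair representation with a Linear layer. The sketch below assumes 15 bins whose lower edges are evenly spaced from 3.25 Å to 20.75 Å, with the last bin open-ended; these values are our recollection of the AF2 configuration and should be checked against the source. Distances below the first edge (including each residue's distance to itself) get an all-zero encoding. `recycling_distogram` is a hypothetical name.

```python
import numpy as np


def recycling_distogram(positions, num_bins=15, min_bin=3.25, max_bin=20.75):
    """One-hot distance histogram of shape [N_res, N_res, num_bins].

    positions: [N_res, 3] pseudo-beta coordinates from the previous pass.
    Bin edges are compared in squared distance, which avoids a sqrt.
    """
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2   # [num_bins]
    upper = np.concatenate([lower[1:], [1e8]])             # last bin open-ended
    diff = positions[:, None, :] - positions[None, :, :]   # [N, N, 3]
    dist2 = np.sum(diff ** 2, axis=-1, keepdims=True)      # [N, N, 1]
    return ((dist2 > lower) & (dist2 < upper)).astype(np.float32)


# Usage: residues spaced 3.8 Å apart along a line
pos = np.arange(5)[:, None] * np.array([3.8, 0.0, 0.0])
dgram = recycling_distogram(pos)
print(dgram.shape)          # (5, 5, 15)
print(dgram.sum(axis=-1))   # 0 on the diagonal, 1 elsewhere
```

Binning rather than feeding raw distances gives the network a coarse, bounded signal about the previous structure, which a learned Linear layer can then map to any pair-channel pattern.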

## Test Example

```python
# Create test batch
N_res = 64
N_seq = 128

batch = {
    'aatype': np.random.randint(0, 20, size=N_res),
    'msa': np.random.randint(0, 21, size=(N_seq, N_res)),
    'residue_index': np.arange(N_res),
}

print("Test AlphaFold2 Inference")
print("="*60)

# Run inference
model = AlphaFoldInference()
result = model(batch, num_recycle=3)
```

```python
# Verify output structure
print("\nOutput Verification:")
print("="*60)

expected_keys = ['final_atom_positions', 'plddt', 'msa_first_row', 'pair_repr', 'single_repr']
for key in expected_keys:
    if key in result:
        print(f"  {key}: {result[key].shape}")
    else:
        print(f"  {key}: MISSING")
```

```python
# Run with different numbers of recycles. The layers are random placeholders,
# so differences here are noise, not evidence that recycling helps.
print("\nTest: Effect of Recycling")
print("="*60)

for n_recycle in [0, 1, 2, 3]:
    np.random.seed(42)
    result = model(batch, num_recycle=n_recycle)
    pos = result['final_atom_positions']

    # Sequential CA-CA distances (about 3.8 Å in real protein backbones)
    ca_pos = pos[:, 1, :]  # CA atoms (atom37 index 1)
    ca_dist = np.linalg.norm(np.diff(ca_pos, axis=0), axis=1)

    print(f"\nRecycle={n_recycle}: Mean CA-CA distance = {ca_dist.mean():.2f} Å (std={ca_dist.std():.2f})")
```
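Because every layer here is a random placeholder, the loop above cannot show whether recycling helps. With a real model, one way to check whether successive passes are converging is to measure how much the CA trace changes between passes after optimal rigid superposition (Kabsch alignment). The sketch below is a generic NumPy Kabsch RMSD; `kabsch_rmsd` is a hypothetical helper, not part of the AF2 code.

```python
import numpy as np


def kabsch_rmsd(P, Q):
    """RMSD between point sets P and Q ([N, 3]) after optimal rigid superposition."""
    P = P - P.mean(axis=0)               # center both sets at the origin
    Q = Q - Q.mean(axis=0)
    H = P.T @ Q                          # 3x3 cross-covariance
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1.0, 1.0, d])           # flip one axis if needed to avoid a reflection
    R = Vt.T @ D @ U.T                   # rotation mapping P onto Q
    P_rot = P @ R.T
    return np.sqrt(np.mean(np.sum((P_rot - Q) ** 2, axis=1)))


# Usage: a rotated and shifted copy of a structure has RMSD ~0
rng = np.random.default_rng(0)
ca_a = rng.normal(size=(20, 3)) * 10.0
theta = 0.5
rot_z = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
ca_b = ca_a @ rot_z.T + np.array([1.0, 2.0, 3.0])
print(f"{kabsch_rmsd(ca_a, ca_b):.6f}")   # 0.000000
```

With this notebook's model, the inputs would be CA coordinates such as `result['final_atom_positions'][:, 1]` from runs with different `num_recycle`; superposition matters because the predicted structure's global placement is arbitrary.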

## Verification: Key Properties

```python
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
result = model(batch, num_recycle=3)

# Property 1: Correct output shapes
assert result['final_atom_positions'].shape == (N_res, 37, 3), "Atom positions shape"
assert result['plddt'].shape == (N_res,), "pLDDT shape"
print("Property 1 - Output shapes correct: True")

# Property 2: pLDDT in valid range [0, 100]
plddt_valid = (result['plddt'] >= 0).all() and (result['plddt'] <= 100).all()
print(f"Property 2 - pLDDT in [0, 100]: {plddt_valid}")

# Property 3: Atom positions are finite
positions_finite = np.isfinite(result['final_atom_positions']).all()
print(f"Property 3 - Atom positions finite: {positions_finite}")

# Property 4: Pair representation asymmetry. AF2 does not constrain
# z_ij == z_ji, so this is reported for reference, not as a pass/fail check.
z = result['pair_repr']
symmetry_diff = np.abs(z - z.transpose(1, 0, 2)).mean()
print(f"Property 4 - Mean |z_ij - z_ji|: {symmetry_diff:.4f}")
```
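Property 2 holds by construction: pLDDT is the expectation of the bin centers under a softmax distribution, so it always lies between the smallest and largest center. The sketch below assumes 50 equal-width bins over [0, 1] with centers 0.01, 0.03, ..., 0.99, as in Algorithm 29, so any logits give a pLDDT in [1, 99]. `plddt_from_logits` is a hypothetical helper.

```python
import numpy as np


def plddt_from_logits(logits):
    """Expected lDDT (0-100 scale) from per-residue bin logits [N_res, num_bins]."""
    num_bins = logits.shape[-1]
    bin_centers = (np.arange(num_bins) + 0.5) / num_bins    # 0.01, 0.03, ..., 0.99 for 50 bins
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    return 100.0 * (probs * bin_centers).sum(axis=-1)


# Extreme logits push pLDDT to the first or last bin center; uniform logits give the middle
logits = np.zeros((3, 50))
logits[0, 0] = 1000.0    # all mass on the lowest bin
logits[2, -1] = 1000.0   # all mass on the highest bin
print(plddt_from_logits(logits))   # approximately [1, 50, 99]
```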

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class AlphaFold(hk.Module):
  """AlphaFold model with recycling.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference"
  Jumper et al. (2021) Suppl. Alg. 30 "RecyclingInference"
  """

  def __call__(self, batch, is_training, ...):
    # Implementation using AlphaFoldIteration
    impl = AlphaFoldIteration(self.config, self.global_config)
    
    # Initialize previous outputs
    prev = {
        'prev_pos': jnp.zeros([num_res, 37, 3]),
        'prev_msa_first_row': jnp.zeros([num_res, msa_channel]),
        'prev_pair': jnp.zeros([num_res, num_res, pair_channel]),
    }
    
    # Recycling loop
    for recycle_idx in range(num_recycle + 1):
      # Add recycled embeddings
      batch_with_prev = {**batch, **prev}
      
      # Run iteration
      ret = impl(batch_with_prev, is_training=False, ...)
      
      # Update prev (with stop_gradient for training)
      prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
    
    return ret
```

## Key Insights

1. **Recycling**: The same network is run multiple times, with each iteration using the previous output as additional input. This allows iterative refinement of predictions.

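2. **Three Recycled Tensors**:
   - `prev_pos`: Previous atom positions → pairwise pseudo-beta (CB, or CA for glycine) distances, one-hot binned and linearly projected into the pair representation
   - `prev_msa_first_row`: Previous first MSA row → LayerNorm, then added to the first row of the new MSA representation
   - `prev_pair`: Previous pair representation → LayerNorm, then added to the new pair representation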

3. **Evoformer Dominates**: The 48-block Evoformer stack holds most of the model's capacity; the full network has on the order of 100M parameters in total.

4. **Structure Module is Iterative**: 8 iterations with shared weights, each refining the backbone frames through IPA and a backbone update

5. **Confidence Prediction**: pLDDT provides per-residue confidence, crucial for filtering predictions

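6. **Memory Efficiency**: During training, the number of recycling iterations is sampled randomly and gradients flow only through the final iteration (`stop_gradient` on the `prev` tensors), so earlier iterations add compute but no stored activations for backpropagation.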
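The recycled MSA row and pair tensors in insight 2 are normalized before they are added to the fresh embeddings (Algorithm 32). The sketch below shows that update with a parameter-free LayerNorm; dropping the learned scale and offset (equivalent to fixing them at 1 and 0) is an assumption of this sketch. The position term is represented by an optional precomputed `pos_features` array standing in for the Linear projection of the one-hot distogram. `recycle_embed` and `layer_norm` are hypothetical helper names.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learned scale/offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def recycle_embed(m, z, prev_msa_first_row, prev_pair, pos_features=None):
    """Add recycled tensors to fresh embeddings (Alg 32, simplified).

    m: [N_seq, N_res, c_m], z: [N_res, N_res, c_z]
    prev_msa_first_row: [N_res, c_m], prev_pair: [N_res, N_res, c_z]
    pos_features: optional [N_res, N_res, c_z], stands in for Linear(one_hot(distances))
    """
    m = m.copy()
    m[0] = m[0] + layer_norm(prev_msa_first_row)   # only the first MSA row is updated
    z = z + layer_norm(prev_pair)
    if pos_features is not None:
        z = z + pos_features
    return m, z


# Usage: large-magnitude recycled values are rescaled before being added
rng = np.random.default_rng(0)
N_seq, N_res, c_m, c_z = 4, 6, 8, 5
m0 = np.zeros((N_seq, N_res, c_m))
z0 = np.zeros((N_res, N_res, c_z))
m1, z1 = recycle_embed(m0, z0, rng.normal(size=(N_res, c_m)) * 50,
                       rng.normal(size=(N_res, N_res, c_z)))
print(np.abs(m1[1:]).max())   # 0.0: rows other than the first are untouched
```

Normalizing keeps the recycled signal on a fixed scale regardless of how large the previous pass's activations grew, so repeated recycling does not compound activation magnitudes.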