# Algorithm 30: Recycling (Inference)

During inference, AlphaFold2 recycles predictions multiple times to iteratively refine the structure. Each iteration uses the previous output as additional input, allowing the model to correct errors and improve predictions.

## Algorithm Pseudocode

![RecyclingInference](../imgs/algorithms/RecyclingInference.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `AlphaFold`
- **Lines**: 123-200

## Overview

### Recycling Process

```
Iteration 0:
├── prev_pos = zeros (no prior structure)
├── prev_msa = zeros (no prior MSA embedding)
├── prev_pair = zeros (no prior pair representation)
└── Run model → output₀
        ↓
Iteration 1:
├── prev_pos = output₀.positions
├── prev_msa = output₀.msa_first_row
├── prev_pair = output₀.pair_repr
└── Run model → output₁
        ↓
    ... (repeat) ...
        ↓
Iteration N (final):
└── Return output_N as final prediction
```

### Key Features

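1. **Fixed iterations**: Inference typically uses `num_recycle = 3`, so the model runs 4 times in total (one initial pass plus 3 recycles)
2. **No gradients**: Recycled outputs pass through stop_gradient; at inference no gradients are computed, so this matters only for training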
3. **Same weights**: The same model is applied in each iteration

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
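def dgram_from_positions(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """
    Compute a one-hot distance histogram (distogram) from positions.
    Used to convert previous positions to pair features.

    Defaults follow the AlphaFold2 recycling (`prev_pos`) settings:
    15 bins between 3.25 Å and 20.75 Å.

    Args:
        pos: [N_res, 3] coordinates, one atom per residue

    Returns:
        [N_res, N_res, num_bins] one-hot array. Pairs closer than min_bin
        or at/beyond max_bin (including the diagonal) are all zeros.
    """
    N_res = pos.shape[0]

    # Pairwise distances; the epsilon keeps sqrt well-behaved at zero
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.sqrt(np.sum(diff**2, axis=-1) + 1e-8)

    # num_bins equal-width bins need num_bins + 1 edges
    bins = np.linspace(min_bin, max_bin, num_bins + 1)

    # One-hot histogram: bin i covers [bins[i], bins[i+1])
    dgram = np.zeros((N_res, N_res, num_bins))
    for i in range(num_bins):
        mask = (dist >= bins[i]) & (dist < bins[i+1])
        dgram[:, :, i] = mask.astype(float)

    return dgram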
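
The function above drops pairs that fall outside the histogram range. An alternative, and the scheme assumed here for the public AlphaFold2 code, compares squared distances against `num_bins` lower edges and leaves the last bin open-ended, so distant pairs still get a one-hot entry. The sketch below is illustrative only: the name `dgram_open_ended` and the toy coordinates are hypothetical.

```python
import numpy as np

def dgram_open_ended(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """One-hot distogram with an open-ended last bin (vectorized sketch)."""
    # num_bins lower edges, squared so no square root is needed
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    # Each bin ends where the next begins; the last one is effectively unbounded
    upper = np.concatenate([lower[1:], [1e8]])

    diff = pos[:, None, :] - pos[None, :, :]
    dist2 = np.sum(diff ** 2, axis=-1, keepdims=True)  # [N, N, 1]

    # Broadcasting against the edges gives [N, N, num_bins]
    return ((dist2 > lower) & (dist2 < upper)).astype(np.float32)

# Three points on a line at 0 Å, 5 Å and 100 Å
pos = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [100.0, 0.0, 0.0]])
dgram = dgram_open_ended(pos)

print(dgram.shape)           # (3, 3, 15)
print(dgram[0, 1].argmax())  # 5 Å falls in bin 1 (4.5 Å to 5.75 Å)
print(dgram[0, 2].argmax())  # 100 Å falls in the open-ended last bin, 14
```

The diagonal (distance 0) stays all zeros in both variants because it lies below `min_bin`.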

In [None]:
class SimpleModel:
    """
    Simplified AlphaFold-like model for demonstrating recycling.
    """
    
    def __init__(self, c_m=256, c_z=128):
        self.c_m = c_m
        self.c_z = c_z
    
    def __call__(self, batch, prev, is_training=False):
        """
        Single forward pass of the model.
        
        Args:
            batch: Input features
            prev: Previous iteration outputs
            is_training: Training mode flag
        
        Returns:
            Model outputs
        """
        N_res = batch['aatype'].shape[0]
        
        # Simulate embedding + Evoformer
        msa_repr = np.random.randn(1, N_res, self.c_m) * 0.1
        pair_repr = np.random.randn(N_res, N_res, self.c_z) * 0.1
        
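        # Add previous outputs (recycling embeddings).
        # Simplification: the full model layer-normalizes the recycled
        # representations before adding them; here they are added directly.
        if prev['prev_msa_first_row'] is not None:
            msa_repr[0] = msa_repr[0] + prev['prev_msa_first_row']

        if prev['prev_pair'] is not None:
            pair_repr = pair_repr + prev['prev_pair'] * 0.1

        if prev['prev_pos'] is not None:
            # Add distance features from previous positions.
            # Index 1 is CA in the 37-atom layout; AlphaFold2 itself uses
            # pseudo-beta positions (CB, or CA for glycine).
            ca_pos = prev['prev_pos'][:, 1, :]
            dgram = dgram_from_positions(ca_pos)
            # Stand-in for the learned linear projection of the distogram:
            # write the scaled bins into the first 15 pair channels
            pair_repr[:, :, :15] += dgram * 0.1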
        
        # Simulate structure module
        atom_positions = np.random.randn(N_res, 37, 3) * 5
        if prev['prev_pos'] is not None:
            # Refine from previous positions
            atom_positions = prev['prev_pos'] + np.random.randn(N_res, 37, 3) * 0.5
        
        return {
            'final_atom_positions': atom_positions,
            'msa_first_row': msa_repr[0],
            'pair_repr': pair_repr,
        }
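
The recycling step in `SimpleModel` adds the previous outputs directly. In the AlphaFold2 supplement this step is the recycling embedder (Algorithm 32), which layer-normalizes the recycled representations and linearly projects the distogram before adding them. The sketch below assumes that structure. The function names, the random projection matrix `w_dgram`, and the omission of the learned LayerNorm scale and offset are simplifications for illustration.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the channel axis, without learned scale and offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def recycling_embedder(msa_repr, pair_repr, prev_msa_first_row,
                       prev_pair, prev_dgram, w_dgram):
    """Add recycled information to fresh MSA and pair representations."""
    msa_out = msa_repr.copy()
    # Only the first MSA row (the target sequence) receives recycled input
    msa_out[0] = msa_out[0] + layer_norm(prev_msa_first_row)
    # Pair update: normalized previous pair plus projected distogram
    pair_out = pair_repr + layer_norm(prev_pair) + prev_dgram @ w_dgram
    return msa_out, pair_out

rng = np.random.default_rng(0)
n_seq, n_res, c_m, c_z, n_bins = 4, 8, 16, 12, 15
msa = rng.normal(size=(n_seq, n_res, c_m))
pair = rng.normal(size=(n_res, n_res, c_z))
w_dgram = rng.normal(size=(n_bins, c_z)) * 0.1  # hypothetical projection weights

# First iteration: zero "previous" tensors leave this sketch's outputs unchanged
msa_0, pair_0 = recycling_embedder(
    msa, pair,
    np.zeros((n_res, c_m)), np.zeros((n_res, n_res, c_z)),
    np.zeros((n_res, n_res, n_bins)), w_dgram)

# Later iteration: non-zero recycled tensors change row 0 and the pair tensor
msa_1, pair_1 = recycling_embedder(
    msa, pair,
    rng.normal(size=(n_res, c_m)), rng.normal(size=(n_res, n_res, c_z)),
    np.zeros((n_res, n_res, n_bins)), w_dgram)

print(np.allclose(msa_0, msa), np.allclose(msa_1[1:], msa[1:]))
```

With learned LayerNorm offsets, as in a trained model, the zero-initialized first iteration would generally not be an exact no-op.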

In [None]:
def recycling_inference(batch, model, num_recycle=3):
    """
    Recycling Inference - Algorithm 30.
    
    Iteratively refines predictions using recycling.
    
    Args:
        batch: Input features
        model: Model to run
        num_recycle: Number of recycling iterations (the model runs
            num_recycle + 1 times in total)
    
    Returns:
        Tuple of (final prediction, list of per-iteration outputs)
    """
    N_res = batch['aatype'].shape[0]
    
    print("Recycling Inference")
    print("=" * 50)
    print(f"Residues: {N_res}")
    print(f"Recycle iterations: {num_recycle}")
    
    # Initialize previous outputs. None marks "nothing to recycle yet";
    # the reference code uses zero tensors for the first iteration instead.
    prev = {
        'prev_pos': None,
        'prev_msa_first_row': None,
        'prev_pair': None,
    }
    
    outputs_history = []
    
    # Recycling loop
    for i in range(num_recycle + 1):
        is_final = (i == num_recycle)
        
        print(f"\nIteration {i}{' (final)' if is_final else ''}:")
        
        # Run model
        output = model(batch, prev, is_training=False)
        
        # Track metrics
        pos = output['final_atom_positions']
        ca_pos = pos[:, 1, :]  # CA atoms
        
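        # RMSD to the previous iteration (no superposition): sum squared
        # differences over x, y, z per atom, then average over atoms
        if len(outputs_history) > 0:
            prev_ca = outputs_history[-1]['final_atom_positions'][:, 1, :]
            rmsd = np.sqrt(np.mean(np.sum((ca_pos - prev_ca) ** 2, axis=-1)))
            print(f"  RMSD from previous: {rmsd:.4f}Å")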
        
        # Compute CA-CA distances for structure check
        ca_dist = np.linalg.norm(np.diff(ca_pos, axis=0), axis=1)
        print(f"  Mean CA-CA distance: {ca_dist.mean():.2f}Å (std={ca_dist.std():.2f})")
        
        outputs_history.append(output)
        
        # Update prev for next iteration (if not final)
        if not is_final:
            prev['prev_pos'] = output['final_atom_positions'].copy()
            prev['prev_msa_first_row'] = output['msa_first_row'].copy()
            prev['prev_pair'] = output['pair_repr'].copy()
    
    print("\nFinal output:")
    print(f"  Atom positions: {output['final_atom_positions'].shape}")
    
    return output, outputs_history

## Test Examples

In [None]:
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("="*60)

batch = {
    'aatype': np.random.randint(0, 20, size=64),
}

model = SimpleModel(c_m=256, c_z=128)
output, history = recycling_inference(batch, model, num_recycle=3)

In [None]:
# Test 2: Effect of recycling on structure convergence
print("\nTest 2: Structure Convergence with Recycling")
print("="*60)

np.random.seed(42)
batch = {'aatype': np.random.randint(0, 20, size=32)}
model = SimpleModel()

# Run with different numbers of recycles
for n_recycle in [0, 1, 2, 3, 5]:
    np.random.seed(42)  # Reset for consistency
    output, history = recycling_inference(batch, model, num_recycle=n_recycle)
    
    # Compute final structure metrics
    ca_pos = output['final_atom_positions'][:, 1, :]
    radius_of_gyration = np.sqrt(np.mean(np.sum((ca_pos - ca_pos.mean(axis=0)) ** 2, axis=1)))
    
    print(f"\nRecycles={n_recycle}: Radius of gyration = {radius_of_gyration:.2f}Å")

In [None]:
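# Test 3: Track the RMSD between consecutive iterations
# SimpleModel adds noise of the same scale on every pass, so these values
# stay roughly constant; a trained model is expected to show them shrinking.
print("\nTest 3: RMSD Across Iterations")
print("="*60)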

np.random.seed(42)
batch = {'aatype': np.random.randint(0, 20, size=48)}
model = SimpleModel()

output, history = recycling_inference(batch, model, num_recycle=5)

print("\nIteration-by-iteration analysis:")
for i in range(1, len(history)):
    ca_prev = history[i-1]['final_atom_positions'][:, 1, :]
    ca_curr = history[i]['final_atom_positions'][:, 1, :]
    # Per-atom RMSD: sum over x, y, z before averaging over residues
    rmsd = np.sqrt(np.mean(np.sum((ca_curr - ca_prev) ** 2, axis=-1)))
    print(f"  Iteration {i-1} -> {i}: RMSD = {rmsd:.4f}Å")
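
A small worked example of the RMSD convention may help here: squared differences are summed over x, y, z for each atom and then averaged over atoms. Averaging over every coordinate instead gives a value that is too small by a factor of √3. The helper name `ca_rmsd` and the toy coordinates are illustrative only, and no structural superposition is performed.

```python
import numpy as np

def ca_rmsd(a, b):
    """RMSD between two [N, 3] coordinate sets, without superposition."""
    return np.sqrt(np.mean(np.sum((a - b) ** 2, axis=-1)))

# Shift every atom by 1 Å along each axis: each atom moves sqrt(3) Å
a = np.zeros((10, 3))
b = a + 1.0

per_atom = ca_rmsd(a, b)                         # sqrt(3), about 1.732
per_coordinate = np.sqrt(np.mean((a - b) ** 2))  # 1.0, too small by sqrt(3)
print(f"{per_atom:.3f} {per_coordinate:.3f}")
```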

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
batch = {'aatype': np.random.randint(0, 20, size=32)}
model = SimpleModel()

output, history = recycling_inference(batch, model, num_recycle=3)

# Property 1: Correct number of iterations
n_iterations = len(history)
expected = 4  # num_recycle + 1
print(f"Property 1 - Correct iterations: {n_iterations == expected} ({n_iterations})")

# Property 2: Final output has correct shape
N_res = batch['aatype'].shape[0]
shape_correct = output['final_atom_positions'].shape == (N_res, 37, 3)
print(f"Property 2 - Output shape correct: {shape_correct}")

# Property 3: All outputs are finite
all_finite = all(np.isfinite(h['final_atom_positions']).all() for h in history)
print(f"Property 3 - All outputs finite: {all_finite}")

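# Property 4: Recycling uses previous output
# SimpleModel refines prev_pos with noise of std 0.5, whereas an independent
# draw would differ from iteration 0 with std of about 7 (sqrt(2) * 5)
pos_0 = history[0]['final_atom_positions']
pos_1 = history[1]['final_atom_positions']
differs = not np.allclose(pos_0, pos_1)
stays_close = np.std(pos_1 - pos_0) < 1.0
print(f"Property 4 - Iterations differ but stay close: {differs and stays_close}")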

## Source Code Reference

```python
# Simplified sketch of the recycling loop in AF2-source-code/model/modules.py

class AlphaFold(hk.Module):
  """AlphaFold model with recycling.

  Jumper et al. (2021) Suppl. Alg. 30 "RecyclingInference"
  Jumper et al. (2021) Suppl. Alg. 2 "Inference"
  """

  def __call__(self, batch, is_training, ...):
    impl = AlphaFoldIteration(self.config, self.global_config)
    
    # Initialize previous outputs
    prev = {
        'prev_pos': jnp.zeros([num_res, 37, 3]),
        'prev_msa_first_row': jnp.zeros([num_res, msa_channel]),
        'prev_pair': jnp.zeros([num_res, num_res, pair_channel]),
    }
    
    # Recycling loop
    for recycle_idx in range(num_recycle + 1):
      # stop_gradient on the recycled tensors (gradients never flow between iterations)
      prev = jax.tree_map(jax.lax.stop_gradient, prev)
      
      # Run iteration
      ret = impl(batch, prev, is_training=False, ...)
      
      # Extract outputs for next iteration
      prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
    
    return ret
```

## Key Insights

1. **Iterative Refinement**: Recycling allows the model to iteratively refine predictions, correcting errors from earlier iterations.

2. **Three Recycled Tensors**:
   - `prev_pos`: Previous atom positions (binned into a pairwise distance histogram, projected, and added to the pair representation)
   - `prev_msa_first_row`: First row of the previous MSA representation (added to the first row of the MSA embedding)
   - `prev_pair`: Previous pair representation (added to the pair representation)

3. **Fixed Iterations**: During inference, a fixed number of recycles (typically 3) is used, giving `num_recycle + 1` forward passes in total.

4. **No Gradients**: All recycled tensors have stop_gradient applied, so the model treats them as fixed inputs.

5. **Same Weights**: The same model weights are used in each iteration - this is weight sharing across iterations.

6. **Convergence**: In practice, predictions typically converge after 2-3 iterations, with diminishing returns beyond that.
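
The loop structure behind these points can be shown with a toy stand-in. The sketch below is not AlphaFold: `ToyModel` and its update rule `x ← 0.5·x + 1` are hypothetical, chosen because the rule is a contraction with fixed point 2, so successive outputs visibly settle. It illustrates that one set of weights is reused on every pass, that `num_recycle = 3` means four forward passes, and that the change per iteration shrinks.

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in for one model pass; not AlphaFold."""

    def __init__(self):
        self.weight = 0.5   # single shared parameter, reused on every pass
        self.calls = 0      # counts forward passes

    def __call__(self, prev_pos):
        self.calls += 1
        return self.weight * prev_pos + 1.0


def run_recycling(model, n_res, num_recycle=3):
    prev_pos = np.zeros((n_res, 3))   # nothing to recycle on the first pass
    step_sizes = []
    for _ in range(num_recycle + 1):
        new_pos = model(prev_pos)
        # Largest coordinate change relative to the previous iteration
        step_sizes.append(float(np.abs(new_pos - prev_pos).max()))
        prev_pos = new_pos            # output becomes the next input
    return prev_pos, step_sizes


model = ToyModel()
final_pos, step_sizes = run_recycling(model, n_res=4, num_recycle=3)
print(model.calls)    # 4
print(step_sizes)     # [1.0, 0.5, 0.25, 0.125]
```

A real model has no such guarantee of contraction, so the halving step sizes here are a property of the toy rule rather than of recycling itself.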