# Algorithm 31: Recycling (Training)

During training, recycling differs from inference: gradients must be computed, but storing activations for every iteration would be expensive. Gradients therefore flow only through the final iteration. The outputs of earlier iterations pass through `stop_gradient` and are treated as constants.
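The sketch below shows what "treated as constants" means for the gradient. It is a toy, not AlphaFold. A hypothetical scalar "model" maps the recycled value `x` to `w * x` and runs for `R + 1` iterations, with a squared-error loss on the final output. With `stop_gradient` on the recycled input, `dL/dw` sees only the final step. Full backprop would also differentiate through the earlier steps. Both gradients are estimated with central finite differences, so the example needs only plain Python.

```python
def final_output(w, x0, num_recycle, w_frozen=None):
    """Run num_recycle + 1 iterations of the toy update x <- w * x.

    If w_frozen is given, the earlier iterations use w_frozen instead of w.
    This mimics stop_gradient: those iterations become constants with
    respect to w, and only the final iteration depends on it.
    """
    x = x0
    w_early = w if w_frozen is None else w_frozen
    for _ in range(num_recycle):
        x = w_early * x          # recycling iterations
    return w * x                 # final iteration


def loss(w, x0, y, num_recycle, w_frozen=None):
    return (final_output(w, x0, num_recycle, w_frozen) - y) ** 2


def finite_diff(f, w, eps=1e-6):
    """Central finite-difference estimate of df/dw."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)


w, x0, y, R = 1.5, 1.0, 2.0, 3

# Gradient through all R + 1 iterations (no stop_gradient)
g_full = finite_diff(lambda v: loss(v, x0, y, R), w)

# Gradient with earlier iterations frozen at the current w (stop_gradient)
g_trunc = finite_diff(lambda v: loss(v, x0, y, R, w_frozen=w), w)

# Closed forms for comparison: x_final = w**(R + 1) * x0
x_prev = w ** R * x0
x_final = w * x_prev
g_full_exact = 2 * (x_final - y) * (R + 1) * w ** R * x0
g_trunc_exact = 2 * (x_final - y) * x_prev

print(f"full backprop gradient:     {g_full:.4f} (exact {g_full_exact:.4f})")
print(f"stop_gradient (final only): {g_trunc:.4f} (exact {g_trunc_exact:.4f})")
```

In this linear toy the full gradient is exactly `R + 1` times the truncated one. In a real network the two differ in more complicated ways, but the truncated version never needs the earlier iterations' activations.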

## Algorithm Pseudocode

![RecyclingTraining](../imgs/algorithms/RecyclingTraining.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `AlphaFold`
- **Lines**: 123-200

## Overview

### Training vs Inference Recycling

| Aspect | Inference (Alg 30) | Training (Alg 31) |
|--------|-------------------|-------------------|
| Recycles | Fixed (3) | Random, uniform in [0, max_recycle] |
| Gradients | None (no backward pass) | Final iteration only |
| stop_gradient | On all recycled outputs | On all recycled outputs |
| Activation memory | No activations kept for backprop | Activations kept for the final iteration only |

### Random Recycling

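During training:
1. Sample the number of recycles uniformly from {0, ..., max_recycle}
2. Run that many iterations, applying `stop_gradient` to each iteration's recycled outputs
3. Run one final iteration WITH gradients, feeding it the stopped outputs of the previous iteration
4. Compute the loss on the final iteration's outputs and backpropagate through that iteration only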
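The four steps can be written as an iteration plan. The following is a minimal sketch with a hypothetical helper, `recycling_schedule`. Given the sampled count, it lists which forward passes run under `stop_gradient` and which one is differentiated. It uses NumPy's `Generator.integers`, whose upper bound is exclusive, hence `max_recycle + 1`.

```python
import numpy as np


def recycling_schedule(num_recycle):
    """Return one entry per forward pass: True if gradients flow through it.

    The first num_recycle passes only produce recycled inputs (their outputs
    go through stop_gradient). The last pass is the one the loss is
    backpropagated through.
    """
    return [False] * num_recycle + [True]


rng = np.random.default_rng(0)
max_recycle = 3

# Step 1: sample uniformly from {0, 1, ..., max_recycle}
num_recycle = int(rng.integers(0, max_recycle + 1))

# Steps 2-4: num_recycle passes without gradients, then one with gradients
schedule = recycling_schedule(num_recycle)
print(f"sampled {num_recycle} recycles -> {len(schedule)} forward passes")
for i, with_grad in enumerate(schedule):
    print(f"  pass {i}: {'gradients' if with_grad else 'stop_gradient'}")
```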

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def stop_gradient(x):
    """
    Simulated stop_gradient operation.
    
    In JAX: jax.lax.stop_gradient(x)
    In PyTorch: x.detach()
    
    Here we just return a copy (NumPy has no gradients).
    """
    if isinstance(x, dict):
        return {k: stop_gradient(v) for k, v in x.items()}
    elif isinstance(x, np.ndarray):
        return x.copy()
    elif x is None:
        return None
    else:
        return x
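The `stop_gradient` above recurses into dicts only. Lists and tuples fall through to the final `return x`, so any arrays inside them stay aliased. The following sketches a variant that also walks lists and tuples, loosely mirroring how `jax.tree_util.tree_map` traverses nested containers. The name `stop_gradient_tree` is hypothetical.

```python
import numpy as np


def stop_gradient_tree(x):
    """Copy every NumPy array in a nested dict/list/tuple structure.

    Stand-in for applying jax.lax.stop_gradient leaf by leaf. NumPy has no
    autograd, so "stopping" the gradient is simulated by breaking aliasing.
    """
    if isinstance(x, dict):
        return {k: stop_gradient_tree(v) for k, v in x.items()}
    if isinstance(x, list):
        return [stop_gradient_tree(v) for v in x]
    if isinstance(x, tuple):
        return tuple(stop_gradient_tree(v) for v in x)
    if isinstance(x, np.ndarray):
        return x.copy()
    return x  # None, scalars, strings, ... are passed through unchanged


prev = {'prev_pos': [np.zeros(3)], 'prev_pair': (np.ones(2), None)}
stopped = stop_gradient_tree(prev)
stopped['prev_pos'][0][0] = 999.0   # mutate the copy
print(prev['prev_pos'][0])           # the original array is untouched
```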

In [None]:
class SimpleModel:
    """Simplified model for demonstrating training recycling."""
    
    def __init__(self, c_m=256, c_z=128):
        self.c_m = c_m
        self.c_z = c_z
    
    def __call__(self, batch, prev, is_training=True):
        N_res = batch['aatype'].shape[0]
        
        # Simulate forward pass
        msa_repr = np.random.randn(1, N_res, self.c_m) * 0.1
        pair_repr = np.random.randn(N_res, N_res, self.c_z) * 0.1
        
        if prev['prev_msa_first_row'] is not None:
            msa_repr[0] = msa_repr[0] + prev['prev_msa_first_row']
        
        if prev['prev_pair'] is not None:
            pair_repr = pair_repr + prev['prev_pair'] * 0.1
        
        atom_positions = np.random.randn(N_res, 37, 3) * 5
        if prev['prev_pos'] is not None:
            atom_positions = prev['prev_pos'] + np.random.randn(N_res, 37, 3) * 0.5
        
        return {
            'final_atom_positions': atom_positions,
            'msa_first_row': msa_repr[0],
            'pair_repr': pair_repr,
        }

In [None]:
def recycling_training(batch, model, max_recycle=3):
    """
    Recycling Training - Algorithm 31.
    
    Training-time recycling with random number of iterations.
    
    Args:
        batch: Input features including ground truth
        model: Model to train
        max_recycle: Maximum recycling iterations
    
    Returns:
        Final prediction and loss info
    """
    N_res = batch['aatype'].shape[0]
    
    # STEP 1: Sample random number of recycles (Algorithm 31, Line 1)
    num_recycle = np.random.randint(0, max_recycle + 1)
    
    print("Recycling Training")
    print("=" * 50)
    print(f"Residues: {N_res}")
    print(f"Sampled recycles: {num_recycle} (max: {max_recycle})")
    
    # Initialize previous outputs
    prev = {
        'prev_pos': None,
        'prev_msa_first_row': None,
        'prev_pair': None,
    }
    
    # STEP 2: Recycling loop WITHOUT gradients (Lines 2-6)
    for i in range(num_recycle):
        print(f"\nIteration {i} (no gradient):")
        
        # stop_gradient on previous outputs (Line 3)
        prev_stopped = stop_gradient(prev)
        
        # Forward pass without gradients
        output = model(batch, prev_stopped, is_training=True)
        
        # Update prev with stop_gradient (Line 5)
        prev = {
            'prev_pos': stop_gradient(output['final_atom_positions']),
            'prev_msa_first_row': stop_gradient(output['msa_first_row']),
            'prev_pair': stop_gradient(output['pair_repr']),
        }
        
        print("  stop_gradient applied to prev tensors")
    
    # STEP 3: Final iteration WITH gradients (Lines 7-9)
    print("\nFinal iteration (WITH gradient):")
    
    # prev is either the initial dict of Nones or was already stopped in the
    # loop, so this call is a no-op kept for clarity. The outputs of THIS
    # pass are not stopped.
    prev_stopped = stop_gradient(prev)
    
    # This is where gradients would flow in real training
    final_output = model(batch, prev_stopped, is_training=True)
    
    print("  Gradients flow through this iteration")
    
    # STEP 4: Compute loss on the final iteration only (needs ground truth)
    loss = None
    if 'gt_positions' in batch:
        loss = np.mean((final_output['final_atom_positions'] - batch['gt_positions']) ** 2)
        print(f"  Position MSE: {loss:.4f}")
    
    return {
        'output': final_output,
        'num_recycle': num_recycle,
        'loss': loss,
    }
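`recycling_training` draws from NumPy's global random state. That makes a specific count, such as the zero-recycle case in Test 4, awkward to force. One alternative is sketched below with a hypothetical helper, `sample_num_recycle`. It takes an explicit `np.random.Generator` and allows a fixed override for tests.

```python
import numpy as np


def sample_num_recycle(max_recycle, rng=None, override=None):
    """Sample the number of recycling iterations for one training step.

    Draws uniformly from {0, ..., max_recycle} using an explicit Generator,
    unless `override` is given (e.g. to exercise the zero-recycle path).
    """
    if override is not None:
        if not 0 <= override <= max_recycle:
            raise ValueError(f"override must be in [0, {max_recycle}]")
        return int(override)
    rng = np.random.default_rng() if rng is None else rng
    return int(rng.integers(0, max_recycle + 1))  # upper bound is exclusive


rng = np.random.default_rng(123)
counts = [sample_num_recycle(3, rng) for _ in range(8)]
print(counts)
print(sample_num_recycle(3, override=0))  # forces the zero-recycle case
```

Threading the generator through the training loop would also make it possible to log or replay the exact count used at each step.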

## Test Examples

In [None]:
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("="*60)

batch = {
    'aatype': np.random.randint(0, 20, size=64),
    'gt_positions': np.random.randn(64, 37, 3) * 10,
}

model = SimpleModel(c_m=256, c_z=128)
result = recycling_training(batch, model, max_recycle=3)

In [None]:
# Test 2: Random recycle sampling distribution
print("\nTest 2: Recycle Sampling Distribution")
print("="*60)

max_recycle = 3
n_samples = 10000

recycle_counts = np.zeros(max_recycle + 1)

for _ in range(n_samples):
    num_recycle = np.random.randint(0, max_recycle + 1)
    recycle_counts[num_recycle] += 1

print(f"Distribution of recycles over {n_samples} samples:")
for i in range(max_recycle + 1):
    pct = recycle_counts[i] / n_samples * 100
    print(f"  {i} recycles: {pct:.1f}%")

print(f"\nExpected: uniform {100/(max_recycle+1):.1f}% each")

In [None]:
# Test 3: Multiple training steps with different recycles
print("\nTest 3: Multiple Training Steps")
print("="*60)

batch = {
    'aatype': np.random.randint(0, 20, size=32),
    'gt_positions': np.random.randn(32, 37, 3) * 10,
}

model = SimpleModel()

for step in range(5):
    print(f"\n--- Training step {step + 1} ---")
    result = recycling_training(batch, model, max_recycle=3)
    print(f"Used {result['num_recycle']} recycles")

In [None]:
# Test 4: Zero recycles case
print("\nTest 4: Zero Recycles Case")
print("="*60)

batch = {
    'aatype': np.random.randint(0, 20, size=32),
}

model = SimpleModel()
max_recycle = 3

# Find a seed whose first draw from randint(0, max_recycle + 1) is 0.
# recycling_training makes that draw first, so reseeding with this seed
# forces it to sample 0 recycles.
for seed in range(1000):
    np.random.seed(seed)
    if np.random.randint(0, max_recycle + 1) == 0:
        break

np.random.seed(seed)
result = recycling_training(batch, model, max_recycle=max_recycle)
assert result['num_recycle'] == 0
print("Got 0 recycles - only the final iteration runs, with gradients")

print("\nWith 0 recycles:")
print("  - No previous outputs (all prev_* inputs are None)")
print("  - Model sees only raw inputs")
print("  - Equivalent to a single pass without recycling")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
batch = {
    'aatype': np.random.randint(0, 20, size=32),
    'gt_positions': np.random.randn(32, 37, 3) * 10,
}

model = SimpleModel()
max_recycle = 3

# Property 1: num_recycle in valid range
valid_range = True
for _ in range(100):
    n = np.random.randint(0, max_recycle + 1)
    if n < 0 or n > max_recycle:
        valid_range = False
        break
print(f"Property 1 - Recycles in [0, {max_recycle}]: {valid_range}")

# Property 2: Final output exists regardless of num_recycle
result = recycling_training(batch, model, max_recycle=3)
has_output = result['output'] is not None
print(f"Property 2 - Final output exists: {has_output}")

# Property 3: Output shape correct
N_res = batch['aatype'].shape[0]
shape_correct = result['output']['final_atom_positions'].shape == (N_res, 37, 3)
print(f"Property 3 - Output shape correct: {shape_correct}")

# Property 4: stop_gradient behavior (all prev are copies)
# In our simulation, stop_gradient returns copies
prev = {'x': np.array([1, 2, 3])}
prev_stopped = stop_gradient(prev)
prev_stopped['x'][0] = 999  # Modify copy
unchanged = prev['x'][0] == 1  # Original unchanged
print(f"Property 4 - stop_gradient creates copy: {unchanged}")

## Source Code Reference

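```python
# Simplified paraphrase of AF2-source-code/model/modules.py (not verbatim)

class AlphaFold(hk.Module):
  """AlphaFold model with recycling.

  Jumper et al. (2021) Suppl. Alg. 31 "RecyclingTraining"
  """

  def __call__(self, batch, is_training, ...):
    # do_call(prev, recycle_idx, compute_loss) runs one AlphaFoldIteration

    def get_prev(ret):
      new_prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      # stop_gradient is applied to every recycled tensor here
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    # Recycled inputs start as zeros (prev_pos, prev_msa_first_row, prev_pair)
    prev = {...}

    if 'num_iter_recycling' in batch:
      # Training: the random recycle count arrives with the batch;
      # clip it to the configured maximum
      num_iter = jnp.minimum(batch['num_iter_recycling'][0],
                             self.config.num_recycle)
    else:
      # Eval / tests: use the maximum number of iterations
      num_iter = self.config.num_recycle

    # Recycling iterations without gradients (outputs go through get_prev)
    body = lambda x: (x[0] + 1,
                      get_prev(do_call(x[1], recycle_idx=x[0],
                                       compute_loss=False)))
    _, prev = hk.while_loop(lambda x: x[0] < num_iter, body, (0, prev))

    # Final iteration WITH gradients; prev is already stop-gradiented
    ret = do_call(prev=prev, recycle_idx=num_iter)
    return ret
```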
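In the released `modules.py`, the model reads the recycle count from the batch whenever a `num_iter_recycling` key is present, so the random sampling happens upstream of the model. It then clips the count to its configured maximum. Below is a NumPy sketch of that split with hypothetical helper names, `add_recycle_count` and `resolve_num_iter`. It is not the AF2 input pipeline.

```python
import numpy as np


def add_recycle_count(batch, max_recycle, rng):
    """Pipeline side (hypothetical): attach a sampled recycle count.

    Stored as a 1-element array because the model reads element [0].
    Returns a new dict; the input batch is not modified.
    """
    n = int(rng.integers(0, max_recycle + 1))  # uniform in {0, ..., max_recycle}
    return {**batch, 'num_iter_recycling': np.array([n])}


def resolve_num_iter(batch, config_num_recycle):
    """Model side (hypothetical): read and clip the count when present,
    otherwise fall back to the configured maximum (eval / inference)."""
    if 'num_iter_recycling' in batch:
        return int(min(batch['num_iter_recycling'][0], config_num_recycle))
    return config_num_recycle


rng = np.random.default_rng(0)
batch = {'aatype': np.zeros(16, dtype=np.int32)}

train_batch = add_recycle_count(batch, max_recycle=3, rng=rng)
print(resolve_num_iter(train_batch, config_num_recycle=3))  # random in [0, 3]
print(resolve_num_iter(batch, config_num_recycle=3))        # 3 (no key: inference)

# Clipping guards against a pipeline configured with a larger maximum
too_many = {**batch, 'num_iter_recycling': np.array([7])}
print(resolve_num_iter(too_many, config_num_recycle=3))     # 3
```

Keeping the sampling outside the model means the jitted model function does not need its own random key for this choice.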

## Key Insights

1. **Random Recycles**: Sampling the number of recycles acts as a form of augmentation over iteration depth. The model must produce useful outputs whether it has been recycled zero times or max_recycle times, so it cannot overfit to one specific number of iterations.

2. **stop_gradient**: The `stop_gradient` operation is crucial. It blocks gradients from flowing back into previous iterations, so their intermediate activations do not have to be stored for the backward pass.

3. **Only Final Gradients**: Only the final iteration contributes to gradient computation. This is memory-efficient but means the model learns to make good final predictions given (possibly imperfect) previous outputs.

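4. **Training Efficiency**: With random recycles, some training steps use 0 recycles (no recycling overhead), while others use up to max_recycle iterations. With uniform sampling and max_recycle = 3, a step runs 2.5 forward passes on average, compared with 4 if every step used the maximum.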

5. **Generalization**: Training with varying numbers of recycles helps the model generalize to inference where a fixed number is used.

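6. **Memory Savings**: Because gradients are not computed through earlier iterations, activation memory for the backward pass is roughly that of a single iteration. That is about 1/(num_recycle + 1) of what backpropagating through all num_recycle + 1 iterations would need. Parameter and optimizer-state memory are unaffected.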
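The efficiency and memory points can be made concrete with a small calculation. It assumes that `num_recycle` is drawn uniformly from {0, ..., max_recycle} and that every forward pass costs about the same. That is a simplification that ignores the backward pass and any per-iteration differences. Under these assumptions, a step runs `max_recycle / 2 + 1` forward passes on average, and only one of them keeps activations for backprop. The function names are hypothetical.

```python
from fractions import Fraction


def expected_forward_passes(max_recycle):
    """Mean of (num_recycle + 1) for num_recycle ~ Uniform{0, ..., max_recycle}."""
    return Fraction(max_recycle, 2) + 1


def activation_memory_ratio(num_recycle):
    """Iterations whose activations are kept (1) divided by the number full
    backprop through all num_recycle + 1 iterations would keep."""
    return Fraction(1, num_recycle + 1)


max_recycle = 3
print(f"expected forward passes: {float(expected_forward_passes(max_recycle))}")  # 2.5
print(f"if every step used the maximum: {max_recycle + 1}")                       # 4
for r in range(max_recycle + 1):
    print(f"num_recycle={r}: activation memory ratio {activation_memory_ratio(r)}")
```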