# Algorithm 31: Recycling (Training)

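During training, recycling works differently from inference. The number of recycling iterations is sampled at random, and gradients flow only through the final iteration. Outputs of earlier iterations are fed back through `stop_gradient`, so their activations do not have to be kept for backpropagation. This keeps training memory and compute manageable.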

## Algorithm Pseudocode

![RecyclingTraining](../imgs/algorithms/RecyclingTraining.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `AlphaFold`
- **Lines**: 123-200

## Training vs Inference Recycling

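| Aspect | Inference | Training |
|--------|-----------|----------|
| Recycling iterations | Fixed (3) | Random, uniform over 0-3 |
| Gradients | Not computed | Only through the final iteration |
| `stop_gradient` | On recycled (`prev`) tensors; no effect, since nothing is differentiated | On recycled (`prev`) tensors, so earlier iterations get no gradient |
| Activations kept for backprop | None | Final iteration only |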
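Here is a quick back-of-the-envelope check of what random sampling costs. It assumes the recycle count is drawn uniformly from {0, 1, 2, 3}, as in the table. Each training example then runs `num_recycle + 1` forward passes but only one backward pass, while inference always runs four forward passes. The helper names below are illustrative and do not come from the AlphaFold code.

```python
from fractions import Fraction


def expected_training_passes(max_recycle: int = 3) -> tuple[Fraction, int]:
    """Expected forward passes and (fixed) backward passes per training example,
    assuming num_recycle ~ Uniform{0, ..., max_recycle}."""
    choices = range(max_recycle + 1)
    # Every recycle count is equally likely; the final pass is always added.
    expected_forward = sum(Fraction(n + 1) for n in choices) / len(choices)
    backward = 1  # gradients flow through the final iteration only
    return expected_forward, backward


def inference_passes(num_recycle: int = 3) -> int:
    """Inference runs every recycle plus the final pass."""
    return num_recycle + 1


fwd, bwd = expected_training_passes()
print(f"training:  {fwd} forward (expected), {bwd} backward")
print(f"inference: {inference_passes()} forward, 0 backward")
```

On average a training example therefore costs 5/2 forward passes instead of 4. Backpropagation only ever touches a single iteration.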

```python
import numpy as np

np.random.seed(42)
```

```python
def recycling_training(batch, model_fn, max_recycle=3):
    """
    Recycling Training - Algorithm 31.

    Training-time recycling with a random number of iterations.

    Args:
        batch: Input features; batch['aatype'] has shape [N_res].
        model_fn: Forward pass, called as model_fn(batch, prev, is_training=True).
        max_recycle: Maximum number of recycling iterations.

    Returns:
        Output of the final iteration, the only one that receives gradients.
    """
    N_res = batch['aatype'].shape[0]

    # Sample the number of recycles uniformly from {0, ..., max_recycle}
    num_recycle = np.random.randint(0, max_recycle + 1)

    print("Recycling Training")
    print(f"  Residues: {N_res}")
    print(f"  Sampled recycles: {num_recycle} (max: {max_recycle})")

    # Recycled tensors start at zero: 37 atom positions per residue,
    # 256 MSA channels and 128 pair channels (AF2 default sizes)
    prev = {
        'prev_pos': np.zeros((N_res, 37, 3)),
        'prev_msa_first_row': np.zeros((N_res, 256)),
        'prev_pair': np.zeros((N_res, N_res, 128)),
    }

    # Recycling loop: these iterations receive no gradients
    for i in range(num_recycle):
        print(f"\n  Iteration {i} (no gradient):")

        # Ordinary forward pass; the gradient is cut on its outputs below
        output = model_fn(batch, prev, is_training=True)

        # In JAX: jax.lax.stop_gradient on each tensor. NumPy has no autodiff,
        # so .copy() only marks where the gradient would be cut.
        prev['prev_pos'] = output['final_atom_positions'].copy()
        prev['prev_msa_first_row'] = output['msa_first_row'].copy()
        prev['prev_pair'] = output['pair_repr'].copy()

        print("    stop_gradient applied to prev tensors")

    # Final iteration: loss and gradients are computed on this output only
    print("\n  Final iteration (with gradient):")
    final_output = model_fn(batch, prev, is_training=True)
    print("    Computing loss and gradients")

    return final_output


def dummy_model(batch, prev, is_training=True):
    """Dummy model for testing: perturbs each recycled tensor with Gaussian noise."""
    N_res = batch['aatype'].shape[0]

    return {
        'final_atom_positions': prev['prev_pos'] + np.random.randn(N_res, 37, 3),
        'msa_first_row': prev['prev_msa_first_row'] + np.random.randn(N_res, 256) * 0.1,
        'pair_repr': prev['prev_pair'] + np.random.randn(N_res, N_res, 128) * 0.1,
    }
```
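NumPy has no automatic differentiation, so the `.copy()` calls in `recycling_training` only mark where `jax.lax.stop_gradient` would go. The sketch below shows what that truncation actually does to a gradient. It uses a hypothetical one-parameter scalar "model" and propagates dh/dw by hand in forward mode. `step` and `loss_and_grad` are illustrative names, not AlphaFold code. Applying stop_gradient to the recycled state is equivalent to discarding its incoming derivative before the final pass.

```python
import math


def step(w, x, h_prev, dh_prev):
    """One toy recycling iteration and its forward-mode derivative w.r.t. w.

    h = tanh(w * x + 0.5 * h_prev), so
    dh/dw = (1 - h**2) * (x + 0.5 * dh_prev/dw).
    """
    h = math.tanh(w * x + 0.5 * h_prev)
    dh = (1.0 - h * h) * (x + 0.5 * dh_prev)
    return h, dh


def loss_and_grad(w, x, target, num_recycle, stop_gradient=True):
    """Squared-error loss after num_recycle recycles plus one final pass,
    together with dloss/dw."""
    h, dh = 0.0, 0.0  # recycled state starts at zero, like `prev` above
    for _ in range(num_recycle):
        h, dh = step(w, x, h, dh)
    if stop_gradient:
        # stop_gradient on the recycled state: treat it as a constant,
        # i.e. drop its derivative before the final pass.
        dh = 0.0
    h, dh = step(w, x, h, dh)
    loss = (h - target) ** 2
    return loss, 2.0 * (h - target) * dh


for n in range(4):
    loss, g_trunc = loss_and_grad(0.8, 1.0, 0.2, n, stop_gradient=True)
    _, g_full = loss_and_grad(0.8, 1.0, 0.2, n, stop_gradient=False)
    print(f"recycles={n}  loss={loss:.4f}  truncated={g_trunc:.4f}  full={g_full:.4f}")
```

The loss is identical in both modes because stop_gradient never changes forward values. Only the gradient changes. With zero recycles the two gradients agree, because the recycled state is then the constant zero initialization.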

```python
# Test on a random 64-residue sequence
batch = {
    'aatype': np.random.randint(0, 20, size=64),
}

print("Test Recycling Training")
print("=" * 50)

output = recycling_training(batch, dummy_model, max_recycle=3)

print(f"\nFinal output keys: {list(output.keys())}")
```

```python
# Draw the recycle count a few times to show the spread over {0, ..., max_recycle}
max_recycle = 3
print("\nRandom recycle sampling:")
for run in range(5):
    num_recycle = np.random.randint(0, max_recycle + 1)
    print(f"  Run {run}: {num_recycle} recycles")
```
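Five draws say little about the distribution. A more systematic check tallies many draws and confirms that the sampler covers 0 through `max_recycle` with roughly equal frequency. It uses NumPy's `Generator` API instead of the legacy global seed. `sample_recycles` is a hypothetical helper written for this illustration.

```python
import numpy as np


def sample_recycles(rng, max_recycle=3, n_samples=10_000):
    """Draw recycle counts uniformly from {0, ..., max_recycle}."""
    # Generator.integers excludes `high` by default, hence max_recycle + 1.
    return rng.integers(0, max_recycle + 1, size=n_samples)


max_recycle = 3
rng = np.random.default_rng(0)
counts = sample_recycles(rng, max_recycle)

# Fraction of draws that landed on each recycle count.
freqs = np.bincount(counts, minlength=max_recycle + 1) / counts.size
for k, f in enumerate(freqs):
    print(f"  {k} recycles: {f:.3f}")
```

Each frequency should be close to 0.25. A missing bin or a value of 4 would point to an off-by-one in the upper bound.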

## Source Code Reference

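```python
# Simplified paraphrase of the recycling logic in
# AF2-source-code/model/modules.py (class AlphaFold); not verbatim.

class AlphaFold(hk.Module):
  """AlphaFold model with recycling.

  Jumper et al. (2021) Suppl. Alg. 31 "RecyclingTraining"
  """

  def __call__(self, batch, is_training, ...):
    ...
    def get_prev(ret):
      new_prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      # Cut gradients: recycled tensors are treated as constants.
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    # prev starts as zeros (prev_pos, prev_msa_first_row, prev_pair).
    if 'num_iter_recycling' in batch:
      # Training: the sampled recycle count is supplied with the batch.
      num_iter = batch['num_iter_recycling'][0]
      num_iter = jnp.minimum(num_iter, self.config.num_recycle)
    else:
      # Inference: always run the configured number of recycles.
      num_iter = self.config.num_recycle

    # Recycling iterations: outputs reach the next step only via get_prev.
    body = lambda x: (x[0] + 1,
                      get_prev(do_call(x[1], recycle_idx=x[0],
                                       compute_loss=False)))
    _, prev = hk.while_loop(lambda x: x[0] < num_iter, body, (0, prev))

    # Final iteration: the only one whose loss is backpropagated.
    ret = do_call(prev=prev, recycle_idx=num_iter)
```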
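The excerpt above is simplified. As far as we can tell from the released code, the sampled count arrives with the batch as an array rather than a Python integer. Under `jax.jit`, a Python `for` loop over `range(num_iter)` cannot be traced, so the recycling loop is written with `hk.while_loop`, Haiku's wrapper around `jax.lax.while_loop`. The plain-Python sketch below mimics the `jax.lax.while_loop` semantics given in the JAX documentation to show the `(counter, prev)` carry pattern. `while_loop`, `run_recycling`, and `toy_model` are hypothetical stand-ins, not AlphaFold or JAX code.

```python
def while_loop(cond_fun, body_fun, init_val):
    """Plain-Python reference semantics of jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def run_recycling(model_fn, prev, num_iter):
    """Hypothetical helper: num_iter recycles via while_loop, then a final pass.

    The carry is (iteration counter, recycled state), mirroring a
    (0, prev) initial value passed to a while loop.
    """
    def body(carry):
        i, state = carry
        # In JAX, the new state would pass through jax.lax.stop_gradient here.
        return i + 1, model_fn(state)

    _, prev = while_loop(lambda carry: carry[0] < num_iter, body, (0, prev))
    return model_fn(prev)  # final iteration: the one the loss is taken on


calls = []


def toy_model(state):
    """Toy model that records each call and increments its input."""
    calls.append(state)
    return state + 1


print(run_recycling(toy_model, 0, num_iter=2))  # prints 3
```

Carrying the counter inside the loop state keeps the loop bound a runtime value. The same compiled function therefore handles any sampled recycle count without retracing.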