# Algorithm 32: Recycling Embedder

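The Recycling Embedder takes the outputs of the previous iteration (the first row of the MSA representation, the pair representation, and the predicted atom positions), embeds them, and adds them to the current MSA and pair representations. This is how recycled information is fed back into the model.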

## Algorithm Pseudocode

![RecyclingEmbedder](../imgs/algorithms/RecyclingEmbedder.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `EmbeddingsAndEvoformer` (recycling section)
- **Lines**: 1700-1740

## Overview

### Recycled Information

| Input | Processing | Output |
|-------|------------|--------|
| prev_pos | → Pseudo-β distance histogram → Linear | → Added to pair repr |
| prev_pair | → LayerNorm | → Added to pair repr |
| prev_msa_first_row | → LayerNorm | → Added to MSA first row |

### Processing Steps

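1. **Position → Distance Histogram**: Convert the previous pseudo-β positions (Cβ, or Cα for glycine) into one-hot binned pairwise distances
2. **Project**: A Linear layer maps the distance histogram to the pair channel dimension
3. **Layer Norm**: Normalize the previous pair representation and MSA first row; they already have the right channel dimensions, so no projection is applied
4. **Add**: Add the three terms to the current pair representation or MSA first row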
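A quick way to see the binning: assuming the AF2 `prev_pos` defaults (`min_bin=3.25`, `max_bin=20.75`, `num_bins=15`), the 15 lower edges are 1.25 Å apart. The sketch below maps a few example distances to bin indices with `np.digitize`; it illustrates the edge arithmetic only and is not the AF2 code path.

```python
import numpy as np

# Assumed AF2 prev_pos binning: 15 lower edges evenly spaced in [3.25, 20.75] Å
min_bin, max_bin, num_bins = 3.25, 20.75, 15
lower = np.linspace(min_bin, max_bin, num_bins)

# Bin index for a few example distances; -1 means "below min_bin, no bin".
# np.digitize puts an exact edge value into the upper bin, whereas AF2's
# strict inequalities leave exact-edge distances unassigned.
distances = np.array([2.0, 3.8, 10.0, 20.0, 35.0])
bin_index = np.digitize(distances, lower) - 1

print("bin width (Å):", lower[1] - lower[0])
print("bin index:", bin_index)
```

Here 2.0 Å falls below the first edge (no bin), 3.8 Å lands in bin 0, 10 Å in bin 5, and 35 Å in the open-ended last bin 14.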

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Layer normalization over `axis`, without learned scale or offset."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def dgram_from_positions(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """
    Compute a one-hot distance histogram (dgram) from pseudo-beta positions.
    
    Follows the AF2 binning: `num_bins` lower edges evenly spaced in
    [min_bin, max_bin], each bin ending at the next edge, and an
    open-ended last bin. Distances below min_bin (including the diagonal)
    fall into no bin, so those pairs get an all-zero vector.
    
    Args:
        pos: Pseudo-beta (CB; CA for glycine) positions [N_res, 3]
        min_bin: Lower edge of the first bin (Å)
        max_bin: Lower edge of the last, open-ended bin (Å)
        num_bins: Number of distance bins
    
    Returns:
        dgram: Distance histogram [N_res, N_res, num_bins]
    """
    # Pairwise squared distances
    diff = pos[:, None, :] - pos[None, :, :]  # [N, N, 3]
    dist2 = np.sum(diff**2, axis=-1, keepdims=True)  # [N, N, 1]
    
    # Squared bin edges; the last upper edge is effectively infinity
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.concatenate([lower[1:], [1e8]])
    
    # One-hot binning by broadcasting [N, N, 1] against [num_bins]
    dgram = ((dist2 > lower) & (dist2 < upper)).astype(np.float32)
    
    return dgram
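The `layer_norm` helper above omits the learned per-channel scale and offset that AF2's LayerNorm modules carry. A minimal sketch with those parameters (the helper name `layer_norm_affine` is hypothetical), initialized to the identity, also shows a property that matters for recycling: a vector that is constant across channels has zero variance and normalizes to the offset, which is zero at initialization.

```python
import numpy as np

def layer_norm_affine(x, scale, offset, axis=-1, eps=1e-5):
    """LayerNorm followed by a per-channel scale and offset (hypothetical helper)."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + offset

c = 8
scale = np.ones(c)    # identity initialization of the learned scale
offset = np.zeros(c)  # identity initialization of the learned offset

# Constant across channels -> zero after normalization; a ramp keeps its shape
constant = np.full((4, c), 3.0)
ramp = np.tile(np.linspace(-1.0, 1.0, c), (4, 1))

out_const = layer_norm_affine(constant, scale, offset)
out_ramp = layer_norm_affine(ramp, scale, offset)
print(np.abs(out_const).max())
print(out_ramp.mean(axis=-1), out_ramp.std(axis=-1))
```

This is why a recycled signal must vary across channels to survive normalization; a uniform input carries no information after LayerNorm.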

In [None]:
def recycling_embedder(m, z, prev_m, prev_z, prev_pos, num_bins=15):
    """
    Recycling Embedder - Algorithm 32.
    
    Adds recycled information to current embeddings. Learned parameters are
    replaced by stand-ins: a random matrix for the dgram Linear, and the
    identity (scale=1, offset=0) for both LayerNorms.
    
    Args:
        m: Current MSA representation [N_seq, N_res, c_m]
        z: Current pair representation [N_res, N_res, c_z]
        prev_m: Previous MSA first row [N_res, c_m] (or None)
        prev_z: Previous pair representation [N_res, N_res, c_z] (or None)
        prev_pos: Previous pseudo-beta positions [N_res, 3] (or None)
        num_bins: Number of distance bins
    
    Returns:
        Updated m, z
    """
    # Channel sizes come from the inputs, so any c_m / c_z works
    N_seq, N_res, c_m = m.shape
    c_z = z.shape[-1]
    
    print("Recycling Embedder")
    print("=" * 50)
    print(f"MSA: [{N_seq}, {N_res}, {c_m}]")
    print(f"Pair: [{N_res}, {N_res}, {c_z}]")
    
    m_updated = m.copy()
    z_updated = z.copy()
    
    # STEP 1: Distance histogram of previous positions -> Linear -> add to pair
    if prev_pos is not None:
        print("\nStep 1: Position → Distance histogram")
        dgram = dgram_from_positions(prev_pos, num_bins=num_bins)
        print(f"  Distance histogram: {dgram.shape}")
        
        # Random stand-in for the learned Linear(num_bins -> c_z)
        W_dgram = np.random.randn(num_bins, c_z) * (num_bins ** -0.5)
        z_dgram = np.einsum('ijd,dc->ijc', dgram, W_dgram)
        
        z_updated = z_updated + z_dgram
        print(f"  Added to pair: dgram_norm={np.linalg.norm(z_dgram):.4f}")
    
    # STEP 2: LayerNorm(previous pair) -> add to pair (no extra projection)
    if prev_z is not None:
        print("\nStep 2: Previous pair representation")
        z_prev_embed = layer_norm(prev_z)
        z_updated = z_updated + z_prev_embed
        print(f"  Added prev_z: norm={np.linalg.norm(z_prev_embed):.4f}")
    
    # STEP 3: LayerNorm(previous MSA first row) -> add to MSA row 0 only
    if prev_m is not None:
        print("\nStep 3: Previous MSA first row")
        m_prev_embed = layer_norm(prev_m)
        m_updated[0] = m_updated[0] + m_prev_embed
        print(f"  Added prev_m: norm={np.linalg.norm(m_prev_embed):.4f}")
    
    print("\nOutput:")
    print(f"  MSA: {m_updated.shape}")
    print(f"  Pair: {z_updated.shape}")
    
    return m_updated, z_updated
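The embedder runs once per recycling iteration, with the previous iteration's outputs as its `prev_*` inputs. The sketch below wires a compact version of Algorithm 32 into such a loop. `toy_trunk` is a hypothetical stand-in for the Evoformer and structure module, all weights are random, the same embedded inputs are reused every iteration for simplicity, and the first iteration starts from zero-filled previous outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learned scale/offset."""
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def dgram(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """One-hot distance histogram with an open-ended last bin."""
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.append(lower[1:], 1e8)
    d2 = np.sum((pos[:, None] - pos[None]) ** 2, axis=-1, keepdims=True)
    return ((d2 > lower) & (d2 < upper)).astype(np.float32)

def embed_recycled(m, z, prev, w_dgram):
    """Compact Alg. 32: Linear(dgram) + LayerNorm(pair) -> z; LayerNorm(row) -> m[0]."""
    z = z + dgram(prev["pos"]) @ w_dgram + layer_norm(prev["pair"])
    m = m.copy()
    m[0] = m[0] + layer_norm(prev["msa_first_row"])
    return m, z

def toy_trunk(m, z):
    """Hypothetical stand-in for the Evoformer and structure module."""
    pos = np.cumsum(np.tanh(m[0, :, :3]) * 3.8, axis=0)  # fake chain coordinates
    return {"msa_first_row": np.tanh(m[0]), "pair": np.tanh(z), "pos": pos}

N_seq, N_res, c_m, c_z, num_recycle = 4, 10, 16, 8, 3
m_input = rng.normal(size=(N_seq, N_res, c_m))
z_input = rng.normal(size=(N_res, N_res, c_z))
w_dgram = rng.normal(size=(15, c_z)) / np.sqrt(15)  # random stand-in for learned weights

# First iteration: zero-filled "previous" outputs
prev = {"msa_first_row": np.zeros((N_res, c_m)),
        "pair": np.zeros((N_res, N_res, c_z)),
        "pos": np.zeros((N_res, 3))}

for iteration in range(num_recycle + 1):
    m, z = embed_recycled(m_input, z_input, prev, w_dgram)
    prev = toy_trunk(m, z)  # outputs of this pass feed the next one

print(m.shape, z.shape, prev["pos"].shape)
```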

## Test Examples

In [None]:
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("="*60)

N_seq, N_res, c_m, c_z = 128, 32, 256, 128

# Current representations
m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

# Previous outputs
prev_m = np.random.randn(N_res, c_m).astype(np.float32)
prev_z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
prev_pos = np.random.randn(N_res, 3).astype(np.float32) * 10  # pseudo-beta positions (Å)

m_updated, z_updated = recycling_embedder(m, z, prev_m, prev_z, prev_pos)

In [None]:
# Test 2: First iteration (no previous outputs)
print("\nTest 2: First Iteration (No Previous Outputs)")
print("="*60)

N_seq, N_res, c_m, c_z = 64, 24, 256, 128

m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

# No previous outputs (first iteration)
m_updated, z_updated = recycling_embedder(m, z, None, None, None)

# Should be unchanged
m_unchanged = np.allclose(m, m_updated)
z_unchanged = np.allclose(z, z_updated)
print(f"\nMSA unchanged (expected): {m_unchanged}")
print(f"Pair unchanged (expected): {z_unchanged}")
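Test 2 passes `None` for the previous outputs. The reference AF2 code instead feeds zero-filled `prev_pos`, `prev_pair`, and `prev_msa_first_row` tensors on the first iteration (stated here as an assumption about that code; this notebook does not exercise it). Zeros behave almost like `None`: LayerNorm maps a zero vector to zero, and all-zero positions give distances below `min_bin`, so the histogram is empty. With learned parameters, the LayerNorm offsets and the dgram Linear's bias would still add constant terms. A self-contained check, with minimal copies of the two helpers:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learned scale/offset."""
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def dgram(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """One-hot distance histogram with an open-ended last bin."""
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.append(lower[1:], 1e8)
    d2 = np.sum((pos[:, None] - pos[None]) ** 2, axis=-1, keepdims=True)
    return ((d2 > lower) & (d2 < upper)).astype(np.float32)

N_res, c_m, c_z = 6, 16, 8
zero_first_row = np.zeros((N_res, c_m))
zero_pair = np.zeros((N_res, N_res, c_z))
zero_pos = np.zeros((N_res, 3))

# Zero vectors have zero mean and variance, so LayerNorm maps them to 0
ln_first_row = layer_norm(zero_first_row)
ln_pair = layer_norm(zero_pair)
# Every pairwise distance is 0 < min_bin, so no bin is set
empty_dgram = dgram(zero_pos)

print(np.abs(ln_first_row).max(), np.abs(ln_pair).max(), empty_dgram.sum())
```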

In [None]:
# Test 3: Distance histogram analysis
print("\nTest 3: Distance Histogram Analysis")
print("="*60)

N_res = 32

# Create positions along a helix-like structure
t = np.linspace(0, 4 * np.pi, N_res)
pos = np.stack([
    t * 0.5,  # X: linear
    np.sin(t) * 5,  # Y: sinusoidal
    np.cos(t) * 5,  # Z: sinusoidal
], axis=1)

print(f"Positions shape: {pos.shape}")
print(f"Position range: [{pos.min():.1f}, {pos.max():.1f}]")

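# Compute distance histogram (AF2 prev_pos binning)
min_bin, max_bin, num_bins = 3.25, 20.75, 15
dgram = dgram_from_positions(pos, min_bin, max_bin, num_bins)
print(f"\nDistance histogram: {dgram.shape}")

# Analyze histogram: each pair is one-hot, or empty if closer than min_bin
# (this includes the diagonal and adjacent residues on this helix)
per_pair = dgram.sum(axis=-1)
print(f"Max bins set per pair (should be 1): {per_pair.max():.0f}")
print(f"Fraction of pairs assigned a bin: {per_pair.mean():.4f}")

# Show which bins are most populated
bin_counts = dgram.sum(axis=(0, 1))
print("\nBin populations:")
lower = np.linspace(min_bin, max_bin, num_bins)
upper = np.append(lower[1:], np.inf)
for i, (lo, hi, count) in enumerate(zip(lower, upper, bin_counts)):
    if count > 0:
        print(f"  Bin {i} [{lo:.2f}-{hi:.2f}Å]: {count:.0f} pairs")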

In [None]:
# Test 4: Effect of recycled info on representations
print("\nTest 4: Effect on Representations")
print("="*60)

np.random.seed(42)
N_seq, N_res, c_m, c_z = 64, 32, 128, 64

# Current (random) representations
m = np.random.randn(N_seq, N_res, c_m).astype(np.float32) * 0.1
z = np.random.randn(N_res, N_res, c_z).astype(np.float32) * 0.1

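# Strong previous signal that varies across channels
# (a channel-constant vector would be zeroed out by LayerNorm)
prev_m = np.tile(np.linspace(-1.0, 1.0, c_m), (N_res, 1)).astype(np.float32)
prev_z = (np.eye(N_res)[:, :, None] * np.linspace(-1.0, 1.0, c_z)).astype(np.float32)  # Diagonal pattern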
prev_pos = np.random.randn(N_res, 3).astype(np.float32) * 10

m_updated, z_updated = recycling_embedder(m, z, prev_m, prev_z, prev_pos)

print(f"\nMSA first row change:")
m_diff = np.linalg.norm(m_updated[0] - m[0])
m_other_diff = np.linalg.norm(m_updated[1:] - m[1:])
print(f"  First row change: {m_diff:.4f}")
print(f"  Other rows change: {m_other_diff:.4f} (should be ~0)")

print(f"\nPair representation change:")
z_diff = np.linalg.norm(z_updated - z)
print(f"  Total change: {z_diff:.4f}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_seq, N_res, c_m, c_z = 64, 24, 128, 64

m = np.random.randn(N_seq, N_res, c_m).astype(np.float32)
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
prev_m = np.random.randn(N_res, c_m).astype(np.float32)
prev_z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
prev_pos = np.random.randn(N_res, 3).astype(np.float32) * 10

m_updated, z_updated = recycling_embedder(m, z, prev_m, prev_z, prev_pos)

# Property 1: Shape preserved
shape_preserved = m_updated.shape == m.shape and z_updated.shape == z.shape
print(f"Property 1 - Shape preserved: {shape_preserved}")

# Property 2: Finite output
all_finite = np.isfinite(m_updated).all() and np.isfinite(z_updated).all()
print(f"Property 2 - All finite: {all_finite}")

# Property 3: MSA first row affected, others unchanged
first_row_changed = not np.allclose(m_updated[0], m[0])
other_rows_unchanged = np.allclose(m_updated[1:], m[1:])
print(f"Property 3 - First row changed: {first_row_changed}")
print(f"           - Other rows unchanged: {other_rows_unchanged}")

# Property 4: Pair representation changed
pair_changed = not np.allclose(z_updated, z)
print(f"Property 4 - Pair changed: {pair_changed}")

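# Property 5: Distance histogram is valid
# (one-hot for pairs at least min_bin apart, all-zero for closer pairs and the diagonal)
dgram = dgram_from_positions(prev_pos)
dist = np.linalg.norm(prev_pos[:, None] - prev_pos[None], axis=-1)
per_pair = dgram.sum(axis=-1)
dgram_valid = (np.all(per_pair[dist > 3.25] == 1)
               and np.all(per_pair[dist < 3.25] == 0))
print(f"Property 5 - Distance histogram valid: {dgram_valid}")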

## Source Code Reference

```python
# From AF2-source-code/model/modules.py (EmbeddingsAndEvoformer), abridged

# Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"

if self.config.recycle_pos:
    # Embed previous pseudo-beta positions as a distance histogram
    prev_pseudo_beta = pseudo_beta_fn(
        batch['aatype'], batch['prev_pos'], None)
    
    dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
    
    pair_activations += common_modules.Linear(
        c.pair_channel, name='prev_pos_linear')(dgram)

if self.config.recycle_features:
    # Previous MSA first row: LayerNorm only, added to row 0
    prev_msa_first_row = common_modules.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='prev_msa_first_row_norm')(batch['prev_msa_first_row'])
    msa_activations = msa_activations.at[0].add(prev_msa_first_row)
    
    # Previous pair representation: LayerNorm only, added to the pair
    pair_activations += common_modules.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='prev_pair_norm')(batch['prev_pair'])
```
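`pseudo_beta_fn` picks the Cβ atom of each residue, falling back to Cα for glycine, which has no Cβ. A minimal NumPy sketch of that selection follows. The helper name `pseudo_beta` is hypothetical, and the atom37 slot indices (CA = 1, CB = 3) and glycine's restype index (7, in the order `ARNDCQEGHILKMFPSTWYV`) are assumptions based on AF2's `residue_constants` that should be checked against that module.

```python
import numpy as np

# Assumed AF2 residue_constants conventions
CA_IDX, CB_IDX, GLY_IDX = 1, 3, 7

def pseudo_beta(aatype, all_atom_positions):
    """CB coordinates, falling back to CA for glycine (hypothetical helper).

    aatype: [N_res] integer residue types
    all_atom_positions: [N_res, 37, 3] atom37 coordinates
    """
    is_gly = (aatype == GLY_IDX)[:, None]  # [N_res, 1], broadcasts over xyz
    return np.where(is_gly,
                    all_atom_positions[:, CA_IDX],
                    all_atom_positions[:, CB_IDX])

rng = np.random.default_rng(0)
aatype = np.array([0, 7, 5, 7])  # Ala, Gly, Gln, Gly
atoms = rng.normal(size=(4, 37, 3))
pb = pseudo_beta(aatype, atoms)
print(pb.shape)
```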

## Key Insights

1. **Position → Distance**: Previous positions are converted to distance histograms, which make better features than raw coordinates because pairwise distances do not change under any rotation or translation of the structure.

2. **Layer Normalization**: The previous pair representation and MSA first row are layer-normalized (with a learned scale and offset) before being added, which keeps the recycled signal on a stable scale across iterations. They get no extra Linear projection because they already have c_z and c_m channels.

3. **MSA First Row Only**: Only the first MSA row (the query sequence representation) receives recycled MSA information; the other rows are untouched.

4. **Additive Updates**: Recycled information is added to the current representations rather than replacing them, allowing gradual refinement across iterations.

5. **Three Information Sources**: The model receives structural (positions), pairwise (pair representation), and sequence (MSA first row) information from the previous iteration.

6. **Distance Binning**: 15 bins with lower edges evenly spaced from 3.25 Å to 20.75 Å (1.25 Å apart) resolve short- and medium-range distances; every distance beyond 20.75 Å falls into the open-ended last bin, and distances below 3.25 Å set no bin at all.
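Insight 1 can be checked directly: a rigid motion leaves every pairwise distance, and therefore the histogram, unchanged. The sketch below uses a compact AF2-style `dgram` helper (hypothetical name) and a random proper rotation built from a QR decomposition. A mirror image also preserves all distances, so the histogram alone does not encode handedness.

```python
import numpy as np

def dgram(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """One-hot distance histogram with an open-ended last bin."""
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.append(lower[1:], 1e8)
    d2 = np.sum((pos[:, None] - pos[None]) ** 2, axis=-1, keepdims=True)
    return ((d2 > lower) & (d2 < upper)).astype(np.float32)

rng = np.random.default_rng(1)
pos = rng.normal(size=(20, 3)) * 8.0

# Random proper rotation from a QR decomposition, plus a translation
q, r = np.linalg.qr(rng.normal(size=(3, 3)))
q = q * np.sign(np.diag(r))  # make the decomposition unique
if np.linalg.det(q) < 0:
    q[:, 0] = -q[:, 0]       # flip one axis so det(q) = +1
moved = pos @ q.T + np.array([5.0, -3.0, 12.0])

# Mirror image: flips handedness but preserves all distances
mirrored = pos * np.array([-1.0, 1.0, 1.0])

same_after_motion = bool((dgram(pos) == dgram(moved)).all())
same_after_mirror = bool((dgram(pos) == dgram(mirrored)).all())
print(same_after_motion, same_after_mirror)
```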