# Algorithm 32: Recycling Embedder

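The Recycling Embedder takes the outputs of the previous recycling iteration (the first row of the MSA representation, the pair representation, and the predicted structure) and adds embeddings of them to the current iteration's MSA and pair representations. Running the network several times this way lets it refine its own prediction iteratively.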
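The outer loop that drives recycling can be sketched as follows. This is a simplified illustration, not AF2's implementation: `toy_input_embedder` and `toy_trunk` are hypothetical stand-ins for the real input embedder and Evoformer/structure trunk, and the recycling step here is plain addition rather than the LayerNorm and distance-histogram embedding of Algorithm 32.

```python
import numpy as np

# Hypothetical stand-ins: the real embedder and trunk are large learned
# networks; these toy functions only track shapes and a running value.
def toy_input_embedder(n_seq, n_res, c_m, c_z):
    return np.zeros((n_seq, n_res, c_m)), np.zeros((n_res, n_res, c_z))

def toy_trunk(m, z):
    # Pretend to refine: bump the representations and emit placeholder coords.
    n_res = z.shape[0]
    positions = np.cumsum(np.ones((n_res, 3)), axis=0)
    return m + 1.0, z + 1.0, positions

def run_with_recycling(n_seq=4, n_res=8, c_m=16, c_z=8, num_recycle=3):
    # First pass: nothing to recycle yet, so recycled tensors start at zero.
    prev_m = np.zeros((n_res, c_m))
    prev_z = np.zeros((n_res, n_res, c_z))
    prev_pos = np.zeros((n_res, 3))
    for _ in range(num_recycle + 1):
        m, z = toy_input_embedder(n_seq, n_res, c_m, c_z)
        # Simplified recycling step (the toy ignores prev_pos; Algorithm 32
        # turns it into a distance histogram and LayerNorms prev_m / prev_z).
        m[0] += prev_m
        z += prev_z
        m, z, pos = toy_trunk(m, z)
        # Keep only what gets recycled: MSA first row, pair, positions.
        prev_m, prev_z, prev_pos = m[0], z, pos
    return prev_m, prev_z, prev_pos
```

With `num_recycle=3` the toy runs four passes, and each pass builds on the previous one, so the recycled values grow to 4.0 rather than staying at 1.0.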

## Algorithm Pseudocode

![RecyclingEmbedder](../imgs/algorithms/RecyclingEmbedder.png)
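In equation form, paraphrasing the code in the Source Code Reference section below rather than transcribing the figure, the updates are as follows. Here $\mathbf{x}^{\text{prev}}_i$ is the previous pseudo-β position of residue $i$, $\mathrm{onehot}$ assigns $d_{ij}$ to one of 15 distance bins, and $\mathrm{Linear}$ and $\mathrm{LayerNorm}$ carry learned parameters:

$$
\begin{aligned}
d_{ij} &= \lVert \mathbf{x}^{\text{prev}}_i - \mathbf{x}^{\text{prev}}_j \rVert \\
\mathbf{z}_{ij} &\mathrel{+}= \mathrm{Linear}\big(\mathrm{onehot}(d_{ij})\big) + \mathrm{LayerNorm}\big(\mathbf{z}^{\text{prev}}_{ij}\big) \\
\mathbf{m}_{1i} &\mathrel{+}= \mathrm{LayerNorm}\big(\mathbf{m}^{\text{prev}}_{1i}\big)
\end{aligned}
$$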

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `EmbeddingsAndEvoformer` (recycling section)
- **Lines**: 1700-1740

## Recycled Information

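1. **prev_pos**: Previous pseudo-β positions (Cβ, or Cα for glycine) → pairwise distance histogram → linear projection → added to the pair representation
2. **prev_msa_first_row**: Previous first row of the MSA representation → LayerNorm → added to the first row of the current MSA representation
3. **prev_pair**: Previous pair representation → LayerNorm → added to the current pair representation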
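On the first iteration there is nothing to recycle yet, and AF2 initialises `prev_pos`, `prev_msa_first_row` and `prev_pair` to zeros. The self-contained sketch below checks what that means for the recycled terms. It assumes LayerNorm with scale 1 and offset 0 and omits the linear layer (both simplifications): all-zero positions give an all-zero distance histogram (every distance is 0 Å, below the first bin edge) and all-zero normalised tensors, so in the trained model the first pass only contributes the learned bias and offset constants.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the last axis with scale 1 and offset 0 (assumption).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def dgram_from_positions(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    # One-hot distance histogram; last bin open-ended, d < min_bin -> no bin.
    dist2 = np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1, keepdims=True)
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.concatenate([lower[1:], [1e8]])
    return ((dist2 > lower) & (dist2 < upper)).astype(float)

n_res, c_m, c_z = 8, 256, 128
# First recycling iteration: all recycled tensors are zeros.
prev_pos = np.zeros((n_res, 3))
prev_m = np.zeros((n_res, c_m))
prev_z = np.zeros((n_res, n_res, c_z))

dgram = dgram_from_positions(prev_pos)
print(dgram.sum())                        # 0.0
print(np.abs(layer_norm(prev_m)).max())   # 0.0
print(np.abs(layer_norm(prev_z)).max())   # 0.0
```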

```python
import numpy as np

np.random.seed(42)
```

```python
def dgram_from_positions(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    """
    Compute a one-hot distance histogram from positions.

    Follows the binning of AF2's dgram_from_positions: `num_bins` lower
    edges evenly spaced from min_bin to max_bin (Å), the last bin is
    open-ended, and distances below min_bin fall into no bin.

    Args:
        pos: Pseudo-beta positions [N_res, 3]
        min_bin, max_bin: First and last lower bin edge (Å)
        num_bins: Number of bins

    Returns:
        dgram: One-hot distance histogram [N_res, N_res, num_bins]
    """
    # Squared pairwise distances, [N, N, 1]
    diff = pos[:, None, :] - pos[None, :, :]
    dist2 = np.sum(diff**2, axis=-1, keepdims=True)

    # Compare squared distances against squared bin edges
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.concatenate([lower[1:], [1e8]])  # last bin is open-ended

    dgram = ((dist2 > lower) & (dist2 < upper)).astype(float)
    return dgram


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis (scale=1, offset=0, as at initialisation)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def recycling_embedder(m, z, prev_m, prev_z, prev_pos):
    """
    Recycling Embedder - Algorithm 32.

    Adds recycled information to current embeddings.

    Args:
        m: Current MSA representation [N_seq, N_res, c_m]
        z: Current pair representation [N_res, N_res, c_z]
        prev_m: Previous MSA first row [N_res, c_m]
        prev_z: Previous pair representation [N_res, N_res, c_z]
        prev_pos: Previous pseudo-beta positions [N_res, 3]

    Returns:
        Updated m, z
    """
    N_seq, N_res, c_m = m.shape
    c_z = z.shape[-1]

    print("Recycling Embedder")
    print(f"  MSA: [{N_seq}, {N_res}, {c_m}]")
    print(f"  Pair: [{N_res}, {N_res}, {c_z}]")

    # Step 1: Distance histogram from previous positions
    dgram = dgram_from_positions(prev_pos)  # [N, N, num_bins]
    print(f"  Distance histogram: {dgram.shape}")

    # Step 2: Linear projection of the histogram to c_z
    # (random placeholder weights; the real layer is learned)
    num_bins = dgram.shape[-1]
    W_dgram = np.random.randn(num_bins, c_z) * 0.02
    b_dgram = np.zeros(c_z)
    z_dgram = np.einsum('ijd,dc->ijc', dgram, W_dgram) + b_dgram  # [N, N, c_z]

    # Step 3: LayerNorm the previous pair (AF2 applies no linear projection here)
    z_prev_embed = layer_norm(prev_z)

    # Step 4: Add both terms to the pair representation
    z_updated = z + z_dgram + z_prev_embed
    print("  Pair updated with dgram + LayerNorm(prev_z)")

    # Step 5: LayerNorm the previous MSA first row (no linear projection)
    m_prev_embed = layer_norm(prev_m)  # [N_res, c_m]

    # Step 6: Add to the first row of the MSA only
    m_updated = m.copy()
    m_updated[0] += m_prev_embed
    print("  MSA first row updated with LayerNorm(prev_m)")

    return m_updated, z_updated
```
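To make the binning concrete: with the settings AF2 uses for recycling (`min_bin=3.25`, `max_bin=20.75`, `num_bins=15` in the model config's `prev_pos` entry), the lower bin edges are 1.25 Å apart and the last bin is open-ended. The sketch below uses a hypothetical helper `bin_index` to report which bin a single distance falls into. Comparing plain distances is equivalent to AF2's comparison of squared distances, since both sides are non-negative.

```python
import numpy as np

min_bin, max_bin, num_bins = 3.25, 20.75, 15
lower = np.linspace(min_bin, max_bin, num_bins)  # lower edges: 3.25, 4.5, ..., 20.75 Å
upper = np.append(lower[1:], np.inf)             # last bin has no upper edge

def bin_index(d):
    """Hypothetical helper: one-hot bin index for distance d, or None if d < min_bin."""
    hits = np.nonzero((d > lower) & (d < upper))[0]
    return int(hits[0]) if hits.size else None

print(bin_index(2.0))   # None  (closer than the first edge)
print(bin_index(5.0))   # 1     (4.5 < 5.0 < 5.75)
print(bin_index(30.0))  # 14    (open-ended last bin)
```

Because the inequalities are strict, a distance lying exactly on an edge falls into no bin; with real-valued coordinates this is practically never hit.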

```python
# Test
N_seq, N_res, c_m, c_z = 128, 32, 256, 128

# Current representations
m = np.random.randn(N_seq, N_res, c_m)
z = np.random.randn(N_res, N_res, c_z)

# Previous outputs
prev_m = np.random.randn(N_res, c_m)
prev_z = np.random.randn(N_res, N_res, c_z)
prev_pos = np.random.randn(N_res, 3) * 10  # pseudo-beta positions (Å)

print("Test Recycling Embedder")
print("=" * 50)

m_updated, z_updated = recycling_embedder(m, z, prev_m, prev_z, prev_pos)

print("\nOutput shapes:")
print(f"  MSA: {m_updated.shape}")
print(f"  Pair: {z_updated.shape}")

# Only the first MSA row receives recycled information
assert np.allclose(m_updated[1:], m[1:])
```
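Because the histogram depends only on pairwise distances, the recycled structure enters the pair representation in a form that is invariant to rotating or translating the previous prediction. The self-contained sketch below checks this numerically; `random_rotation` is a hypothetical helper, and `dgram_from_positions` is re-declared with AF2's recycling bins so the cell runs on its own.

```python
import numpy as np

def dgram_from_positions(pos, min_bin=3.25, max_bin=20.75, num_bins=15):
    # One-hot distance histogram; last bin open-ended, d < min_bin -> no bin.
    dist2 = np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1, keepdims=True)
    lower = np.linspace(min_bin, max_bin, num_bins) ** 2
    upper = np.concatenate([lower[1:], [1e8]])
    return ((dist2 > lower) & (dist2 < upper)).astype(float)

def random_rotation(rng):
    # QR of a Gaussian matrix gives an orthogonal Q; fix signs so det(Q) = +1.
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)
pos = rng.normal(size=(32, 3)) * 10
R = random_rotation(rng)
t = rng.normal(size=3) * 50
moved = pos @ R.T + t  # rigidly rotated and translated copy

same = np.array_equal(dgram_from_positions(pos), dgram_from_positions(moved))
print(same)
```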

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

# In EmbeddingsAndEvoformer:
# Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"

# Embed previous positions as a distance histogram (linear projection)
prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
dgram = dgram_from_positions(prev_pseudo_beta, ...)
pair_activations += common_modules.Linear(c.pair_channel)(dgram)

# Add LayerNorm of the previous MSA first row (no linear projection);
# JAX arrays are immutable, so row 0 is updated functionally
prev_msa_first_row = common_modules.LayerNorm()(batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)

# Add LayerNorm of the previous pair representation (no linear projection)
pair_activations += common_modules.LayerNorm()(batch['prev_pair'])
```
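The reference calls `pseudo_beta_fn` to pick one representative atom per residue before computing distances. A NumPy sketch of that selection is below. It assumes AF2's `residue_constants` conventions (atom37 slot 1 is CA, slot 3 is CB, and glycine has index 7 in the residue order `ARNDCQEGHILKMFPSTWYV`); `pseudo_beta` is a hypothetical name, not the library function.

```python
import numpy as np

# Assumed AF2 conventions: atom37 slot 1 = CA, slot 3 = CB, glycine = restype 7.
CA_IDX, CB_IDX, GLY_IDX = 1, 3, 7

def pseudo_beta(aatype, all_atom_positions):
    """Hypothetical sketch: Cβ per residue, falling back to Cα for glycine.

    aatype: [N_res] integer residue types
    all_atom_positions: [N_res, 37, 3] atom37 coordinates
    """
    is_gly = (aatype == GLY_IDX)[:, None]  # [N_res, 1], broadcasts over xyz
    ca = all_atom_positions[:, CA_IDX, :]
    cb = all_atom_positions[:, CB_IDX, :]
    return np.where(is_gly, ca, cb)        # [N_res, 3]

aatype = np.array([0, 7, 3])  # Ala, Gly, Asp
atoms = np.zeros((3, 37, 3))
atoms[:, CA_IDX] = [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
atoms[:, CB_IDX] = [[9, 9, 9], [8, 8, 8], [7, 7, 7]]
print(pseudo_beta(aatype, atoms))  # Gly row takes its CA coordinates
```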