# Algorithm 3: Input Embedder

The Input Embedder converts raw sequence and MSA features into initial representations for the Evoformer. It creates both the MSA representation (`m`) and pair representation (`z`).

## Algorithm Pseudocode

![Input Embedder](../imgs/algorithms/InputEmbedder.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `EmbeddingsAndEvoformer`
- **Lines**: 1673-1760 (embedding portion)

## Overview

The Input Embedder performs several key operations:

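1. **Target Features → 1D Embedding**: Linearly project the target sequence features to `c_m` channels
2. **MSA Features → MSA Embedding**: Linearly project the MSA features to `c_m` channels and add the target embedding, broadcast over all MSA rows
3. **Pair Features**: Create the initial pair representation from:
   - Outer sum of left and right linear projections of the target features
   - Relative position encoding (Algorithm 4, using the one-hot of Algorithm 5)
   - Recycled outputs from the previous iteration (previous structure, pair representation, first MSA row); these are added by the separate recycling embedder (Algorithm 32), not by Algorithm 3, and are not implemented in this notebook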
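The pair initialization is an outer *sum* of two per-residue projections, `z_ij = a_i + b_j`, not an outer product. A minimal NumPy sketch with toy sizes, where the arrays `a` and `b` stand in for the `left_single` and `right_single` projections (the values are arbitrary):

```python
import numpy as np

# Toy per-residue projections: 3 residues, 2 pair channels (arbitrary values)
a = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]])     # stands in for left_single
b = np.array([[10.0, 5.0], [20.0, 6.0], [30.0, 7.0]])  # stands in for right_single

# Outer sum via broadcasting: z[i, j, :] = a[i, :] + b[j, :]
z = a[:, None, :] + b[None, :, :]

print(z.shape)   # (3, 3, 2)
print(z[1, 2])   # a[1] + b[2] = [32, 8]
```

Each `z_ij` depends on residue `i` and residue `j` only through separate projections; dependence on the pair itself enters through the relative position term and, later, the Evoformer.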

In [None]:
import numpy as np

np.random.seed(42)

## Input Feature Description

| Feature | Shape | Description |
|---------|-------|-------------|
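| `target_feat` | [N_res, 22] | `has_break` flag (1) + one-hot amino acid type (21: 20 amino acids + unknown) |
| `msa_feat` | [N_seq, N_res, 49] | One-hot MSA residue (23: 20 amino acids + unknown + gap + mask), `has_deletion` (1), deletion value (1), cluster profile (23), cluster deletion mean (1) |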
| `residue_index` | [N_res] | Residue position indices |
| `seq_mask` | [N_res] | Valid residue mask |
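The channel counts in the table come from concatenation: 1 + 21 = 22 for `target_feat` and 23 + 1 + 1 + 23 + 1 = 49 for `msa_feat`. The sketch below reproduces that layout, following the `make_msa_feat` transform of the AF2 data pipeline as we understand it. The function name is hypothetical, and because no MSA clustering is done here, each sequence is assumed to be its own cluster (its cluster profile is its own one-hot).

```python
import numpy as np

def make_input_features_sketch(aatype, msa, deletion_matrix, has_break=None):
    """Hypothetical sketch: assemble target_feat (22) and msa_feat (49).

    Assumption: no clustering, so each sequence is its own cluster and the
    cluster profile / cluster deletion mean equal its own one-hot / deletions.
    """
    n_seq, n_res = msa.shape
    if has_break is None:
        has_break = np.zeros(n_res)

    # target_feat: [has_break (1), aatype one-hot over 21 classes] -> 22 channels
    target_feat = np.concatenate([has_break[:, None], np.eye(21)[aatype]], axis=-1)

    msa_1hot = np.eye(23)[msa]                                         # 23 channels
    has_deletion = np.clip(deletion_matrix, 0.0, 1.0)                  # 1 channel
    deletion_value = np.arctan(deletion_matrix / 3.0) * (2.0 / np.pi)  # 1 channel, in [0, 1)
    cluster_profile = msa_1hot                                         # 23 channels (assumed)
    deletion_mean_value = deletion_value                               # 1 channel (assumed)

    msa_feat = np.concatenate(
        [msa_1hot, has_deletion[..., None], deletion_value[..., None],
         cluster_profile, deletion_mean_value[..., None]], axis=-1)
    return target_feat, msa_feat


rng = np.random.default_rng(0)
aatype = rng.integers(0, 21, size=8)
msa = rng.integers(0, 23, size=(4, 8))
deletions = rng.integers(0, 5, size=(4, 8)).astype(float)
t_feat, m_feat = make_input_features_sketch(aatype, msa, deletions)
print(t_feat.shape, m_feat.shape)  # (8, 22) (4, 8, 49)
```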

In [None]:
def one_hot(indices, num_classes):
    """Create one-hot encoding."""
    return np.eye(num_classes)[indices]


def relpos_encoding(residue_index, max_relative_feature=32):
    """
    Relative position encoding (Algorithm 4 + Algorithm 5).
    
    Creates one-hot encoding of clipped relative positions.
    
    Args:
        residue_index: [N_res] residue indices
        max_relative_feature: Maximum relative distance to encode
    
    Returns:
        rel_pos: [N_res, N_res, 2*max_relative_feature+1] one-hot encoding
    """
    # Pairwise position offsets: offset[i, j] = residue_index[i] - residue_index[j]
    offset = residue_index[:, None] - residue_index[None, :]
    
    # Shift by +max, then clip to [0, 2*max]
    # (equivalent to clipping the offset to [-max, max] and then shifting)
    clipped = np.clip(offset + max_relative_feature, 0, 2 * max_relative_feature)
    
    # One-hot encode
    rel_pos = one_hot(clipped.astype(int), 2 * max_relative_feature + 1)
    
    return rel_pos

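Algorithm 5 defines the one-hot as the nearest of a set of bins, whereas `relpos_encoding` (and the AF2 source) shift and clip the integer offset. For integer offsets with bins from `-max` to `max`, the two give identical encodings. The check below uses a toy `residue_index` with a gap and `max_rel = 4`; both are arbitrary choices for illustration, and `one_hot_nearest_bin` is a hypothetical helper name.

```python
import numpy as np

def one_hot_nearest_bin(x, v_bins):
    """One-hot of the nearest bin, in the spirit of Algorithm 5: argmin |x - v_bins|."""
    idx = np.argmin(np.abs(x[..., None] - v_bins), axis=-1)
    return np.eye(len(v_bins))[idx]

max_rel = 4
residue_index = np.array([0, 1, 2, 3, 10, 11])  # toy index with a gap between 3 and 10
offset = residue_index[:, None] - residue_index[None, :]

# Nearest-bin formulation with bins -max_rel..max_rel
v_bins = np.arange(-max_rel, max_rel + 1)
nearest = one_hot_nearest_bin(offset, v_bins)

# Shift-and-clip formulation, as in relpos_encoding
clipped = np.clip(offset + max_rel, 0, 2 * max_rel)
clip_shift = np.eye(2 * max_rel + 1)[clipped]

print(np.array_equal(nearest, clip_shift))  # True
print(np.argmax(clip_shift[0], axis=-1))    # [4 3 2 1 0 0]
```

Residues 10 and 11 are more than `max_rel` away from residue 0, so both land in the edge bin 0.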
## NumPy Implementation

In [None]:
def input_embedder(
    target_feat,      # [N_res, d_target] target sequence features
    msa_feat,         # [N_seq, N_res, d_msa] MSA features
    residue_index,    # [N_res] residue indices
    c_m=256,          # MSA channel dimension
    c_z=128,          # Pair channel dimension
    max_relative_feature=32  # Max relative position
):
    """
    Input Embedder - Algorithm 3.
    
    Embeds input features into initial MSA and pair representations.
    
    Args:
        target_feat: Target sequence features [N_res, d_target]
        msa_feat: MSA features [N_seq, N_res, d_msa]
        residue_index: Residue position indices [N_res]
        c_m: MSA channel dimension
        c_z: Pair channel dimension
        max_relative_feature: Maximum relative position to encode
    
    Returns:
        msa_act: MSA representation [N_seq, N_res, c_m]
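        pair_act: Pair representation [N_res, N_res, c_z]
    
    Note:
        Weights are random placeholders (std 0.01), not trained AF2 parameters,
        and bias terms are omitted (AF2's Linear layers include a bias).
    """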
    N_res, d_target = target_feat.shape
    N_seq, _, d_msa = msa_feat.shape
    
    print(f"Input Embedder")
    print(f"="*50)
    print(f"Target features: {target_feat.shape}")
    print(f"MSA features: {msa_feat.shape}")
    print(f"Residue indices: {residue_index.shape}")
    print()
    
    # ========== Step 1: Embed target sequence (Alg. 3, line 4, target term) ==========
    # Linear projection of target features
    w_target = np.random.randn(d_target, c_m) * 0.01
    preprocess_1d = target_feat @ w_target  # [N_res, c_m]
    
    print(f"Step 1 - Target embedding: {preprocess_1d.shape}")
    
    # ========== Step 2: Embed MSA (Alg. 3, line 4, MSA term) ==========
    # Linear projection of MSA features
    w_msa = np.random.randn(d_msa, c_m) * 0.01
    preprocess_msa = np.einsum('sra,ad->srd', msa_feat, w_msa)  # [N_seq, N_res, c_m]
    
    print(f"Step 2 - MSA embedding: {preprocess_msa.shape}")
    
    # ========== Step 3: Combine for MSA representation (Alg. 3, line 4, sum) ==========
    # Add target embedding (broadcasted) to MSA embedding
    msa_act = preprocess_1d[None, :, :] + preprocess_msa  # [N_seq, N_res, c_m]
    
    print(f"Step 3 - Combined MSA representation: {msa_act.shape}")
    
    # ========== Step 4: Create pair representation (Alg. 3, lines 1-2) ==========
    # Left projection
    w_left = np.random.randn(d_target, c_z) * 0.01
    left_single = target_feat @ w_left  # [N_res, c_z]
    
    # Right projection
    w_right = np.random.randn(d_target, c_z) * 0.01
    right_single = target_feat @ w_right  # [N_res, c_z]
    
    # Outer sum to create pair features
    pair_act = left_single[:, None, :] + right_single[None, :, :]  # [N_res, N_res, c_z]
    
    print(f"Step 4 - Initial pair representation: {pair_act.shape}")
    
    # ========== Step 5: Add relative position encoding (Alg. 3, line 3) ==========
    # Compute relative position features
    rel_pos = relpos_encoding(residue_index, max_relative_feature)
    
    print(f"Step 5 - Relative position encoding: {rel_pos.shape}")
    
    # Project relative position to pair dimension
    d_relpos = 2 * max_relative_feature + 1
    w_relpos = np.random.randn(d_relpos, c_z) * 0.01
    rel_pos_proj = np.einsum('ija,ad->ijd', rel_pos, w_relpos)  # [N_res, N_res, c_z]
    
    # Add to pair representation
    pair_act = pair_act + rel_pos_proj
    
    print(f"Step 5 - Pair with relpos: {pair_act.shape}")
    
    return msa_act, pair_act

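A linear layer applied to a one-hot vector selects one row of its weight matrix, so the relative position projection in Step 5 can equivalently be written as an index lookup. The check below uses small, arbitrary sizes and random placeholder weights. The gather form avoids materializing the `[N_res, N_res, 2*max+1]` one-hot tensor; that is a design option for a reimplementation, not necessarily what AF2 does.

```python
import numpy as np

rng = np.random.default_rng(0)
n_res, max_rel, c_z = 5, 2, 3
n_bins = 2 * max_rel + 1

residue_index = np.arange(n_res)
clipped = np.clip(residue_index[:, None] - residue_index[None, :] + max_rel,
                  0, 2 * max_rel)                # bin index per (i, j)
rel_pos = np.eye(n_bins)[clipped]                # [n_res, n_res, n_bins] one-hot

w_relpos = rng.standard_normal((n_bins, c_z))    # random placeholder weights

via_matmul = np.einsum('ija,ad->ijd', rel_pos, w_relpos)  # as in Step 5
via_lookup = w_relpos[clipped]                            # gather one row per (i, j)

print(np.allclose(via_matmul, via_lookup))  # True
```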
## Test Example

In [None]:
# Test parameters
N_res = 64      # Number of residues
N_seq = 128     # Number of MSA sequences
d_target = 22   # Target feature dimension (1 has_break flag + 21 amino-acid classes)
d_msa = 49      # MSA feature dimension (23 MSA classes + 26 deletion and cluster-profile features)

# Create test inputs
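# Random amino acid sequence, one-hot encoded into the first 21 of d_target columns
# (a simplified stand-in for the real has_break + aatype layout)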
aatype = np.random.randint(0, 21, size=N_res)
target_feat = one_hot(aatype, d_target).astype(np.float32)

# Random MSA (one-hot + features)
msa_aatype = np.random.randint(0, 23, size=(N_seq, N_res))  # 23 = 20 amino acids + unknown + gap + mask
msa_onehot = one_hot(msa_aatype, 23)
msa_features = np.concatenate([
    msa_onehot,
    np.random.rand(N_seq, N_res, d_msa - 23)  # Placeholder for the 26 deletion and cluster-profile features
], axis=-1).astype(np.float32)

# Residue indices
residue_index = np.arange(N_res)

print(f"Test Input Summary")
print(f"="*50)
print(f"Sequence length: {N_res}")
print(f"MSA depth: {N_seq}")
print(f"Target features: {target_feat.shape}")
print(f"MSA features: {msa_features.shape}")
print()

In [None]:
# Run input embedder
msa_act, pair_act = input_embedder(
    target_feat,
    msa_features,
    residue_index,
    c_m=256,
    c_z=128,
    max_relative_feature=32
)

print(f"\nOutput Summary")
print(f"="*50)
print(f"MSA representation: {msa_act.shape}")
print(f"Pair representation: {pair_act.shape}")
print(f"\nMSA stats: mean={msa_act.mean():.4f}, std={msa_act.std():.4f}")
print(f"Pair stats: mean={pair_act.mean():.4f}, std={pair_act.std():.4f}")

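With the channel sizes used in this notebook (`c_m = 256`, `c_z = 128`, `max_relative_feature = 32`), the embedder has few parameters. A quick count, assuming each linear layer has a weight matrix and a bias vector, as `common_modules.Linear` does by default (the NumPy implementation above omits biases); `linear_params` is a hypothetical helper:

```python
def linear_params(d_in, d_out, bias=True):
    """Parameters of a dense layer: weight matrix plus optional bias."""
    return d_in * d_out + (d_out if bias else 0)

d_target, d_msa, c_m, c_z, max_rel = 22, 49, 256, 128, 32
d_relpos = 2 * max_rel + 1  # 65 relative-position bins

counts = {
    'preprocess_1d': linear_params(d_target, c_m),   # 22*256 + 256 = 5888
    'preprocess_msa': linear_params(d_msa, c_m),     # 49*256 + 256 = 12800
    'left_single': linear_params(d_target, c_z),     # 22*128 + 128 = 2944
    'right_single': linear_params(d_target, c_z),    # 2944
    'relpos': linear_params(d_relpos, c_z),          # 65*128 + 128 = 8448
}
total = sum(counts.values())
print(total)  # 33024
```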
## Inspect Pair Representation Structure

In [None]:
# Inspect the pair representation for structure from the relative position encoding.
# The relpos term depends only on the offset i - j, so it is constant along each diagonal.

# Take the first channel
pair_channel_0 = pair_act[:, :, 0]
off_diag_mask = ~np.eye(pair_channel_0.shape[0], dtype=bool)

print("Pair representation structure (first channel):")
print(f"  Diagonal mean: {np.diag(pair_channel_0).mean():.4f}")
print(f"  Off-diagonal mean: {pair_channel_0[off_diag_mask].mean():.4f}")

# The pair representation is not symmetric: left_single and right_single use
# different weights, and offsets i - j and j - i fall into different relpos bins.
asymmetry = np.abs(pair_channel_0 - pair_channel_0.T).mean()
print(f"  Asymmetry: {asymmetry:.6f}")

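The asymmetry printed above is expected to be nonzero. Isolating the relative position term makes its structure explicit: for a contiguous `residue_index` it depends only on the signed separation `j - i`, so it is constant along each diagonal; it saturates once the separation exceeds `max_relative_feature`; and `(i, j)` and `(j, i)` map to different bins. A small check with arbitrary sizes and random placeholder weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n_res, max_rel, c_z = 12, 3, 4
n_bins = 2 * max_rel + 1

residue_index = np.arange(n_res)  # contiguous chain, no gaps
clipped = np.clip(residue_index[:, None] - residue_index[None, :] + max_rel,
                  0, 2 * max_rel)
w_relpos = rng.standard_normal((n_bins, c_z))  # random placeholder weights
relpos_term = w_relpos[clipped]                # [n_res, n_res, c_z]

# 1) Constant along every diagonal: entries depend only on j - i.
toeplitz = True
for o in range(-(n_res - 1), n_res):
    d = np.diagonal(relpos_term, offset=o, axis1=0, axis2=1)  # [c_z, n_diag], entries with j - i = o
    toeplitz = toeplitz and np.allclose(d, d[:, :1])

# 2) Saturation: separations beyond max_rel share the edge bin's vector.
saturates = (np.allclose(relpos_term[0, 5], relpos_term[0, 11])
             and np.allclose(relpos_term[11, 0], relpos_term[5, 0]))

# 3) Direction matters: (0, 1) and (1, 0) fall in different bins.
directional = not np.allclose(relpos_term[0, 1], relpos_term[1, 0])

print(toeplitz, saturates, directional)  # True True True
```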
## Source Code Reference

```python
# From AF2-source-code/model/modules.py (EmbeddingsAndEvoformer.__call__)

# Embed clustered MSA.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
# Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
preprocess_1d = common_modules.Linear(
    c.msa_channel, name='preprocess_1d')(batch['target_feat'])

preprocess_msa = common_modules.Linear(
    c.msa_channel, name='preprocess_msa')(batch['msa_feat'])

msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

left_single = common_modules.Linear(
    c.pair_channel, name='left_single')(batch['target_feat'])
right_single = common_modules.Linear(
    c.pair_channel, name='right_single')(batch['target_feat'])
pair_activations = left_single[:, None] + right_single[None]

# Relative position encoding
# Jumper et al. (2021) Suppl. Alg. 4 "relpos"
# Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
if c.max_relative_feature:
    pos = batch['residue_index']
    offset = pos[:, None] - pos[None, :]
    rel_pos = jax.nn.one_hot(
        jnp.clip(offset + c.max_relative_feature, 0, 2 * c.max_relative_feature),
        2 * c.max_relative_feature + 1)
    pair_activations += common_modules.Linear(
        c.pair_channel, name='pair_activiations')(rel_pos)
```

## Key Insights

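1. **Two Representations**: Creates both the MSA representation (`[N_seq, N_res, c_m]`, one vector per sequence and residue) and the pair representation (`[N_res, N_res, c_z]`, one vector per residue pair).

2. **Target + MSA**: The MSA representation adds the target sequence embedding, broadcast over all MSA rows, to the per-sequence MSA embedding.

3. **Outer Sum**: The pair representation starts as the outer sum `z_ij = a_i + b_j` of the left and right target projections.

4. **Relative Position**: The relative position encoding is the only input to the initial pair representation that depends on sequence separation. Offsets are clipped to ±`max_relative_feature`, so all residues farther apart than that share one bin per direction.

5. **Linear Projections**: Every embedding is a single linear layer; the capacity to model interactions comes from the downstream Evoformer.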
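Insight 2 can be checked directly: the target term is computed once per residue and broadcast over the `N_seq` axis, so it contributes the same vector to every MSA row, and only the MSA term varies between rows. A sketch with small, arbitrary sizes and random placeholder weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n_seq, n_res, d_target, d_msa, c_m = 4, 6, 22, 49, 8

target_feat = rng.standard_normal((n_res, d_target))
msa_feat = rng.standard_normal((n_seq, n_res, d_msa))
w_target = rng.standard_normal((d_target, c_m))  # placeholder weights
w_msa = rng.standard_normal((d_msa, c_m))        # placeholder weights

# Target embedding broadcast over MSA rows, plus per-row MSA embedding
msa_act = (target_feat @ w_target)[None] + msa_feat @ w_msa  # [n_seq, n_res, c_m]

# Removing the MSA term leaves the same target embedding in every row.
target_part = msa_act - msa_feat @ w_msa
rows_identical = np.allclose(target_part, target_part[0:1])
print(msa_act.shape, rows_identical)  # (4, 6, 8) True
```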