# Algorithm 1: Input Embedder (Boltz)

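The Input Embedder concatenates per-token features (residue type, MSA profile, deletion mean, pocket indicators) with atom-level features and, in this notebook, projects the result to the token dimension to form the initial token representations.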

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/trunk.py`
- **Class**: `InputEmbedder`

## Overview

### Input Features

| Feature | Shape | Description |
|---------|-------|-------------|
| `res_type` | `[N, 32]` | Residue type one-hot encoding |
| `profile` | `[N, 32]` | MSA profile (amino acid frequencies) |
| `deletion_mean` | `[N]` | Average deletion frequency |
| `pocket_feature` | `[N, pocket_dim]` | Binding pocket indicators |
| Atom features | `[N, atom_dim]` | Output of AtomAttentionEncoder (optional) |
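
As a quick sanity check on how these features combine, the sketch below computes the width of the concatenated feature vector. `concat_width` is a hypothetical helper, and the defaults (32 residue types, `pocket_dim = 4`) are simply the values used in this notebook's test, not confirmed Boltz settings.

```python
def concat_width(atom_dim, num_res_types=32, pocket_dim=4):
    """Width of [atom features | res_type | profile | deletion_mean | pocket_feature]."""
    # res_type and profile both span num_res_types; deletion_mean contributes one column.
    return atom_dim + num_res_types + num_res_types + 1 + pocket_dim

width = concat_width(atom_dim=256)
print(width)  # 256 + 32 + 32 + 1 + 4 = 325
```

The projection weight must have exactly this many rows, so the check is useful whenever `atom_dim` or `pocket_dim` changes.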

In [None]:
import numpy as np
np.random.seed(42)

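def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance (no learnable scale or shift)."""
    # Helper only; input_embedder below does not call it.
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)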
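
The helper above normalizes without the per-feature scale and shift that layer normalization commonly carries. A minimal sketch of that affine variant follows; `layer_norm_affine`, `gamma`, and `beta` are hypothetical names introduced here for illustration, and nothing in this notebook depends on them.

```python
import numpy as np

def layer_norm_affine(x, gamma, beta, eps=1e-5):
    # Normalize the last axis, then apply a per-feature scale (gamma) and shift (beta).
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)) * 3.0 + 5.0  # rows with non-zero mean and non-unit variance

# With gamma = 1 and beta = 0 this reduces to the plain layer_norm.
y = layer_norm_affine(x, gamma=np.ones(8), beta=np.zeros(8))
row_means = y.mean(axis=-1)
row_vars = y.var(axis=-1)
print(np.allclose(row_means, 0.0, atol=1e-6), np.allclose(row_vars, 1.0, atol=1e-3))
```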

In [None]:
def input_embedder(res_type, profile, deletion_mean, pocket_feature, atom_features=None, token_s=384):
    """
    Input Embedder - combines input features into token representations.
    
    Args:
        res_type: One-hot residue types [N, 32]
        profile: MSA profile [N, 32]
        deletion_mean: Deletion frequency [N]
        pocket_feature: Pocket indicators [N, pocket_dim]
        atom_features: Atom encoder output [N, atom_dim]; zeros [N, token_s] are used if None
        token_s: Output token dimension
    
    Returns:
        Token embeddings [N, token_s]
    """
    N = res_type.shape[0]
    
    print("Input Embedder")
    print("=" * 50)
    print(f"Tokens: {N}")
    
    # Expand deletion_mean
    deletion_mean = deletion_mean[:, np.newaxis]  # [N, 1]
    
    # Atom features (if not provided, use zeros)
    if atom_features is None:
        atom_features = np.zeros((N, token_s))
        print(f"  Atom features: zeros [{N}, {token_s}]")
    else:
        print(f"  Atom features: [{N}, {atom_features.shape[-1]}]")
    
    # Concatenate all features
    features = np.concatenate([
        atom_features,     # [N, token_s or atom_dim]
        res_type,          # [N, 32]
        profile,           # [N, 32]
        deletion_mean,     # [N, 1]
        pocket_feature,    # [N, pocket_dim]
    ], axis=-1)
    
    print(f"  Concatenated: {features.shape}")
    
    # Project to token dimension.
    # NOTE: W is a random stand-in for a learned weight matrix and is redrawn on every call.
    W = np.random.randn(features.shape[-1], token_s) * (features.shape[-1] ** -0.5)
    output = features @ W
    
    print(f"  Output: {output.shape}")
    
    return output

In [None]:
# Test
print("Test: Input Embedder")
print("="*60)

N = 50
num_res_types = 32
pocket_dim = 4

# Simulate input features
res_type = np.eye(num_res_types)[np.random.randint(0, 20, N)]  # One-hot; only the first 20 of 32 classes are sampled
profile = np.random.dirichlet(np.ones(num_res_types), N)  # Frequencies; each row sums to 1
deletion_mean = np.random.uniform(0, 0.5, N)
pocket_feature = np.random.randint(0, 2, (N, pocket_dim)).astype(np.float32)

output = input_embedder(res_type, profile, deletion_mean, pocket_feature, token_s=256)

print(f"\nOutput shape: {output.shape}")
print(f"Output finite: {np.isfinite(output).all()}")
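
Before trusting the output, it can help to confirm that the simulated inputs have the properties the feature table describes. The sketch below regenerates two of them with a seeded `np.random.default_rng` (an assumption made here for reproducibility; the test above uses the legacy `np.random` functions) and checks that each `res_type` row is one-hot and each `profile` row is a valid frequency vector.

```python
import numpy as np

rng = np.random.default_rng(0)
N, num_res_types = 50, 32

# Same construction as the test above, but with a local seeded generator.
res_type = np.eye(num_res_types)[rng.integers(0, 20, N)]
profile = rng.dirichlet(np.ones(num_res_types), N)

# One-hot: entries are 0 or 1 and each row contains exactly one 1.
one_hot_ok = bool(np.all((res_type == 0) | (res_type == 1)) and np.all(res_type.sum(axis=-1) == 1))
# Frequencies: non-negative and each row sums to 1.
profile_ok = bool(np.all(profile >= 0) and np.allclose(profile.sum(axis=-1), 1.0))

print(one_hot_ok, profile_ok)
```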

## Key Insights

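1. **Feature Concatenation**: Atom-level features, residue type, MSA profile, deletion mean, and pocket indicators are concatenated along the last axis
2. **Atom Encoder Integration**: Atom-level features come from the AtomAttentionEncoder; this notebook substitutes zeros when none are supplied
3. **Pocket Features**: Binding pocket indicators are part of the input (useful for drug design)
4. **Flexible Dimensions**: A linear projection maps the concatenated features to the chosen token dimension (`token_s`); here the weights are random placeholders, not trained parameters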
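
The `input_embedder` function in this notebook draws a new random projection on every call, so identical inputs give different outputs. A minimal sketch of holding the projection fixed, the way a learned linear layer would, is shown below. `InputEmbedderSketch` and its arguments are hypothetical names, and the weights are seeded random stand-ins rather than trained Boltz parameters.

```python
import numpy as np

class InputEmbedderSketch:
    """Hypothetical wrapper: concatenate features, then apply one fixed projection."""

    def __init__(self, in_dim, token_s, seed=0):
        rng = np.random.default_rng(seed)
        # Random stand-in for a learned weight matrix, scaled by 1/sqrt(in_dim).
        self.W = rng.standard_normal((in_dim, token_s)) * in_dim ** -0.5

    def __call__(self, atom_features, res_type, profile, deletion_mean, pocket_feature):
        features = np.concatenate(
            [atom_features, res_type, profile, deletion_mean[:, np.newaxis], pocket_feature],
            axis=-1,
        )
        if features.shape[-1] != self.W.shape[0]:
            raise ValueError(f"expected width {self.W.shape[0]}, got {features.shape[-1]}")
        return features @ self.W

rng = np.random.default_rng(1)
N, token_s, pocket_dim = 5, 16, 4
atom = np.zeros((N, token_s))                      # zeros stand in for atom encoder output
res_type = np.eye(32)[rng.integers(0, 20, N)]
profile = rng.dirichlet(np.ones(32), N)
deletion_mean = rng.uniform(0, 0.5, N)
pocket = rng.integers(0, 2, (N, pocket_dim)).astype(np.float64)

embedder = InputEmbedderSketch(in_dim=token_s + 32 + 32 + 1 + pocket_dim, token_s=token_s)
out1 = embedder(atom, res_type, profile, deletion_mean, pocket)
out2 = embedder(atom, res_type, profile, deletion_mean, pocket)
print(out1.shape, np.array_equal(out1, out2))  # (5, 16) True
```

Creating the weights once in the constructor keeps repeated calls deterministic, and the width check catches a mismatch between the concatenated features and the projection early.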