# Algorithm 5: Input Embedder v2 (Boltz-2)

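A simplified NumPy sketch of the Boltz-2 input embedder: per-atom features are pooled to token level, concatenated with residue type, MSA profile, deletion mean, and pocket features, then linearly projected to the token dimension. Weights are drawn at random here and stand in for learned parameters.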

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/trunkv2.py`
- **Class**: `InputEmbedder`

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

In [None]:
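def atom_encoder(atom_features, atom_to_token, c_atom=64, n_tokens=None):
    """Mean-pool per-atom features to token level.

    Args:
        atom_features: Per-atom features [N_atoms, c_atom]
        atom_to_token: Token index of each atom [N_atoms]
        c_atom: Atom feature width
        n_tokens: Number of tokens; defaults to atom_to_token.max() + 1

    Returns:
        Token features [n_tokens, c_atom]; tokens with no atoms stay zero.
    """
    N_atoms = atom_features.shape[0]
    assert atom_features.shape[1] == c_atom, "atom_features width must equal c_atom"
    N_tokens = int(atom_to_token.max()) + 1 if n_tokens is None else n_tokens

    # Accumulate feature sums and atom counts per token
    token_features = np.zeros((N_tokens, c_atom))
    counts = np.zeros(N_tokens)

    for a in range(N_atoms):
        t = atom_to_token[a]
        token_features[t] += atom_features[a]
        counts[t] += 1

    # Divide by the atom count; clamp to 1 so empty tokens do not divide by zero
    token_features = token_features / np.maximum(counts[:, None], 1)
    return token_features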
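
The per-atom loop above can also be written without a Python loop. The following is a minimal sketch, assuming the same inputs as `atom_encoder`; the name `atom_encoder_vectorized` and its explicit `n_tokens` argument are illustrative and not taken from the Boltz source.

```python
import numpy as np

def atom_encoder_vectorized(atom_features, atom_to_token, n_tokens):
    """Mean-pool atom features per token with a scatter-add (hypothetical helper)."""
    c_atom = atom_features.shape[1]
    sums = np.zeros((n_tokens, c_atom))
    # np.add.at accumulates correctly even when a token index repeats
    np.add.at(sums, atom_to_token, atom_features)
    counts = np.bincount(atom_to_token, minlength=n_tokens)
    # Tokens with no atoms keep a zero vector instead of dividing by zero
    return sums / np.maximum(counts[:, None], 1)

# Toy example: 4 atoms, 3 tokens, token 2 has no atoms
feats = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
mapping = np.array([0, 0, 1, 1])
pooled = atom_encoder_vectorized(feats, mapping, n_tokens=3)
print(pooled)
```

Passing `n_tokens` explicitly keeps the output aligned with the other per-token inputs even when the last tokens have no atoms.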

In [None]:
def input_embedder_v2(res_type, profile, deletion_mean, pocket_feature, 
                       atom_features=None, atom_to_token=None, token_s=384):
    """
    Input Embedder v2 for Boltz-2.
    
    Enhanced with better atom encoding and pocket features.
    
    Args:
        res_type: Residue type one-hot [N, 32]
        profile: MSA profile [N, 32]
        deletion_mean: Deletion frequency [N]
        pocket_feature: Pocket indicators [N, pocket_dim]
        atom_features: Per-atom features [N_atoms, c_atom]
        atom_to_token: Atom to token mapping [N_atoms]
        token_s: Output dimension
    
    Returns:
        Token embeddings [N, token_s]
    """
    N = res_type.shape[0]
    
    print("Input Embedder v2 (Boltz-2)")
    print("=" * 50)
    print(f"Tokens: {N}")
    
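    # Atom features (enhanced in v2)
    if atom_features is not None and atom_to_token is not None:
        atom_emb = atom_encoder(atom_features, atom_to_token, c_atom=64)
        # atom_encoder infers the token count from atom_to_token.max() + 1,
        # so pad with zero rows if the last tokens have no atoms
        if atom_emb.shape[0] < N:
            atom_emb = np.pad(atom_emb, ((0, N - atom_emb.shape[0]), (0, 0)))
        # Project to token_s (random weights stand in for learned parameters)
        W_atom = np.random.randn(64, token_s) * (64 ** -0.5)
        atom_emb = atom_emb @ W_atom
        print("  Atom features encoded")
    else:
        atom_emb = np.zeros((N, token_s))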
    
    # Expand deletion_mean
    deletion_mean = deletion_mean[:, np.newaxis]
    
    # Concatenate all features
    features = np.concatenate([
        atom_emb,
        res_type,
        profile,
        deletion_mean,
        pocket_feature,
    ], axis=-1)
    
    print(f"  Combined features: {features.shape}")
    
    # Project to token dimension
    W = np.random.randn(features.shape[-1], token_s) * (features.shape[-1] ** -0.5)
    output = features @ W
    
    print(f"  Output: {output.shape}")
    
    return output
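
`input_embedder_v2` draws fresh random weights on every call, so two calls with identical inputs return different outputs. The sketch below shows one way to sample the projection once and reuse it, and works out the concatenated feature width for the shapes listed in the docstring. `FixedProjection` is a hypothetical helper for illustration, not part of the Boltz source.

```python
import numpy as np

class FixedProjection:
    """Bias-free linear map with weights sampled once (hypothetical helper)."""

    def __init__(self, d_in, d_out, seed=0):
        rng = np.random.default_rng(seed)
        # Same 1/sqrt(d_in) scaling as the random projection above
        self.W = rng.standard_normal((d_in, d_out)) * d_in ** -0.5

    def __call__(self, x):
        return x @ self.W

# Width of the concatenated features:
# atom_emb (token_s) + res_type (32) + profile (32) + deletion_mean (1) + pocket_dim
token_s, pocket_dim = 256, 4
d_in = token_s + 32 + 32 + 1 + pocket_dim  # 325

proj = FixedProjection(d_in, token_s)
x = np.ones((5, d_in))
out1 = proj(x)
out2 = proj(x)
print(d_in, out1.shape, np.array_equal(out1, out2))
```

Holding the weights in an object makes repeated calls reproducible, which is closer to how a learned layer behaves at inference time.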

In [None]:
# Test
print("Test: Input Embedder v2")
print("="*60)

N = 50
N_atoms = 200

res_type = np.eye(32)[np.random.randint(0, 20, N)]
profile = np.random.dirichlet(np.ones(32), N)
deletion_mean = np.random.uniform(0, 0.5, N)
pocket_feature = np.random.randint(0, 2, (N, 4)).astype(np.float32)

# With atom features
atom_features = np.random.randn(N_atoms, 64)
atom_to_token = np.random.randint(0, N, N_atoms)

output = input_embedder_v2(
    res_type, profile, deletion_mean, pocket_feature,
    atom_features, atom_to_token, token_s=256
)

print(f"\nOutput finite: {np.isfinite(output).all()}")

## Key Insights

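1. **Enhanced Atom Encoding**: Atom features are aggregated atom→token (mean pooling in this sketch) and then projected to `token_s`.
2. **Pocket Features**: Per-token pocket indicators enter the embedding directly, which matters for drug binding.
3. **Unified Embedding**: Atom, residue-type, MSA-profile, deletion, and pocket features are concatenated and projected together.
4. **Ligand Support**: The same atom-to-token path handles both protein residues and small molecules.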
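
To make the aggregation and ligand points concrete, here is a toy example. It assumes, purely for illustration and not taken from the Boltz source, that each protein residue is one token spanning several atoms while each ligand atom is its own token. Mean pooling is written as a row-normalized assignment matrix.

```python
import numpy as np

# Hypothetical layout: 2 residue tokens (3 atoms each) + 2 ligand tokens (1 atom each)
atom_to_token = np.array([0, 0, 0, 1, 1, 1, 2, 3])
n_tokens = 4
rng = np.random.default_rng(0)
atom_features = rng.standard_normal((atom_to_token.size, 8))

# One-hot assignment matrix [n_tokens, n_atoms]; row-normalizing gives mean pooling
assign = np.eye(n_tokens)[atom_to_token].T
weights = assign / assign.sum(axis=1, keepdims=True)
token_features = weights @ atom_features

print(token_features.shape)
```

Under this layout a single-atom ligand token simply carries its atom's features, while a residue token carries the average over its atoms.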