# Algorithm 2: Template Embedder (AlphaFold3)

The Template Embedder processes structural templates to provide prior knowledge about similar structures.

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/template_modules.py`
- **Class**: `TemplateEmbedding`

## Overview

### Template Features

| Feature | Description |
|---------|-------------|
| `template_distogram` | Binned distances between tokens [N_templ, N, N, 39] |
| `template_unit_vector` | Direction vectors [N_templ, N, N, 3] |
| `template_backbone_mask` | Valid backbone atoms [N_templ, N] |
| `template_restype` | Residue types [N_templ, N] |

In [None]:
import numpy as np
# Seed NumPy's global RNG so the randomly generated features and weights
# in the cells below are reproducible from run to run.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize `x` along its last axis to zero mean and unit variance.

    No learned scale or bias is applied.

    Args:
        x: Array-like input; statistics are taken over the last axis.
        eps: Small constant added to the variance before the square root
            to avoid division by zero.

    Returns:
        Array with the same shape as `x`.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    spread = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return centered / spread

In [None]:
def create_template_features(N_templ, N_tokens, num_bins=39):
    """
    Create random stand-in template features.

    Args:
        N_templ: Number of templates
        N_tokens: Sequence length
        num_bins: Number of distance bins

    Returns:
        Dictionary with:
            'distogram':     [N_templ, N_tokens, N_tokens, num_bins] one-hot
                             binned distances (exactly one bin set per pair)
            'unit_vector':   [N_templ, N_tokens, N_tokens, 3] unit-length vectors
            'backbone_mask': [N_templ, N_tokens] 0/1 mask (about 90% ones)
            'restype':       [N_templ, N_tokens, 32] one-hot residue types
    """
    print("Creating Template Features")
    print(f"  Templates: {N_templ}, Tokens: {N_tokens}")

    # Bin edges are the same for every template, so compute them once.
    bins = np.linspace(2, 22, num_bins + 1)
    # Only the interior edges are used for the lookup: anything below the
    # first interior edge falls in bin 0 and anything at or above the last
    # interior edge falls in the final bin. Simulated distances reach 50,
    # well past the last edge; without the open-ended final bin those pairs
    # would get an all-zero row instead of a one-hot vector.
    interior_edges = bins[1:-1]
    one_hot_table = np.eye(num_bins)

    # Distance histogram (one-hot binned distances)
    distogram = np.zeros((N_templ, N_tokens, N_tokens, num_bins))
    for t in range(N_templ):
        # Simulate distances
        distances = np.random.uniform(3, 50, (N_tokens, N_tokens))
        distances = (distances + distances.T) / 2  # Symmetric

        # Bin distances: index i means bins[i] <= d < bins[i + 1], with the
        # first and last bins acting as catch-alls.
        bin_index = np.digitize(distances, interior_edges)
        distogram[t] = one_hot_table[bin_index]

    # Unit vectors (epsilon guards against a zero-length draw)
    unit_vector = np.random.randn(N_templ, N_tokens, N_tokens, 3)
    unit_vector = unit_vector / (np.linalg.norm(unit_vector, axis=-1, keepdims=True) + 1e-8)

    # Masks
    backbone_mask = np.random.choice([0, 1], (N_templ, N_tokens), p=[0.1, 0.9])

    # Residue types (one-hot): indices are drawn from 0..19 and embedded in a
    # 32-wide one-hot vector.
    restype = np.random.randint(0, 20, (N_templ, N_tokens))
    restype_onehot = np.eye(32)[restype]

    return {
        'distogram': distogram,
        'unit_vector': unit_vector,
        'backbone_mask': backbone_mask,
        'restype': restype_onehot,
    }

In [None]:
def template_embedder(template_features, pair_repr, c=64):
    """
    Template Embedder - embeds template features into pair representation.

    This is a simplified sketch: weights are random (drawn from NumPy's
    global RNG on every call) and the per-template Pairformer stack is
    replaced by a single layer norm.

    Args:
        template_features: Dictionary of template features with keys
            'distogram' [N_templ, N, N, bins], 'unit_vector' [N_templ, N, N, 3],
            'backbone_mask' [N_templ, N] and 'restype' [N_templ, N, r]
        pair_repr: Current pair representation [N, N, c_z]
        c: Hidden dimension for template processing

    Returns:
        Template embedding to add to pair [N, N, c_z]
    """
    distogram = template_features['distogram']
    N_templ = distogram.shape[0]
    N = distogram.shape[1]
    c_z = pair_repr.shape[-1]

    print("Template Embedder")
    print("=" * 50)
    print(f"Templates: {N_templ}, Tokens: {N}")

    # Width of the concatenated per-pair feature vector:
    # distogram bins + unit vector + pair mask + restype_i + restype_j
    # (39 + 3 + 1 + 32 + 32 = 107 with the default features).
    restype_dim = template_features['restype'].shape[-1]
    n_feat = (
        distogram.shape[-1]
        + template_features['unit_vector'].shape[-1]
        + 1
        + 2 * restype_dim
    )

    # Projection weights are shared by all templates, so they are sampled
    # once here rather than once per template inside the loop.
    W_a = np.random.randn(n_feat, c) * 0.02
    W_z = np.random.randn(c_z, c) * 0.02

    # The conditioning term from the current pair representation does not
    # depend on the template, so compute it a single time.
    pair_proj = np.einsum('ijc,cd->ijd', layer_norm(pair_repr), W_z)

    # Initialize output
    u = np.zeros((N, N, c))

    for t in range(N_templ):
        # Step 1: Concatenate features for this template
        dgram = distogram[t]
        unit_vec = template_features['unit_vector'][t]

        # Pair mask: 1 only where both tokens have a valid backbone
        bb_mask = template_features['backbone_mask'][t]
        pair_mask = bb_mask[:, None] * bb_mask[None, :]  # [N, N]

        # Residue types, broadcast along rows (i) and columns (j)
        restype = template_features['restype'][t]  # [N, r]
        restype_i = np.broadcast_to(restype[:, None, :], (N, N, restype_dim))
        restype_j = np.broadcast_to(restype[None, :, :], (N, N, restype_dim))

        a_t = np.concatenate([
            dgram,  # [N, N, bins]
            unit_vec,  # [N, N, 3]
            pair_mask[:, :, None],  # [N, N, 1]
            restype_i,  # [N, N, r]
            restype_j,  # [N, N, r]
        ], axis=-1)  # [N, N, n_feat]

        # Step 2 + 3: Project template features to the hidden dim and add
        # the projected pair representation
        v = np.einsum('ijf,fc->ijc', a_t, W_a) + pair_proj

        # Step 4: Process through simplified Pairformer block
        # (In real AF3, uses full PairformerStack)
        v = layer_norm(v)

        # Accumulate
        u = u + v

    # Step 5: Average and project (max() avoids dividing by zero when there
    # are no templates; the result is then all zeros)
    u = u / max(N_templ, 1)
    u = layer_norm(u)

    # Final projection to pair dimension
    W_out = np.random.randn(c, c_z) * 0.02
    output = np.einsum('ijc,cd->ijd', np.maximum(u, 0), W_out)  # ReLU + project

    print(f"Output: {output.shape}")
    return output

In [None]:
# Test: run the embedder end to end on small random inputs and check the
# basic contract (output shape and finiteness).
print("Test: Template Embedder")
print("="*60)

N_templ = 4
N_tokens = 32
c_z = 128

template_features = create_template_features(N_templ, N_tokens)
pair_repr = np.random.randn(N_tokens, N_tokens, c_z)

output = template_embedder(template_features, pair_repr)

print(f"\nOutput shape: {output.shape}")
print(f"Output finite: {np.isfinite(output).all()}")

# Fail loudly instead of relying on someone reading the printed values.
assert output.shape == (N_tokens, N_tokens, c_z)
assert np.isfinite(output).all()

## Key Insights

1. **Multi-Template**: Processes multiple templates and averages their contributions
2. **Pair Conditioning**: Uses current pair representation as conditioning
3. **Feature Composition**: Combines distance, direction, mask, and residue type information
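4. **Chain Masking**: In the full AF3 algorithm, template features are zeroed out for cross-chain pairs; this simplified notebook has no chain-ID feature and omits that step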