# Algorithm 6: Template Module v2 (Boltz-2)

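A simplified NumPy sketch of the enhanced template processing used in Boltz-2: template features are projected into a shared space, weighted per token pair by attention over the templates, and added to the pair representation as a residual update.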

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/trunkv2.py`
- **Classes**: `TemplateModule`, `TemplateV2Module`

In [None]:
import numpy as np
# Seed NumPy's global RNG so the random weights and inputs drawn later in
# this notebook are reproducible from run to run.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along its last axis.

    No learnable scale or shift is applied.

    Args:
        x: Array-like input; statistics are computed over the last axis.
        eps: Constant added to the variance before the square root so the
            division stays finite when the variance is zero.

    Returns:
        Array with the same shape as ``x``.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps)

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow; this does not change the result.

    Args:
        x: Array-like input of scores.
        axis: Axis along which the outputs sum to one.

    Returns:
        Array with the same shape as ``x``.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer

In [None]:
def template_module_v2(template_features, z, num_templates=4, token_z=128):
    """
    Template Module v2.

    Projects every template into a ``token_z``-dimensional space, scores it
    against a query derived from the current pair representation, normalizes
    the scores with a softmax over the template axis for each (i, j) pair,
    and adds the attention-weighted template mix back to ``z`` as a residual
    update.

    The projection matrices are drawn from ``np.random.randn`` on every call
    (in the order W_t, W_q, W_o), so the result depends on the state of
    NumPy's global RNG.

    Args:
        template_features: Template features [T, N, N, c_t]
        z: Current pair representation [N, N, c_z]
        num_templates: Unused; kept for interface compatibility. The template
            count T is read from ``template_features.shape[0]``.
        token_z: Width of the internal space that queries and template
            embeddings are projected into

    Returns:
        Updated pair representation [N, N, c_z]. When T is zero there is
        nothing to attend over and an unchanged copy of ``z`` is returned.
    """
    T, N, _, c_t = template_features.shape
    c_z = z.shape[-1]

    print("Template Module v2 (Boltz-2)")
    print("=" * 50)
    print(f"Templates: {T}, Tokens: {N}")

    if T == 0:
        # A softmax over an empty template axis is undefined (the max
        # reduction inside softmax raises on a zero-size axis), so skip the
        # update entirely instead of crashing.
        output = z.copy()
        print(f"Output: {output.shape}")
        return output

    # Project templates to token_z dimension. layer_norm acts on the last
    # axis only, so all templates can be normalized and projected at once.
    W_t = np.random.randn(c_t, token_z) * (c_t ** -0.5)
    template_embeds = layer_norm(template_features) @ W_t  # [T, N, N, token_z]

    # Compute query from current pair representation
    W_q = np.random.randn(c_z, token_z) * (c_z ** -0.5)
    q = layer_norm(z) @ W_q  # [N, N, token_z]

    # Scaled dot product between the query and each template embedding;
    # q broadcasts across the leading template axis.
    attn_scores = np.sum(q * template_embeds, axis=-1) / np.sqrt(token_z)  # [T, N, N]

    # Softmax over templates, independently for each (i, j) pair
    attn = softmax(attn_scores, axis=0)  # [T, N, N]

    # Weighted combination of templates
    template_combined = np.sum(
        attn[..., np.newaxis] * template_embeds, axis=0
    )  # [N, N, token_z]

    # Project back to c_z and apply as a residual update
    W_o = np.random.randn(token_z, c_z) * (token_z ** -0.5)
    update = template_combined @ W_o

    output = z + update

    print(f"Output: {output.shape}")

    return output

In [None]:
# Test: run the template module on random inputs and sanity-check the result.
print("Test: Template Module v2")
print("="*60)

# Problem sizes: tokens, templates, template feature width, pair width.
N = 32
num_templates = 4
c_t = 64
c_z = 128

template_features = np.random.randn(num_templates, N, N, c_t)
z = np.random.randn(N, N, c_z)

# Pass the configured template count explicitly rather than relying on it
# happening to equal the function's default.
z_out = template_module_v2(template_features, z, num_templates=num_templates)

# The module applies a residual update, so the pair shape must be preserved.
assert z_out.shape == z.shape, f"expected {z.shape}, got {z_out.shape}"

print(f"\nOutput finite: {np.isfinite(z_out).all()}")

## Key Insights

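1. **Multi-Template**: Processes several templates at once and combines them with attention.
2. **Learned Weighting**: A softmax over the template axis gives every (i, j) token pair its own set of template weights. In this demo the projection matrices are random stand-ins for learned parameters.
3. **Pair Integration**: The combined template signal is projected back to `c_z` and added to the pair representation as a residual update.
4. **Flexible**: The number of templates is read from the leading dimension of the input, so any template count T ≥ 1 works.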