# Algorithm 16: Template Pair Stack

The Template Pair Stack processes template structure information to produce pair features that are combined with the main pair representation. It uses a simplified version of the Evoformer architecture.

## Algorithm Pseudocode

![TemplatePairStack](../imgs/algorithms/TemplatePairStack.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `TemplatePairStack`, `TemplatePairStackIteration`
- **Lines**: 522-623

## Overview

Templates are experimentally determined structures of homologous proteins. The Template Pair Stack:

1. **Input**: Template pair features [N_templ, N_res, N_res, c_t]
2. **Processing**: Multiple blocks of triangle operations (like Evoformer)
3. **Output**: Processed template features with the same shape as the input, [N_templ, N_res, N_res, c_t]

### Block Structure

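Each block contains the following residual updates (note: the simplified NumPy implementation below only performs the two triangle multiplications and the pair transition; the two triangle-attention steps are omitted):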
- Triangle Multiplication (Outgoing)
- Triangle Multiplication (Incoming)
- Triangle Attention (Starting)
- Triangle Attention (Ending)
- Pair Transition

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Standardize ``x`` along ``axis`` (no learnable scale or offset).

    Args:
        x: Input array.
        axis: Axis over which mean and variance are computed.
        eps: Constant added to the variance before taking the square root,
            which keeps the division finite for constant inputs.

    Returns:
        Array of the same shape as ``x`` with (approximately) zero mean and
        unit variance along ``axis``.
    """
    mu = np.mean(x, axis=axis, keepdims=True)
    spread = np.sqrt(np.var(x, axis=axis, keepdims=True) + eps)
    centered = x - mu
    return centered / spread


def sigmoid(x):
    """Logistic sigmoid, evaluated on the input clipped to [-20, 20].

    Clipping bounds the argument of ``np.exp`` so that very large inputs
    cannot overflow.
    """
    bounded = np.clip(x, -20, 20)
    return 1.0 / (1.0 + np.exp(-bounded))


def relu(x):
    """Rectified linear unit: element-wise ``max(0, x)``."""
    floor = 0
    return np.maximum(floor, x)


def triangle_multiplication_outgoing(z, c=32):
    """Simplified triangle multiplicative update using "outgoing" edges.

    The update for pair (i, j) combines the edges i->k and j->k over all k:
    ``sum_k a[i, k] * b[j, k]``. This is a simplification: ``a`` and ``b``
    are plain sigmoid projections, and all weights are freshly drawn from
    NumPy's global RNG on every call (there are no trained parameters).

    Args:
        z: Pair representation, unpacked as a 3-D array ``[N, N, c_z]``
            (the two leading dimensions are expected to be equal).
        c: Hidden channel dimension of the ``a`` / ``b`` projections.

    Returns:
        The gated update with the same shape as ``z``. The caller is
        responsible for adding it to ``z`` (residual connection).
    """
    _, _, c_z = z.shape
    z_norm = layer_norm(z)

    # Randomly initialised projections; the draw order (a, b, gate, output)
    # determines which random numbers each matrix receives.
    W_a = np.random.randn(c_z, c) * (c_z ** -0.5)
    W_b = np.random.randn(c_z, c) * (c_z ** -0.5)
    W_g = np.random.randn(c_z, c_z) * (c_z ** -0.5)
    W_o = np.random.randn(c, c_z) * (c ** -0.5)

    a = sigmoid(z_norm @ W_a)  # [N, N, c]
    b = sigmoid(z_norm @ W_b)  # [N, N, c]
    g = sigmoid(z_norm @ W_g)  # [N, N, c_z]

    # Outgoing edges: sum over k of a[i,k] * b[j,k].
    # (The previous 'ikc,kjc->ijc' contracted a[i,k] with b[k,j], which is a
    # matrix-product path and not the outgoing triangle update.)
    ab = np.einsum('ikc,jkc->ijc', a, b)

    return g * (layer_norm(ab) @ W_o)


def triangle_multiplication_incoming(z, c=32):
    """Simplified triangle multiplicative update using "incoming" edges.

    For every pair (i, j) the update sums ``a[k, i] * b[k, j]`` over k.
    All projection weights are drawn from NumPy's global RNG on each call.

    Args:
        z: Pair representation, unpacked as a 3-D array ``[N, N, c_z]``.
        c: Hidden channel dimension of the two edge projections.

    Returns:
        Gated update with the same shape as ``z``.
    """
    _, _, c_z = z.shape
    normed = layer_norm(z)
    input_scale = c_z ** -0.5

    # Draw order is left, right, gate, output.
    proj_left = np.random.randn(c_z, c) * input_scale
    proj_right = np.random.randn(c_z, c) * input_scale
    proj_gate = np.random.randn(c_z, c_z) * input_scale
    proj_out = np.random.randn(c, c_z) * (c ** -0.5)

    left = sigmoid(normed @ proj_left)
    right = sigmoid(normed @ proj_right)
    gate = sigmoid(normed @ proj_gate)

    # Incoming edges: contract the shared first index k of both projections.
    combined = np.einsum('kic,kjc->ijc', left, right)
    return gate * (layer_norm(combined) @ proj_out)


def pair_transition(z, n=2):
    """Simplified pair transition: layer norm followed by a two-layer MLP.

    The channel dimension is expanded by a factor ``n``, passed through a
    ReLU and projected back. Both weight matrices are drawn from NumPy's
    global RNG on every call.

    Args:
        z: Pair representation, unpacked as a 3-D array ``[N, N, c_z]``.
        n: Expansion factor of the hidden layer.

    Returns:
        Update with the same shape as ``z``.
    """
    _, _, c_z = z.shape
    width = n * c_z
    normed = layer_norm(z)

    # Expansion weights are drawn before the projection weights.
    expand = np.random.randn(c_z, width) * (c_z ** -0.5)
    project = np.random.randn(width, c_z) * (width ** -0.5)

    hidden = relu(normed @ expand)
    return hidden @ project

In [None]:
def template_pair_stack_iteration(t, c_t=64, c_hidden=32):
    """
    One iteration of Template Pair Stack.

    Applies the same residual updates to each template independently:
    triangle multiplication (outgoing), triangle multiplication (incoming)
    and pair transition. The triangle-attention steps are not implemented
    in this simplified version.

    The layer functions draw their weights from NumPy's global RNG on every
    call. To make all templates pass through the *same* layers, the RNG
    state at entry is restored before each template is processed. The
    result for a template therefore does not depend on its position in the
    batch. After the call the global RNG has advanced by one template's
    worth of draws (or not at all when there are no templates).

    Args:
        t: Template pair features [N_templ, N_res, N_res, c_t]
        c_t: Template channel dimension. Kept for backward compatibility;
            it is not used, the channel size is taken from ``t`` itself.
        c_hidden: Hidden dimension for triangle operations

    Returns:
        Updated template features [N_templ, N_res, N_res, c_t], with the
        same dtype as ``t`` (results are cast on assignment).
    """
    N_templ, _, _, _ = t.shape

    output = np.zeros_like(t)

    # Snapshot of the RNG so every template sees identical random weights.
    shared_rng_state = np.random.get_state()

    for templ_idx in range(N_templ):
        np.random.set_state(shared_rng_state)
        z = t[templ_idx]  # [N_res, N_res, c_t]

        # Step 1: Triangle Multiplication (Outgoing)
        z = z + triangle_multiplication_outgoing(z, c=c_hidden)

        # Step 2: Triangle Multiplication (Incoming)
        z = z + triangle_multiplication_incoming(z, c=c_hidden)

        # Step 3: Pair Transition
        z = z + pair_transition(z, n=2)

        output[templ_idx] = z

    return output


def template_pair_stack(t, num_blocks=2, c_t=64, c_hidden=32):
    """
    Template Pair Stack - Algorithm 16.

    Runs ``num_blocks`` stack iterations over the template pair features
    and prints a short progress report for each block.

    Args:
        t: Template pair features [N_templ, N_res, N_res, c_t]
        num_blocks: Number of stack iterations
        c_t: Template channel dimension. Kept for backward compatibility;
            the channel size that is reported and used is read from ``t``,
            so a stale default can no longer produce a misleading report.
        c_hidden: Hidden dimension for triangle operations

    Returns:
        Processed template features with the same shape as ``t``. With
        ``num_blocks=0`` the input array itself is returned.
    """
    N_templ, N_res, _, n_channels = t.shape

    print("Template Pair Stack")
    print("=" * 50)
    print(f"Templates: {N_templ}")
    print(f"Residues: {N_res}")
    # Report the channel size of the actual input, not the keyword default.
    print(f"Channels: {n_channels}")
    print(f"Blocks: {num_blocks}")

    for block_idx in range(num_blocks):
        print(f"\nBlock {block_idx + 1}:")
        t = template_pair_stack_iteration(t, c_t=n_channels, c_hidden=c_hidden)
        print(f"  Processed {N_templ} templates")
        print(f"  Output norm: {np.linalg.norm(t):.4f}")

    return t

## Test Examples

In [None]:
# Test 1: Basic functionality -- the stack must return features of the
# same shape as its input.
print("Test 1: Basic Functionality")
print("=" * 60)

np.random.seed(42)
N_templ, N_res, c_t = 4, 32, 64
input_shape = (N_templ, N_res, N_res, c_t)
t = np.random.randn(*input_shape).astype(np.float32)

output = template_pair_stack(t, num_blocks=2, c_t=c_t, c_hidden=32)

shapes_agree = output.shape == t.shape
print(f"\nInput shape: {t.shape}")
print(f"Output shape: {output.shape}")
print(f"Shape preserved: {shapes_agree}")

In [None]:
# Test 2: Template independence
# The output for template i must depend only on template i's own input.
print("\nTest 2: Template Independence")
print("="*60)

np.random.seed(42)
N_templ, N_res, c_t = 4, 16, 32
t = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)

# The layers draw their weights from NumPy's global RNG on every call, so two
# runs are only comparable when both start from the same seed. A separate
# seed keeps the weights from reusing the random stream of the input data.
WEIGHT_SEED = 0

# Process all templates together
np.random.seed(WEIGHT_SEED)
output_all = template_pair_stack(t.copy(), num_blocks=1, c_t=c_t)

# Process each template in isolation and compare. The template stays in its
# original slot (all other slots are blanked out) so that it is processed
# with the same randomly drawn weights as in the batched run.
print("\nVerifying independence:")
for i in range(N_templ):
    isolated = np.zeros_like(t)
    isolated[i] = t[i]

    np.random.seed(WEIGHT_SEED)  # Reset seed for same weights
    output_isolated = template_pair_stack(isolated, num_blocks=1, c_t=c_t)

    match = np.allclose(output_all[i], output_isolated[i], rtol=1e-4)
    print(f"  Template {i}: Independent processing matches: {match}")

In [None]:
# Test 3: Run every building block of the stack on the same input and
# report the shape and norm of its output.
print("\nTest 3: Component Analysis")
print("=" * 60)

np.random.seed(42)
N_res, c_t = 16, 32
z = np.random.randn(N_res, N_res, c_t).astype(np.float32)

print(f"Input: {z.shape}")

# (label, operation, keyword arguments) for each component under test
component_specs = (
    ("Triangle Mult (Out)", triangle_multiplication_outgoing, {"c": 16}),
    ("Triangle Mult (In)", triangle_multiplication_incoming, {"c": 16}),
    ("Pair Transition", pair_transition, {"n": 2}),
)

component_outputs = []
for label, operation, kwargs in component_specs:
    np.random.seed(42)  # every component starts from the same RNG state
    result = operation(z, **kwargs)
    print(f"{label}: {result.shape}, norm={np.linalg.norm(result):.4f}")
    component_outputs.append(result)

# Keep the individual results available under their usual names
out_tri_out, out_tri_in, out_trans = component_outputs

In [None]:
# Test 4: Multiple templates scenario
print("\nTest 4: Multiple Templates Scenario")
print("="*60)

# Seed BEFORE drawing the template data; previously the templates were drawn
# from whatever state the global RNG happened to be in, so the inputs (and
# every number printed below) changed from run to run.
np.random.seed(42)

# Simulate realistic template scenario
N_templ, N_res, c_t = 4, 64, 64

# Create templates with different quality levels
t = np.zeros((N_templ, N_res, N_res, c_t), dtype=np.float32)

for i in range(N_templ):
    # Better templates (lower index) have stronger signal
    signal_strength = 1.0 / (i + 1)
    t[i] = np.random.randn(N_res, N_res, c_t) * signal_strength
    print(f"Template {i}: signal strength = {signal_strength:.2f}, norm = {np.linalg.norm(t[i]):.2f}")

# Separate seed for the randomly drawn layer weights, so they do not reuse
# the random stream that produced the template data above.
np.random.seed(0)
output = template_pair_stack(t, num_blocks=2, c_t=c_t)

print("\nAfter processing:")
for i in range(N_templ):
    print(f"Template {i}: norm = {np.linalg.norm(output[i]):.2f}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_templ, N_res, c_t = 3, 24, 48
t = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)

# The layers draw their weights from NumPy's global RNG on every call, so two
# runs are only comparable when both start from the same seed. A separate
# seed keeps the weights from reusing the random stream of the input data.
WEIGHT_SEED = 0

np.random.seed(WEIGHT_SEED)
output = template_pair_stack(t.copy(), num_blocks=2, c_t=c_t)

# Property 1: Shape preserved
shape_preserved = output.shape == t.shape
print(f"Property 1 - Shape preserved: {shape_preserved}")

# Property 2: Finite output
output_finite = np.isfinite(output).all()
print(f"Property 2 - Output finite: {output_finite}")

# Property 3: Non-trivial transformation
not_identity = not np.allclose(output, t)
print(f"Property 3 - Non-trivial: {not_identity}")

# Property 4: Each template processed independently
# Modifying one template shouldn't affect others
t_modified = t.copy()
t_modified[0] *= 2.0
# Same weight seed as the reference run above; without it the two runs use
# different random weights and the comparison is meaningless.
np.random.seed(WEIGHT_SEED)
output_modified = template_pair_stack(t_modified, num_blocks=2, c_t=c_t)

# Templates 1 and 2 should be unchanged
other_unchanged = np.allclose(output[1:], output_modified[1:], rtol=1e-4)
print(f"Property 4 - Template independence: {other_unchanged}")

# Property 5: Reasonable output scale
input_norm = np.linalg.norm(t)
output_norm = np.linalg.norm(output)
scale_ratio = output_norm / input_norm
print(f"Property 5 - Scale ratio: {scale_ratio:.4f}")

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class TemplatePairStack(hk.Module):
  """Pair stack for the templates.

  Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
  """

  # NOTE(review): this looks like an abridged excerpt -- the literal `...`
  # below stands in for arguments that are not shown here, and
  # `pair_mask`, `is_training` and `safe_key` are not used in the visible
  # lines. Confirm details against the upstream file.
  def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
    c = self.config
    gc = self.global_config
    
    # With `num_block` not greater than zero, `pair_act` is returned as-is.
    if c.num_block > 0:
      # Stack multiple iterations
      pair_act = hk.remat(
          TemplatePairStackIteration(c, gc, name='pair_stack_iteration'),
          ...)
    
    return pair_act


class TemplatePairStackIteration(hk.Module):
  """Single iteration of TemplatePairStack.
  
  Similar to EvoformerIteration but simpler (no MSA track).
  """
  
  # NOTE(review): `c` and `gc` are not bound anywhere in this excerpt and
  # `safe_key` is unused in the visible lines -- presumably the excerpt is
  # abridged; confirm against the upstream file.
  # Every step below adds the layer's result onto `pair_act` (`+=`), and the
  # accumulated value is returned.
  def __call__(self, pair_act, pair_mask, is_training, safe_key):
    # Triangle multiplication (outgoing)
    pair_act += TriangleMultiplication(c, gc, name='triangle_multiplication_outgoing')(
        pair_act, pair_mask, is_training)
    
    # Triangle multiplication (incoming)
    pair_act += TriangleMultiplication(c, gc, name='triangle_multiplication_incoming')(
        pair_act, pair_mask, is_training)
    
    # Triangle attention (starting)
    pair_act += TriangleAttention(c, gc, name='triangle_attention_starting_node')(
        pair_act, pair_mask, is_training)
    
    # Triangle attention (ending)
    pair_act += TriangleAttention(c, gc, name='triangle_attention_ending_node')(
        pair_act, pair_mask, is_training)
    
    # Pair transition
    pair_act += Transition(c, gc, name='pair_transition')(
        pair_act, pair_mask, is_training)
    
    return pair_act
```

## Key Insights

1. **Template Processing**: Each template is processed independently, allowing the model to handle varying numbers of templates.

2. **Simplified Evoformer**: The Template Pair Stack is similar to the pair track of Evoformer, but without the MSA track.

3. **Fewer Blocks**: Typically uses 2 blocks vs 48 for the main Evoformer, reflecting the auxiliary nature of template information.

4. **Geometric Consistency**: The triangle operations ensure that template features capture geometric relationships consistent with protein structure.

5. **Integration**: After processing, template features are combined with the main pair representation via attention (Algorithm 17).

6. **Optional Component**: Templates are optional; the model can work without them for orphan proteins.