# Algorithm 17: Template Pointwise Attention

Template Pointwise Attention aggregates information from multiple templates into the pair representation using attention. Each position in the pair representation attends over all templates.

## Algorithm Pseudocode

![TemplatePointwiseAttention](../imgs/algorithms/TemplatePointwiseAttention.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `Attention` (used for template attention)
- **Location**: Within `EmbeddingsAndEvoformer`

## Overview

When multiple templates are available, each pair position attends over them as follows:

1. **Query**: Current pair representation z[i,j]
2. **Key/Value**: Template features t[n,i,j] for n templates
3. **Output**: Weighted combination of template information

This allows the model to:
- Selectively use relevant template information
- Combine insights from multiple templates
- Weight templates by their relevance to each position

In [None]:
import numpy as np

# Seed NumPy's global RNG so the random inputs and the random projection
# weights drawn via np.random in this notebook are reproducible.
np.random.seed(42)

## NumPy Implementation

In [None]:
def layer_norm(x, axis=-1, eps=1e-5):
    """Normalize ``x`` to zero mean and unit variance along ``axis``.

    No learnable gain or bias is applied. ``eps`` is added to the variance
    before taking the square root so constant slices do not divide by zero.
    """
    mu = np.mean(x, axis=axis, keepdims=True)
    sigma2 = np.var(x, axis=axis, keepdims=True)
    centered = x - mu
    return centered / np.sqrt(sigma2 + eps)


def softmax(x, axis=-1):
    """Softmax along ``axis``, shifted by the per-slice maximum for stability.

    Subtracting the maximum leaves the result unchanged mathematically but
    keeps the exponentials from overflowing for large inputs.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer


def _init_projection_weights(c_z, c_t, num_heads, head_dim):
    """Draw random (W_q, W_k, W_v, W_o) projections from NumPy's global RNG."""
    # The draw order W_q, W_k, W_v, W_o is fixed so that a given global seed
    # always produces the same set of weights.
    W_q = np.random.randn(c_z, num_heads, head_dim) * (c_z ** -0.5)
    W_k = np.random.randn(c_t, num_heads, head_dim) * (c_t ** -0.5)
    W_v = np.random.randn(c_t, num_heads, head_dim) * (c_t ** -0.5)
    W_o = np.random.randn(num_heads, head_dim, c_z) * ((num_heads * head_dim) ** -0.5)
    return W_q, W_k, W_v, W_o


def template_pointwise_attention(z, t, num_heads=4, weights=None):
    """
    Template Pointwise Attention - Algorithm 17.

    Aggregates template information using attention: every (i, j) pair
    position attends independently over the template axis.

    Args:
        z: Pair representation [N_res, N_res, c_z] (query)
        t: Template features [N_templ, N_res, N_res, c_t] (key/value)
        num_heads: Number of attention heads
        weights: Optional ``(W_q, W_k, W_v, W_o)`` with shapes
            ``(c_z, H, c)``, ``(c_t, H, c)``, ``(c_t, H, c)``, ``(H, c, c_z)``
            where ``H = num_heads`` and ``c = c_z // num_heads``. When omitted,
            fresh random weights are drawn from NumPy's global RNG on every
            call, so pass the same tuple to make several calls comparable.

    Returns:
        Template-aggregated features [N_res, N_res, c_z]. All zeros when
        ``t`` contains no templates.

    Raises:
        ValueError: If the input ranks or spatial dimensions do not match,
            if ``num_heads`` is not positive, if ``c_z < num_heads`` (which
            would give a zero head dimension), or if ``weights`` has
            unexpected shapes.
    """
    if z.ndim != 3:
        raise ValueError(f"z must have shape [N_res, N_res, c_z], got {z.shape}")
    if t.ndim != 4:
        raise ValueError(
            f"t must have shape [N_templ, N_res, N_res, c_t], got {t.shape}"
        )
    if t.shape[1:3] != z.shape[:2]:
        raise ValueError(
            f"template spatial dims {t.shape[1:3]} do not match pair dims {z.shape[:2]}"
        )
    if num_heads < 1:
        raise ValueError(f"num_heads must be positive, got {num_heads}")

    N_res, N_col, c_z = z.shape
    N_templ = t.shape[0]
    c_t = t.shape[-1]
    c = c_z // num_heads
    if c < 1:
        # A zero head dimension would make every logit 0/0 = NaN.
        raise ValueError(
            f"c_z={c_z} is smaller than num_heads={num_heads}; head dim would be 0"
        )

    print("Template Pointwise Attention")
    print("=" * 50)
    print(f"Query (pair): [{N_res}, {N_col}, {c_z}]")
    print(f"Key/Value (templates): [{N_templ}, {N_res}, {N_col}, {c_t}]")
    print(f"Heads: {num_heads}, Head dim: {c}")

    if N_templ == 0:
        # Softmax over an empty template axis is undefined; with nothing to
        # attend to, the update added to the pair representation is zero.
        return np.zeros((N_res, N_col, c_z))

    # Step 1: Layer normalize inputs over the channel axis.
    # NOTE(review): confirm against the published pseudocode whether both
    # inputs are meant to be normalized at this point.
    z_norm = layer_norm(z)
    t_norm = layer_norm(t)

    # Step 2: Resolve projection weights (caller-supplied or freshly drawn)
    expected_shapes = (
        (c_z, num_heads, c),
        (c_t, num_heads, c),
        (c_t, num_heads, c),
        (num_heads, c, c_z),
    )
    if weights is None:
        weights = _init_projection_weights(c_z, c_t, num_heads, c)
    else:
        weights = tuple(np.asarray(w) for w in weights)
        actual_shapes = tuple(w.shape for w in weights)
        if actual_shapes != expected_shapes:
            raise ValueError(
                f"weights must have shapes {expected_shapes}, got {actual_shapes}"
            )
    W_q, W_k, W_v, W_o = weights

    # Step 3: Compute query from pair representation
    # z: [N_res, N_res, c_z] -> q: [N_res, N_res, H, c]
    q = np.einsum('ijc,chd->ijhd', z_norm, W_q)
    print(f"\nStep 3 - Query: {q.shape}")

    # Step 4: Compute key/value from templates
    # t: [N_templ, N_res, N_res, c_t] -> k,v: [N_templ, N_res, N_res, H, c]
    k = np.einsum('nijc,chd->nijhd', t_norm, W_k)
    v = np.einsum('nijc,chd->nijhd', t_norm, W_v)
    print(f"Step 4 - Key: {k.shape}, Value: {v.shape}")

    # Step 5: Compute attention
    # For each (i,j) position and head, score every template n by the scaled
    # dot product of q[i,j,h] with k[n,i,j,h].
    # Attention logits: [N_res, N_res, H, N_templ]
    attn_logits = np.einsum('ijhd,nijhd->ijhn', q, k) / np.sqrt(c)
    print(f"Step 5 - Attention logits: {attn_logits.shape}")

    # Softmax over the template axis, so the weights for each (i,j,h) sum to 1
    attn_weights = softmax(attn_logits, axis=-1)
    print(f"         Attention weights: {attn_weights.shape}")

    # Step 6: Apply attention to values
    # attended[i,j,h,d] = sum_n attn[i,j,h,n] * v[n,i,j,h,d]
    attended = np.einsum('ijhn,nijhd->ijhd', attn_weights, v)
    print(f"Step 6 - Attended: {attended.shape}")

    # Step 7: Output projection, merging heads back to c_z channels
    output = np.einsum('ijhd,hdc->ijc', attended, W_o)
    print(f"Step 7 - Output: {output.shape}")

    return output

## Test Examples

In [None]:
# Test 1: smoke test on random inputs; the result must have the pair shape
print("Test 1: Basic Functionality")
print("=" * 60)

N_res = 32
c_z = 64
N_templ = 4
c_t = 64

# Draw order matters for reproducibility: the pair tensor first, then templates.
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
t = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)

output = template_pointwise_attention(z, t, num_heads=4)

shapes_agree = output.shape == z.shape
print(f"\nInput pair shape: {z.shape}")
print(f"Output shape: {output.shape}")
print(f"Shape matches pair: {shapes_agree}")

In [None]:
# Test 2: Attention weight distribution
print("\nTest 2: Attention Weight Distribution")
print("=" * 60)

np.random.seed(42)
N_res, N_templ = 16, 4
c_z = c_t = 32

z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

# Templates are drawn at decreasing amplitude: template n is scaled by
# 1/(n+1), so earlier templates carry a stronger raw signal than later ones.
# NOTE: template_pointwise_attention layer-normalizes t over the channel axis,
# so this amplitude difference is largely removed before keys/values are built.
t = np.empty((N_templ, N_res, N_res, c_t), dtype=np.float32)
for n in range(N_templ):
    scale = 1.0 / (n + 1)
    t[n] = scale * np.random.randn(N_res, N_res, c_t)
    print(f"Template {n}: scale={scale:.2f}, norm={np.linalg.norm(t[n]):.2f}")

output = template_pointwise_attention(z, t, num_heads=4)

In [None]:
# Test 3: degenerate case with exactly one template to attend over
print("\nTest 3: Single Template")
print("=" * 60)

N_res = 24
N_templ = 1  # Single template
c_z = c_t = 64

z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
t = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)

output = template_pointwise_attention(z, t, num_heads=4)

# A softmax over a single logit is always 1, whatever the logit's value.
print("\nWith single template, attention weights are all 1.0 (trivial softmax)")

In [None]:
# Test 4: how the attention output is used as a residual update of z
print("\nTest 4: Residual Connection Pattern")
print("=" * 60)

np.random.seed(42)
N_res, N_templ = 16, 4
c_z = c_t = 32

z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
t = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)

# In practice: z = z + TemplatePointwiseAttention(z, t)
delta = template_pointwise_attention(z, t, num_heads=4)
z_updated = z + delta

# Compare the magnitude of the update with the magnitude of the input.
z_norm_value = np.linalg.norm(z)
delta_norm_value = np.linalg.norm(delta)
updated_norm_value = np.linalg.norm(z_updated)

print(f"\nInput z norm: {z_norm_value:.4f}")
print(f"Delta norm: {delta_norm_value:.4f}")
print(f"Output z norm: {updated_norm_value:.4f}")
print(f"Delta/Input ratio: {delta_norm_value / z_norm_value:.4f}")

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_res, c_z = 20, 64
N_templ, c_t = 4, 64

z = np.random.randn(N_res, N_res, c_z).astype(np.float32)
t = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)

# Independent second draws serve as the "modified" inputs. A pure rescaling
# (e.g. t * 2.0) is not a useful perturbation here: the function
# layer-normalizes both inputs, which cancels a global scale up to eps.
t_modified = np.random.randn(N_templ, N_res, N_res, c_t).astype(np.float32)
z_modified = np.random.randn(N_res, N_res, c_z).astype(np.float32)

WEIGHT_SEED = 0


def _attend_with_fixed_weights(z_in, t_in):
    """Run the attention with identical projection weights on every call.

    template_pointwise_attention draws its weights from NumPy's global RNG,
    so reseeding immediately before the call makes the weights the same
    across calls and any output difference attributable to the inputs.
    """
    np.random.seed(WEIGHT_SEED)
    return template_pointwise_attention(z_in, t_in, num_heads=4)


output = _attend_with_fixed_weights(z, t)

# Property 1: Output shape matches pair representation
shape_correct = output.shape == z.shape
print(f"Property 1 - Shape matches pair: {shape_correct}")

# Property 2: Finite output
output_finite = np.isfinite(output).all()
print(f"Property 2 - Output finite: {output_finite}")

# Property 3: Non-trivial (different from input)
not_identity = not np.allclose(output, z)
print(f"Property 3 - Non-trivial: {not_identity}")

# Property 4: Depends on templates (same query, same weights, new templates)
output_modified = _attend_with_fixed_weights(z, t_modified)
depends_on_templates = not np.allclose(output, output_modified)
print(f"Property 4 - Depends on templates: {depends_on_templates}")

# Property 5: Depends on query (same templates, same weights, new query)
output_z_modified = _attend_with_fixed_weights(z_modified, t)
depends_on_query = not np.allclose(output, output_z_modified)
print(f"Property 5 - Depends on query: {depends_on_query}")

## Source Code Reference

```python
# Template attention in EmbeddingsAndEvoformer
# Uses standard Attention class with:
#   - Query: current pair representation
#   - Key/Value: processed template features
#
# NOTE: illustrative, simplified excerpt, not runnable on its own. `c`,
# `template_feat`, and `pair_activations` come from surrounding code that is
# not shown here.

# From modules.py (simplified):
template_pair_representation = TemplatePairStack(...)(template_feat)

# Pointwise attention: query from pair, k/v from templates
template_embedding = Attention(
    num_head=c.num_head,
    key_dim=c.key_dim,
    value_dim=c.value_dim
)(
    query=pair_activations,  # [N, N, c_z]
    key=template_pair_representation,   # [T, N, N, c_t]
    value=template_pair_representation,
)

# Residual update: the template embedding is added onto the pair activations
pair_activations += template_embedding
```

## Key Insights

1. **Pointwise Operation**: The attention is "pointwise" because each (i,j) position in the pair representation independently attends over templates.

2. **Query/Key Separation**: Query comes from the current pair representation while keys/values come from templates, enabling content-based template selection.

3. **Weighted Combination**: The softmax attention allows the model to weight templates differently based on their relevance to each position.

4. **Optional Component**: When no templates are available, this attention is skipped (zeros added via residual connection).

5. **Complementary to MSA**: Templates provide structural prior information that complements the evolutionary information from MSA.

6. **Position-Specific Weighting**: Different (i,j) positions can weight the same template differently, allowing flexible template usage.