# Algorithm 17: Template Pointwise Attention

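Template Pointwise Attention aggregates the features of multiple templates into a single update for the pair representation. Attention runs independently at each residue pair (i, j) and only over the template axis, so no information is mixed between pair positions.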

## Algorithm Pseudocode

![TemplatePointwiseAttention](../imgs/algorithms/TemplatePointwiseAttention.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `Attention` (used for template attention)
- **Location**: Within `EmbeddingsAndEvoformer`

## Purpose

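When multiple templates are available, this attention mechanism:
1. Projects the current pair representation entry `z[i, j]` to a query
2. Projects the template features `t[:, i, j]` at the same position to keys and values
3. Applies a softmax over the template axis and returns the weighted combination of the template values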
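
The three steps can be traced at a single pair position before looking at the batched version. This is a minimal single-head sketch with toy sizes; the weight matrices are random stand-ins for learned parameters, and all names here are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes, chosen only for illustration
N_templ, c_z, c_t, c = 3, 8, 6, 4

z_ij = rng.standard_normal(c_z)              # pair entry at one (i, j)
t_ij = rng.standard_normal((N_templ, c_t))   # template features at the same (i, j)

# Hypothetical single-head projections (random, not learned)
W_q = rng.standard_normal((c_z, c))
W_k = rng.standard_normal((c_t, c))
W_v = rng.standard_normal((c_t, c))

q = z_ij @ W_q   # step 1: query, shape [c]
k = t_ij @ W_k   # step 2: keys, shape [N_templ, c]
v = t_ij @ W_v   #         values, shape [N_templ, c]

# One score per template, then a softmax over the template axis
logits = k @ q / np.sqrt(c)
weights = np.exp(logits - logits.max())
weights /= weights.sum()

# Step 3: weighted combination of the template values, shape [c]
combined = weights @ v

print("attention weights:", weights)
print("combined shape:", combined.shape)
```

The weights form a probability distribution over the templates at this one position; the full function below repeats the same computation for every (i, j) and every head.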

In [None]:
import numpy as np

np.random.seed(42)

In [None]:
def template_pointwise_attention(z, t, num_heads=4):
    """
    Template Pointwise Attention - Algorithm 17.
    
    Aggregates template information using attention.
    
    Args:
        z: Pair representation [N_res, N_res, c_z] (query)
        t: Template features [N_templ, N_res, N_res, c_t] (key/value)
        num_heads: Number of attention heads
    
    Returns:
        Template-aggregated features [N_res, N_res, c_z]
    """
    N_res, _, c_z = z.shape
    N_templ = t.shape[0]
    c_t = t.shape[-1]
    c = c_z // num_heads  # head dim derived from c_z for this demo; it could also be set independently
    
    print("Template Pointwise Attention")
    print(f"  Query (pair): [{N_res}, {N_res}, {c_z}]")
    print(f"  Key/Value (templates): [{N_templ}, {N_res}, {N_res}, {c_t}]")
    print(f"  Heads: {num_heads}, Head dim: {c}")
    
    # Projection weights: random stand-ins for learned parameters (re-drawn on every call)
    W_q = np.random.randn(c_z, num_heads, c) * 0.02
    W_k = np.random.randn(c_t, num_heads, c) * 0.02
    W_v = np.random.randn(c_t, num_heads, c) * 0.02
    W_o = np.random.randn(num_heads, c, c_z) * 0.02
    
    # Query from pair representation
    # z: [N_res, N_res, c_z] -> q: [N_res, N_res, H, c]
    q = np.einsum('ijc,chd->ijhd', z, W_q)
    
    # Key/Value from templates
    # t: [N_templ, N_res, N_res, c_t] -> k,v: [N_templ, N_res, N_res, H, c]
    k = np.einsum('tijc,chd->tijhd', t, W_k)
    v = np.einsum('tijc,chd->tijhd', t, W_v)
    
    # Attention over templates
    # For each (i,j) position, attend over N_templ templates
    # q[i,j]: [H, c], k[t,i,j]: [H, c] for t in templates
    
    # Reshape for batch attention
    q_flat = q.reshape(N_res * N_res, num_heads, c)  # [N*N, H, c]
    k_flat = k.reshape(N_templ, N_res * N_res, num_heads, c).transpose(1, 0, 2, 3)  # [N*N, T, H, c]
    v_flat = v.reshape(N_templ, N_res * N_res, num_heads, c).transpose(1, 0, 2, 3)
    
    # Attention weights: [N*N, H, T]
    attn_logits = np.einsum('bhc,bthc->bht', q_flat, k_flat) / np.sqrt(c)
    attn_weights = np.exp(attn_logits - attn_logits.max(axis=-1, keepdims=True))
    attn_weights /= attn_weights.sum(axis=-1, keepdims=True)
    
    # Apply attention: [N*N, H, c]
    attended = np.einsum('bht,bthc->bhc', attn_weights, v_flat)
    
    # Reshape and project
    attended = attended.reshape(N_res, N_res, num_heads, c)
    output = np.einsum('ijhc,hcd->ijd', attended, W_o)
    
    print(f"  Output: {output.shape}")
    
    return output
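
The reshape and transpose in the function above are a convenience rather than a requirement. As a hedged sketch, the same contraction can be written directly with `einsum` on the unflattened arrays; the check below uses its own random `q`, `k`, `v` (not the function's internals) and a small helper `softmax` defined only for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
N_res, N_templ, H, c = 5, 3, 2, 4

q = rng.standard_normal((N_res, N_res, H, c))
k = rng.standard_normal((N_templ, N_res, N_res, H, c))
v = rng.standard_normal((N_templ, N_res, N_res, H, c))

def softmax(x, axis):
    """Numerically stable softmax along one axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Direct form: contract the channel axis, keep (i, j, head, template)
logits = np.einsum('ijhc,tijhc->ijht', q, k) / np.sqrt(c)
weights = softmax(logits, axis=-1)
direct = np.einsum('ijht,tijhc->ijhc', weights, v)

# Flattened form, mirroring the reshape/transpose used in the function
q_flat = q.reshape(N_res * N_res, H, c)
k_flat = k.reshape(N_templ, N_res * N_res, H, c).transpose(1, 0, 2, 3)
v_flat = v.reshape(N_templ, N_res * N_res, H, c).transpose(1, 0, 2, 3)
w_flat = softmax(np.einsum('bhc,bthc->bht', q_flat, k_flat) / np.sqrt(c), axis=-1)
flat = np.einsum('bht,bthc->bhc', w_flat, v_flat).reshape(N_res, N_res, H, c)

print("forms agree:", np.allclose(direct, flat))
```

The direct form avoids the intermediate copies created by `transpose` followed by `einsum`, at the cost of longer index strings.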

In [None]:
# Test
N_res, c_z = 32, 64
N_templ, c_t = 4, 64

z = np.random.randn(N_res, N_res, c_z)
t = np.random.randn(N_templ, N_res, N_res, c_t)

print("Test Template Pointwise Attention")
print("="*50)

output = template_pointwise_attention(z, t, num_heads=4)
print(f"\nShape matches pair repr: {output.shape == z.shape}")
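
The test above assumes every template slot holds a real template. When the template axis is padded to a fixed size, one common approach is to add a large negative bias to the logits of padded slots so the softmax gives them zero weight. The sketch below shows that idea on stand-in logits. The `template_mask` array and the bias constant are assumptions made for illustration, not something the function above implements.

```python
import numpy as np

rng = np.random.default_rng(0)
N_res, N_templ, H = 4, 3, 2

# Stand-in attention logits with layout [N_res, N_res, H, N_templ]
logits = rng.standard_normal((N_res, N_res, H, N_templ))

# Hypothetical mask: 1 for a real template, 0 for a padded slot
template_mask = np.array([1.0, 1.0, 0.0])

# Large negative bias on masked templates, broadcast over (i, j, head)
bias = 1e9 * (template_mask - 1.0)
masked_logits = logits + bias

# Softmax over the template axis
weights = np.exp(masked_logits - masked_logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print("max weight on padded template:", weights[..., 2].max())
print("weights sum to 1:", np.allclose(weights.sum(axis=-1), 1.0))
```

This only works if at least one template is unmasked; with an all-zero mask the softmax falls back to uniform weights over padded slots, so that case needs separate handling.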

## Source Code Reference

```python
# Template attention in EmbeddingsAndEvoformer
# Uses standard Attention class with:
#   - Query: current pair representation
#   - Key/Value: processed template features

# The attention mechanism aggregates information from
# multiple templates into the pair representation
```