# Algorithm 4: Relative Position Encoding (relpos)

Relative position encoding adds positional information to the pair representation by encoding the signed, clipped sequence separation between every pair of residues. It gives the model an explicit signal of sequence order and of how far apart two residues are along the chain.

## Algorithm Pseudocode

![relpos](../imgs/algorithms/relpos.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Location**: `EmbeddingsAndEvoformer.__call__`
- **Lines**: 1744-1758

## Overview

Relative position encoding captures the sequential distance between residues:

1. **Compute Offset**: For each residue pair (i, j), compute `offset = pos[i] - pos[j]`
2. **Clip Range**: Limit to `[-max_relative_feature, max_relative_feature]` (typically [-32, 32])
3. **Shift to Non-Negative**: Add `max_relative_feature` to get integer values in `[0, 2*max]`
4. **One-Hot Encode**: Convert to one-hot vectors (65 classes for max=32)

This encoding helps the model understand:
- Adjacent residues (offset = ±1)
- Local structure (|offset| < 5)
- Long-range interactions (|offset| > 32, clipped)
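
The four steps above can be traced by hand for a few offsets. The sketch below uses a hypothetical helper, `offset_to_class`, which is not part of the AlphaFold source; it only illustrates the clip-and-shift arithmetic, assuming `max_relative_feature = 32`.

```python
import numpy as np

def offset_to_class(offset, max_relative_feature=32):
    """Map signed offsets to class indices (illustrative helper, not from AF2)."""
    # Clip to [-max, max], then shift so that offset 0 lands on class `max`
    clipped = np.clip(offset, -max_relative_feature, max_relative_feature)
    return clipped + max_relative_feature

offsets = np.array([-200, -33, -32, -1, 0, 1, 4, 32, 33, 200])
classes = offset_to_class(offsets)
for o, c in zip(offsets, classes):
    print(f"offset {o:+4d} -> class {c}")
```

Every offset at or beyond ±32 lands on one of the two end classes (0 or 64), while offsets inside the window each get their own class.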

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def one_hot(indices, num_classes, dtype=np.float32):
    """
    Algorithm 5: One-hot encoding.
    
    Converts integer indices to one-hot vectors.
    
    Args:
        indices: Integer array of any shape
        num_classes: Number of output classes
        dtype: Output data type
    
    Returns:
        One-hot array with shape [..., num_classes]
    """
    # Create identity matrix and index into it
    return np.eye(num_classes, dtype=dtype)[indices]


def relpos(residue_index, max_relative_feature=32):
    """
    Relative Position Encoding - Algorithm 4.
    
    Computes one-hot encoded relative positions between all residue pairs.
    
    Args:
        residue_index: [N_res] array of residue position indices
        max_relative_feature: Maximum relative distance to encode (default: 32)
    
    Returns:
        rel_pos: [N_res, N_res, 2*max+1] one-hot encoded relative positions
    """
    residue_index = np.asarray(residue_index)  # Accept lists as well as arrays
    N_res = len(residue_index)
    num_classes = 2 * max_relative_feature + 1  # 65 for max=32
    
    print("Relative Position Encoding")
    print("=" * 50)
    print(f"Residues: {N_res}")
    print(f"Max relative feature: ±{max_relative_feature}")
    print(f"Number of position classes: {num_classes}")
    
    # Step 1: Compute pairwise offset (Line 1 in pseudocode)
    # offset[i,j] = residue_index[i] - residue_index[j]
    offset = residue_index[:, None] - residue_index[None, :]
    
    print(f"\nStep 1 - Offset matrix: {offset.shape}")
    print(f"  Range: [{offset.min()}, {offset.max()}]")
    
    # Step 2: Clip to valid range (Line 2)
    clipped = np.clip(offset, -max_relative_feature, max_relative_feature)
    
    print(f"Step 2 - Clipped range: [{clipped.min()}, {clipped.max()}]")
    
    # Step 3: Shift to [0, 2*max] for one-hot encoding (Line 3)
    shifted = clipped + max_relative_feature
    
    print(f"Step 3 - Shifted range: [{shifted.min()}, {shifted.max()}]")
    
    # Step 4: One-hot encode (Algorithm 5)
    rel_pos = one_hot(shifted.astype(np.int32), num_classes)
    
    print(f"Step 4 - Output shape: {rel_pos.shape}")
    
    return rel_pos

## Test Examples

In [None]:
# Test 1: Standard sequential indices
print("Test 1: Standard Sequential Indices")
print("="*60)

N_res = 64
residue_index = np.arange(N_res)

rel_pos = relpos(residue_index, max_relative_feature=32)

# Verify diagonal (self-distance = 0)
print(f"\nVerification:")
diag_indices = np.argmax(rel_pos[np.arange(N_res), np.arange(N_res)], axis=-1)
print(f"  Diagonal positions (should all be 32): {np.unique(diag_indices)}")

# Verify adjacent (i, i+1) has offset = -1, so index = 31
adjacent_indices = np.argmax(rel_pos[np.arange(N_res-1), np.arange(1, N_res)], axis=-1)
print(f"  Adjacent (i, i+1) positions (should all be 31): {np.unique(adjacent_indices)}")

In [None]:
# Test 2: Check symmetry property
print("\nTest 2: Symmetry Property")
print("="*60)

N_res = 16
residue_index = np.arange(N_res)
rel_pos = relpos(residue_index, max_relative_feature=32)

# For position encoding: rel_pos[i,j,k] should relate to rel_pos[j,i, 2*max-k]
# Because if offset(i,j) = d, then offset(j,i) = -d
max_rel = 32

# Active class for every pair. idx[i, j] + idx[j, i] must equal 2*max_rel,
# because the two classes mirror each other around the center class max_rel.
idx = np.argmax(rel_pos, axis=-1)
is_antisymmetric = bool(np.all(idx + idx.T == 2 * max_rel))

print(f"Antisymmetry verified: {is_antisymmetric}")
print(f"  rel_pos[0,5] peak at: {np.argmax(rel_pos[0,5])} (offset = -5 → 27)")
print(f"  rel_pos[5,0] peak at: {np.argmax(rel_pos[5,0])} (offset = +5 → 37)")
print(f"  Sum: {np.argmax(rel_pos[0,5]) + np.argmax(rel_pos[5,0])} (should be 64)")

In [None]:
# Test 3: Non-contiguous residue indices (e.g., multi-chain)
print("\nTest 3: Non-contiguous Indices (Multi-chain)")
print("="*60)

# Simulate two chains with a large gap
chain1 = np.arange(20)           # Residues 0-19
chain2 = np.arange(200, 220)     # Residues 200-219
residue_index = np.concatenate([chain1, chain2])

rel_pos = relpos(residue_index, max_relative_feature=32)

# Check within-chain distances
within_chain1 = np.argmax(rel_pos[0, 10])  # Should be 32-10=22
print(f"Within chain1 (0→10): index={within_chain1} (expected 22, offset=-10)")

# Check cross-chain distances (should be clipped)
cross_chain = np.argmax(rel_pos[0, 20])  # Array position 20 is the first residue of chain2 (residue 200)
print(f"Cross-chain (0→chain2_first): index={cross_chain} (expected 0, offset=-200 clipped to -32)")

In [None]:
# Test 4: Visualize relative position pattern
print("\nTest 4: Position Pattern Analysis")
print("="*60)

N_res = 32
residue_index = np.arange(N_res)
rel_pos = relpos(residue_index, max_relative_feature=32)

# Extract the argmax indices (which position class is active)
active_class = np.argmax(rel_pos, axis=-1)

print(f"\nActive position class matrix (first 8x8):")
print(active_class[:8, :8])

# Verify it's a Toeplitz-like structure
# Each diagonal should have constant value
print(f"\nDiagonal values:")
for d in range(-3, 4):
    diag = np.diag(active_class, k=d)
    print(f"  Diagonal {d:+d}: all={diag[0]}, consistent={np.all(diag == diag[0])}")

## Integration with Pair Representation

In [None]:
def add_relpos_to_pair(pair_act, residue_index, max_relative_feature=32):
    """
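    Add relative position encoding to the pair representation.
    
    Mirrors how relpos is used in the Input Embedder: the one-hot features
    are linearly projected to c_z channels and added to the pair activations.
    The projection weights here are random stand-ins for learned parameters.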
    
    Args:
        pair_act: Pair activations [N_res, N_res, c_z]
        residue_index: Residue position indices [N_res]
        max_relative_feature: Maximum relative distance
    
    Returns:
        Updated pair activations
    """
    N_res, _, c_z = pair_act.shape
    num_classes = 2 * max_relative_feature + 1
    
    # Compute relative position features
    rel_pos = relpos(residue_index, max_relative_feature)
    
    # Project to pair dimension
    W_relpos = np.random.randn(num_classes, c_z) * 0.01
    relpos_emb = np.einsum('ijk,kc->ijc', rel_pos, W_relpos)
    
    # Add to pair representation
    pair_act_updated = pair_act + relpos_emb
    
    print(f"\nIntegration with Pair Representation:")
    print(f"  Input pair: {pair_act.shape}")
    print(f"  Relpos embedding: {relpos_emb.shape}")
    print(f"  Output pair: {pair_act_updated.shape}")
    
    return pair_act_updated

# Test integration
N_res, c_z = 32, 128
pair_act = np.random.randn(N_res, N_res, c_z) * 0.1
residue_index = np.arange(N_res)

pair_updated = add_relpos_to_pair(pair_act, residue_index)
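
Because each `rel_pos[i, j]` is one-hot, multiplying it by the projection matrix simply selects one row of that matrix, so the projection behaves like a lookup table with one vector per offset class. The sketch below checks this numerically. `W` is a random stand-in for the learned weights, and all variable names are chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
max_rel, c_z = 32, 8
num_classes = 2 * max_rel + 1

idx = np.arange(10)
offset = idx[:, None] - idx[None, :]
shifted = np.clip(offset, -max_rel, max_rel) + max_rel  # class indices in [0, 2*max_rel]

W = rng.standard_normal((num_classes, c_z))  # random stand-in for learned weights

# Route 1: explicit one-hot features followed by a matrix product
onehot_feats = np.eye(num_classes)[shifted]              # [N, N, num_classes]
projected = np.einsum('ijk,kc->ijc', onehot_feats, W)    # [N, N, c_z]

# Route 2: index the rows of W directly
looked_up = W[shifted]                                   # [N, N, c_z]

print(np.allclose(projected, looked_up))
```

Direct indexing avoids materializing the `[N_res, N_res, 65]` one-hot tensor. This is a possible optimization, not what the reference code does.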

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

N_res = 64
residue_index = np.arange(N_res)
max_rel = 32
rel_pos = relpos(residue_index, max_relative_feature=max_rel)

# Property 1: One-hot (each position has exactly one 1)
sums = rel_pos.sum(axis=-1)
is_onehot = np.allclose(sums, 1.0)
print(f"Property 1 - Valid one-hot: {is_onehot}")

# Property 2: Correct number of classes
expected_classes = 2 * max_rel + 1
actual_classes = rel_pos.shape[-1]
print(f"Property 2 - Correct classes: {actual_classes == expected_classes} ({actual_classes})")

# Property 3: Diagonal represents zero offset
diag_class = np.argmax(rel_pos[0, 0])
print(f"Property 3 - Diagonal class is center: {diag_class == max_rel} (class {diag_class})")

# Property 4: Clipping works for large offsets
# Create long sequence
long_idx = np.arange(100)
long_rel = relpos(long_idx, max_relative_feature=max_rel)
far_offset = np.argmax(long_rel[0, 99])  # 99 positions apart
print(f"Property 4 - Large offset clipped: {far_offset == 0} (class {far_offset}, should be 0)")

# Property 5: Output shape
expected_shape = (N_res, N_res, expected_classes)
print(f"Property 5 - Output shape: {rel_pos.shape == expected_shape}")

## Source Code Reference

```python
# From AF2-source-code/model/modules.py (EmbeddingsAndEvoformer.__call__)

# Relative position encoding.
# Jumper et al. (2021) Suppl. Alg. 4 "relpos"
# Jumper et al. (2021) Suppl. Alg. 5 "one_hot"

if c.max_relative_feature:
    pos = batch['residue_index']
    offset = pos[:, None] - pos[None, :]  # Pairwise offset
    
    rel_pos = jax.nn.one_hot(
        jnp.clip(
            offset + c.max_relative_feature,  # Shift so offset 0 maps to class max
            a_min=0,
            a_max=2 * c.max_relative_feature),  # Clip into [0, 2*max]
        2 * c.max_relative_feature + 1)  # One-hot encode into 2*max+1 classes
    
    # Add to pair representation via linear projection
    pair_activations += common_modules.Linear(
        c.pair_channel, name='pair_activiations')(rel_pos)
```
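
The reference code shifts first and then clips to `[0, 2*max]`, whereas the NumPy implementation in this notebook clips to `[-max, max]` and then shifts. Both orderings should yield the same class indices, as the following minimal check illustrates (the residue indices are arbitrary test data).

```python
import numpy as np

max_rel = 32
# Arbitrary residue indices, including a large gap, used as test data
pos = np.array([0, 1, 2, 40, 41, 300, 301])
offset = pos[:, None] - pos[None, :]

# Ordering used in the NumPy implementation: clip, then shift
clip_then_shift = np.clip(offset, -max_rel, max_rel) + max_rel

# Ordering used in the reference snippet: shift, then clip
shift_then_clip = np.clip(offset + max_rel, 0, 2 * max_rel)

print(np.array_equal(clip_then_shift, shift_then_clip))
```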

## Key Insights

1. **Sequential Information**: Relative position encoding provides information about sequence order that attention alone does not supply, since attention mechanisms are permutation-equivariant and carry no built-in notion of position.

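2. **Clipping Strategy**: By clipping to ±32, the model assigns one shared class to all residues more than 32 positions upstream and another to all residues more than 32 positions downstream. This design choice trades resolution at long range for a small, fixed number of classes.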

3. **One-Hot vs. Continuous**: Using one-hot encoding (65 discrete classes) rather than continuous features allows the model to learn distinct behaviors for each relative distance.

4. **Multi-Chain Support**: When residue indices jump between chains by at least `max_relative_feature` (as in Test 3), every cross-chain pair is clipped to one of the two extreme classes, so residues on different chains are treated as far apart in sequence.

5. **Toeplitz Structure**: For contiguous residue indices, the relative position matrix is Toeplitz: each diagonal has a constant class, so the encoding depends only on the offset `i - j` and not on absolute position.
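
Insight 4 depends on the size of the index gap: cross-chain pairs collapse onto a single extreme class only once every cross-chain offset has a magnitude of at least `max_relative_feature`. The sketch below compares a small and a large gap. The helper `cross_chain_classes` and the gap values are hypothetical and chosen only for illustration.

```python
import numpy as np

def cross_chain_classes(gap, chain_len=20, max_relative_feature=32):
    """Return the distinct classes seen by (chain1, chain2) pairs (illustrative helper)."""
    chain1 = np.arange(chain_len)
    chain2 = np.arange(chain_len) + chain_len + gap  # chain2 indices start after an extra gap
    offset = chain1[:, None] - chain2[None, :]       # chain1 residue minus chain2 residue
    classes = np.clip(offset, -max_relative_feature, max_relative_feature) + max_relative_feature
    return np.unique(classes)

small_gap = cross_chain_classes(gap=5)
large_gap = cross_chain_classes(gap=200)
print(f"gap=5:   {len(small_gap)} distinct classes")
print(f"gap=200: {len(large_gap)} distinct class(es): {large_gap}")
```

With the small gap, many cross-chain pairs receive the same classes as nearby residues within a chain. With the large gap, all of them share class 0.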