# Algorithm 4: Relative Position Encoding (AlphaFold3)

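Relative position encoding turns the sequence offset and chain membership of every token pair into pair features of shape `[N, N, C]`, so the network can tell which tokens are sequence neighbours and which belong to the same chain.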
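For a token pair (i, j), the simplified encoding implemented in this notebook can be written as follows. The symbols are our own shorthand: r_i = `residue_index[i]`, c_i = `chain_index[i]`, r_max = `max_relative_idx`, s_max = `max_relative_chain`.

$$
\begin{aligned}
d_{ij} &= \operatorname{clip}(r_i - r_j,\, -r_{\max},\, r_{\max}) + r_{\max} \in \{0, \dots, 2r_{\max}\} \\
b_{ij} &= \mathbf{1}[c_i = c_j] \\
e_{ij} &= \operatorname{clip}(c_i - c_j,\, -s_{\max},\, s_{\max}) + s_{\max} \in \{0, \dots, 2s_{\max}\} \\
\mathbf{f}_{ij} &= \big[\operatorname{onehot}(d_{ij});\ b_{ij};\ \operatorname{onehot}(e_{ij})\big] \in \{0,1\}^{(2r_{\max}+1) + 1 + (2s_{\max}+1)}
\end{aligned}
$$

With the defaults r_max = 32 and s_max = 2, this gives 65 + 1 + 5 = 71 features per pair.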

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/featurization.py`
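- **Function**: `create_relative_encoding` (the notebook cells below are a simplified re-implementation for illustration, not a copy of the official code)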

In [None]:
import numpy as np
np.random.seed(42)

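def one_hot(indices, num_classes):
    """One-hot encode integer indices; values outside [0, num_classes - 1] are clipped into the edge bins."""
    indices = np.clip(np.asarray(indices, dtype=np.int64), 0, num_classes - 1)
    return np.eye(num_classes, dtype=np.float32)[indices]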
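Here is a quick check of the helper's clipping behaviour. Indices below 0 or at/above `num_classes` do not raise an error. Instead, they are folded into the first or last bin, and the clipped position offsets below depend on exactly that. The helper is redefined so the cell runs on its own. It also works on `[N, N]` index grids, which is how the encoding uses it.

```python
import numpy as np

def one_hot(indices, num_classes):
    # Same behaviour as the helper above: clip out-of-range indices into
    # [0, num_classes - 1], then select rows of the identity matrix.
    return np.eye(num_classes)[np.clip(indices, 0, num_classes - 1)]

# -3 is clipped to bin 0 and 7 is clipped to bin 4 (the last bin).
idx = np.array([-3, 0, 2, 7])
enc = one_hot(idx, 5)
print(enc.shape)            # (4, 5)
print(enc.argmax(axis=-1))  # [0 0 2 4]

# A 2-D grid of indices gains a trailing one-hot axis: [2, 2] -> [2, 2, 4].
print(one_hot(np.array([[0, 1], [2, 3]]), 4).shape)  # (2, 2, 4)
```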

In [None]:
def create_relative_encoding(residue_index, chain_index, max_relative_idx=32, max_relative_chain=2):
    """
    Create relative position encoding for AF3.
    
    Args:
        residue_index: Residue indices [N]
        chain_index: Chain indices [N]
        max_relative_idx: Maximum relative position
        max_relative_chain: Maximum relative chain distance
    
    Returns:
        Relative encoding features [N, N, num_features]
    """
    N = len(residue_index)
    
    print(f"Relative Position Encoding")
    print(f"="*50)
    print(f"Tokens: {N}")
    
    # Relative residue position
    rel_pos = residue_index[:, None] - residue_index[None, :]  # [N, N]
    rel_pos_clipped = np.clip(rel_pos, -max_relative_idx, max_relative_idx)
    rel_pos_shifted = rel_pos_clipped + max_relative_idx  # [0, 2*max]
    rel_pos_onehot = one_hot(rel_pos_shifted, 2 * max_relative_idx + 1)  # [N, N, 65]
    
    # Same chain indicator
    same_chain = (chain_index[:, None] == chain_index[None, :]).astype(np.float32)  # [N, N]
    
    # Relative chain position (for multi-chain)
    rel_chain = chain_index[:, None] - chain_index[None, :]
    rel_chain_clipped = np.clip(rel_chain, -max_relative_chain, max_relative_chain)
    rel_chain_shifted = rel_chain_clipped + max_relative_chain
    rel_chain_onehot = one_hot(rel_chain_shifted, 2 * max_relative_chain + 1)  # [N, N, 5]
    
    # Concatenate features
    features = np.concatenate([
        rel_pos_onehot,  # 65
        same_chain[:, :, None],  # 1
        rel_chain_onehot,  # 5
    ], axis=-1)
    
    print(f"Relative position features: {rel_pos_onehot.shape[-1]}")
    print(f"Same chain features: 1")
    print(f"Relative chain features: {rel_chain_onehot.shape[-1]}")
    print(f"Total features: {features.shape[-1]}")
    
    return features
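The concatenation order fixes where each feature group lives in the last axis. Computing those channel ranges from the two `max_*` parameters avoids hard-coding offsets such as 65. `feature_slices` is a hypothetical helper, not part of the notebook or of AF3, and it assumes the order rel_pos → same_chain → rel_chain used above.

```python
def feature_slices(max_relative_idx=32, max_relative_chain=2):
    """Hypothetical helper: channel ranges of the concatenated pair features."""
    n_pos = 2 * max_relative_idx + 1      # relative-position bins
    n_chain = 2 * max_relative_chain + 1  # relative-chain bins
    return {
        "rel_pos": slice(0, n_pos),
        "same_chain": slice(n_pos, n_pos + 1),
        "rel_chain": slice(n_pos + 1, n_pos + 1 + n_chain),
    }

s = feature_slices()
print(s["rel_pos"], s["same_chain"], s["rel_chain"])
# slice(0, 65, None) slice(65, 66, None) slice(66, 71, None)
print(s["rel_chain"].stop)  # 71 (total feature count)
```

With these ranges, `rel_enc[..., s["same_chain"]]` returns the same-chain flag as an `[N, N, 1]` array, and the indexing stays correct if `max_relative_idx` changes.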

In [None]:
# Test: Single chain
print("Test 1: Single Chain")
print("="*60)

N = 50
residue_index = np.arange(N)
chain_index = np.zeros(N, dtype=np.int32)

rel_enc = create_relative_encoding(residue_index, chain_index)
print(f"Output: {rel_enc.shape}")

In [None]:
# Test: Multi-chain
print("\nTest 2: Multi-Chain Complex")
print("="*60)

# Two chains of length 30 and 20
residue_index = np.concatenate([np.arange(30), np.arange(20)])
chain_index = np.concatenate([np.zeros(30), np.ones(20)]).astype(np.int32)

rel_enc = create_relative_encoding(residue_index, chain_index)
print(f"Output: {rel_enc.shape}")

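# Check same-chain indicator: it follows the 2*32 + 1 = 65 relative-position bins
same_chain = rel_enc[:, :, 2 * 32 + 1]  # channel 65
print(f"Same chain (0,0): {same_chain[0, 0]}")    # both tokens in chain 0 -> 1.0
print(f"Same chain (0,35): {same_chain[0, 35]}")  # token 35 is in chain 1 -> 0.0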

## Key Insights

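1. **Relative Position**: Encodes the signed residue offset between two tokens, clipped to ±32, as one of 65 one-hot bins; every offset beyond ±32 shares an edge bin.
2. **Chain Information**: Adds a same-chain flag (1 feature) and the chain-index offset clipped to ±2 (5 bins).
3. **Multi-chain Support**: The chain features let the pair representation distinguish intra-chain from inter-chain pairs in complexes. In this simplified version, the residue offset is still computed for inter-chain pairs, whereas AF3 routes such pairs to a dedicated extra bin.
4. **One-hot Encoding**: Each clipped offset gets its own bin rather than a single scalar, so a downstream linear projection can learn an independent embedding for every offset.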
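Point 4 is easiest to see once the features are projected. The AF3 supplementary information describes passing the concatenated features through a bias-free linear layer to the pair width. Because each bin block is one-hot, that projection reduces to a sum of rows looked up from the weight matrix, in effect a learned embedding per offset. The sketch below checks this for one token pair. The width `c_z = 16` and the random weight matrix are placeholders, not trained AF3 parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
r_max, s_max, c_z = 32, 2, 16                     # c_z: arbitrary placeholder width
n_pos, n_chain = 2 * r_max + 1, 2 * s_max + 1
W = rng.normal(size=(n_pos + 1 + n_chain, c_z))   # random stand-in weights, [71, c_z]

# One token pair: residue offset -5, different chains with chain offset -1.
d = -5 + r_max   # relative-position bin (27)
b = 0.0          # same-chain flag
e = -1 + s_max   # relative-chain bin (1)
f = np.concatenate([np.eye(n_pos)[d], [b], np.eye(n_chain)[e]])  # [71]

projected = f @ W                                   # linear layer, no bias
lookup = W[d] + b * W[n_pos] + W[n_pos + 1 + e]     # equivalent row lookups
print(np.allclose(projected, lookup))  # True
```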