# Algorithm 5: One-Hot Encoding

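One-hot encoding converts categorical indices into binary vectors that contain a single 1 at the position of the index and 0 everywhere else. AlphaFold2 uses it throughout to encode amino acid types, relative positions, and other categorical features.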
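Formally, with $N$ classes and an index $a \in \{0, \dots, N-1\}$, the encoding is the $a$-th standard basis vector of length $N$:

$$
\mathrm{onehot}(a)_c =
\begin{cases}
1 & \text{if } c = a \\
0 & \text{otherwise}
\end{cases}
\qquad c = 0, \dots, N-1
$$

For example, with $N = 5$, $\mathrm{onehot}(2) = (0, 0, 1, 0, 0)$.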

## Algorithm Pseudocode

![one_hot](../imgs/algorithms/one_hot.png)
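In the AlphaFold2 supplementary information, Algorithm 5 is written in a slightly more general nearest-bin form: given a value `x` and a vector of bin values `v_bins`, start from a zero vector `p`, pick `b = argmin |x - v_bins|`, and set `p_b = 1`. Ordinary integer class indices are the special case where the bins are `0, 1, ..., N-1`. The NumPy sketch below follows that reading; it is an illustration, not code from the AlphaFold2 repository, and the helper name `one_hot_nearest_bin` is hypothetical.

```python
import numpy as np

def one_hot_nearest_bin(x, v_bins, dtype=np.float32):
    """Hypothetical helper: one-hot encode each value of x by its nearest entry in v_bins."""
    x = np.asarray(x)
    v_bins = np.asarray(v_bins)
    # Distance from every value to every bin: shape x.shape + (len(v_bins),)
    dist = np.abs(x[..., None] - v_bins)
    # Index of the closest bin (np.argmin returns the first minimum on ties)
    b = np.argmin(dist, axis=-1)
    # p_b = 1, all other entries 0
    return np.eye(len(v_bins), dtype=dtype)[b]

# Relative-position style bins: -32, -31, ..., 32 (65 bins)
v_bins = np.arange(-32, 33)
d = np.array([-100, -3, 0, 5, 40])
p = one_hot_nearest_bin(d, v_bins)
print(p.shape)                    # (5, 65)
print(v_bins[p.argmax(axis=-1)])  # [-32  -3   0   5  32]
```

Values beyond the outermost bins land in the first or last bin, so with integer bins `-32..32` this behaves like clipping the input to `[-32, 32]`.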

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Usage**: Throughout, via `jax.nn.one_hot`

```python
import numpy as np

np.random.seed(42)
```

```python
def one_hot(indices, num_classes, dtype=np.float32):
    """
    One-Hot Encoding - Algorithm 5.

    Converts integer indices to one-hot vectors by selecting rows of an
    identity matrix.

    Args:
        indices: Integer array of any shape, with values in [0, num_classes)
        num_classes: Number of classes (length of one-hot vector)
        dtype: Output dtype

    Returns:
        One-hot encoded array of shape indices.shape + (num_classes,)

    Note:
        Because this relies on NumPy indexing, an index >= num_classes raises
        IndexError and a negative index wraps around (-1 selects the last class).
    """
    # Row i of the identity matrix is the one-hot vector for class i
    eye = np.eye(num_classes, dtype=dtype)

    # Fancy indexing gathers one row per index, appending the class axis
    return eye[indices]


def one_hot_explicit(indices, num_classes, dtype=np.float32):
    """
    Explicit implementation: allocate zeros, then set a single 1 per index.
    """
    indices = np.asarray(indices)

    # Work on a flattened (num_elements, num_classes) output
    flat_indices = indices.reshape(-1)
    flat_output = np.zeros((flat_indices.size, num_classes), dtype=dtype)

    # For row i, set column flat_indices[i] to 1
    flat_output[np.arange(flat_indices.size), flat_indices] = 1.0

    # Restore the input shape, with the class axis appended last
    return flat_output.reshape(indices.shape + (num_classes,))
```

```python
# Test
print("Test One-Hot Encoding")
print("=" * 50)

# Example 1: Simple array
indices = np.array([0, 1, 2, 3])
result = one_hot(indices, num_classes=5)
print(f"Input: {indices}")
print(f"Output shape: {result.shape}")
print(f"Output:\n{result}")
print(f"Matches explicit version: {np.array_equal(result, one_hot_explicit(indices, num_classes=5))}")

# Example 2: Amino acid encoding
print("\nAmino acid encoding (20 classes):")
aa_sequence = np.array([0, 5, 10, 15, 19])  # A, Q, L, S, V in AlphaFold2's residue order (ARNDCQEGHILKMFPSTWYV)
aa_onehot = one_hot(aa_sequence, num_classes=20)
print(f"Sequence indices: {aa_sequence}")
print(f"One-hot shape: {aa_onehot.shape}")
print(f"Verification - argmax recovers indices: {np.argmax(aa_onehot, axis=-1)}")
```
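One-hot vectors are usually passed straight into a linear layer. Because each row has exactly one 1, multiplying by a weight matrix simply selects the matching weight row, so "one-hot, then linear" is equivalent to an embedding lookup. The check below uses random weights as stand-ins for learned parameters; the sizes are illustrative assumptions, not AlphaFold2's.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, channels = 20, 8  # illustrative sizes, not AlphaFold2's
W = rng.normal(size=(num_classes, channels)).astype(np.float32)  # stand-in for learned weights
b = rng.normal(size=(channels,)).astype(np.float32)              # stand-in bias

aa_sequence = np.array([0, 5, 10, 15, 19])
aa_onehot = np.eye(num_classes, dtype=np.float32)[aa_sequence]   # (5, 20)

# Linear layer applied to one-hot inputs
via_matmul = aa_onehot @ W + b                                   # (5, 8)
# Equivalent embedding lookup: gather the rows of W directly
via_lookup = W[aa_sequence] + b                                  # (5, 8)

print(np.allclose(via_matmul, via_lookup))  # True
```

The lookup form avoids materializing the one-hot tensor, which matters when the number of classes or positions is large.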

```python
# 2D example (MSA)
print("\n2D Example (MSA-like):")
msa = np.random.randint(0, 21, size=(4, 8))  # 4 sequences, 8 residues
msa_onehot = one_hot(msa, num_classes=21)
print(f"MSA shape: {msa.shape}")
print(f"One-hot MSA shape: {msa_onehot.shape}")
```
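Keeping the class axis last turns per-column statistics into simple reductions. As a general illustration (not a claim about a specific AlphaFold2 feature), averaging a one-hot MSA over its sequence axis gives the fraction of sequences with each residue class at each position:

```python
import numpy as np

rng = np.random.default_rng(0)
num_seqs, num_res, num_classes = 4, 8, 21
msa = rng.integers(0, num_classes, size=(num_seqs, num_res))
msa_onehot = np.eye(num_classes, dtype=np.float32)[msa]  # (4, 8, 21)

# Mean over sequences -> per-position class frequencies
profile = msa_onehot.mean(axis=0)                         # (8, 21)
print(profile.shape)                                      # (8, 21)
print(np.allclose(profile.sum(axis=-1), 1.0))             # True: each position is a distribution
```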

## Usage in AlphaFold2

One-hot encoding is used for:
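1. **Amino acid types**: 20 classes for the standard amino acids, or 21 with an extra class for unknown residues
2. **Relative positions**: 65 classes, for residue-index offsets clipped to [-32, +32]
3. **Template features**: categorical template inputs such as template residue types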
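For relative positions, the residue-index offset `d_ij = r_i - r_j` is clipped to `[-32, 32]` and shifted by 32, giving a class index in `0..64`. A minimal NumPy sketch of that mapping, assuming a single chain with consecutive residue indices; the helper name `relpos_one_hot` and its `max_rel` parameter are illustrative, not AlphaFold2 identifiers:

```python
import numpy as np

def relpos_one_hot(residue_index, max_rel=32, dtype=np.float32):
    """Hypothetical helper: pairwise relative-position one-hot, shape (L, L, 2*max_rel + 1)."""
    r = np.asarray(residue_index)
    # Pairwise offsets d_ij = r_i - r_j, shape (L, L)
    d = r[:, None] - r[None, :]
    # Clip to [-max_rel, max_rel], then shift to class indices 0 .. 2*max_rel
    cls = np.clip(d, -max_rel, max_rel) + max_rel
    return np.eye(2 * max_rel + 1, dtype=dtype)[cls]

residue_index = np.arange(40)       # one chain of 40 residues
feat = relpos_one_hot(residue_index)
print(feat.shape)                   # (40, 40, 65)
print(feat[0, 39].argmax() - 32)    # -32: d = 0 - 39 = -39, clipped to -32
print(feat[5, 5].argmax() - 32)     # 0: a residue relative to itself
```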

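```python
import jax
import jax.numpy as jnp

indices = jnp.array([0, 1, 2, 3])
# JAX implementation; the class axis is appended last by default (axis=-1)
onehot = jax.nn.one_hot(indices, num_classes=5)
```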
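One behavioral difference matters when porting between the two: `jax.nn.one_hot` encodes indices outside `[0, num_classes)` as all-zero vectors, whereas the `np.eye(num_classes)[indices]` approach raises `IndexError` for indices `>= num_classes` and silently wraps negative indices. The NumPy sketch below mirrors the JAX behavior by comparing each index against `np.arange(num_classes)`; the helper name `one_hot_like_jax` is hypothetical and this is not JAX's source code.

```python
import numpy as np

def one_hot_like_jax(indices, num_classes, dtype=np.float32):
    """Hypothetical helper: one-hot via comparison; out-of-range indices give all-zero rows."""
    indices = np.asarray(indices)
    # Broadcast (..., 1) against (num_classes,) -> boolean (..., num_classes)
    return (indices[..., None] == np.arange(num_classes)).astype(dtype)

print(one_hot_like_jax(np.array([0, 2, -1, 3]), num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 0. 0.]
#  [0. 0. 0.]]
```

For in-range indices this agrees with the identity-matrix version; the all-zero rows for invalid indices can act as an implicit mask instead of an error.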