# Algorithm 5: One-Hot Encoding

One-hot encoding converts categorical indices to binary vectors. This fundamental operation is used throughout AlphaFold2 for encoding amino acid types, relative positions, and other discrete features.

## Algorithm Pseudocode

![one_hot](../imgs/algorithms/one_hot.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Usage**: Throughout, via `jax.nn.one_hot`

## Overview

One-hot encoding transforms integer indices into binary vectors:

```
Index 0 with 5 classes → [1, 0, 0, 0, 0]
Index 2 with 5 classes → [0, 0, 1, 0, 0]
Index 4 with 5 classes → [0, 0, 0, 0, 1]
```

### Usage in AlphaFold2

| Feature | Classes | Description |
|---------|---------|-------------|
| Amino acid type | 20-21 | Standard amino acids (+ unknown) |
| MSA amino acid | 23 | 20 AA + gap + mask + unknown |
| Relative position | 65 | -32 to +32 relative positions |
| Distance bins | 64 | For distogram predictions |

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def one_hot(indices, num_classes, dtype=np.float32):
    """
    One-Hot Encoding - Algorithm 5.
    
    Converts integer indices to one-hot vectors.
    
    Args:
        indices: Integer array of any shape
        num_classes: Number of classes (length of one-hot vector)
        dtype: Output data type (default: float32)
    
    Returns:
        One-hot encoded array with shape [..., num_classes]
    """
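    # Row i of the identity matrix is the one-hot vector for class i, so
    # indexing with `indices` gathers one row per index.
    # Indices >= num_classes raise IndexError; negative indices wrap around.
    eye = np.eye(num_classes, dtype=dtype)
    return eye[indices]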


def one_hot_explicit(indices, num_classes, dtype=np.float32):
    """
    Explicit implementation showing the algorithm step-by-step.
    
    This version is clearer about what one-hot encoding does.
    """
    # Accept lists as well as arrays
    indices = np.asarray(indices)

    # Create an all-zero output array with a trailing class axis
    output_shape = indices.shape + (num_classes,)
    output = np.zeros(output_shape, dtype=dtype)
    
    # Flatten for easier indexing; flat_output is a view of output
    # (output is freshly allocated and contiguous), so writes go through
    flat_indices = indices.ravel()
    flat_output = output.reshape(-1, num_classes)
    
    # For each flat position i, set flat_output[i, flat_indices[i]] = 1
    flat_output[np.arange(len(flat_indices)), flat_indices] = 1.0
    
    return output


def one_hot_with_mask(indices, num_classes, mask=None, dtype=np.float32):
    """
    One-hot encoding with optional masking.
    
    Masked positions get zero vectors.
    
    Args:
        indices: Integer indices
        num_classes: Number of classes
        mask: Boolean mask (True = valid, False = masked)
        dtype: Output data type
    
    Returns:
        One-hot encoded array with masked positions zeroed
    """
    onehot = one_hot(indices, num_classes, dtype)
    
    if mask is not None:
        # Expand mask to match one-hot dimensions
        mask_expanded = mask[..., None].astype(dtype)
        onehot = onehot * mask_expanded
    
    return onehot
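
One behavioral difference matters when comparing against the JAX reference. `np.eye(num_classes)[indices]` raises `IndexError` for indices `>= num_classes` and silently wraps negative indices (so `-1` selects the last class), whereas `jax.nn.one_hot` is documented to encode out-of-range indices as all-zero vectors. Below is a minimal sketch of a NumPy variant with the zero-vector behavior, built from a broadcast comparison. The name `one_hot_compare` is hypothetical and not part of the AlphaFold2 code.

```python
import numpy as np


def one_hot_compare(indices, num_classes, dtype=np.float32):
    """One-hot via broadcast comparison (hypothetical helper).

    Indices outside [0, num_classes) produce all-zero vectors instead of
    raising IndexError or wrapping around.
    """
    indices = np.asarray(indices)
    # Compare each index against every class id: [..., 1] == [num_classes]
    classes = np.arange(num_classes)
    return (indices[..., None] == classes).astype(dtype)


example = one_hot_compare(np.array([0, 2, -1, 5]), num_classes=3)
print(example)
# Rows for -1 and 5 are all zeros; rows for 0 and 2 are ordinary one-hot vectors.
```

The comparison form avoids allocating the `num_classes × num_classes` identity matrix, which may matter when the number of classes is large.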

## Test Examples

In [None]:
# Test 1: Basic one-hot encoding
print("Test 1: Basic One-Hot Encoding")
print("="*60)

indices = np.array([0, 1, 2, 3, 4])
num_classes = 5

onehot = one_hot(indices, num_classes)

print(f"Input indices: {indices}")
print(f"Number of classes: {num_classes}")
print(f"Output shape: {onehot.shape}")
print(f"\nOne-hot matrix:")
print(onehot)

In [None]:
# Test 2: Amino acid encoding
print("\nTest 2: Amino Acid Encoding")
print("="*60)

# Amino acid indices (0-19 for standard amino acids)
AA_NAMES = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
            'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

# Example sequence: "ACDEF"
aa_indices = np.array([0, 4, 3, 6, 13])  # A, C, D, E, F

aa_onehot = one_hot(aa_indices, num_classes=20)

print(f"Sequence: {[AA_NAMES[i] for i in aa_indices]}")
print(f"Indices: {aa_indices}")
print(f"One-hot shape: {aa_onehot.shape}")

# Verify we can recover the original indices
recovered = np.argmax(aa_onehot, axis=-1)
print(f"Recovered indices: {recovered}")
print(f"Recovery correct: {np.array_equal(aa_indices, recovered)}")
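
The 21-class variant listed in the usage table adds an unknown class. Below is a minimal sketch of mapping a sequence string to indices, assuming unknown residues go to index 20 and reusing the residue order from Test 2. `encode_sequence`, `AA_TO_INDEX`, and `UNKNOWN_INDEX` are hypothetical names introduced here for illustration.

```python
import numpy as np

AA_NAMES = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
            'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
AA_TO_INDEX = {aa: i for i, aa in enumerate(AA_NAMES)}
UNKNOWN_INDEX = len(AA_NAMES)  # assumption: unknown residues use class 20


def encode_sequence(sequence):
    """Map a one-letter sequence to indices and a [len(sequence), 21] one-hot array."""
    idx = np.array([AA_TO_INDEX.get(aa, UNKNOWN_INDEX) for aa in sequence])
    return idx, np.eye(UNKNOWN_INDEX + 1, dtype=np.float32)[idx]


seq_indices, seq_onehot = encode_sequence("ACDXF")
print(seq_indices)       # 'X' is not in AA_NAMES, so it maps to UNKNOWN_INDEX
print(seq_onehot.shape)
```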

In [None]:
# Test 3: 2D input (MSA)
print("\nTest 3: MSA One-Hot Encoding")
print("="*60)

N_seq, N_res = 4, 8
num_aa_classes = 23  # 20 AA + gap + mask + unknown

# Random MSA
msa = np.random.randint(0, num_aa_classes, size=(N_seq, N_res))

msa_onehot = one_hot(msa, num_aa_classes)

print(f"MSA shape: {msa.shape}")
print(f"MSA one-hot shape: {msa_onehot.shape}")
print(f"Expected: ({N_seq}, {N_res}, {num_aa_classes})")

# Verify sum along last axis equals 1
sums = msa_onehot.sum(axis=-1)
print(f"\nAll rows sum to 1: {np.allclose(sums, 1.0)}")

In [None]:
# Test 4: With masking
print("\nTest 4: One-Hot with Masking")
print("="*60)

indices = np.array([0, 1, 2, 3, 4])
mask = np.array([True, True, False, True, False])  # Mask positions 2 and 4

onehot_masked = one_hot_with_mask(indices, num_classes=5, mask=mask)

print(f"Indices: {indices}")
print(f"Mask: {mask}")
print(f"\nMasked one-hot:")
print(onehot_masked)
print(f"\nMasked positions are zero vectors: {onehot_masked[2].sum() == 0 and onehot_masked[4].sum() == 0}")

In [None]:
# Test 5: Compare implementations
print("\nTest 5: Compare Implementations")
print("="*60)

indices = np.random.randint(0, 20, size=(100, 64))

result1 = one_hot(indices, 20)
result2 = one_hot_explicit(indices, 20)

print(f"Efficient implementation shape: {result1.shape}")
print(f"Explicit implementation shape: {result2.shape}")
print(f"Results match: {np.allclose(result1, result2)}")

In [None]:
# Test 6: Relative position one-hot (as used in Algorithm 4)
print("\nTest 6: Relative Position One-Hot")
print("="*60)

max_relative = 32
num_classes = 2 * max_relative + 1  # 65 classes

# Simulate clipped and shifted relative positions
offsets = np.array([-50, -32, -10, 0, 10, 32, 50])  # Various offsets
clipped = np.clip(offsets, -max_relative, max_relative)
shifted = clipped + max_relative  # Now in [0, 64]

rel_onehot = one_hot(shifted, num_classes)

print(f"Original offsets: {offsets}")
print(f"Clipped: {clipped}")
print(f"Shifted (indices): {shifted}")
print(f"\nOne-hot peaks at: {np.argmax(rel_onehot, axis=-1)}")
print(f"Shape: {rel_onehot.shape}")
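
Clipping and shifting is one way to turn offsets into class indices. An equivalent view, and the form the Algorithm 5 pseudocode appears to take, is to choose the bin whose value is nearest to the input. Below is a minimal sketch under that assumption. `one_hot_nearest_bin` and `v_bins` are names chosen here for illustration.

```python
import numpy as np


def one_hot_nearest_bin(x, v_bins, dtype=np.float32):
    """One-hot encode each value in x by its nearest bin in v_bins."""
    x = np.asarray(x)
    v_bins = np.asarray(v_bins)
    # Distance from every value to every bin: shape [..., num_bins]
    nearest = np.argmin(np.abs(x[..., None] - v_bins), axis=-1)
    return np.eye(len(v_bins), dtype=dtype)[nearest]


max_relative = 32
v_bins = np.arange(-max_relative, max_relative + 1)  # 65 bins: -32 .. +32
offsets = np.array([-50, -32, -10, 0, 10, 32, 50])

nearest_onehot = one_hot_nearest_bin(offsets, v_bins)

# Compare against the clip-and-shift route used in Test 6
clip_shift = np.clip(offsets, -max_relative, max_relative) + max_relative
print(np.array_equal(np.argmax(nearest_onehot, axis=-1), clip_shift))
```

With the nearest-bin form, out-of-range offsets land in the end bins without an explicit clip.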

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

indices = np.random.randint(0, 20, size=(32, 64))
onehot = one_hot(indices, 20)

# Property 1: Binary values (0 or 1)
is_binary = np.all((onehot == 0) | (onehot == 1))
print(f"Property 1 - Binary values: {is_binary}")

# Property 2: Exactly one 1 per vector
row_sums = onehot.sum(axis=-1)
is_valid_onehot = np.allclose(row_sums, 1.0)
print(f"Property 2 - Exactly one 1 per vector: {is_valid_onehot}")

# Property 3: Correct shape
expected_shape = indices.shape + (20,)
shape_correct = onehot.shape == expected_shape
print(f"Property 3 - Shape [..., num_classes]: {shape_correct}")

# Property 4: Invertible via argmax
recovered = np.argmax(onehot, axis=-1)
invertible = np.array_equal(indices, recovered)
print(f"Property 4 - Invertible via argmax: {invertible}")

# Property 5: Correct dtype
dtype_correct = onehot.dtype == np.float32
print(f"Property 5 - Float32 dtype: {dtype_correct}")

## Performance Comparison

In [None]:
import time

print("Performance Comparison")
print("="*60)

# Moderately sized input (128 x 256 indices, 20 classes)
indices = np.random.randint(0, 20, size=(128, 256))
n_runs = 100

# Identity-matrix indexing; perf_counter is a monotonic, high-resolution clock
start = time.perf_counter()
for _ in range(n_runs):
    _ = one_hot(indices, 20)
efficient_time = (time.perf_counter() - start) / n_runs * 1000

# Explicit zero-fill and scatter
start = time.perf_counter()
for _ in range(n_runs):
    _ = one_hot_explicit(indices, 20)
explicit_time = (time.perf_counter() - start) / n_runs * 1000

print(f"Input shape: {indices.shape}")
print(f"Number of runs: {n_runs}")
print(f"\nEye indexing: {efficient_time:.3f} ms")
print(f"Explicit (manual setting): {explicit_time:.3f} ms")
print(f"Ratio (explicit / eye): {explicit_time / efficient_time:.2f}x")

## Source Code Reference

```python
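# JAX implementation used in AlphaFold2
import jax.nn
import jax.numpy as jnp

# Placeholder inputs (illustrative values only) so the snippets run standalone
indices = jnp.array([0, 2, 4])
num_classes = 5
batch = {'aatype': jnp.array([0, 4, 3, 6, 13])}
max_relative = 32
pos = jnp.arange(8)
offset = pos[:, None] - pos[None, :]
distance_bins = jnp.array([[0, 10], [10, 0]])
num_bins = 64

# One-hot encoding in JAX
onehot = jax.nn.one_hot(indices, num_classes)

# Usage examples in AlphaFold2:

# 1. Amino acid encoding (in data preprocessing)
aatype_onehot = jax.nn.one_hot(batch['aatype'], 21)

# 2. Relative position encoding (Algorithm 4)
rel_pos = jax.nn.one_hot(
    jnp.clip(offset + max_relative, 0, 2 * max_relative),
    2 * max_relative + 1
)

# 3. Distance binning for distogram
dgram = jax.nn.one_hot(distance_bins, num_bins)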
```
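
The third snippet takes `distance_bins` as given. Below is a minimal NumPy sketch of how pairwise distances could be turned into bin indices before one-hot encoding. The break range (`first_break`, `last_break`) and the toy coordinates are assumptions for illustration, not values taken from this notebook.

```python
import numpy as np

num_bins = 64
first_break, last_break = 2.0, 22.0  # assumed range in Angstroms
# num_bins - 1 interior breaks split the distance axis into num_bins bins
breaks = np.linspace(first_break, last_break, num_bins - 1)

# Toy coordinates: 4 points on a line, 5 Angstroms apart
coords = np.array([[0.0, 0.0, 0.0],
                   [5.0, 0.0, 0.0],
                   [10.0, 0.0, 0.0],
                   [15.0, 0.0, 0.0]])
dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

# Bin index = number of breaks each distance exceeds, in [0, num_bins - 1]
distance_bins = np.sum(dist[..., None] > breaks, axis=-1)
dgram = np.eye(num_bins, dtype=np.float32)[distance_bins]

print(distance_bins.shape, dgram.shape)
```

Counting exceeded breaks keeps every index in range, so distances beyond `last_break` fall into the final bin rather than raising an error.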

## Key Insights

1. **Fundamental Operation**: One-hot encoding is one of the most basic but essential operations in neural networks for handling categorical data.

2. **Memory Trade-off**: One-hot encoding expands a single integer to a vector of length `num_classes`, so the encoded array holds `num_classes` times as many elements as the index array. For large vocabularies, this can be memory-intensive.

3. **Embedding Alternative**: In practice, one-hot vectors are often immediately multiplied by a weight matrix (embedding lookup). This is equivalent to selecting a row from the weight matrix using the index directly.

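4. **Gradient Flow**: Integer indices carry no gradient, so nothing flows back through the encoding itself. When a one-hot input feeds a linear layer, the gradient with respect to the weight matrix is nonzero only in the rows of classes that occur in the input, and deep learning frameworks can exploit this sparsity.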

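5. **Implementation Efficiency**: The identity matrix indexing approach (`np.eye(n)[indices]`) is concise and fully vectorized, at the cost of allocating an `n × n` identity matrix. The explicit version above is also vectorized rather than loop-based, so which one is faster depends on the input shape and `num_classes`; the timing cell above measures this for one configuration.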
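
Insights 2 to 4 can be checked numerically. Below is a minimal sketch, assuming a random weight matrix `W` of shape `[num_classes, d]` standing in for a learned embedding. All names in the block are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, d = 23, 8
indices = np.array([[0, 5, 22], [5, 5, 1]])               # shape [2, 3]
W = rng.normal(size=(num_classes, d)).astype(np.float32)  # stand-in embedding

onehot = np.eye(num_classes, dtype=np.float32)[indices]   # shape [2, 3, 23]

# Insight 2: the encoding holds num_classes times as many elements
print(indices.size, onehot.size)

# Insight 3: multiplying by W equals looking up rows of W
via_matmul = onehot @ W
via_lookup = W[indices]
print(np.allclose(via_matmul, via_lookup))

# Insight 4: for y = onehot @ W, dL/dW = onehot^T @ dL/dy, so only rows of
# classes present in `indices` receive a nonzero gradient
grad_y = np.ones_like(via_matmul)
grad_W = onehot.reshape(-1, num_classes).T @ grad_y.reshape(-1, d)
rows_with_grad = np.flatnonzero(np.abs(grad_W).sum(axis=1))
print(rows_with_grad)
```

The lookup form `W[indices]` never materializes the one-hot array, which is why embedding layers are usually implemented that way.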