# Algorithm 1: MSA Features (AlphaFold3)

MSA (Multiple Sequence Alignment) features provide evolutionary information about protein sequences. AlphaFold3 (AF3) simplifies MSA processing compared to AlphaFold2 (AF2).

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/features.py`

## Overview

### MSA Features in AF3

| Feature | Shape | Description |
|---------|-------|-------------|
| `msa` | [N_msa, N_token, 32] | One-hot encoded residue types |
| `has_deletion` | [N_msa, N_token] | Binary deletion indicator |
| `deletion_value` | [N_msa, N_token] | Normalized deletion count |
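
The two deletion features can be derived from the alignment itself. The sketch below assumes the a3m convention, in which lowercase letters mark insertions relative to the query, and assumes the arctan scaling, (2/π)·arctan(d/3), described for AlphaFold's deletion features. The helper name `deletion_features_from_a3m_row` is hypothetical and is not taken from `features.py`.

```python
import numpy as np

def deletion_features_from_a3m_row(row):
    """Hypothetical helper: split one a3m row into aligned residues and deletion features."""
    aligned = []
    counts = []
    pending = 0  # insertions seen since the last aligned column
    for ch in row:
        if ch.islower():
            # Lowercase = insertion relative to the query (assumed a3m convention)
            pending += 1
        else:
            # Uppercase residue or gap: an aligned column
            aligned.append(ch)
            counts.append(pending)
            pending = 0
    # Insertions after the last aligned column are ignored in this sketch
    counts = np.asarray(counts, dtype=np.float32)
    has_deletion = (counts > 0).astype(np.float32)
    # Assumed normalization: squash raw counts into [0, 1)
    deletion_value = (2.0 / np.pi) * np.arctan(counts / 3.0)
    return "".join(aligned), has_deletion, deletion_value

aligned, has_del, del_val = deletion_features_from_a3m_row("MKaaF-L")
print(aligned)   # MKF-L
print(has_del)   # [0. 0. 1. 0. 0.]
```

The two insertions before `F` set `has_deletion` to 1 at that column only. Gap characters are aligned columns and do not count as deletions under this convention.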

### Key Difference from AF2
- AF3 uses a simplified 4-block MSA module instead of full Evoformer processing
- MSA is subsampled before processing

In [None]:
import numpy as np
np.random.seed(42)

In [None]:
# Amino acid vocabulary
AA_VOCAB = 'ACDEFGHIKLMNPQRSTVWY-X'  # 20 AAs + gap + unknown
NUM_AA = len(AA_VOCAB)  # 22; the other slots of the 32-class encoding stay unused in this protein-only example

def one_hot(indices, num_classes=32):
    """One-hot encode indices."""
    return np.eye(num_classes)[indices]

def create_msa_features(msa_sequences, max_msa=1024):
    """
    Create MSA features from sequences.
    
    Args:
        msa_sequences: List of aligned sequences
        max_msa: Maximum number of MSA sequences
    
    Returns:
        Dictionary of MSA features
    """
    N_msa = min(len(msa_sequences), max_msa)
    N_token = len(msa_sequences[0])
    
    print("Creating MSA Features")
    print("=" * 50)
    print(f"MSA depth: {N_msa}, Length: {N_token}")
    
    # Initialize
    msa_indices = np.zeros((N_msa, N_token), dtype=np.int32)
    has_deletion = np.zeros((N_msa, N_token), dtype=np.float32)
    deletion_value = np.zeros((N_msa, N_token), dtype=np.float32)
    
    # Process sequences
    for i, seq in enumerate(msa_sequences[:N_msa]):
        for j, aa in enumerate(seq):
            if aa in AA_VOCAB:
                msa_indices[i, j] = AA_VOCAB.index(aa)
            else:
                msa_indices[i, j] = AA_VOCAB.index('X')  # Unknown
            
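            # Simulate deletion features. This is a placeholder: real deletion
            # counts come from insertions in the alignment, not from gap
            # characters. Gaps get a random value so the arrays are non-trivial.
            if aa == '-':
                has_deletion[i, j] = 1.0
                deletion_value[i, j] = np.random.uniform(0, 1)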
    
    # One-hot encode
    msa_onehot = one_hot(msa_indices, num_classes=32)
    
    print("\nFeatures:")
    print(f"  msa: {msa_onehot.shape}")
    print(f"  has_deletion: {has_deletion.shape}")
    print(f"  deletion_value: {deletion_value.shape}")
    
    return {
        'msa': msa_onehot,
        'has_deletion': has_deletion,
        'deletion_value': deletion_value,
        'msa_indices': msa_indices,
    }

In [None]:
def subsample_msa(msa_features, n_samples=128):
    """
    Subsample MSA for efficient processing.
    
    AF3 uses random subsampling to reduce MSA depth.
    """
    N_msa = msa_features['msa'].shape[0]
    
    if N_msa <= n_samples:
        return msa_features
    
    # Always keep first sequence (query)
    indices = [0] + list(np.random.choice(
        range(1, N_msa), size=n_samples-1, replace=False
    ))
    
    # Index every per-row array so that no feature (e.g. 'msa_indices') is dropped
    return {key: value[indices] for key, value in msa_features.items()}

In [None]:
# Test: Create MSA features
print("Test: MSA Feature Creation")
print("="*60)

# Simulate MSA
query = "MKFLILLFNILCLFPVLAADNHGVGPQGAS"
N_msa = 256

msa_seqs = [query]
for _ in range(N_msa - 1):
    # Create variants with mutations and gaps
    seq = list(query)
    for j in range(len(seq)):
        if np.random.random() < 0.1:  # 10% mutation rate
            seq[j] = np.random.choice(list('ACDEFGHIKLMNPQRSTVWY-'))
    msa_seqs.append(''.join(seq))

features = create_msa_features(msa_seqs)

# Subsample
print("\nSubsampling to 64 sequences:")
features_sub = subsample_msa(features, n_samples=64)
print(f"  Subsampled msa: {features_sub['msa'].shape}")

In [None]:
# Verification
print("\nVerification")
print("="*60)

# Check one-hot validity
print(f"One-hot sums to 1: {np.allclose(features['msa'].sum(axis=-1), 1)}")
print(f"Query preserved after subsample: {np.allclose(features['msa'][0], features_sub['msa'][0])}")
print(f"Finite values: {np.isfinite(features['msa']).all()}")

## Key Insights

1. **Simplified Processing**: AF3 processes MSA in only 4 blocks vs 48 in AF2's Evoformer
2. **Subsampling**: Random subsampling reduces computational cost
3. **Query Preservation**: First sequence (query) is always preserved
4. **Feature Concatenation**: MSA one-hot (32) + `has_deletion` (1) + `deletion_value` (1) = 34 dimensions per MSA position
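
The concatenation in the last point can be sketched as follows. This is a minimal illustration with toy random inputs. The projection width `c_m = 64` and the random weight matrix `w` are assumptions for demonstration, not values read from the reference code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_msa, n_token = 4, 6

# Toy inputs with the shapes from the feature table
msa = np.eye(32, dtype=np.float32)[rng.integers(0, 22, size=(n_msa, n_token))]
has_deletion = rng.integers(0, 2, size=(n_msa, n_token)).astype(np.float32)
deletion_value = rng.uniform(0, 1, size=(n_msa, n_token)).astype(np.float32) * has_deletion

# Stack along the channel axis: 32 + 1 + 1 = 34
msa_feat = np.concatenate(
    [msa, has_deletion[..., None], deletion_value[..., None]], axis=-1
)
print(msa_feat.shape)  # (4, 6, 34)

# Hypothetical linear projection of the 34 channels to c_m channels
c_m = 64
w = rng.normal(size=(34, c_m)).astype(np.float32)
m = msa_feat @ w
print(m.shape)  # (4, 6, 64)
```

Adding a trailing axis with `[..., None]` lets the two scalar deletion features line up with the one-hot channels before concatenation.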