# Chapter 1: Feature Extraction

AlphaFold 3 uses three main types of features to predict protein structures:
1. **Token Features**: These include an encoding of the residue type, ids for the residue's index, which chain it belongs to, and which chain type it belongs to.
2. **MSA Features**: These are derived from the multiple sequence alignment, and contain information about its residue profile and deletion counts. Multiple MSA features are sampled, one per recycling iteration.
3. **RefStruct Features**: These are atom-level features and encode information on the atom's element type, charge, and so on.

Additionally, there is the contact matrix (encoding which tokens are covalently bonded to which other tokens) and template features (which we will leave out, since we are doing template-free prediction).

In our implementation, we will use atomworks for the input pipeline, a toolkit by the Baker lab for working with biological data. Its central object is the AtomArray, which is essentially a set of numpy arrays of shape (n_atoms,) encoding various properties: `atom_array.charge[i]` gives the charge of the i-th atom, as does `atom_array[i].charge`.

All imports are relative to the tutorials folder. Make sure that the tutorials folder is on the system path (output of the next cell). If not, you can add it to the VSCode settings under Jupyter: Notebook File Root. Alternatively, you can manually add it to the system path in python.

In [None]:
import sys
sys.path

# Manually add the folder:
# sys.path.append('path/to/alphafold3-decoded/tutorials')

Run the following cell to set up the environment and the test cases.

In [None]:
import os
# Set these so that atomworks does not raise a warning; we don't need to actually download the mirrors for this notebook.
os.environ["PDB_MIRROR_PATH"] = "data/datasets/pdb_mirror"
os.environ["CCD_MIRROR_PATH"] = "data/datasets/ccd_mirror"

import numpy as np
import torch
from common.utils import load_alphafold_input
import tensortrace as ttr
from atomworks.constants import STANDARD_AA, STANDARD_RNA, STANDARD_DNA
from atomworks.ml.transforms.atom_array import AddGlobalAtomIdAnnotation
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.filters import RemoveHydrogens
from atomworks.ml.transforms.base import Compose

%load_ext autoreload
%autoreload 2


base_transforms = [
    RemoveHydrogens(),
    AddGlobalAtomIdAnnotation(),
    AtomizeByCCDName(atomize_by_default=True, res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA),
]

ttr.TensorTrace('data/tensortraces/feature_extraction', mode='write', framework='numpy').start_trace()

ttr.current_trace().atol = 1e-6
ttr.current_trace().rtol = 1e-5

test_inputs_pipeline = [
    load_alphafold_input(f'data/fold_inputs/fold_input_{name}.json')
    for name in ['lysozyme', 'multimer', 'protein_dna_ion', 'protein_rna_ion']
]

To warm up, we will start by implementing a few utility functions and the contact matrix. Afterward, we will go through token features, RefStruct features, and MSA features, in that order.

Head over to `common/utils.py` to implement the functions `pad_to_shape`, `round_down_to`, and `masked_mean`. Afterward, run the following cell to test your implementation.

In [None]:
from common import utils

np.random.seed(0)
ttr.reset_loading_index()

# Test pad_to_shape
initial_shapes = [(3, 5), (16, 24, 8)]
target_shapes = ((4, 6), (16, 32, 16))
padding_vals = [0, -1]

success = True
for initial_shape, target_shape, padding_val in zip(initial_shapes, target_shapes, padding_vals):
    x = np.random.randn(*initial_shape)
    y = utils.pad_to_shape(x, target_shape, padding_val)
    success = success and ttr.log_or_compare(y, 'pad_to_shape')

if success: print('pad_to_shape tests created.')

# Test round_down_to
values_1 = np.random.randint(0, 100, size=50)
values_2 = np.random.randint(0, 100, size=(50, 4, 2))
bases = np.sort(np.random.randint(1, 100, size=10))

success = True
for values in [values_1, values_2]:
    y, y_inds = utils.round_down_to(values, bases, return_indices=True)
    success = success and ttr.log_or_compare(y, 'round_down_to')
    success = success and ttr.log_or_compare(y_inds, 'round_down_to_inds')

if success: print('round_down_to tests created.')

# Test masked_mean
shapes = [(5,), (4, 6), (3, 4, 5)]
axes = [0, -1, 2]
success = True
for shape, axis in zip(shapes, axes):
    x = np.random.randn(*shape)
    mask = np.random.rand(*shape) > 0.5
    y = utils.masked_mean(x, mask, axis=axis)
    # So that it is an array, even if ndim = 0
    y = np.array(y)
    success = success and ttr.log_or_compare(y, 'masked_mean')

if success: print('masked_mean tests created.')
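
For orientation, here is a minimal sketch of what the three helpers could look like. The call signatures are taken from the test cell above, but the edge-case behavior (padding only at the end of each axis, clipping values below the smallest base, returning 0 for an empty mask) is an assumption and may differ from what the reference traces expect. The `_sketch` suffix marks these as illustrations, not the tutorial's solution.

```python
import numpy as np


def pad_to_shape_sketch(x, target_shape, value=0):
    """Pad `x` at the end of every axis until it has `target_shape`."""
    pad_width = [(0, target - current) for current, target in zip(x.shape, target_shape)]
    if len(target_shape) != x.ndim or any(after < 0 for _, after in pad_width):
        raise ValueError(f"cannot pad shape {x.shape} to {target_shape}")
    return np.pad(x, pad_width, mode="constant", constant_values=value)


def round_down_to_sketch(values, bases):
    """Map each value to the largest base <= value; `bases` must be sorted."""
    indices = np.searchsorted(bases, values, side="right") - 1
    # Assumption: values below the smallest base are clipped to that base.
    indices = np.clip(indices, 0, len(bases) - 1)
    return bases[indices], indices


def masked_mean_sketch(x, mask, axis=None, eps=1e-10):
    """Mean over the entries of `x` where `mask` is True."""
    weights = mask.astype(x.dtype)
    # Assumption: a slice without any valid entry yields 0 instead of NaN.
    return (x * weights).sum(axis=axis) / np.maximum(weights.sum(axis=axis), eps)


padded = pad_to_shape_sketch(np.ones((2, 3)), (3, 4), value=-1)
rounded, rounded_idx = round_down_to_sketch(np.array([10, 19, 75]), np.array([10, 20, 50]))
means = masked_mean_sketch(
    np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    np.array([[True, False, True], [False, False, False]]),
    axis=-1,
)
```

In the toy calls, the second row of the mask is entirely False, which is exactly the case where the `eps` guard matters.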


We now implement the first AlphaFold feature, the contact matrix. The matrix has shape (n_tokens, n_tokens) and denotes which tokens are covalently bonded to each other. In AlphaFold, one token corresponds to one amino acid (in proteins), one nucleotide (in DNA/RNA), or one atom of an _atomized_ structure (e.g. a modified residue or ligand). The number of tokens is theoretically arbitrary, but AlphaFold rounds it up to certain bucket sizes (e.g. 256, 512, 768, ...) for better performance, padding the features with zeros if necessary.
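
The bucket rounding can be sketched as a lookup in a sorted list. The bucket list below is an assumption (it is the default list of the public AlphaFold 3 release as far as I know, and the tutorial's own list may differ), and so is the choice to leave inputs beyond the largest bucket unpadded.

```python
from bisect import bisect_left

# Assumption: default bucket sizes of the public AlphaFold 3 release.
ASSUMED_BUCKETS = (256, 512, 768, 1024, 1280, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120)


def round_to_bucket_sketch(n_tokens, buckets=ASSUMED_BUCKETS):
    """Return the smallest bucket size that is >= n_tokens."""
    position = bisect_left(buckets, n_tokens)
    if position == len(buckets):
        # Assumption: inputs larger than the largest bucket are not padded.
        return n_tokens
    return buckets[position]


examples = {n: round_to_bucket_sketch(n) for n in (0, 512, 2583)}
```

Note that the buckets are not simply multiples of 256: with this list, 2583 tokens are padded to 3072, not to 2816.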

Go to `common/utils.py` and implement the function `round_to_bucket` to calculate the padded number of tokens based on the actual number of tokens. Then, implement the transform `CalculateContactMatrix` in `feature_extraction/contact_features.py` to compute the contact matrix. The contact matrix should be one at indices (i, j) if i < j, there is a bond between any atom of token i and any atom of token j, and at least one of the two tokens is atomized (e.g. a modified residue or ligand).
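
The rule for the contact matrix can be illustrated on plain arrays. The sketch below does not use the atomworks API: `bonds`, `atom_to_token`, and `token_is_atomized` are hypothetical stand-ins for information you would read from the atom_array, and padding to the bucket size is left out.

```python
import numpy as np


def contact_matrix_sketch(bonds, atom_to_token, token_is_atomized):
    """Mark (i, j) with i < j if the tokens are bonded and at least one is atomized."""
    n_tokens = len(token_is_atomized)
    contact = np.zeros((n_tokens, n_tokens), dtype=np.float32)
    for atom_a, atom_b in bonds:
        i, j = sorted((int(atom_to_token[atom_a]), int(atom_to_token[atom_b])))
        if i == j:
            continue  # bond inside a single token
        if token_is_atomized[i] or token_is_atomized[j]:
            contact[i, j] = 1.0
    return contact


# Two residues (atoms 0-1 and 2-3) followed by a two-atom ligand (atoms 4 and 5).
atom_to_token = np.array([0, 0, 1, 1, 2, 3])
token_is_atomized = np.array([False, False, True, True])
bonds = [(0, 1), (1, 2), (3, 4), (5, 4)]
contact = contact_matrix_sketch(bonds, atom_to_token, token_is_atomized)
```

The peptide bond between the two residues, `(1, 2)`, is ignored because neither token is atomized, while the residue-ligand bond and the bond within the ligand each set one entry above the diagonal.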

 After implementing both functions, test your code by running the following cell.

In [None]:
from feature_extraction.contact_features import CalculateContactMatrix
from common.utils import round_to_bucket

ttr.reset_loading_index()

for x, y in zip([2583, 512, 0], [3072, 512, 256]):
    assert round_to_bucket(x) == y, f'Problem with round_to_bucket, x={x}, y_desired={y}, y_actual={round_to_bucket(x)}'

print('round_to_bucket tests created.')

transform = Compose(base_transforms + [CalculateContactMatrix()])


success = True
for inputs in test_inputs_pipeline:
    contact_matrix = transform(inputs)['contact_matrix']
    success = success and ttr.log_or_compare(contact_matrix, 'contact_matrix')

if success: print('Contact matrix tests created.')

## Token Features
Now that we are used to working with arrays again, we move on to the first bigger feature: token features. The code for these is in `feature_extraction/token_features.py`.

Start with the token features by implementing the utility function `encode_restype`. This maps the string residue names provided by the atomworks atom_array to integer encodings. Also, read through the definition of the dataclass `TokenFeatures` to get familiar with the different features. Don't worry about the `block_mask` property for now; it will be covered in the next chapter on input embedding. Test your implementation by running the following cell.

In [None]:
from feature_extraction.token_features import encode_restype

ttr.reset_loading_index()

success = True
for test_input in test_inputs_pipeline:
    restype_strs = test_input['atom_array'].res_name
    restype_ints = encode_restype(restype_strs)
    success = success and ttr.log_or_compare(restype_ints, 'encoded_restypes')

if success: print('Restype encoding tests created.')
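
As a rough illustration of what a residue-type encoder does, the sketch below maps names to integers through a lookup table with a fallback index. The vocabulary `TOY_VOCAB` and the fallback index are hypothetical and deliberately tiny; the real ordering and the handling of unknown residues are defined by the tutorial code.

```python
import numpy as np

# Hypothetical toy vocabulary, not the ordering used by AlphaFold 3.
TOY_VOCAB = ["ALA", "GLY", "SER", "A", "C", "G", "U"]
TOY_UNKNOWN_INDEX = len(TOY_VOCAB)


def encode_restype_sketch(res_names):
    """Map residue names to integers; unseen names share one fallback index."""
    lookup = {name: index for index, name in enumerate(TOY_VOCAB)}
    return np.array(
        [lookup.get(str(name), TOY_UNKNOWN_INDEX) for name in res_names],
        dtype=np.int64,
    )


encoded = encode_restype_sketch(np.array(["GLY", "ALA", "HEM", "U"]))
```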

Now, we will implement the actual feature computation: The `forward` method of the `CalculateTokenFeatures` transform. The docstring contains a detailed description of how to compute the features. Notably, we will construct the features based on the atom_array, which has shape (n_atoms,), not (n_tokens,). To get representative atoms for each token, we will use the method `get_token_starts` from atomworks to get indices of the first atom of each token. We can index into the atom_array with these to get the representative atoms. 
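
The idea of representative atoms can be sketched without atomworks. Assuming a hypothetical per-atom array `token_id` that says which token each atom belongs to, the first atom of each token sits wherever the id changes, and indexing any per-atom annotation with those positions yields a per-token array.

```python
import numpy as np

# Hypothetical per-atom annotations for four tokens (three residues and one ion).
token_id = np.array([0, 0, 0, 1, 1, 2, 3, 3, 3, 3])
res_name = np.array(["ALA"] * 3 + ["GLY"] * 2 + ["ZN"] + ["SER"] * 4)

# A token starts at atom 0 and wherever the token id differs from the previous atom.
is_start = np.concatenate([[True], token_id[1:] != token_id[:-1]])
token_starts = np.flatnonzero(is_start)

# Indexing with the start positions turns (n_atoms,) arrays into (n_tokens,) arrays.
res_name_per_token = res_name[token_starts]
```

In the tutorial code, `get_token_starts` plays the role of `token_starts`, so you do not need to derive the positions yourself.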

After you are done, run the following cell to test your implementation.

In [None]:
from feature_extraction.token_features import CalculateTokenFeatures


ttr.reset_loading_index()

transform = Compose(base_transforms + [CalculateTokenFeatures()])

success = True

for test_input in test_inputs_pipeline:
    token_features = transform(test_input)['token_features'].__dict__
    success = success and ttr.log_or_compare(token_features, 'token_features')

if success: print('Token feature calculation tests created.')

## RefStruct Features
RefStruct features are atom-level features, meaning they have shape (n_atoms,) instead of (n_tokens,). Just as n_tokens is the padded bucket size and not the real number of tokens, n_atoms is not the real number of atoms, but n_tokens * 24 (24 is the fixed number of atoms per token that AlphaFold uses). We will refer to the real number of atoms as the unpadded atom count. However, during feature construction, the docstrings generally use n_atoms for the unpadded atom count for simplicity. Padding is done at the end of the ref_struct feature construction transform.
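
The padding step can be sketched as follows. The helper name `pad_atom_feature_sketch` and the returned mask are assumptions for illustration; the toy example pads to two tokens instead of a real bucket size to keep the arrays small.

```python
import numpy as np

ATOMS_PER_TOKEN = 24


def pad_atom_feature_sketch(feat, n_tokens_padded, pad_value=0):
    """Pad an atom-level feature along axis 0 to n_tokens_padded * 24 entries."""
    n_atoms_padded = n_tokens_padded * ATOMS_PER_TOKEN
    n_atoms_real = feat.shape[0]
    if n_atoms_real > n_atoms_padded:
        raise ValueError("more atoms than the padded layout can hold")
    pad_width = [(0, n_atoms_padded - n_atoms_real)] + [(0, 0)] * (feat.ndim - 1)
    padded = np.pad(feat, pad_width, mode="constant", constant_values=pad_value)
    # The mask records which entries are real atoms and which are padding.
    mask = np.arange(n_atoms_padded) < n_atoms_real
    return padded, mask


element = np.array([6, 7, 8, 6, 6])  # five real atoms
element_padded, atom_mask = pad_atom_feature_sketch(element, n_tokens_padded=2)
```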

The RefStruct features themselves are quite simple to construct; most of them can be obtained directly from the atom_array without any processing. The exception is `positions`, which requires the computation of a reference conformer for the residues and ligands (a proposed 3D structure of the molecule, computed using chemistry software).

Aside from that, the main challenge in implementing RefStruct features is allowing for conversion between a so-called atom-layout (having the features in shape (n_atoms,)) and a token-layout (having the features in shape (n_tokens, 24)). This is a bit tricky, because in the atom-layout, the atoms are dense (e.g. token_1_atom_1, token_1_atom_2, token_2_atom_1, ..., if the first token has only two atoms), while the token-layout would be sparse ([token_1_atom_1, token_1_atom_2, pad, pad, pad, ...], [token_2_atom_1, ...], ...). Because of this, the conversion is not a simple reshape. But it's not super complex either.

We will start with the computation of the features. In `ref_struct_features.py`, implement the whole class `CalculateRefStructFeatures` (i.e. the methods `calculate_ref_positions`, `prep_atom_name_chars` for encoding the string atom names into integers, and the main `forward` method). Note that this class is located below the dataclass `RefStructFeatures` in the file; we will implement that dataclass afterwards. After you are done, run the following cell to test your implementation.

In [None]:
from feature_extraction.ref_struct_features import CalculateRefStructFeatures
import numpy as np

np.random.seed(0)
ttr.reset_loading_index()

transform_no_ref = Compose(base_transforms + [CalculateTokenFeatures()])
calc_ref = CalculateRefStructFeatures()
transform = Compose(base_transforms + [CalculateTokenFeatures(), calc_ref])

success = True

for test_input in test_inputs_pipeline:
    ref_struct_features = transform(test_input)['ref_struct'].__dict__
    success = success and ttr.log_or_compare(ref_struct_features, 'ref_struct_features')


if success: print('Ref struct feature calculation (without unknown entries) tests created.')

test_unknown_ligand = test_inputs_pipeline[0]
data_no_ref = transform_no_ref(test_unknown_ligand)
atom_array = test_unknown_ligand['atom_array']
atom_array.res_name[atom_array.res_name == 'GLY'] = 'UNL'
ref_struct_features = calc_ref(data_no_ref)['ref_struct'].__dict__
success = ttr.log_or_compare(ref_struct_features, 'ref_struct_features_unknown_ligand')

if success: print('Ref struct feature calculation (with unknown entries) tests created.')
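
To make the atom name encoding more concrete, here is a sketch of the scheme described in the AlphaFold 3 supplement: each character is encoded as `ord(c) - 32` and names are padded to length 4. Whether `prep_atom_name_chars` is expected to return exactly this integer layout is an assumption, so follow the docstring where it differs.

```python
import numpy as np


def encode_atom_names_sketch(atom_names, max_len=4):
    """Encode each character as ord(c) - 32, padding names to `max_len` with 0."""
    encoded = np.zeros((len(atom_names), max_len), dtype=np.int64)
    for row, name in enumerate(atom_names):
        codes = [ord(char) - 32 for char in str(name)[:max_len]]
        encoded[row, : len(codes)] = codes
    return encoded


name_chars = encode_atom_names_sketch(["CA", "N", "OXT"])
```

Since `ord(' ') - 32` is 0, padding with zeros is equivalent to padding the names with spaces.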

Now, we will complete the RefStruct features by implementing the conversion between atom-layout and token-layout, i.e. by implementing the methods `to_token_layout` and `to_atom_layout`, as well as the two helper methods `token_layout_ref_mask` and `patch_atom_dimension` in the class `RefStructFeatures`. The basic idea is this: We already have the mask for the RefStruct features in atom-layout, i.e. of shape `(*batch_shape, n_atoms)`. Now, we can build an equivalent mask in token-layout, of shape `(*batch_shape, n_tokens, 24)`, based on how many atoms are present in each token, which we do in `token_layout_ref_mask`. Then, we can do feature conversion as `feat = zeros(out_shape); feat[token_layout_ref_mask] = feat_atom_layout[atom_layout_ref_mask]` and vice versa. The method `patch_atom_dimension` is simply a utility for the case that we want to do token-layout -> atom-layout conversion, but the input feature doesn't have an atom dimension and simply requires broadcasting along that dimension.
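
The mask-based conversion can be sketched with plain numpy. The function names and arguments below are hypothetical (in the tutorial, the masks and counts live on the `RefStructFeatures` object), but the two assignments are the ones described above.

```python
import numpy as np

ATOMS_PER_TOKEN = 24


def token_layout_mask_sketch(atoms_per_token):
    """True at [..., t, a] if token t has more than a atoms."""
    counts = np.asarray(atoms_per_token)
    return np.arange(ATOMS_PER_TOKEN) < counts[..., None]


def to_token_layout_sketch(feat_atom, atom_mask, token_mask):
    trailing = feat_atom.shape[atom_mask.ndim:]  # feature dims after the atom axis
    out = np.zeros(token_mask.shape + trailing, dtype=feat_atom.dtype)
    out[token_mask] = feat_atom[atom_mask]
    return out


def to_atom_layout_sketch(feat_token, atom_mask, token_mask):
    trailing = feat_token.shape[token_mask.ndim:]
    out = np.zeros(atom_mask.shape + trailing, dtype=feat_token.dtype)
    out[atom_mask] = feat_token[token_mask]
    return out


def add_atom_dimension_sketch(feat_per_token, token_axis):
    """Repeat a per-token feature 24 times along a new axis after the token axis."""
    expanded = np.expand_dims(feat_per_token, token_axis + 1)
    return np.repeat(expanded, ATOMS_PER_TOKEN, axis=token_axis + 1)


# Three tokens with 2, 3 and 0 atoms (the last one is padding), so 5 real atoms.
token_mask = token_layout_mask_sketch([2, 3, 0])
atom_mask = np.arange(3 * ATOMS_PER_TOKEN) < 5
feat_atom = np.arange(72 * 3, dtype=np.float64).reshape(72, 3)

feat_token = to_token_layout_sketch(feat_atom, atom_mask, token_mask)
feat_back = to_atom_layout_sketch(feat_token, atom_mask, token_mask)
per_token = np.array([10.0, 20.0, 30.0])
spread = to_atom_layout_sketch(add_atom_dimension_sketch(per_token, 0), atom_mask, token_mask)
```

This works because boolean indexing visits the True entries in row-major order, so the k-th valid slot in token-layout lines up with the k-th valid atom in atom-layout. The same holds with leading batch dimensions, as long as both masks mark the same number of atoms per sample.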

Implement these methods in `ref_struct_features.py`, and test your implementation by running the following cell. Note that the layout conversions need to support batch dimensions (in contrast to the feature extraction pipeline, which only works on individual samples).

In [None]:
from feature_extraction.ref_struct_features import CalculateRefStructFeatures, RefStructFeatures
from feature_extraction.feature_extraction import collate_batch


ttr.reset_loading_index()
np.random.seed(0)
transform = Compose(base_transforms + [CalculateTokenFeatures(), CalculateRefStructFeatures()])

test_data_individual = [transform(test_input) for test_input in test_inputs_pipeline]

n_token_list = [data['token_features'].token_count for data in test_data_individual]
test_token_layout = [np.random.randn(n_tokens, 24) for n_tokens in n_token_list]
test_token_layout_with_feats = [np.random.randn(n_tokens, 24, 3, 2) for n_tokens in n_token_list]
test_token_layout_no_atoms = [np.random.randn(n_tokens,) for n_tokens in n_token_list]
test_atom_layout = [np.random.randn(n_tokens*24, 3) for n_tokens in n_token_list]

to_atom_list = []
to_atom_with_feats_list = []
to_atom_no_atom_dim_list = []  
to_token_list = []

# Single feature tests
success = True
for i, test_data in enumerate(test_data_individual):
    ref_struct_features: RefStructFeatures = test_data['ref_struct']
    to_atom = ref_struct_features.to_atom_layout(test_token_layout[i], has_atom_dimension=True)
    to_atom_with_feats = ref_struct_features.to_atom_layout(test_token_layout_with_feats[i], has_atom_dimension=True)
    to_atom_no_atom_dim = ref_struct_features.to_atom_layout(test_token_layout_no_atoms[i], has_atom_dimension=False)
    to_token = ref_struct_features.to_token_layout(test_atom_layout[i])

    success = success and ttr.log_or_compare(to_atom, 'to_atom')
    success = success and ttr.log_or_compare(to_atom_with_feats, 'to_atom_with_feats')
    success = success and ttr.log_or_compare(to_atom_no_atom_dim, 'to_atom_no_atom_dim')
    success = success and ttr.log_or_compare(to_token, 'to_token')

    to_atom_list.append(to_atom)
    to_atom_with_feats_list.append(to_atom_with_feats)
    to_atom_no_atom_dim_list.append(to_atom_no_atom_dim)
    to_token_list.append(to_token)

if success: print('Individual to_atom and to_token tests created.')
# Batch tests
test_data_batch = collate_batch(test_data_individual, drop_unconvertible_entries=True)
ref_struct_features_batch: RefStructFeatures = test_data_batch['ref_struct']

to_atom_batch = ref_struct_features_batch.to_atom_layout(collate_batch(test_token_layout), has_atom_dimension=True)
to_atom_with_feats_batch = ref_struct_features_batch.to_atom_layout(collate_batch(test_token_layout_with_feats), has_atom_dimension=True)
to_atom_no_atom_dim_batch = ref_struct_features_batch.to_atom_layout(collate_batch(test_token_layout_no_atoms), has_atom_dimension=False)
to_token_batch = ref_struct_features_batch.to_token_layout(collate_batch(test_atom_layout))

assert np.allclose(to_atom_batch, collate_batch(to_atom_list)), 'Batch to_atom does not match individual'
assert np.allclose(to_atom_with_feats_batch, collate_batch(to_atom_with_feats_list)), 'Batch to_atom_with_feats does not match individual'
assert np.allclose(to_atom_no_atom_dim_batch, collate_batch(to_atom_no_atom_dim_list)), 'Batch to_atom_no_atom_dim does not match individual'
assert np.allclose(to_token_batch, collate_batch(to_token_list)), 'Batch to_token does not match individual'

print('Batch to_atom and to_token tests created.')

## MSA Features
MSA features are derived from the multiple sequence alignment. The MSA feature is generally simpler in AF3 than in AF2 (it doesn't use the clustering procedure), but handling it gets more complex because we have multiple chains and therefore multiple MSAs. Generally, the steps are the following:

1. Retrieve the MSA and deletion counts from the atom_array, and encode the MSA (e.g. with different integers for different residue types, gaps, ...)
2. Stack the MSAs and deletion counts
3. Extract the features (e.g. deletion mean, MSA profile, ...)
4. Subsample the MSA features for the recycling iterations.

First, implement the transform `EncodeMSA` in `feature_extraction/msa_features.py` to do step 1. Test your implementation by running the following cell.

In [None]:
from feature_extraction.msa_features import EncodeMSA
from atomworks.ml.transforms.msa.msa import LoadPolymerMSAs
from config import Config

config = Config()
np.random.seed(0)
ttr.reset_loading_index()
transform = Compose(base_transforms + [
    CalculateTokenFeatures(), 
    CalculateRefStructFeatures(), 
    LoadPolymerMSAs(max_msa_sequences=config.featurization_config.max_msa_sequences, use_paths_in_chain_info=True), 
    EncodeMSA()
    ])

success = True
for i, test_input in enumerate(test_inputs_pipeline):
    data = transform(test_input)
    msa_encs = {chain_id: chain_data['msa'] for chain_id, chain_data in data['polymer_msas_by_chain_id'].items()}
    success = success and ttr.log_or_compare(msa_encs, f'msa_encodings_{i}')

if success: print('MSA encoding tests created.')

Next up is the transform `ConcatMSAs` to do step 2. Test it by running the following cell.

In [None]:
from feature_extraction.msa_features import EncodeMSA, ConcatMSAs
from atomworks.ml.transforms.msa.msa import LoadPolymerMSAs
from config import Config

ttr.reset_loading_index()
np.random.seed(0)
config = Config()
transform = Compose(base_transforms + [
    CalculateTokenFeatures(), 
    CalculateRefStructFeatures(), 
    LoadPolymerMSAs(max_msa_sequences=config.featurization_config.max_msa_sequences, use_paths_in_chain_info=True), 
    EncodeMSA(),
    ConcatMSAs(max_msa_sequences=config.featurization_config.max_msa_sequences),
    ])

success = True
for test_input in test_inputs_pipeline:
    data = transform(test_input)
    success = success and ttr.log_or_compare(data['msa_features'], 'raw_msa_features')

if success: print('MSA feature concatenation tests created.')

Steps 3 and 4 are both done in the same transform, `AssembleMSAFeatures`. Concretely, it creates two features, the target_feat (which does not include an MSA dimension) and the msa_feat (which includes an MSA dimension), then subsamples this msa_feat. Implement the transform and run the following cell to test your implementation.

In [None]:
from feature_extraction.msa_features import EncodeMSA, ConcatMSAs, AssembleMSAFeatures
from atomworks.ml.transforms.msa.msa import LoadPolymerMSAs
from config import Config
import torch

torch.manual_seed(0)
np.random.seed(0)
ttr.reset_loading_index()

config = Config()
config.featurization_config.max_msa_sequences = 256
config.featurization_config.msa_trunc_count = 128

transform = Compose(base_transforms + [
    CalculateTokenFeatures(), 
    CalculateRefStructFeatures(), 
    LoadPolymerMSAs(max_msa_sequences=config.featurization_config.max_msa_sequences, use_paths_in_chain_info=True), 
    EncodeMSA(),
    ConcatMSAs(max_msa_sequences=config.featurization_config.max_msa_sequences),
    AssembleMSAFeatures(config.featurization_config.msa_trunc_count, config.global_config.n_cycle)
    ])

success = True
for test_input in test_inputs_pipeline:
    data = transform(test_input)
    success = success and ttr.log_or_compare(data['msa_features'].__dict__, 'msa_features')

if success: print('MSA feature assembly tests created.')
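
The kind of computation behind steps 3 and 4 can be sketched on a toy alignment. The profile, the deletion mean, and the `2/pi * arctan(d/3)` transform follow the feature descriptions in the AlphaFold 3 supplement. The subsampling scheme (a uniform random subset of rows per recycling iteration) and all names below are assumptions, so the tutorial's docstrings take precedence.

```python
import numpy as np


def msa_summary_sketch(msa, deletions, n_classes):
    """Per-position statistics of an integer-encoded MSA of shape (n_seq, n_tokens)."""
    one_hot = np.eye(n_classes)[msa]  # (n_seq, n_tokens, n_classes)
    return {
        "profile": one_hot.mean(axis=0),  # residue distribution per position
        "deletion_mean": deletions.mean(axis=0),
        "has_deletion": (deletions > 0).astype(np.float32),
        "deletion_value": 2.0 / np.pi * np.arctan(deletions / 3.0),
    }


def subsample_rows_sketch(msa, n_rows, n_cycle, rng):
    """Assumption: draw an independent random subset of rows for every cycle."""
    samples = []
    for _ in range(n_cycle):
        keep = rng.permutation(msa.shape[0])[:n_rows]
        samples.append(msa[keep])
    return samples


toy_msa = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [0, 1, 2]])
toy_deletions = np.array([[0, 0, 3], [0, 0, 0], [0, 1, 0], [0, 0, 0]])
summary = msa_summary_sketch(toy_msa, toy_deletions, n_classes=4)
samples = subsample_rows_sketch(toy_msa, n_rows=2, n_cycle=3, rng=np.random.default_rng(0))
```

The profile and deletion mean have no MSA dimension, which is why statistics of this kind can go into a per-token feature, while the per-row quantities keep the MSA dimension and are the ones that get subsampled.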

Lastly, the transform `CalculateMSAFeatures` simply joins the previous transforms together into a single transform. Note that it also includes two transforms you didn't implement yourself, `HotfixDuplicateRowIfSingleMSA` and `HotfixAF3LigandAsGap`. These account for some quirks in the AF3 feature encoding that might not be desirable for a general feature extraction pipeline, and are included solely to mirror AF3 exactly (so that the AF3 weights can be used for the model).

Implement `CalculateMSAFeatures` and run the following cell to test your implementation.

In [None]:
from feature_extraction.msa_features import CalculateMSAFeatures
from atomworks.ml.transforms.msa.msa import LoadPolymerMSAs
from config import Config
import torch

torch.manual_seed(0)
np.random.seed(0)
ttr.reset_loading_index()

config = Config()
config.featurization_config.max_msa_sequences = 256
config.featurization_config.msa_trunc_count = 128

transform = Compose(base_transforms + [
    CalculateTokenFeatures(), 
    CalculateRefStructFeatures(), 
    CalculateMSAFeatures(config.featurization_config.max_msa_sequences, config.featurization_config.msa_trunc_count, config.global_config.n_cycle)
    ])

success = True
for test_input in test_inputs_pipeline:
    data = transform(test_input)
    success = success and ttr.log_or_compare(data['msa_features'].__dict__, 'msa_features_full_pipeline')

if success: print('Full MSA feature calculation tests created.')

## Putting it all together

There's nothing difficult left to do. We only need to stitch all of the previous transforms together. This is done in `feature_extraction.py` in the method `custom_af3_pipeline`. The file also includes a few utility functions (for example, for recursively mapping a function over each tensor within `Batch`, or collating a list of `Batch` objects into one), which are already implemented. 
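
To give an idea of what such a recursive mapping utility does, here is a small sketch over nested dictionaries, lists, and dataclasses. The names `map_arrays_sketch` and `ToyFeatures` are hypothetical, and the already implemented helpers in `feature_extraction.py` may be structured differently.

```python
from dataclasses import dataclass, fields, is_dataclass, replace

import numpy as np


def map_arrays_sketch(fn, obj):
    """Apply `fn` to every numpy array inside a nested container."""
    if isinstance(obj, np.ndarray):
        return fn(obj)
    if is_dataclass(obj) and not isinstance(obj, type):
        updates = {f.name: map_arrays_sketch(fn, getattr(obj, f.name)) for f in fields(obj)}
        return replace(obj, **updates)
    if isinstance(obj, dict):
        return {key: map_arrays_sketch(fn, value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_arrays_sketch(fn, value) for value in obj)
    return obj  # leave strings, numbers, etc. untouched


@dataclass
class ToyFeatures:
    restype: np.ndarray
    mask: np.ndarray


toy_batch = {"token_features": ToyFeatures(np.zeros(4), np.ones(4)), "name": "lysozyme"}
# Example use: add a leading batch dimension to every array.
with_batch_dim = map_arrays_sketch(lambda array: array[None, ...], toy_batch)
```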

After implementing `custom_af3_pipeline`, you can test your final code for this Chapter by running the following cell.

In [None]:
from feature_extraction.feature_extraction import custom_af3_pipeline
from config import Config

ttr.reset_loading_index()
np.random.seed(0)
torch.manual_seed(0)

config = Config()
config.featurization_config.max_msa_sequences = 256
config.featurization_config.msa_trunc_count = 128

transform = custom_af3_pipeline(config)

success = True
for test_input in test_inputs_pipeline:
    data = transform(test_input)
    batch_as_dict = {
        k: v.__dict__ if hasattr(v, '__dict__') else v
        for k, v in data['batch'].__dict__.items()
    }
    success = success and ttr.log_or_compare(batch_as_dict, 'full_batch')

if success: print('Full feature extraction pipeline tests created.')

## Conclusion
You made it! This is everything for the first chapter. I know this was a lot. I wouldn't quite say that the next weeks get easier, but I think they are at least more fun. Feature extraction is a bit annoying, because it's mostly following instructions and implementing many small details. The next chapters are real ML modeling, starting with input embedding in the next one. Great job making it all the way through this. See you in the next chapter!