In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [21]:
import torch
from torch import nn
import math
import os

In [5]:
# You might need to change this path, depending on your directory structure and current working directory
control_folder = 'solutions/feature_extraction/control_values'
assert os.path.isdir(control_folder), 'Folder for control values not found. Please check your path.'

# Feature Extraction

In this notebook, we will implement the feature extraction pipeline for AlphaFold. The pipeline consists of the following steps:

- Parse the a3m file
- Count and remove deletions (residues that are present in an aligned sequence but not in the query sequence)
- Randomly select cluster centers
- Randomly replace some residues of the cluster centers (this is called masking)
- Assign non-cluster sequences to their closest cluster center
- Summarize the features of all sequences assigned to the cluster
- Crop non-cluster sequences to a fixed number
- Assemble the features
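The assignment step can be illustrated in isolation. The following is a minimal sketch, assuming sequences are already encoded as integer residue indices of equal length and using plain Hamming distance. `assign_to_nearest_center` is a hypothetical helper, not part of `feature_extractor.py`, and AlphaFold's actual assignment may weight or exclude positions (for example gaps or masked residues) differently.

```python
import torch

def assign_to_nearest_center(centers: torch.Tensor, others: torch.Tensor) -> torch.Tensor:
    """Return, for each non-cluster sequence, the index of its nearest cluster center.

    centers: (N_clust, N_res) integer residue indices
    others:  (N_extra, N_res) integer residue indices
    Distance is the plain Hamming distance (number of differing positions).
    """
    # Broadcast to (N_extra, N_clust, N_res) and count mismatching positions
    mismatches = (others[:, None, :] != centers[None, :, :]).sum(dim=-1)
    # Pick the center with the fewest mismatches for every sequence
    return mismatches.argmin(dim=1)

centers = torch.tensor([[0, 1, 2, 3],
                        [4, 4, 4, 4]])
others = torch.tensor([[0, 1, 2, 4],   # 1 mismatch to center 0, 3 to center 1
                       [4, 4, 4, 3],   # 3 mismatches to center 0, 1 to center 1
                       [0, 1, 4, 3]])  # 1 mismatch to center 0, 3 to center 1
print(assign_to_nearest_center(centers, others))
```

Broadcasting compares every sequence against every center in one vectorized operation, which avoids a Python loop over the (potentially thousands of) non-cluster sequences.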

Most of the work goes into creating the `msa_feat` and the `extra_msa_feat`. The input features also include the `target_feat` and the `residue_index`, but these are easy to implement.
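To get a feeling for the sizes involved, the channel counts can be added up. The sketch below reflects our reading of the input feature table in Section 1.2.9 of the supplement: the one-hot MSA features use 23 classes (20 amino acids, unknown, gap, mask token), and raw deletion counts d are squashed into [0, 1) with 2/π · arctan(d/3). The dictionary names and `deletion_value` are illustrative and not taken from `feature_extractor.py`.

```python
import math
import torch

# Channel layout of msa_feat (per cluster, per residue), as read from the supplement
MSA_FEAT_CHANNELS = {
    'cluster_msa': 23,             # one-hot: 20 amino acids + unknown + gap + mask
    'cluster_has_deletion': 1,
    'cluster_deletion_value': 1,
    'cluster_deletion_mean': 1,
    'cluster_profile': 23,
}
# Channel layout of extra_msa_feat (per extra sequence, per residue)
EXTRA_MSA_FEAT_CHANNELS = {
    'extra_msa': 23,
    'extra_msa_has_deletion': 1,
    'extra_msa_deletion_value': 1,
}

def deletion_value(deletion_counts: torch.Tensor) -> torch.Tensor:
    """Squash raw deletion counts d >= 0 into [0, 1) via 2/pi * arctan(d/3)."""
    return 2 / math.pi * torch.atan(deletion_counts / 3)

print(sum(MSA_FEAT_CHANNELS.values()))        # 49
print(sum(EXTRA_MSA_FEAT_CHANNELS.values()))  # 25
print(deletion_value(torch.tensor([0.0, 3.0, 30.0])))  # close to [0.0, 0.5, ~0.94]
```

The arctan transform keeps large deletion counts from dominating the other, mostly binary, input channels.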

For an overview of the features, you can read Section 1.2.9 from [AlphaFold's Supplement](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf).

We will skip the template features in this series; predictions often work well without them (ColabFold disables templates by default).

## File parsing

Take a look at the file `alignment_tautomerase.a3m`. It contains the alignment data of 2-hydroxymuconate tautomerase in a3m format. The alignment was generated by ColabFold, which uses MMseqs2 to create alignments.

Note the format of the file: each entry consists of a header line starting with '>', which contains an identifier and some values from the alignment, followed by the sequence. The first of these sequences is the query sequence.
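One possible way to read this format is sketched below. It works on a string rather than a file so it can be run on a toy example, and it is not necessarily how the notebook's solution implements `load_a3m_file`. Assumptions: `parse_a3m_text` is a hypothetical name, lines starting with '#' are treated as metadata and skipped, and a sequence may be wrapped over several lines.

```python
def parse_a3m_text(text: str) -> list[str]:
    """Return the sequences of an a3m-formatted string, in file order."""
    seqs = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith('#'):
            # Skip blank lines and (assumed) metadata lines
            continue
        if line.startswith('>'):
            # Header line: a new sequence starts after it
            seqs.append('')
        elif seqs:
            # Sequence line; append in case the sequence is wrapped
            seqs[-1] += line
    return seqs

# Made-up example, not taken from alignment_tautomerase.a3m
example = """#8\t1
>101
PIAQIHIL
>UniRef100_X 120 0.9
PIAkvQ-HI
L
"""
print(parse_a3m_text(example))  # ['PIAQIHIL', 'PIAkvQ-HIL']
```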

In the sequence strings, there are upper-case and lower-case letters. Upper-case letters denote aligned residues, while lower-case letters denote residues that are present in the aligned sequence but not in the query sequence (in other formats, these might be denoted by a dash in the query sequence). Dashes in an aligned sequence denote gaps, i.e. query positions without a matching residue.
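To make the convention concrete, here is a toy example with made-up sequences (not taken from `alignment_tautomerase.a3m`): removing the lower-case letters from an aligned sequence leaves a string with the same length as the query. `drop_insertions` is a hypothetical helper for illustration.

```python
query = "PIAQIHIL"
# Made-up aligned sequence: 'kv' is inserted relative to the query,
# '-' marks a query position without a matching residue
aligned = "PIAkvQ-HIL"

def drop_insertions(seq: str) -> str:
    """Keep only upper-case letters and gaps, i.e. the columns aligned to the query."""
    return ''.join(c for c in seq if not c.islower())

stripped = drop_insertions(aligned)
print(stripped)                     # PIAQ-HIL
print(len(stripped) == len(query))  # True
```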

With this knowledge of the file format, implement the method `load_a3m_file` in the file `feature_extractor.py` and test your implementation with the following code cell. Preserve the order of the sequences in the file.

In [22]:
from solutions.feature_extraction.feature_extractor import load_a3m_file

seqs = load_a3m_file('solutions/feature_extraction/alignment_tautomerase.a3m')

first_expected = ['PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK', 'PVVTIELWEGRTPEQKRELVRAVSSAISRVLGCPEEAVHVILHEVPKANWGIGGRLASE', 'PVVTIEMWEGRTPEQKKALVEAVTSAVAGAIGCPPEAVEVIIHEVPKVNWGIGGQIASE', 'PIIQVQMLKGRSPELKKQLISEITDTISRTLGSPPEAVRVILTEVPEENWGVGGVPINE', 'PFVQIHMLEGRTPEQKKAVIEKVTQALVQAVGVPASAVRVLIQEVPKEHWGIGGVSARE']

assert len(seqs) == 8361 and seqs[:5] == first_expected

Next, we will parse the individual sequences to remove deletions and one-hot encode them. For one-hot encoding, the classes must have a predetermined order. The usual way to order the residues is to sort their 3-letter codes alphabetically and to use this order for the 1-letter codes as well.
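This ordering can be reproduced in a few lines. The dictionary below is written out for illustration only and is not necessarily how `feature_extractor.py` stores the residue types.

```python
# One-letter to three-letter codes of the 20 standard amino acids
THREE_LETTER = {
    'A': 'ALA', 'R': 'ARG', 'N': 'ASN', 'D': 'ASP', 'C': 'CYS',
    'Q': 'GLN', 'E': 'GLU', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE',
    'L': 'LEU', 'K': 'LYS', 'M': 'MET', 'F': 'PHE', 'P': 'PRO',
    'S': 'SER', 'T': 'THR', 'W': 'TRP', 'Y': 'TYR', 'V': 'VAL',
}

# Sort the one-letter codes by their three-letter names
order = ''.join(sorted(THREE_LETTER, key=THREE_LETTER.get))
print(order)  # ARNDCQEGHILKMFPSTWYV
```

This is why the order is not alphabetical in the 1-letter codes: for example, Asn (N) and Asp (D) come right after Arg (R).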

The order of the amino acids is provided at the top of `feature_extractor.py`. Initialize the two dictionaries as maps from each letter to its index. Then, implement `onehot_encode_aa_type` and check your implementation with the following cell.

In [25]:
from solutions.feature_extraction.feature_extractor import onehot_encode_aa_type

test_seq = "ARNDCQEGHILKMFPSTWYV"

enc1 = onehot_encode_aa_type(test_seq, include_gap_token=False)
enc2 = onehot_encode_aa_type(test_seq, include_gap_token=True)
enc3 = onehot_encode_aa_type(test_seq+'-', include_gap_token=True)

assert torch.allclose(enc1, nn.functional.one_hot(torch.arange(20), num_classes=21))
assert torch.allclose(enc2, nn.functional.one_hot(torch.arange(20), num_classes=22))
enc3_exp = nn.functional.one_hot(torch.cat((torch.arange(20),torch.tensor([21]))), num_classes=22)
assert torch.allclose(enc3, enc3_exp)
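For comparison, one encoding scheme that is consistent with the checks above is sketched here. It rests on assumptions: `onehot_sketch` is a hypothetical stand-in for `onehot_encode_aa_type`, unknown letters map to index 20 and `'-'` to index 21 (matching the 21 and 22 classes in the test), and a dash is treated as unknown when no gap token is requested. It returns an integer tensor; call `.float()` if your pipeline expects floats.

```python
import torch
from torch import nn

RESTYPES = 'ARNDCQEGHILKMFPSTWYV'

def onehot_sketch(seq: str, include_gap_token: bool) -> torch.Tensor:
    """One-hot encode seq; assumes index 20 = unknown, index 21 = gap ('-')."""
    letter_to_idx = {aa: i for i, aa in enumerate(RESTYPES)}
    num_classes = 22 if include_gap_token else 21
    indices = []
    for c in seq:
        if c == '-' and include_gap_token:
            indices.append(21)
        else:
            # Anything outside the 20 standard residues is treated as unknown
            indices.append(letter_to_idx.get(c, 20))
    return nn.functional.one_hot(torch.tensor(indices), num_classes=num_classes)

print(onehot_sketch("AX-", include_gap_token=True).argmax(dim=1))  # tensor([ 0, 20, 21])
```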

Now implement `initial_data_from_seqs`. The method counts and removes deletions, and then removes sequences that are duplicates once the deletions have been removed.
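The core of this step can be sketched on plain strings. This is a minimal sketch, not the notebook's solution: `deletions_and_dedup` is a hypothetical helper that returns lists instead of the tensors and feature keys that `initial_data_from_seqs` produces. It assumes the deletion count of an aligned position is the number of lower-case letters directly to its left, and that the first occurrence of a duplicate is kept.

```python
def deletions_and_dedup(seqs: list[str]) -> tuple[list[str], list[list[int]]]:
    """Strip insertions, count them per aligned position, and drop duplicates."""
    unique_seqs, deletion_matrix = [], []
    seen = set()
    for seq in seqs:
        stripped, counts, run = [], [], 0
        for c in seq:
            if c.islower():
                run += 1  # residue inserted relative to the query
            else:
                stripped.append(c)
                counts.append(run)  # insertions since the previous aligned column
                run = 0
        # Lower-case letters after the last aligned column are not counted here
        stripped = ''.join(stripped)
        if stripped in seen:
            continue  # keep only the first occurrence of each stripped sequence
        seen.add(stripped)
        unique_seqs.append(stripped)
        deletion_matrix.append(counts)
    return unique_seqs, deletion_matrix

seqs_demo = ['PIAQIHIL', 'PIAkvQ-HIL', 'PIAQ-HIlL', 'PIAQIHIL']
print(deletions_and_dedup(seqs_demo))
# (['PIAQIHIL', 'PIAQ-HIL'], [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 2, 0, 0, 0, 0]])
```

In this sketch, the third demo sequence is dropped even though its deletion counts differ from the second one: duplicates are detected after stripping, and the first occurrence wins.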

After that, test your code by running the following cell:

In [28]:
from solutions.feature_extraction.feature_extractor import initial_data_from_seqs
seqs = load_a3m_file('solutions/feature_extraction/alignment_tautomerase.a3m')

features = initial_data_from_seqs(seqs)

expected_features = torch.load(f'{control_folder}/initial_data.pt')

for key, param in features.items():
    assert torch.allclose(param, expected_features[key]), f'Error in calculation of feature {key}.'