# AlphaFold 2

This notebook contains notes on the AlphaFold 2 model.

## Overview
First, let's look at a high level at the AlphFold2 model.

The required input for AlphaFold2 is a query amino acid sequence representing the protein of interest.
The target output from evaluating AlphaFold2 with this query is a set of atom coordinates $\{\vec{x}_i^a\}$ and the per-residue confidence $\{p_i^{pLDDT}\}$.


The neural network part of AlphaFold2 doesn't take just the query amino acid sequence as input, but rather makes use of the wealth of knowledge that we already have about protein structure. This is done by incorporating evolutionarily related sequences from other species in a Multiple Sequence Alignment (MSA). Features are computed for the query sequence on the basis of this MSA, and these features are used to represent the query sequence.

We'll start by inspecting the data pre-processing steps that happen immediately following the specification of a query sequence. 
![Architecture1](pics/AlphaFold_1.png)
The query pre-processing steps generate two main inputs to the neural network. First, the MSA itself is embedded in a vector space and used as input. Second, the query sequence is represented as a residue pair matrix using features computed in the context of the MSA and structural templates selected from the PDB.

These two inputs (the MSA embedding and the pair representation) are then used as input to the first neural network, the Evoformer. This network performs an additional embedding and outputs a new MSA embedding and new pair representation. These are used as input to the second neural network, the Structure Module. This module outputs the predicted atomic coordinates and per-residue confidence.

The network block is evaluated iteratively, including current predictions as input for the next iteration. The final structure prediction is only achieved after many iterations.

![Architecture2](pics/AlphaFold_2.png)

Now let's take a closer look at the data pre-processing steps.

## 1.1 Notation

In [10]:
Nres = 256  # residues in the human primary sequence
Ntempl = 4  # number of templates
Nall_seq = 1152  # number of MSA sequences
Nclust = 124  # number of MSA clusters
Nseq = 128  # Nclust + Ntempl
Nextra_seq = 1024  # unclustered MSA sequences

- Capitalized function names contain trainable parameters.
- Lower case function names have no parameters, but simply perform a transformation of the input.
- Indices (i, j, k) are reserved for residue dimensions.
- Indices (s, t) are reserved for sequence dimensions.
- Index (h) is reserved for indexing the attention heads.

## 1.2 Data pipeline

The data pipeline is the first step when running AlphaFold. It takes an input FASTA file and produces the input features for the ML model.

Several genetic searches are performed (using JackHMMER and HHBlits on different databases) and the output MSAs are stacked. 

Structural templates (i.e., 3D atomic coordinates to use as priors) are retrieved from the PDB using the MSA sequences as queries. The top 4 templates are input the model. These top templates are selected by the expected number of correctly aligned residues in the MSA. 

### Training filters
Several filters are applied to the training data to help generate diversity and avoid particular edge cases.

- Some QA is applied to remove particularly uncommon sequences. For example, sequences are rejected if any single a.a. accounts for more than 80\% of the sequence.
- Probabilistic rejection of training sequences based on length helps to rebalance the length distribution in training data
- MSA block deletion: contiguous blocks of sequences in the MSA are randomly deleted, which tends to remove phylogenetic branches since MSA ordering typically places evolutionarily similar sequences contiguously
- During training, fewer than 4 structural templates (even 0 templates) can be randomly selected, to try and avoid the ML model simply copying the template structure. 

### MSA clustering
The peak computational complexity of the ML model scales as $\mathcal{O}(N_{seq}^2\times N_{res})$ so it is valuable to reduce the number of sequences in the MSA using unsupervised reduction. The procedure of reducing the MSA size is as follows:

1. $N_{clust}$ sequences are selected with random uniform probability to be cluster centers, although the query sequence (e.g. human) is always one of the cluster centers.
2. A mask is applied along the residue dimension of the MSA cluster center sequences, which is used later in the model, but involves random perturbations to the sequence.
3. The remaining sequences in the MSA (i.e., not cluster centers) are assigned to the closest cluster center as measured by Hamming distance on un-masked residues. 
4. For each cluster, statistical features such as the per-residue a.a. distribution are computed and these features are also input to the ML model.
5. An additional $N_{extra\_seq}$ are selected from the MSA in addition to the cluster centers, and these sequences are input to the model. The remaining MSA sequences only enter the model through their contribution to the per-residue statistical summary features.

During training, the number of residues is cropped to further reduce the computational complexity per training sample. This is not done during evaluation.


In [9]:
with open("five_letter_words.txt", 'r') as f:
    five_letter_words = f.readlines()
five_letter_words = sorted(five_letter_words)
five_letter_words = list(map(str.strip, five_letter_words))
print(five_letter_words[0:3])

['aargh', 'abaca', 'abaci']


In [26]:
# Step 1. Choose Nclust cluster centers with uniform random probability
import random
import string

cluster_centers = [random.choice(five_letter_words) for _ in range(Nclust)]
print("Query word: ", cluster_centers[0])

# Step 2. Mask along residue dimension
for res in range(5):
    if random.randint(1,100) <= 15:  # 15% chance of residue position being included in mask
        for pos, word in enumerate(cluster_centers):
            if random.randint(1,100) <= 70:  # 70% chance of masking residue in a given word
                new_word = word[0:res] + '<masked_msa_token>' + word[res+1:]
                cluster_centers[pos] = new_word
            if random.randint(1,100) <= 10:  # 10% chance of mutating residue in a given word
                new_word = word[0:res] + random.choice(string.ascii_letters.lower()) + word[res+1:] 
                cluster_centers[pos] = new_word
            # 20% chance of doing nothing
            
# Step 3. Assign remaining sequences to nearest cluster center
def Hamming(query, center):
    """ Computes the number of unmatched residues between query
        and center, not including '<masked_msa_token>' """
    count = 0
    for res in range(len(query)):
        if query[res] != center[res]:
            if center[res] == '<masked_msa_token>':
                pass
            else:
                count += 1
    return count


center_to_cluster = {key: [] for key in cluster_centers}
sequence_to_center = {key: '' for key in five_letter_words if key not in cluster_centers}
for sequence in sequence_to_center.keys():
    nearest = 6
    candidate_centers = []
    for center in cluster_centers:
        distance = Hamming(sequence, center)
        if distance < nearest:
            nearest = distance
            candidate_centers = [center]
        elif distance == nearest:
            candidate_centers.append(center)
        # otherwise, not a candidate center

    assignment = random.choice(candidate_centers)  # choose randomly between equally nearest cluster centers
    center_to_cluster[assignment].append(sequence)
    sequence_to_center[sequence] = assignment

print("Query word's cluster: ", center_to_cluster[cluster_centers[0]])
        

Query word:  unfed
Query word's cluster:  ['annex', 'defer', 'ended', 'enter', 'envoi', 'infer', 'infix', 'infra', 'inked', 'inker', 'inter', 'knave', 'kneel', 'lifer', 'offal', 'offen', 'oncet', 'onset', 'unarc', 'unarm', 'unban', 'unbar', 'unbox', 'uncap', 'uncle', 'under', 'undid', 'undue', 'unfed', 'unfit', 'unfix', 'unhip', 'unhit', 'unify', 'unite', 'unlit', 'unman', 'unmap', 'unmet', 'unpeg', 'unrig', 'unsee', 'unset', 'unsew', 'unsex', 'untie', 'until', 'unwed', 'unwon', 'unzip']


In [37]:
# Step 4. Compute per-residue statistical features for each cluster
amino_acids = [el for el in string.ascii_letters[0:26]] + ['<masked_msa_token>']

features_per_cluster = {center: {} for center in cluster_centers}
for center in cluster_centers:
    # dict of statistical features indexed by residue position
    statistical_features = {key: {aa: 0 for aa in amino_acids} for key in range(5)}

    for res in range(5):
        for sequence in center_to_cluster[center]:
            aa = sequence[res]
            statistical_features[res][aa] += 1

    features_per_cluster[center] = statistical_features

# Step 5. An additional Nextra_seq are selected for input to the model, in addition to the cluster centers
MSA_input = cluster_centers
target_length = min(Nextra_seq + len(cluster_centers), len(five_letter_words) - len(cluster_centers))
while(len(MSA_input) < target_length):
    candidate_sequence = random.choice(list(sequence_to_center.keys()))
    if candidate_sequence not in MSA_input:
        MSA_input.append(candidate_sequence)

# The input features for the model will be the sequences listed in MSA_input, 
# plus the list of vectors contained in features_per_cluster