# AlphaFold 2

This notebook contains notes on the AlphaFold 2 model.

## Overview
First, let's take a high-level look at the AlphaFold2 model.

The required input for AlphaFold2 is a query amino acid sequence representing the protein of interest.
The target output from evaluating AlphaFold2 with this query is a set of atom coordinates $\{\vec{x}_i^a\}$ and the per-residue confidence $\{p_i^{pLDDT}\}$.


The neural network part of AlphaFold2 does not take only the query amino acid sequence as input; it also makes use of what we already know about related proteins. This is done by incorporating evolutionarily related sequences from other species in a Multiple Sequence Alignment (MSA). Features are computed for the query sequence on the basis of this MSA, and these features are used to represent the query sequence.

We'll start by inspecting the data pre-processing steps that happen immediately following the specification of a query sequence.

![Architecture1](pics/AlphaFold_1.png)

The query pre-processing steps generate two main inputs to the neural network. First, the MSA itself is embedded in a vector space and used as input. Second, the query sequence is represented as a residue pair matrix using features computed in the context of the MSA and structural templates selected from the PDB.

These two inputs (the MSA embedding and the pair representation) are then used as input to the first neural network, the Evoformer. This network performs an additional embedding and outputs a new MSA embedding and new pair representation. These are used as input to the second neural network, the Structure Module. This module outputs the predicted atomic coordinates and per-residue confidence.

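The network is evaluated iteratively, in a process the authors call "recycling": the outputs of one pass (the pair representation, the first row of the MSA representation, and the predicted structure) are fed back as additional inputs to the next pass. The final structure prediction is taken from the last pass.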
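A minimal sketch of this iterative evaluation in Python. The functions `embed_inputs`, `evoformer` and `structure_module` are hypothetical stand-ins (not AlphaFold's real API), and the default of three extra passes is an assumption that, as far as I know, matches the open-source AlphaFold configuration. Only the control flow is meant to be illustrative: each pass receives the previous pass's outputs, and the last pass's prediction is returned.

```python
def predict_with_recycling(query, embed_inputs, evoformer, structure_module, num_recycles=3):
    """Runs the network 1 + num_recycles times, feeding each pass's outputs into the next.

    embed_inputs, evoformer and structure_module are hypothetical callables;
    prev is None on the first pass."""
    prev = None
    for _ in range(num_recycles + 1):
        msa, pair = embed_inputs(query, prev)   # previous outputs are folded into the embeddings
        msa, pair = evoformer(msa, pair)
        coords, plddt = structure_module(msa, pair)
        prev = (msa, pair, coords)              # simplified: the real model recycles only part of this state
    return coords, plddt


# Toy usage: the stubs just count how many passes have been made
def embed_inputs(query, prev):
    passes = 0 if prev is None else prev[2]
    return query, passes


def evoformer(msa, pair):
    return msa, pair + 1


def structure_module(msa, pair):
    return pair, 1.0


coords, plddt = predict_with_recycling("MKV", embed_inputs, evoformer, structure_module)
print(coords)  # 4: one initial pass plus three recycling passes
```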

![Architecture2](pics/AlphaFold_2.png)

Now let's take a closer look at the data pre-processing steps.

## 1.1 Notation

```python
Nres = 256  # residues in the human primary sequence
Ntempl = 4  # number of templates
Nall_seq = 1152  # number of MSA sequences
Nclust = 124  # number of MSA clusters
Nseq = 128  # Nclust + Ntempl
Nextra_seq = 1024  # unclustered MSA sequences
```

- Capitalized function names contain trainable parameters.
- Lower case function names have no parameters, but simply perform a transformation of the input.
- Indices (i, j, k) are reserved for residue dimensions.
- Indices (s, t) are reserved for sequence dimensions.
- Index (h) is reserved for indexing the attention heads.

## 1.2 Data pipeline

The data pipeline is the first step when running AlphaFold. It takes an input FASTA file and produces the input features for the ML model.

Several genetic searches are performed (using JackHMMER and HHBlits on different databases) and the output MSAs are stacked. 

Structural templates (i.e., 3D atomic coordinates to use as priors) are retrieved from the PDB using the MSA as the search query. The top 4 templates are input to the model. These templates are ranked by the expected number of correctly aligned residues between the template and the query, as estimated by the template search.
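The selection step amounts to a sort over search hits. This sketch assumes each hit is a dict with hypothetical keys `pdb_id` and `sum_probs`, where `sum_probs` holds the expected number of correctly aligned residues; the field names and the example hits are made up for illustration.

```python
def select_templates(hits, max_templates=4):
    """Keeps the top hits ranked by expected number of correctly aligned residues.

    hits: list of dicts with hypothetical keys 'pdb_id' and 'sum_probs'."""
    ranked = sorted(hits, key=lambda hit: hit["sum_probs"], reverse=True)
    return [hit["pdb_id"] for hit in ranked[:max_templates]]


# Made-up search hits, for illustration only
hits = [
    {"pdb_id": "1abc_A", "sum_probs": 120.5},
    {"pdb_id": "2xyz_B", "sum_probs": 310.0},
    {"pdb_id": "3def_A", "sum_probs": 45.2},
    {"pdb_id": "4ghi_C", "sum_probs": 250.7},
    {"pdb_id": "5jkl_A", "sum_probs": 98.1},
]
print(select_templates(hits))  # ['2xyz_B', '4ghi_C', '1abc_A', '5jkl_A']
```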

### Training filters
Several filters are applied to the training data to help generate diversity and avoid particular edge cases.

- Some quality control is applied to remove particularly unusual sequences. For example, sequences are rejected if any single amino acid accounts for more than 80\% of the sequence.
- Probabilistic rejection of training sequences based on length helps to rebalance the length distribution in the training data.
- MSA block deletion: contiguous blocks of sequences in the MSA are randomly deleted, which tends to remove whole phylogenetic branches, since MSA ordering typically places evolutionarily similar sequences next to each other.
- During training, fewer than 4 structural templates (even 0 templates) can be randomly selected, to avoid the ML model simply copying the template structure.
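A sketch of the three sequence-level filters above. The 80% composition threshold comes from the text. The length-acceptance curve (acceptance probability proportional to length, with lengths clamped to [256, 512]) and the block-deletion parameters are assumptions chosen for illustration, not confirmed values.

```python
import random
from collections import Counter


def passes_composition_filter(sequence, max_fraction=0.8):
    """Rejects sequences in which a single amino acid makes up more than max_fraction of residues."""
    most_common_count = Counter(sequence).most_common(1)[0][1]
    return most_common_count / len(sequence) <= max_fraction


def length_acceptance_probability(n_res, low=256, high=512):
    """Assumed rebalancing curve: short chains are down-weighted, lengths are clamped to [low, high]."""
    return max(min(n_res, high), low) / high


def delete_msa_blocks(msa, num_blocks=5, block_fraction=0.1, rng=random):
    """Removes num_blocks contiguous blocks of rows; the query (row 0) is kept in this sketch.

    Block count and block size are illustrative assumptions."""
    if len(msa) <= 1:
        return list(msa)
    block_size = max(1, int(block_fraction * len(msa)))
    deleted = set()
    for _ in range(num_blocks):
        start = rng.randrange(1, len(msa))  # never start a block at the query row
        deleted.update(range(start, min(start + block_size, len(msa))))
    return [row for idx, row in enumerate(msa) if idx not in deleted]


print(passes_composition_filter("AAAAAAAAAG"))  # False: 90% alanine
print(passes_composition_filter("MKVLAAGHTE"))  # True
print(length_acceptance_probability(100))       # 0.5
print(length_acceptance_probability(1000))      # 1.0
```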

### MSA clustering
The peak computational complexity of the ML model scales as $\mathcal{O}(N_{seq}^2\times N_{res})$ so it is valuable to reduce the number of sequences in the MSA using unsupervised reduction. The procedure of reducing the MSA size is as follows:

1. $N_{clust}$ sequences are selected with random uniform probability to be cluster centers, although the query sequence (e.g. human) is always one of the cluster centers.
2. A BERT-style mask is applied to the cluster-center sequences: at randomly chosen residue positions, residues are replaced by a special mask token or randomly perturbed. The masked positions are used later, when the model is trained to reconstruct the original residues.
3. The remaining sequences in the MSA (i.e., those that are not cluster centers) are assigned to the closest cluster center as measured by Hamming distance on unmasked residues.
4. For each cluster, statistical features such as the per-residue amino acid distribution are computed, and these features are also input to the ML model.
5. $N_{extra\_seq}$ further sequences are selected from the remaining (non-center) MSA sequences and are also input to the model. All other MSA sequences only enter the model through their contribution to the per-cluster statistical summary features.

During training, the number of residues is cropped to further reduce the computational complexity per training sample. This is not done during evaluation.
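To get a feel for the savings, the notebook's example sizes can be plugged into the $\mathcal{O}(N_{seq}^2\times N_{res})$ scaling quoted at the start of this subsection. Constant factors are ignored, so only the ratios are meaningful; the 1,000-residue protein used for the cropping comparison is an arbitrary illustrative length.

```python
Nres = 256
Nall_seq = 1152  # full MSA depth, from the notation cell
Nseq = 128       # clustered MSA depth fed to the main stack


def msa_cost(n_seq, n_res):
    """Scaling term n_seq^2 * n_res, without constant factors."""
    return n_seq ** 2 * n_res


full = msa_cost(Nall_seq, Nres)
clustered = msa_cost(Nseq, Nres)
print(full, clustered, full // clustered)  # 339738624 4194304 81

# Cropping a 1,000-residue protein to 256 residues shrinks this (linear in N_res) term by:
print(round(msa_cost(Nseq, 1000) / msa_cost(Nseq, 256), 2))  # 3.91
```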


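As a toy illustration of the clustering steps, the following cells treat five-letter English words as stand-ins for MSA sequences, with letters playing the role of amino acids.

```python
with open("five_letter_words.txt", 'r') as f:
    five_letter_words = f.readlines()
five_letter_words = sorted(five_letter_words)
five_letter_words = list(map(str.strip, five_letter_words))
print(five_letter_words[0:3])
```

```
['aargh', 'abaca', 'abaci']
```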


```python
# Step 1. Choose Nclust distinct cluster centers with uniform random probability
import random
import string

MASK = '#'  # single-character stand-in for '<masked_msa_token>', so each residue stays one character

cluster_centers = random.sample(five_letter_words, Nclust)
print("Query word: ", cluster_centers[0])  # the first center plays the role of the query

# Step 2. Mask along residue dimension
for res in range(5):
    if random.randint(1, 100) <= 15:  # 15% chance of residue position being included in mask
        for pos, word in enumerate(cluster_centers):
            roll = random.randint(1, 100)  # one draw decides what happens to this word
            if roll <= 70:  # 70% chance of masking residue in a given word
                cluster_centers[pos] = word[:res] + MASK + word[res + 1:]
            elif roll <= 80:  # 10% chance of mutating residue in a given word
                cluster_centers[pos] = word[:res] + random.choice(string.ascii_lowercase) + word[res + 1:]
            # remaining 20% chance of doing nothing


# Step 3. Assign remaining sequences to nearest cluster center
def hamming(query, center):
    """Computes the number of unmatched residues between query
    and center, not counting positions masked in center"""
    return sum(1 for q_res, c_res in zip(query, center) if c_res != MASK and q_res != c_res)


center_to_cluster = {center: [] for center in cluster_centers}
sequence_to_center = {seq: '' for seq in five_letter_words if seq not in cluster_centers}
for sequence in sequence_to_center:
    nearest = 6  # larger than any possible distance between five-letter words
    candidate_centers = []
    for center in cluster_centers:
        distance = hamming(sequence, center)
        if distance < nearest:
            nearest = distance
            candidate_centers = [center]
        elif distance == nearest:
            candidate_centers.append(center)
        # otherwise, not a candidate center

    assignment = random.choice(candidate_centers)  # choose randomly between equally near cluster centers
    center_to_cluster[assignment].append(sequence)
    sequence_to_center[sequence] = assignment

print("Query word's cluster: ", center_to_cluster[cluster_centers[0]])
```

```
Query word:  sudsy
Query word's cluster:  ['judos', 'kudos', 'kudzu', 'oddly', 'pudgy', 'ruddy', 'sadly', 'scrod', 'sedge', 'sedum', 'sided', 'sprog', 'sudsy', 'suers', 'suets', 'sugar', 'suing', 'suite', 'supra', 'suras', 'surds']
```


```python
# Step 4. Compute per-residue statistical features for each cluster
import random
import string

# Cluster members are unmasked words, so lowercase letters cover every residue that can occur
amino_acids = list(string.ascii_lowercase)

features_per_cluster = {}
for center in cluster_centers:
    # dict of per-letter counts, indexed by residue position
    statistical_features = {res: {aa: 0 for aa in amino_acids} for res in range(5)}

    for res in range(5):  # iterate through residue positions
        for sequence in center_to_cluster[center]:  # iterate through sequences in a cluster
            statistical_features[res][sequence[res]] += 1

    features_per_cluster[center] = statistical_features

# Step 5. Nextra_seq further sequences are selected for input to the model, in addition to the cluster centers
n_extra = min(Nextra_seq, len(sequence_to_center))  # cannot draw more sequences than remain
extra_sequences = random.sample(list(sequence_to_center), n_extra)
MSA_input = list(cluster_centers) + extra_sequences  # a copy, so cluster_centers itself is left unchanged
```


## The Evoformer Block

![Evoformer](pics/AlphaFold_5.png)

The input features for the model are the sequences listed in `MSA_input`, plus the per-cluster statistics contained in `features_per_cluster`. The pair representation, which we have not discussed in much detail, is also used as input. The MSA features and extra MSA sequence features are passed through simple learned linear transformations to 'embed' them. The embedding also includes a relative positional encoding of the residue indices, which is added to the pair representation.
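A minimal numpy sketch of the embedding step, written channels-last for brevity. It assumes a 21-letter alphabet (20 amino acids plus a gap), uses only one-hot sequence features (the real MSA features also include cluster statistics and more), and replaces trained weights with random ones. The relative positional encoding follows the scheme as I understand it from the paper: the offset i - j is clipped to [-32, 32], one-hot encoded into 65 bins, linearly projected, and added to the pair representation.

```python
import numpy as np

rng = np.random.default_rng(0)
ALPHABET = "ACDEFGHIKLMNPQRSTVWY-"  # assumed: 20 amino acids plus a gap symbol
c_m, c_z, max_offset = 16, 8, 32    # illustrative channel sizes; 32 is the relative-position clip


def one_hot(indices, depth):
    return np.eye(depth)[indices]


def embed(msa, rng=rng):
    """msa: list of equal-length strings, row 0 is the query. Returns m (s, i, c_m) and z (i, j, c_z)."""
    tokens = np.array([[ALPHABET.index(a) for a in row] for row in msa])  # (Nseq, Nres)
    msa_feat = one_hot(tokens, len(ALPHABET))                              # (Nseq, Nres, 21)
    target_feat = msa_feat[0]                                              # query one-hot, (Nres, 21)

    # Random weights stand in for trained linear layers
    W_m, W_t = rng.random((len(ALPHABET), c_m)), rng.random((len(ALPHABET), c_m))
    W_a, W_b = rng.random((len(ALPHABET), c_z)), rng.random((len(ALPHABET), c_z))
    W_pos = rng.random((2 * max_offset + 1, c_z))

    # MSA representation: per-sequence embedding plus the query embedding broadcast over rows
    m = msa_feat @ W_m + (target_feat @ W_t)[None, :, :]

    # Pair representation: outer sum of query embeddings plus relative positional encoding
    z = (target_feat @ W_a)[:, None, :] + (target_feat @ W_b)[None, :, :]
    residue_index = np.arange(len(msa[0]))
    offset = np.clip(residue_index[:, None] - residue_index[None, :], -max_offset, max_offset)
    z = z + one_hot(offset + max_offset, 2 * max_offset + 1) @ W_pos
    return m, z


m, z = embed(["MKVL", "MR-L", "AKVL"])
print(m.shape, z.shape)  # (3, 4, 16) (4, 4, 8)
```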

The Evoformer block takes in the embedded MSA representation $\{\mathbf{m}_{si}\}$ and pair representation $\{\mathbf{z}_{ij}\}$, and outputs updated representations with the same shapes. The Evoformer stacks many of these blocks (48 in AlphaFold2). The algorithm for one block is as follows:

![Algorithm6](pics/AlphaFold_3.png)
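The update order in the algorithm can be written as a chain of residual updates. In this sketch every sub-module is a hypothetical placeholder passed in through a dict; the names follow the algorithm as I recall it, and the call signatures are assumptions. Each placeholder returns an update that is added to its input. The dropout that the real block applies to some updates during training is omitted.

```python
def evoformer_block(m, z, modules):
    """One Evoformer block as a sequence of residual updates.

    modules: dict of hypothetical callables keyed by sub-module name."""
    # MSA stack: row attention is biased by the pair representation
    m = m + modules["MSARowAttentionWithPairBias"](m, z)
    m = m + modules["MSAColumnAttention"](m)
    m = m + modules["MSATransition"](m)

    # Communication from the MSA representation to the pair representation
    z = z + modules["OuterProductMean"](m)

    # Pair stack: triangle updates, then a transition
    z = z + modules["TriangleMultiplicationOutgoing"](z)
    z = z + modules["TriangleMultiplicationIncoming"](z)
    z = z + modules["TriangleAttentionStartingNode"](z)
    z = z + modules["TriangleAttentionEndingNode"](z)
    z = z + modules["PairTransition"](z)
    return m, z


# Toy usage: each placeholder records its name and returns a zero update
calls = []


def make_stub(name):
    def stub(*args):
        calls.append(name)
        return 0
    return stub


names = ["MSARowAttentionWithPairBias", "MSAColumnAttention", "MSATransition", "OuterProductMean",
         "TriangleMultiplicationOutgoing", "TriangleMultiplicationIncoming",
         "TriangleAttentionStartingNode", "TriangleAttentionEndingNode", "PairTransition"]
m, z = evoformer_block(1.0, 2.0, {name: make_stub(name) for name in names})
print(m, z, calls == names)  # 1.0 2.0 True
```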

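Below is an example implementation of one of the functions involved in the Evoformer block, MSA row attention with pair bias (`MSARowAttentionWithPairBias`). Here are some notes on the other functions involved:
- `MSAColumnAttention()` is similar to `MSARowAttentionWithPairBias()`, but attends along the sequence dimension (between the sequences in a single residue column) and has no pair bias.
- `MSATransition()` is a 2-layer neural net applied independently at every MSA position; its hidden layer expands the number of channels of the MSA representation (4× in AlphaFold2) before projecting back to the original channel count.
- `OuterProductMean()` projects the MSA representation to a small number of channels, takes columns $i$ and $j$, forms the outer product of the two channel vectors for each sequence, averages these outer products over the sequences, passes the flattened result through a linear transformation, and adds it to the pair representation at position $(i, j)$.
- Triangular multiplicative updates modify the pair representation entry $z_{ij}$ using the other two edges of each triangle $(i, j, k)$, i.e. the entries at $ik$ and $jk$ (or $ki$ and $kj$).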
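Minimal numpy sketches of two of these functions, written channels-last (shape `(s, i, c)`) for brevity. Random weights stand in for trained ones, the learned scale and offset of layer normalization are omitted, and the small projection size `c=4` in the outer product mean is an illustrative choice.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x, eps=1e-5):
    """Normalizes over the last (channel) axis; learned scale and offset omitted."""
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)


def msa_transition(m, n=4, rng=rng):
    """2-layer MLP applied at every (s, i): expand channels n-fold, ReLU, project back."""
    c_m = m.shape[-1]
    W1, W2 = rng.standard_normal((c_m, n * c_m)), rng.standard_normal((n * c_m, c_m))
    return np.maximum(layer_norm(m) @ W1, 0.0) @ W2  # same shape as m


def outer_product_mean(m, c_z, c=4, rng=rng):
    """Pair update from the MSA: mean over sequences of outer products of columns i and j."""
    n_seq, n_res, c_m = m.shape
    m = layer_norm(m)
    a = m @ rng.standard_normal((c_m, c))                  # (s, i, c)
    b = m @ rng.standard_normal((c_m, c))                  # (s, j, c)
    outer = np.einsum("sic,sjd->ijcd", a, b) / n_seq       # mean over s, shape (i, j, c, c)
    return outer.reshape(n_res, n_res, c * c) @ rng.standard_normal((c * c, c_z))  # (i, j, c_z)


m = rng.standard_normal((15, 5, 8))  # (Nseq, Nres, c_m), toy sizes
print(msa_transition(m).shape)             # (15, 5, 8)
print(outer_product_mean(m, c_z=6).shape)  # (5, 5, 6)
```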


```python
import numpy as np

rng = np.random.default_rng()
m_test = rng.random((3, 15, 5))  # (num channels, num sequences, sequence length)
z_test = rng.random((4, 5, 5))   # (num channels, sequence length, sequence length)


# Some auxiliary functions which will be useful

def LayerNorm(x, eps=1e-5):
    """Assumes x is channels-first; normalizes over the channel axis at every position.
    The learned scale and offset of a real LayerNorm are omitted."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def Linear(x, A, b=None):
    """Projects the channel axis of channels-first x with weights A of shape (c_in, c_out),
    possibly adding bias b of shape (c_out,). Returns a channels-first array (c_out, ...)."""
    z = np.tensordot(A, x, axes=([0], [0]))
    if b is not None:
        z = z + b.reshape((-1,) + (1,) * (z.ndim - 1))
    return z


def sigmoid(x):
    return 1. / (1. + np.exp(-x))


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


# The main function of interest

def MSARowAttentionWithPairBias(m, z, c=32, Nhead=8, rng=None):
    # Perform some checks and interpret input dimensions
    assert m.ndim == 3 and z.ndim == 3
    assert z.shape[1] == z.shape[2] == m.shape[2]
    nMSAfeat, nMSAseq, MSAseqlen = m.shape  # features, sequences, (possibly truncated) sequence length
    nPairFeat = z.shape[0]  # number of pair representation features
    rng = np.random.default_rng() if rng is None else rng

    # Random weights stand in for the trained parameters
    query_kernel, key_kernel, value_kernel, gate_kernel = rng.random((4, Nhead, nMSAfeat, c))
    bias_kernel = rng.random((Nhead, nPairFeat, 1))
    output_kernel = rng.random((Nhead * c, nMSAfeat))

    # -----------------
    # Input projections
    # -----------------
    m = LayerNorm(m)
    q = np.stack([Linear(m, query_kernel[h]) for h in range(Nhead)])
    k = np.stack([Linear(m, key_kernel[h]) for h in range(Nhead)])
    v = np.stack([Linear(m, value_kernel[h]) for h in range(Nhead)])
    print("Shape of q: ", q.shape)  # (Nhead, c, nMSAseq, MSAseqlen)

    #   Bias: one scalar per head for every residue pair (i, j), from the pair representation
    z = LayerNorm(z)
    b = np.stack([Linear(z, bias_kernel[h])[0] for h in range(Nhead)])
    print("Shape of bias: ", b.shape)  # (Nhead, MSAseqlen, MSAseqlen)

    #   Gates:
    g = sigmoid(np.stack([Linear(m, gate_kernel[h]) for h in range(Nhead)]))
    print("Shape of gate: ", g.shape)  # (Nhead, c, nMSAseq, MSAseqlen)

    # -----------------
    # Attention
    # -----------------
    # Within each MSA row s: logits over residue pairs (i, j), contracted along the channel axis
    logits = np.einsum('hcsi,hcsj->hsij', q, k) / np.sqrt(c) + b[:, None, :, :]
    a = softmax(logits, axis=-1)  # normalize over j
    # Gated weighted sum of the values
    o = g * np.einsum('hsij,hcsj->hcsi', a, v)

    # Output projection: concatenate the heads along the channel axis, project back to nMSAfeat
    out = o.reshape(Nhead * c, nMSAseq, MSAseqlen)
    m_tilde = Linear(out, output_kernel)
    return m_tilde


m_tilde = MSARowAttentionWithPairBias(m_test, z_test)
print("Shape of output: ", m_tilde.shape)
```

```
Shape of q:  (8, 32, 15, 5)
Shape of bias:  (8, 5, 5)
Shape of gate:  (8, 32, 15, 5)
Shape of output:  (3, 15, 5)
```

The triangle updates are emphasized by the authors as particularly important for the success of AlphaFold2.

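Multiplicative updates using outgoing or incoming edges are fairly straightforward. The only difference is whether the update combines rows (outgoing) or columns (incoming) of the pair representation.
- Element $z_{ij}$ of the pair representation is updated by taking rows (columns) $i$ and $j$, applying gated linear projections along the channel dimension, multiplying the projected rows (columns) elementwise and summing over the third residue $k$ (i.e., $\sum_k \mathbf{a}_{ik} \odot \mathbf{b}_{jk}$ for outgoing edges, $\sum_k \mathbf{a}_{ki} \odot \mathbf{b}_{kj}$ for incoming edges), then normalizing, projecting, and gating the result and adding it to element $z_{ij}$.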
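A numpy sketch of both variants, channels-last, with random weights standing in for trained ones and a small illustrative projection size `c=4`. Following the algorithm as I understand it from the paper, both projections are gated, the product is summed over the third residue `k`, and the result is normalized, projected and gated before being returned as an update to `z`.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x, eps=1e-5):
    """Normalizes over the last (channel) axis; learned scale and offset omitted."""
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def triangle_multiplication(z, outgoing=True, c=4, rng=rng):
    """Returns the update for every edge z_ij of the pair representation, shape (i, j, c_z)."""
    c_z = z.shape[-1]

    def W(c_in, c_out):
        """Random stand-in for a trained linear layer's weights."""
        return rng.standard_normal((c_in, c_out))

    z = layer_norm(z)
    a = sigmoid(z @ W(c_z, c)) * (z @ W(c_z, c))  # gated projection of every edge, (x, y, c)
    b = sigmoid(z @ W(c_z, c)) * (z @ W(c_z, c))
    if outgoing:
        prod = np.einsum("ikc,jkc->ijc", a, b)    # combine edges ik and jk
    else:
        prod = np.einsum("kic,kjc->ijc", a, b)    # combine edges ki and kj
    g = sigmoid(z @ W(c_z, c_z))                  # output gate
    return g * (layer_norm(prod) @ W(c, c_z))


z = rng.standard_normal((5, 5, 6))  # (Nres, Nres, c_z), toy sizes
print(triangle_multiplication(z, outgoing=True).shape)   # (5, 5, 6)
print(triangle_multiplication(z, outgoing=False).shape)  # (5, 5, 6)
```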

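Triangular self-attention is very much like MSA row attention, but operates on the pair representation. In the "ending node" variant, the query vector $q_{ij}^h$ is compared with key vectors $k_{kj}^h$, plus a bias $b_{ki}^h$ from the third edge of the triangle, to weight the value vectors $v_{kj}^h$; both the softmax and the output sum run over index $k$. The "starting node" variant uses $k_{ik}^h$, $v_{ik}^h$, and bias $b_{jk}^h$ instead.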
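A single-head numpy sketch of the ending-node variant, channels-last, with random weights standing in for trained ones. For simplicity it omits the input layer normalization, multiple heads, and the final output projection, so it returns `c` channels rather than `c_z`.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def triangle_attention_ending_node(z, c=4, rng=rng):
    """Single-head triangle attention around the ending node; returns an update of shape (i, j, c)."""
    c_z = z.shape[-1]
    q = z @ rng.standard_normal((c_z, c))  # q_ij
    k = z @ rng.standard_normal((c_z, c))  # k_kj
    v = z @ rng.standard_normal((c_z, c))  # v_kj
    b = z @ rng.standard_normal(c_z)       # one scalar bias per edge, b[x, y] = b_xy
    g = 1.0 / (1.0 + np.exp(-(z @ rng.standard_normal((c_z, c)))))  # gate g_ij

    # logits[i, j, k] = q_ij . k_kj / sqrt(c) + b_ki
    logits = np.einsum("ijc,kjc->ijk", q, k) / np.sqrt(c) + b.T[:, None, :]
    a = softmax(logits, axis=-1)           # normalize over k
    return g * np.einsum("ijk,kjc->ijc", a, v)


z = rng.standard_normal((5, 5, 6))  # (Nres, Nres, c_z), toy sizes
print(triangle_attention_ending_node(z).shape)  # (5, 5, 4)
```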

The extra MSA sequences are processed differently from the main MSA representation: their column attention is global, meaning the per-sequence queries are averaged over all sequences (species) in a column, instead of each sequence keeping its own query.
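A single-head numpy sketch of the idea, assuming channels-last input of shape `(s, i, c)` and random weights in place of trained ones. Averaging the queries leaves one query per column, so the attention cost per column grows linearly rather than quadratically with the number of sequences. Multiple heads, layer normalization and the output projection are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def global_column_attention(e, c=4, rng=rng):
    """Single-head global attention down each column i of the extra MSA e (shape (s, i, c_e))."""
    c_e = e.shape[-1]
    q = (e @ rng.standard_normal((c_e, c))).mean(axis=0)    # one averaged query per column, (i, c)
    k = e @ rng.standard_normal((c_e, c))                    # (t, i, c)
    v = e @ rng.standard_normal((c_e, c))                    # (t, i, c)
    a = softmax(np.einsum("ic,tic->it", q, k) / np.sqrt(c), axis=-1)  # weights over sequences t
    o = np.einsum("it,tic->ic", a, v)                        # one output per column, (i, c)
    g = 1.0 / (1.0 + np.exp(-(e @ rng.standard_normal((c_e, c)))))   # per-sequence gate, (s, i, c)
    return g * o[None, :, :]                                 # broadcast back to every sequence


e = rng.standard_normal((1024, 5, 8))  # (Nextra_seq, Nres, c_e), toy sizes
print(global_column_attention(e).shape)  # (1024, 5, 4)
```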

![TriangleUpdates](pics/AlphaFold_4.png)
