<a href="https://colab.research.google.com/github/mattfeng/research/blob/main/ESM_Variant_Prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/facebookresearch/esm/blob/master/examples/variant_prediction.ipynb)

# Variant prediction with ESM

This tutorial demonstrates how to train a simple variant predictor, that is, a model that predicts the biological activity of mutations of a protein, using fixed embeddings from ESM. You can adopt a similar protocol to train a model for any downstream task, even with limited data.

We will use a simple regression model in sklearn (a "head" on top of the transformer features) to predict the mutation effect from precomputed ESM embeddings. The embeddings for your dataset can be dumped once using a GPU. The rest of your analysis can then be done on CPU.

### Background

In this particular example, we will train a model to predict the activity of ß-lactamase variants.

We provide the training data in `examples/P62593.fasta`, a FASTA file where each entry contains:
- the mutated ß-lactamase sequence, where a single residue is mutated (swapped with another amino acid)
- the target value in the last field of the header, describing the scaled effect of the mutation

The [data originally comes](https://github.com/FowlerLab/Envision2017/blob/master/data/dmsTraining_2017-02-20.csv) from a deep mutational scan and was released with the Envision paper (Gray et al., 2018).

### Goals
- Obtain an embedding (fixed-dimensional vector representation) for each mutated sequence.
- Train a regression model in sklearn that can predict the "effect" score given the embedding.


### Prerequisites
- You will need the following modules: tqdm, matplotlib, numpy, pandas, seaborn, scipy, scikit-learn
- You have obtained sequence embeddings for ß-lactamase as described in the README, either by:
    - running `python extract.py esm1_t34_670M_UR50S examples/P62593.fasta examples/P62593_reprs/ --repr_layers 34 --include mean` OR
    - downloading the embeddings we precomputed for your convenience from [here](https://dl.fbaipublicfiles.com/fair-esm/examples/P62593_reprs.tar.gz); the Prelims cell below fetches them from within this notebook


### Table of Contents
1. [Prelims](#prelims)
1. [Loading Embeddings](#load_embeddings)
1. [Visualizing Embeddings](#viz_embeddings)
1. [Initializing / Running Grid Search](#grid_search)
1. [Browse Grid Search Results](#browse)
1. [Evaluating Results](#eval)

<a id='prelims'></a>
## Prelims

If you are using Colab, run the cell below.
It will pip install the `esm` code and fetch the FASTA file and the precomputed embeddings.

In [None]:
!pip install git+https://github.com/facebookresearch/esm.git
!curl -O https://dl.fbaipublicfiles.com/fair-esm/examples/P62593_reprs.tar.gz
!tar -xzf P62593_reprs.tar.gz
!curl -O https://dl.fbaipublicfiles.com/fair-esm/examples/P62593.fasta
!pwd
!ls

In [None]:
import random
from collections import Counter
from tqdm import tqdm

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import esm

In [None]:
import scipy.stats  # import the submodule explicitly; spearmanr is used in the evaluation cell
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.svm import SVC, SVR
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression, SGDRegressor

## Add the path to your embeddings here:


In [None]:
FASTA_PATH = "./P62593.fasta" # Path to P62593.fasta
EMB_PATH = "./P62593_reprs/" # Path to directory of embeddings for P62593.fasta
EMB_LAYER = 34

<a id='load_embeddings'></a>
## Load embeddings (Xs) and target effects (ys)
Our FASTA file is formatted as such:
```
>{index}|{mutation_id}|{effect}
{seq}
```
We will be extracting the effect from each entry.

Each embedding is stored in a file named after its FASTA header: `{index}|{mutation_id}|{effect}.pt`
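
As a minimal sketch of this naming convention, the snippet below splits one header into its fields and builds the matching embedding path. The helper `parse_header`, the header string, and the default directory are illustrative assumptions, not part of the `esm` package; the sketch also assumes the header still carries its leading `>`.

```python
from pathlib import Path


def parse_header(header: str, emb_dir: str = "./P62593_reprs"):
    """Split a '>{index}|{mutation_id}|{effect}' header into its parts.

    Returns (index, mutation_id, effect, embedding_path).
    """
    name = header.lstrip(">")  # file names omit the leading '>'
    index, mutation_id, effect = name.split("|")
    return index, mutation_id, float(effect), Path(emb_dir) / f"{name}.pt"


# Hypothetical header, for illustration only (not taken from the dataset)
example = parse_header(">0|A1B|0.5")
print(example)
```

The target value is simply the last `|`-separated field converted to a float, which is what the loading cell relies on.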

In [None]:
ys = []
Xs = []
for header, _seq in esm.data.read_fasta(FASTA_PATH):
    scaled_effect = header.split('|')[-1]
    ys.append(float(scaled_effect))
    fn = f'{EMB_PATH}/{header[1:]}.pt'
    embs = torch.load(fn)
    Xs.append(embs['mean_representations'][EMB_LAYER])
Xs = torch.stack(Xs, dim=0).numpy()
print(len(ys))
print(Xs.shape)

### Train / Test Split

Here we follow the Envision paper and use 80% of the data for training. In practice, we found that pre-trained ESM embeddings require fewer downstream training examples to reach the same level of performance.
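
One way to probe how much training data is needed is to hold the test set fixed and fit on progressively larger training subsets. The sketch below does this on synthetic data; `X_demo`, `y_demo`, the subset sizes, and the choice of a KNN regressor are all assumptions made for illustration, so the scores say nothing about the ß-lactamase task itself.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 8))                  # synthetic stand-in for embeddings
y_demo = X_demo[:, 0] + 0.1 * rng.normal(size=200)  # synthetic stand-in for effects

# Fix one test set first so every training size is scored on the same examples.
X_pool, X_te, y_pool, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42
)

scores = {}
for n_train in (20, 80, 160):
    model = KNeighborsRegressor(n_neighbors=5)
    model.fit(X_pool[:n_train], y_pool[:n_train])
    scores[n_train] = model.score(X_te, y_te)  # R^2 on the fixed test set

print(scores)
```

To try this on the real data, substitute `Xs` and `ys` for the synthetic arrays.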

In [None]:
train_size = 0.8
Xs_train, Xs_test, ys_train, ys_test = train_test_split(Xs, ys, train_size=train_size, random_state=42)

In [None]:
Xs_train.shape, Xs_test.shape, len(ys_train), len(ys_test)

### PCA

Principal Component Analysis is a popular technique for dimensionality reduction. Given `n_features` input features (1280 in our case), PCA computes a new, smaller set of features (48 in the cell below) that "best explain the data," i.e. capture the most variance. We've found that this enables downstream models to be trained faster with minimal loss in performance.

Feel free to change the number of components!
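
When picking the number of components, it can help to check how much variance a given number retains. The following is a minimal sketch on synthetic data; the array `X_demo` and the choice of 16 components are assumptions for illustration, and random Gaussian data will retain far less variance per component than real embeddings typically do.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 64))  # synthetic stand-in for the (n_samples, 1280) matrix

pca_demo = PCA(n_components=16).fit(X_demo)
X_demo_pca = pca_demo.transform(X_demo)

# Cumulative fraction of total variance explained by the first k components
cumulative = np.cumsum(pca_demo.explained_variance_ratio_)

print(X_demo_pca.shape)
print(cumulative[-1])  # fraction of variance kept by all 16 components
```

On the real data, the same attribute is available as `pca.explained_variance_ratio_` once `pca` has been fit.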


In [None]:
pca = PCA(48)
Xs_train_pca = pca.fit_transform(Xs_train)

print(Xs_train_pca.shape)

<a id='viz_embeddings'></a>
## Visualize Embeddings

Here, we plot the first two principal components on the x- and y- axes. Each point is then colored by its scaled effect (what we want to predict).

Visually, we can see a separation based on color/effect, suggesting that our representations are useful for this task, without any task-specific training!

In [None]:
fig_dims = (7, 6)
fig, ax = plt.subplots(figsize=fig_dims)
sc = ax.scatter(Xs_train_pca[:,0], Xs_train_pca[:,1], c=ys_train, marker='.')
ax.set_xlabel('PCA first principal component')
ax.set_ylabel('PCA second principal component')
plt.colorbar(sc, label='Variant Effect')

<a id='grid_search'></a>

## Initialize / Run GridSearch

We will run grid search for three different regression models:
1. [K-nearest-neighbors](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsRegressor.html)
2. [SVM](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html?highlight=svr#sklearn.svm.SVR)
3. [Random Forest Regressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html?highlight=randomforestregressor#sklearn.ensemble.RandomForestRegressor)

Here, we use the PCA-projected features because we observe that they do just as well as the full `Xs` while allowing for faster training. You can easily swap them out for `Xs`.
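
A possible variation, not used in this notebook, is to wrap PCA and the regressor in a scikit-learn `Pipeline` so that the number of components becomes one more hyperparameter and PCA is refit inside each cross-validation fold. The sketch below runs on synthetic data; the step names `pca` and `knn`, the grid values, and `cv=3` are assumptions chosen to keep the example small.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import Pipeline

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(120, 32))                 # synthetic stand-in for embeddings
y_demo = X_demo[:, 0] + 0.1 * rng.normal(size=120)  # synthetic stand-in for effects

pipe = Pipeline([("pca", PCA()), ("knn", KNeighborsRegressor())])

# Parameters of a pipeline step are addressed as '<step name>__<parameter>'.
pipe_grid = {"pca__n_components": [4, 16], "knn__n_neighbors": [5, 10]}

search = GridSearchCV(pipe, pipe_grid, scoring="r2", cv=3)
search.fit(X_demo, y_demo)
print(search.best_params_)
```

With this setup the search is fit on the unprojected features, and the pipeline applies PCA itself at both fit and predict time.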

### Initialize grids for different regression techniques

In [None]:
knn_grid = {
    'n_neighbors': [5, 10],
    'weights': ['uniform', 'distance'],
    'algorithm': ['ball_tree', 'kd_tree', 'brute'],
    'leaf_size' : [15, 30],
    'p' : [1, 2],
}

svm_grid = {
    'C' : [0.1, 1.0, 10.0],
    'kernel' :['linear', 'poly', 'rbf', 'sigmoid'],
    'degree' : [3],
    'gamma': ['scale'],
}

rfr_grid = {
    'n_estimators' : [20],
    'criterion' : ['squared_error', 'absolute_error'],  # called 'mse' / 'mae' before scikit-learn 1.0
    'max_features': ['sqrt', 'log2'],
    'min_samples_split' : [5, 10],
    'min_samples_leaf': [1, 4]
}

**Here you choose which regressor to evaluate. K-nearest-neighbors is selected by default; to try another one, comment it out, uncomment the one you want, and continue.**

In [None]:
regressor = KNeighborsRegressor
param_grid = knn_grid

# regressor = SVR
# param_grid = svm_grid

# regressor = RandomForestRegressor
# param_grid = rfr_grid

### Run Grid Search 

(will take a few minutes on a single core)

In [None]:
print(regressor)
grid = GridSearchCV(
    estimator = regressor(), 
    param_grid = param_grid,
    scoring = 'r2',
    verbose = 1,
    n_jobs = -1 # use all available cores
)
grid.fit(Xs_train_pca, ys_train)
result = pd.DataFrame.from_dict(grid.cv_results_)

<a id='eval'></a>
## Evaluation

Now that we have run grid search, the `grid` object contains a `best_estimator_`: the model refit on the full training set with the best-scoring hyperparameters.

We can use it to evaluate the correlation between our predictions and the true effect scores on the held-out test set.
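
The evaluation step can be sketched end to end on synthetic data as follows. All names here (`X_demo`, `demo_grid`, `ranked`) and the tiny parameter grid are illustrative assumptions; the point is only to show how the cross-validation results can be ranked and how the Spearman correlation is computed on held-out examples.

```python
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsRegressor

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(150, 8))                  # synthetic stand-in for features
y_demo = X_demo[:, 0] + 0.1 * rng.normal(size=150)  # synthetic stand-in for effects
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, train_size=0.8, random_state=42
)

demo_grid = GridSearchCV(KNeighborsRegressor(), {"n_neighbors": [5, 10]}, scoring="r2")
demo_grid.fit(X_tr, y_tr)

# Rank the tried settings by cross-validated R^2 (rank 1 is best).
ranked = pd.DataFrame(demo_grid.cv_results_).sort_values("rank_test_score")
print(ranked[["params", "mean_test_score", "rank_test_score"]])

# With the default refit=True, predict() delegates to best_estimator_.
rho, p_value = spearmanr(y_te, demo_grid.predict(X_te))
print(rho, p_value)
```

The same `sort_values("rank_test_score")` call works on the `result` DataFrame built from `grid.cv_results_` in the grid search cell.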

In [None]:
Xs_test_pca = pca.transform(Xs_test)

print(grid.best_estimator_)
print()
preds = grid.predict(Xs_test_pca)
print(f'{scipy.stats.spearmanr(ys_test, preds)}')
print('\n', '-' * 80, '\n')


Report which <a href="https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient">Spearman rho</a> you got. 