[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/facebookresearch/esm/blob/master/examples/variant_prediction.ipynb)

# Variant prediction with ESM

This tutorial demonstrates how to train a variant predictor using representations from ESM. You can adopt a similar protocol to train a model for any downstream task, even with limited data.

In this tutorial, we will build a simple downstream head in sklearn to predict the effects of mutations from ESM representations. Because the representations do not change while the top model is fitted, they can be dumped once for your dataset using a GPU. The rest of your analysis can then be performed on CPU.

### Background

In this particular example, we will train a model to predict the activity of ß-lactamase variants.

Training data is located in `examples/P62593.fasta`, a FASTA file where each entry contains:
- the mutated ß-lactamase sequence, where a single residue is mutated (swapped with another amino acid)
- a float describing the scaled effect of the mutation

The data was retrieved from the Envision paper (Gray, et al. 2018).

### Goals
- Obtain a representation for each mutated sequence.
- Train a regression model in sklearn that can predict the "effect" score given the representation.


### Prerequisites
- You will need the following modules: tqdm, matplotlib, numpy, pandas, seaborn, scipy, scikit-learn
- You have obtained sequence representations for ß-lactamase either by:
    - downloading them from [here](https://dl.fbaipublicfiles.com/fair-esm/example/P62593_reprs.tar.gz)
    - running `python extract.py esm1_t34_670M_UR50S examples/P62593.fasta my_reprs/ --repr_layers 34 --include mean`



### Table of Contents
1. [Prelims](#prelims)
2. [Loading Model](#load_model)
3. [Loading Representations](#load_representations)
4. [Visualizing Representations](#viz_representations)
5. [Initializing / Running Grid Search](#grid_search)
6. [Browse Grid Search Results](#browse)
7. [Evaluating Results](#eval)

<a id='prelims'></a>
## Prelims

We assume you have pip-installed the repo as described in the README. Otherwise, you can add it to your path with `sys.path.append(<path_to_repo>)`.

In [None]:
# import sys
# PATH_TO_REPO = ""
# sys.path.append(PATH_TO_REPO)

In [None]:
import glob
import random
from collections import Counter
from pathlib import Path

from tqdm import tqdm

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import esm

In [None]:
import scipy.stats  # import the submodule explicitly; it is used for spearmanr in the evaluation

from sklearn.model_selection import GridSearchCV, train_test_split

from sklearn.decomposition import PCA

from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.svm import SVC, SVR
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression, SGDRegressor

## Add the path to your representations here:


In [None]:
REPR_PATH = "../my_reprs" # Path to directory of representations for P62593.fasta
FASTA_PATH = "../examples/P62593.fasta" # Path to P62593.fasta

<a id='load_model'></a>
## Load model

In [None]:
model, alphabet = esm.pretrained.esm1_t34_670M_UR50S()
batch_converter = alphabet.get_batch_converter()

<a id='load_representations'></a>
## Load Representations (Xs) and Target Effects (ys)
Our FASTA file is formatted as follows:
```
>{index}|{mutation_id}|{effect}
{seq}
```
We will be extracting the effect from each entry.

Our representations are stored under the file name `{index}|{mutation_id}|{effect}.pt`, so each FASTA entry is linked to its representation by `{index}`.
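
The naming convention above can be sketched in plain Python. This is only an illustration: the helper names and the example header and file names are hypothetical, the header is assumed to be passed without the leading `>`, and indices are assumed to run from 0 to the number of entries minus one.

```python
from pathlib import Path


def parse_header(header):
    """Split a '{index}|{mutation_id}|{effect}' header into typed fields."""
    index, mutation_id, effect = header.split("|")
    return int(index), mutation_id, float(effect)


def index_from_filename(path):
    """Recover {index} from a representation file named '<header>.pt'."""
    return int(Path(path).name.split("|")[0])


def missing_indices(paths, num_entries):
    """Return the entry indices that have no matching representation file."""
    found = {index_from_filename(p) for p in paths}
    return sorted(set(range(num_entries)) - found)


# Hypothetical header and file names, for illustration only.
print(parse_header("0|A2C|0.25"))
print(missing_indices(["reprs/0|A2C|0.25.pt", "reprs/2|A2E|1.0.pt"], 3))
```

A check like `missing_indices` is a cheap way to confirm that every FASTA entry has a representation before the two are paired up by row.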

In [None]:
ys = []
for header, _seq in esm.data.read_fasta(FASTA_PATH):
    scaled_effect = header.split('|')[-1]
    ys.append(float(scaled_effect))
print(len(ys))

In [None]:
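# Collect every saved representation file (one .pt file per FASTA entry).
pattern = Path(REPR_PATH) / '*.pt'
files = glob.glob(str(pattern))
num_datapoints = len(files)
dimension = 1280  # embedding size of esm1_t34_670M_UR50S
Xs = torch.zeros((num_datapoints, dimension))
ind_set = set()
for f in tqdm(files):
    data = torch.load(f)
    # The file name starts with {index}, which selects the row for this entry.
    ind = int(Path(f).name.split('|')[0])
    Xs[ind] = data['mean_representations'][34]  # layer 34, averaged over residues
    ind_set.add(ind)
print(len(ind_set))  # should match len(ys)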

### PCA

Principal Component Analysis is a popular technique for dimensionality reduction. Given `n_features` (1280 in our case), PCA computes a new set of `X` features (the principal components) that capture as much of the variance in the data as possible. We've found that this enables downstream models to be trained faster with minimal loss in performance.

Here, we set `X` to 60, but feel free to change it!
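
One way to reason about the number of components is to look at how much variance each one retains. The sketch below is a NumPy-only illustration on synthetic data (not the ß-lactamase representations): it performs PCA through an SVD of the centered data and picks the smallest number of components that keeps 95% of the variance. The 95% threshold and the synthetic data are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for Xs: 200 points, 50 features, driven by 3 latent factors.
latent = rng.normal(size=(200, 3))
mixing = rng.normal(size=(3, 50))
X = latent @ mixing + 0.01 * rng.normal(size=(200, 50))

# PCA via SVD of the centered data: the squared singular values are
# proportional to the variance captured by each principal component.
X_centered = X - X.mean(axis=0)
singular_values = np.linalg.svd(X_centered, compute_uv=False)
explained_ratio = singular_values**2 / np.sum(singular_values**2)
cumulative = np.cumsum(explained_ratio)

# Smallest number of components that retains at least 95% of the variance.
n_components = int(np.searchsorted(cumulative, 0.95) + 1)
print(n_components, cumulative[:5])
```

With scikit-learn, a fitted `PCA` object exposes the same per-component ratios as `explained_variance_ratio_`, so the same cumulative sum can be computed from the `pca` object used in this tutorial.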


In [None]:
pca = PCA(60)
Xs_pca = pca.fit_transform(Xs)

<a id='viz_representations'></a>
## Visualize Representations

Here, we plot the first two principal components on the x- and y- axes. Each point is then colored by its scaled effect (what we want to predict).

Visually, we can see a separation based on color/effect, suggesting that our representations are useful for this task, without any task-specific training!

In [None]:
fig_dims = (7, 6)
fig, ax = plt.subplots(figsize=fig_dims)
sc = ax.scatter(Xs_pca[:,0], Xs_pca[:,1], c=ys, marker='.')
ax.set_xlabel('PCA first principal component')
ax.set_ylabel('PCA second principal component')
plt.colorbar(sc, label='Variant Effect')

<a id='grid_search'></a>

## Initialize / Run GridSearch

We will run grid search for three different regression models:
1. [K-nearest-neighbors](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsRegressor.html)
2. [SVM](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html?highlight=svr#sklearn.svm.SVR)
3. [Random Forest Regressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html?highlight=randomforestregressor#sklearn.ensemble.RandomForestRegressor)

### Initialize grids for different regression techniques

In [None]:
knn_grid = {
    'n_neighbors': [5, 10],
    'weights': ['uniform', 'distance'],
    'algorithm': ['ball_tree', 'kd_tree', 'brute'],
    'leaf_size' : [15, 30],
    'p' : [1, 2],
}

svm_grid = {
    'C' : [1.0],
    'kernel' :['linear', 'poly', 'rbf', 'sigmoid'],
    'degree' : [3],
    'gamma': ['scale'],
}

rfr_grid = {
    'n_estimators' : [20],
    'criterion' : ['squared_error', 'absolute_error'],  # called 'mse' / 'mae' in scikit-learn < 1.0
    'max_features': ['sqrt', 'log2'],
    'min_samples_split' : [5, 10],
    'min_samples_leaf': [1, 4]
}

In [None]:
cls_list = [KNeighborsRegressor, SVR, RandomForestRegressor]
param_grid_list = [knn_grid, svm_grid, rfr_grid]

### Train / Test Split

Choose what fraction of the data to use for training. The Envision paper uses 80% of the data for training, but we find that pre-trained ESM representations require fewer downstream training examples to reach the same level of performance.

Here, we will be using `Xs_pca`, because we observe it does just as well as `Xs` while allowing for faster training. You can easily swap it out for `Xs`.

Some values to try out:
- 0.01
- 0.10
- 0.30
- 0.50
- 0.80

In [None]:
train_size = 0.8
Xs_train, Xs_test, ys_train, ys_test = train_test_split(Xs_pca, ys, train_size=train_size, random_state=42)

In [None]:
Xs_train.shape, Xs_test.shape, len(ys_train), len(ys_test)

### Run Grid Search 

(will take a few minutes on a single core)
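
To get a feel for the cost, it helps to count how many models an exhaustive grid search fits. The sketch below copies the KNN and SVM grids from this tutorial and multiplies the number of parameter combinations by the number of cross-validation folds. The helper name `count_candidates` is hypothetical, and the fold count of 5 assumes the `GridSearchCV` default in recent scikit-learn versions, since no `cv` argument is passed here.

```python
from itertools import product

# Copies of two of the grids defined in this tutorial.
knn_grid = {
    'n_neighbors': [5, 10],
    'weights': ['uniform', 'distance'],
    'algorithm': ['ball_tree', 'kd_tree', 'brute'],
    'leaf_size': [15, 30],
    'p': [1, 2],
}
svm_grid = {
    'C': [1.0],
    'kernel': ['linear', 'poly', 'rbf', 'sigmoid'],
    'degree': [3],
    'gamma': ['scale'],
}


def count_candidates(param_grid):
    """Number of parameter combinations in an exhaustive grid."""
    return len(list(product(*param_grid.values())))


N_FOLDS = 5  # assumption: GridSearchCV's default 5-fold cross-validation

for name, grid in [('KNN', knn_grid), ('SVM', svm_grid)]:
    n = count_candidates(grid)
    print(f'{name}: {n} candidates, {n * N_FOLDS} cross-validation fits')
```

The random forest grid can be counted the same way. With the default `refit=True`, each search also refits its best setting once on the full training split.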

In [None]:
result_list = []
grid_list = []
for cls_name, param_grid in zip(cls_list, param_grid_list):
    print(cls_name)
    grid = GridSearchCV(
        estimator = cls_name(), 
        param_grid = param_grid,
        scoring = 'r2',
        verbose = 1
    )
    grid.fit(Xs_train, ys_train)
    result_list.append(pd.DataFrame.from_dict(grid.cv_results_))
    grid_list.append(grid)

<a id='browse'></a>
## Browse the Sweep Results

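The following tables show the top parameter settings ranked by `mean_test_score` (the top 5 for KNN and random forest, and all four settings for the SVM). Because `GridSearchCV` computes this score by cross-validation within the training split, it should really be thought of as a validation score.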

### K Nearest Neighbors

In [None]:
result_list[0].sort_values('rank_test_score')[:5]

### SVM

In [None]:
result_list[1].sort_values('rank_test_score')

### Random Forest

In [None]:
result_list[2].sort_values('rank_test_score')[:5]

<a id='eval'></a>
## Evaluation

Now that we have run grid search, each `grid` object contains a `best_estimator_`.

We can use this to evaluate the correlation between our predictions and the true effect scores on the held-out test set.
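
For intuition about the metric: Spearman's rho is the Pearson correlation computed on ranks rather than raw values, so it rewards predictions that order the variants correctly even if the predicted magnitudes are off. The following is a minimal pure-Python sketch with toy numbers (not results from this dataset); the tutorial itself relies on `scipy.stats.spearmanr`.

```python
def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the window over a run of tied values.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5


def spearman_rho(xs, ys):
    """Spearman's rho: Pearson correlation of the rank-transformed inputs."""
    return pearson(average_ranks(xs), average_ranks(ys))


# Toy values: only one adjacent pair is ordered differently.
true_effects = [0.1, 0.4, 0.35, 0.8, 0.6]
predictions = [0.2, 0.3, 0.5, 0.9, 0.7]
print(spearman_rho(true_effects, predictions))  # 0.9 for these toy values
```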

In [None]:
for grid in grid_list:
    print(grid.best_estimator_)
    print()
    preds = grid.predict(Xs_test)
    print(f'{scipy.stats.spearmanr(ys_test, preds)}')
    print('\n', '-' * 80, '\n')


The SVM performs the best on the `test` set, with a Spearman's rho of 0.80!

This is in line with our grid-search results, where it also had the best `validation` performance.

In conclusion, our downstream model was able to use pre-trained ESM representations and obtain a decent result.

(For reference, we report a correlation of 0.89 in Table 7 of our paper, but this was achieved by fine-tuning the model.)