<a href="https://colab.research.google.com/github/hbp5181/Linear-Model-uisng-homolog-survey-data/blob/main/future_learning(sequence_to_numbers).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inverse folding with ESM-IF1

The ESM-IF1 inverse folding model is built for predicting protein sequences from their backbone atom coordinates. We provide examples here 1) to sample sequence designs for a given structure and 2) to score sequences for a given structure.

Trained with 12M protein structures predicted by AlphaFold2, the ESM-IF1 model consists of invariant geometric input processing layers followed by a sequence-to-sequence transformer, and achieves 51% native sequence recovery on structurally held-out backbones. The model is also trained with span masking to tolerate missing backbone coordinates and therefore can predict sequences for partially masked structures.

See [GitHub README](https://github.com/facebookresearch/esm/tree/main/examples/inverse_folding) for the complete user guide, and see our [bioRxiv pre-print](https://doi.org/10.1101/2022.04.10.487779) for more details.

## Environment setup (colab)
This step might take up to 10 minutes the first time.

If you are using a local Jupyter environment instead, we recommend setting up a conda environment from the command line on first use, in place of the Colab cell below:
```
conda create -n inverse python=3.9
conda activate inverse
conda install pytorch cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg -c conda-forge
conda install pip
pip install biotite
pip install git+https://github.com/facebookresearch/esm.git
```

Afterwards, run `conda activate inverse` to activate the environment before starting `jupyter notebook`.

Below is the setup for colab notebooks:

We recommend using GPU runtimes on colab (Menu bar -> Runtime -> Change runtime type -> Hardware accelerator -> GPU)
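
The setup cell below selects PyTorch Geometric wheels by turning the installed torch and CUDA versions into a wheel-index URL. As a standalone sketch of that string handling, here is the same logic applied to hard-coded example version strings (illustrative values, not necessarily what the current Colab image ships). The helper name `pyg_wheel_index` is hypothetical.

```python
def pyg_wheel_index(torch_version, cuda_version):
    """Build the PyG wheel-index URL for a torch/CUDA version pair."""
    # Drop any local build tag, e.g. '2.1.0+cu121' -> '2.1.0'.
    torch_tag = torch_version.split('+')[0]
    # Turn a dotted CUDA version into a wheel tag, e.g. '12.1' -> 'cu121'.
    cuda_tag = 'cu' + cuda_version.replace('.', '')
    return f'https://data.pyg.org/whl/torch-{torch_tag}+{cuda_tag}.html'

url = pyg_wheel_index('2.1.0+cu121', '12.1')
print(url)
```

If pip reports that no matching wheel was found, printing this URL and opening it in a browser is a quick way to see whether wheels exist for that torch/CUDA combination.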

In [2]:
# Colab environment setup

# Install the correct version of Pytorch Geometric.
import torch
import os

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

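def format_cuda_version(version):
  # torch.version.cuda is None on CPU-only builds; PyG publishes '+cpu' wheels for that case.
  return 'cpu' if version is None else 'cu' + version.replace('.', '')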

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# Install esm
!pip install -q git+https://github.com/facebookresearch/esm.git

# Install biotite
!pip install -q biotite


### Verify that pytorch-geometric is correctly installed

If the notebook crashes at the import, the installed versions of torch_geometric and torch_sparse are most likely incompatible with the installed torch version.

In [3]:
## Verify that pytorch-geometric is correctly installed
import torch_geometric
import torch_sparse
from torch_geometric.nn import MessagePassing
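
When the import check fails, it helps to see which versions are actually installed. A minimal sketch using only the standard library follows. The helper name `installed_version` is ours, and the distribution names are taken from the pip commands in the setup cell.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None

# Distribution names as used in the pip install commands above.
for name in ['torch', 'torch-geometric', 'torch-sparse', 'torch-scatter', 'torch-cluster']:
    print(f'{name}: {installed_version(name)}')
```

A `None` entry means the package did not install at all, which usually points to a missing wheel for the current torch/CUDA pair rather than a version conflict.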

## Load model
This step takes a few minutes while the model downloads.

**UPDATE**: It is important to set the model in eval mode through `model = model.eval()` to disable random dropout for optimal performance.

In [4]:
import esm
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm_if1_gvp4_t16_142M_UR50.pt" to /root/.cache/torch/hub/checkpoints/esm_if1_gvp4_t16_142M_UR50.pt


## Extract encoder output as structure representation
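The encoder output may also be used as a representation for the structure.

For a set of input coordinates with L amino acids, the ESM-IF1 encoder output will have shape L x 512.

Note that the cell below does not call the ESM-IF1 encoder. It runs `esm-extract` with the ESM-2 model `esm2_t33_650M_UR50D` on FASTA files, so what it saves are sequence embeddings: with `--include mean`, one layer-33 vector per sequence (1280 values for this model), averaged over residues.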

In [5]:
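# esm-extract arguments: model name, input FASTA, output directory.
# One .pt file per sequence is written to the output directory.
! esm-extract esm2_t33_650M_UR50D /content/RBD_align_SSM-backgrounds.fasta.txt \
  coordoutputRBD.fasta --repr_layers 33 --include mean
! esm-extract esm2_t33_650M_UR50D /content/ACE2_aa_modified.fasta \
  coordoutputACE2.fasta --repr_layers 33 --include mean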



Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt
Transferred model to GPU
Read /content/RBD_align_SSM-backgrounds.fasta.txt with 14 sequences
Processing 1 of 1 batches (14 sequences)
Transferred model to GPU
Read /content/ACE2_aa_modified.fasta with 62 sequences
Processing 1 of 13 batches (5 sequences)
Processing 2 of 13 batches (5 sequences)
Processing 3 of 13 batches (5 sequences)
Processing 4 of 13 batches (5 sequences)
Processing 5 of 13 batches (5 sequences)
Processing 6 of 13 batches (5 sequences)
Processing 7 of 13 batches (5 sequences)
Processing 8 of 13 batches (5 sequences)
Processing 9 of 13 batches (5 sequences)
Processing 10 of 13 batches (5 sequences)
Processing 11 of 13 batches (5 sequences)
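
With `--include mean`, each saved file holds a single vector per sequence rather than one vector per residue. Conceptually this is an average over the residue axis. The toy sketch below shows that averaging in pure Python with invented 4-dimensional vectors. The real layer-33 vectors are much wider, and the script does the averaging on tensors. `mean_pool` is a hypothetical helper, not part of the `esm` package.

```python
def mean_pool(per_residue):
    """Average a list of equal-length per-residue vectors into one vector."""
    n = len(per_residue)
    dim = len(per_residue[0])
    return [sum(vec[d] for vec in per_residue) / n for d in range(dim)]

# Three residues, four features each (invented numbers).
per_residue = [
    [1.0, 0.0, 2.0, -1.0],
    [3.0, 0.0, 0.0, -1.0],
    [2.0, 3.0, 1.0, -1.0],
]
pooled = mean_pool(per_residue)
print(pooled)  # [2.0, 1.0, 1.0, -1.0]
```

The pooled vector has the same width whatever the sequence length, which is why the RBD and ACE2 sequences end up with vectors of identical size.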

In [6]:
# Output directories written by esm-extract (one .pt file per sequence)
folder_paths = ['/content/coordoutputRBD.fasta', '/content/coordoutputACE2.fasta']

# Collect the full paths of all .pt files across both directories
pt_files = [os.path.join(folder, f) for folder in folder_paths for f in os.listdir(folder) if f.endswith('.pt')]

# Print the contents of each file
for file_path in pt_files:
    # Each file is a dict holding 'label' and 'mean_representations', not a model
    model_dict = torch.load(file_path, map_location=torch.device('cpu'))
    for value in model_dict.values():
        print(value)


GD-Pangolin
{33: tensor([ 0.0261, -0.0281, -0.0530,  ..., -0.0307, -0.0705, -0.0272])}
SARS-CoV-1_PC4-137_PC04
{33: tensor([ 0.0064, -0.0248, -0.0798,  ..., -0.0482, -0.0354, -0.0502])}
RaTG13
{33: tensor([ 0.0063, -0.0141, -0.0844,  ..., -0.0330, -0.0599, -0.0377])}
SARS-CoV-2
{33: tensor([ 0.0224, -0.0117, -0.0591,  ..., -0.0322, -0.0758, -0.0261])}
Rs7327
{33: tensor([ 0.0082, -0.0257, -0.0588,  ..., -0.0348, -0.0495, -0.0406])}
AncSARS2a_MAP
{33: tensor([ 0.0359, -0.0191, -0.0611,  ..., -0.0384, -0.0611, -0.0403])}
BM48-31
{33: tensor([ 0.0467, -0.0302, -0.0562,  ..., -0.0146, -0.0907, -0.0511])}
SARS-CoV-1_Urbani_HP03L
{33: tensor([ 0.0076, -0.0189, -0.0704,  ..., -0.0407, -0.0424, -0.0426])}
AncSARS2c_MAP
{33: tensor([ 0.0261, -0.0249, -0.0607,  ..., -0.0385, -0.0698, -0.0339])}
AncSARS1a_MAP
{33: tensor([-0.0324, -0.0559, -0.0650,  ..., -0.0284, -0.0549, -0.1151])}
AncAsia_MAP
{33: tensor([ 0.0361, -0.0178, -0.0582,  ..., -0.0418, -0.0524, -0.0402])}
AncClade2_MAP
{33: tensor([ 

In [7]:
# Specify the folders containing the .pt files
folder_paths = ['/content/coordoutputRBD.fasta', '/content/coordoutputACE2.fasta']

formatted_dict = {}

# Iterate over each folder
for folder_path in folder_paths:
    # List all files in the folder with .pt extension
    pt_files = [f for f in os.listdir(folder_path) if f.endswith('.pt')]

    # Iterate over each .pt file in the current folder
    for file_name in pt_files:
        # Construct the full path to the file
        file_path = os.path.join(folder_path, file_name)

        # Load the saved embedding dictionary (not a model) onto the CPU
        model_dict = torch.load(file_path, map_location=torch.device('cpu'))

        # Extract label and tensor values
        label = model_dict['label']
        tensor_values = model_dict['mean_representations'][33].numpy()

        # Include the first three and last three numbers in the tensor
        first_three = ','.join(map(str, tensor_values[:3]))
        last_three = ','.join(map(str, tensor_values[-3:]))
        formatted_dict[label] = f'{first_three} {last_three}'
print(formatted_dict)

{'GD-Pangolin': '0.02613955,-0.02813793,-0.053003557 -0.0306876,-0.0704816,-0.027157158', 'SARS-CoV-1_PC4-137_PC04': '0.0064305943,-0.024829453,-0.07980017 -0.04821357,-0.035376027,-0.05023123', 'RaTG13': '0.0062999446,-0.014084336,-0.084428184 -0.033009015,-0.059891574,-0.037733473', 'SARS-CoV-2': '0.022413155,-0.011717588,-0.059060603 -0.032238685,-0.075829074,-0.02612066', 'Rs7327': '0.00818473,-0.025745649,-0.058836233 -0.03479355,-0.049494307,-0.04055334', 'AncSARS2a_MAP': '0.035894725,-0.019097477,-0.061080266 -0.038425736,-0.061066203,-0.0403473', 'BM48-31': '0.04666289,-0.030151362,-0.05622057 -0.014586593,-0.0907113,-0.051075354', 'SARS-CoV-1_Urbani_HP03L': '0.00760287,-0.018901302,-0.07042486 -0.040663492,-0.042407922,-0.04255275', 'AncSARS2c_MAP': '0.026069278,-0.024882397,-0.060680624 -0.038499042,-0.0697849,-0.03388014', 'AncSARS1a_MAP': '-0.032351878,-0.0559463,-0.06501481 -0.028435616,-0.05491547,-0.11509451', 'AncAsia_MAP': '0.036096375,-0.017799782,-0.058247004 -0.0417
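
The dictionary above keeps only six numbers per sequence, which is handy for a quick look but discards most of each vector. If the full vectors are needed later (for example as model features, which is an assumption about the intended use), one option is to write them out as CSV. The sketch below works under that assumption, using toy values and a hypothetical helper `embeddings_to_csv`. In the notebook, each vector could come from `model_dict['mean_representations'][33].tolist()`.

```python
import csv
import io

def embeddings_to_csv(embeddings, handle):
    """Write one row per label: the label followed by every vector component."""
    writer = csv.writer(handle)
    for label, vector in embeddings.items():
        writer.writerow([label, *vector])

# Invented toy vectors standing in for the real layer-33 means.
toy = {'seq_a': [0.1, -0.2, 0.3], 'seq_b': [0.0, 0.5, -0.5]}
buffer = io.StringIO()
embeddings_to_csv(toy, buffer)
print(buffer.getvalue())
```

To write to disk instead, pass a file opened with `open(path, 'w', newline='')` in place of the `io.StringIO` buffer.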

In [8]:
# Print one "label: values" line per sequence
for key, value in formatted_dict.items():
    print(f'{key}: {value}')

GD-Pangolin: 0.02613955,-0.02813793,-0.053003557 -0.0306876,-0.0704816,-0.027157158
SARS-CoV-1_PC4-137_PC04: 0.0064305943,-0.024829453,-0.07980017 -0.04821357,-0.035376027,-0.05023123
RaTG13: 0.0062999446,-0.014084336,-0.084428184 -0.033009015,-0.059891574,-0.037733473
SARS-CoV-2: 0.022413155,-0.011717588,-0.059060603 -0.032238685,-0.075829074,-0.02612066
Rs7327: 0.00818473,-0.025745649,-0.058836233 -0.03479355,-0.049494307,-0.04055334
AncSARS2a_MAP: 0.035894725,-0.019097477,-0.061080266 -0.038425736,-0.061066203,-0.0403473
BM48-31: 0.04666289,-0.030151362,-0.05622057 -0.014586593,-0.0907113,-0.051075354
SARS-CoV-1_Urbani_HP03L: 0.00760287,-0.018901302,-0.07042486 -0.040663492,-0.042407922,-0.04255275
AncSARS2c_MAP: 0.026069278,-0.024882397,-0.060680624 -0.038499042,-0.0697849,-0.03388014
AncSARS1a_MAP: -0.032351878,-0.0559463,-0.06501481 -0.028435616,-0.05491547,-0.11509451
AncAsia_MAP: 0.036096375,-0.017799782,-0.058247004 -0.04177695,-0.052415904,-0.040235203
AncClade2_MAP: 0.040610