<a href="https://colab.research.google.com/github/hbp5181/Linear-Model-uisng-homolog-survey-data/blob/main/future_learning(sequence_to_numbers).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inverse folding with ESM-IF1

The ESM-IF1 inverse folding model predicts protein sequences from their backbone atom coordinates. We provide examples here to 1) sample sequence designs for a given structure and 2) score sequences for a given structure.

Trained on 12M protein structures predicted by AlphaFold2, the ESM-IF1 model consists of invariant geometric input processing layers followed by a sequence-to-sequence transformer, and achieves 51% native sequence recovery on structurally held-out backbones. The model is also trained with span masking to tolerate missing backbone coordinates and can therefore predict sequences for partially masked structures.
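To make "partially masked" concrete: ESM-IF1 takes backbone coordinates as an L x 3 x 3 array (N, CA and C atoms for each residue) and, as we understand the ESM examples, marks missing positions with non-finite values such as `np.inf`. The numpy sketch below masks a contiguous span of a random toy array and recomputes which residues still have complete coordinates. `mask_span` and `residue_mask` are hypothetical helpers, not part of the `esm` package, and the non-finite convention is an assumption.

```python
import numpy as np


def mask_span(coords, start, end):
    """Return a copy of an (L, 3, 3) backbone array with residues [start, end) masked."""
    masked = coords.copy()
    # Assumption: non-finite values (here np.inf) mark missing backbone coordinates.
    masked[start:end] = np.inf
    return masked


def residue_mask(coords):
    """True for residues whose N, CA and C coordinates are all finite."""
    return np.isfinite(coords).all(axis=(1, 2))


rng = np.random.default_rng(0)
coords = rng.normal(size=(10, 3, 3)).astype(np.float32)  # toy backbone, not a real structure
masked = mask_span(coords, 2, 5)
print(residue_mask(masked).astype(int))  # [1 1 0 0 0 1 1 1 1 1]
```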

See [GitHub README](https://github.com/facebookresearch/esm/tree/main/examples/inverse_folding) for the complete user guide, and see our [bioRxiv pre-print](https://doi.org/10.1101/2022.04.10.487779) for more details.

## Environment setup (Colab)
This step might take up to 10 minutes the first time.

If you are using a local Jupyter environment instead, we recommend creating a conda environment from the command line on first use:
```
conda create -n inverse python=3.9
conda activate inverse
conda install pytorch cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg -c conda-forge
conda install pip
pip install biotite
pip install git+https://github.com/facebookresearch/esm.git
```

Afterwards, run `conda activate inverse` to activate this environment before starting `jupyter notebook`.

Below is the setup for Colab notebooks:

We recommend using a GPU runtime on Colab (Menu bar -> Runtime -> Change runtime type -> Hardware accelerator -> GPU).

In [2]:
# Colab environment setup
import os

import torch

# Install PyTorch Geometric extension wheels built for the installed torch and CUDA versions.
def format_pytorch_version(version):
  # Drop the local build tag, e.g. '2.1.0+cu121' -> '2.1.0'
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  # e.g. '12.1' -> 'cu121'; CPU-only torch builds report None, so fall back to the CPU wheels
  if version is None:
    return 'cpu'
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# Install esm
!pip install -q git+https://github.com/facebookresearch/esm.git

# Install biotite
!pip install -q biotite


### Verify that pytorch-geometric is correctly installed

If the notebook crashes at the import, the installed versions of torch_geometric and torch_sparse are likely incompatible with the installed torch version.

In [3]:
## Verify that pytorch-geometric is correctly installed
import torch_geometric
import torch_sparse
from torch_geometric.nn import MessagePassing
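Because a binary mismatch can take down the kernel rather than raise a Python error, it can help to inspect the installed versions without importing the packages. The sketch below reads package metadata with the standard library's `importlib.metadata`. The distribution names are assumptions based on the pip commands above (`fair-esm` is the name pip reports when building `esm`), and `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names assumed from the pip commands in the setup cell.
PACKAGES = [
    "torch", "torch-geometric", "torch-scatter", "torch-sparse",
    "torch-cluster", "torch-spline-conv", "fair-esm", "biotite",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is not found.

    Only package metadata is read, so compiled extensions are never imported.
    """
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}: {ver}")
```

Comparing the reported torch version with the versions of the PyG extensions is a quick first check before rerunning the import cell.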

## Load model
This step takes a few minutes while the model weights download.

**UPDATE**: It is important to put the model in eval mode with `model = model.eval()`. This disables dropout, which would otherwise inject randomness into the predictions and degrade performance.

In [4]:
import esm
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm_if1_gvp4_t16_142M_UR50.pt" to /root/.cache/torch/hub/checkpoints/esm_if1_gvp4_t16_142M_UR50.pt
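Why eval mode matters: dropout layers behave differently during training and evaluation. The toy, pure-Python sketch below mimics the documented behavior of `torch.nn.Dropout` (inverted dropout). In training mode, each value is zeroed with probability `p` and the survivors are scaled by `1/(1-p)`. In eval mode, the layer is the identity. `dropout` here is a hypothetical illustration, not code from `esm` or PyTorch.

```python
import random


def dropout(xs, p, training, rng):
    """Inverted dropout on a list of floats.

    training=True: zero each value with probability p and scale survivors by 1/(1-p).
    training=False: return the input unchanged, as a module in eval mode does.
    """
    if not training or p == 0.0:
        return list(xs)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else x * scale for x in xs]


values = [1.0, 2.0, 3.0, 4.0]
rng = random.Random(0)
print(dropout(values, 0.5, training=True, rng=rng))   # random: some entries zeroed, the rest doubled
print(dropout(values, 0.5, training=False, rng=rng))  # [1.0, 2.0, 3.0, 4.0] on every call
```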


## Extract encoder output as structure representation
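The encoder output may also be used as a representation of the structure. For a set of input coordinates with L amino acids, the encoder output has shape L x 512.

Note that the cell below does not use the ESM-IF1 encoder. Instead, it runs the `esm-extract` script with the ESM-2 language model `esm2_t33_650M_UR50D` on the sequences in two FASTA files, without any structural input. For each sequence, it saves the mean of its layer-33 per-residue representations, which is a 1280-dimensional vector.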
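If you need a single fixed-length vector per structure (for example, as features for a linear model), a common option is to average the L x 512 encoder output over residues. This mirrors what `--include mean` does for ESM-2 in the cell below. As far as we can tell, that option averages over residue positions and skips the special start and end tokens. The following is a minimal numpy sketch. `mean_pool` is a hypothetical helper, not part of `esm`, and the random array only stands in for a real encoder output.

```python
import numpy as np


def mean_pool(per_residue, mask=None):
    """Average an (L, D) per-residue array into one (D,) vector.

    If mask (length-L booleans) is given, only residues where it is True are averaged.
    """
    per_residue = np.asarray(per_residue, dtype=np.float32)
    if mask is not None:
        per_residue = per_residue[np.asarray(mask, dtype=bool)]
    return per_residue.mean(axis=0)


# Toy stand-in for an ESM-IF1 encoder output with L = 4 residues and D = 512 (random values).
rng = np.random.default_rng(0)
encoder_out = rng.normal(size=(4, 512))
vec = mean_pool(encoder_out)
print(vec.shape)  # (512,)
```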

In [9]:
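# Embed each sequence with the ESM-2 650M model and save its mean layer-33 representation.
# Each output directory receives one .pt file per FASTA record.
! esm-extract esm2_t33_650M_UR50D /content/ancestors_unique.fasta.txt \
  coordoutputRBD.fasta --repr_layers 33 --include mean
! esm-extract esm2_t33_650M_UR50D /content/ACE2_aa_modified.fasta \
  coordoutputACE2.fasta --repr_layers 33 --include mean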



Transferred model to GPU
Read /content/ancestors_unique.fasta.txt with 34 sequences
Processing 1 of 2 batches (20 sequences)
Processing 2 of 2 batches (14 sequences)
Transferred model to GPU
Read /content/ACE2_aa_modified.fasta with 62 sequences
Processing 1 of 13 batches (5 sequences)
Processing 2 of 13 batches (5 sequences)
Processing 3 of 13 batches (5 sequences)
Processing 4 of 13 batches (5 sequences)
Processing 5 of 13 batches (5 sequences)
Processing 6 of 13 batches (5 sequences)
Processing 7 of 13 batches (5 sequences)
Processing 8 of 13 batches (5 sequences)
Processing 9 of 13 batches (5 sequences)
Processing 10 of 13 batches (5 sequences)
Processing 11 of 13 batches (5 sequences)
Processing 12 of 13 batches (5 sequences)
Processing 13 of 13 batches (2 sequences)
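The batch counts in this log come from token-budget batching. As we understand `esm-extract`, it sorts sequences by length and packs them greedily, so that (longest sequence in the batch) x (number of sequences) stays within `--toks_per_batch` (default 4096). The pure-Python sketch below re-implements that rule under those assumptions. `token_budget_batches` is a hypothetical helper. The lengths 200 and 805 are illustrative values and were not read from the FASTA files. With these lengths, the sketch produces batch sizes consistent with the log: 20 + 14, and twelve batches of 5 plus one of 2.

```python
def token_budget_batches(lengths, toks_per_batch=4096, extra_toks_per_seq=1):
    """Greedily group sequence indices, shortest first, so that
    (longest padded length in a batch) * (batch size) <= toks_per_batch.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, buf, max_len = [], [], 0
    for i in order:
        # Allowance for special tokens per sequence (assumed 1, as we understand esm-extract).
        size = lengths[i] + extra_toks_per_seq
        if buf and max(size, max_len) * (len(buf) + 1) > toks_per_batch:
            batches.append(buf)
            buf, max_len = [], 0
        max_len = max(max_len, size)
        buf.append(i)
    if buf:
        batches.append(buf)
    return batches


# Illustrative lengths only.
print([len(b) for b in token_budget_batches([200] * 34)])  # [20, 14]
print([len(b) for b in token_budget_batches([805] * 62)])  # twelve batches of 5, then one of 2
```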


In [10]:
# Specify the folders containing the .pt files
folder_paths = ['/content/coordoutputRBD.fasta', '/content/coordoutputACE2.fasta']

# Collect the paths of all .pt files in both output folders
pt_files = [os.path.join(folder, f) for folder in folder_paths for f in os.listdir(folder) if f.endswith('.pt')]

# Iterate over each .pt file
for file_path in pt_files:
    # Load the saved embedding dict (keys: 'label' and 'mean_representations')
    model_dict = torch.load(file_path, map_location=torch.device('cpu'))
    for key, value in model_dict.items():
        print(value)


AncClade2_alt3_del2-only
{33: tensor([ 0.0325, -0.0057, -0.0603,  ..., -0.0508, -0.0531, -0.0547])}
GD-Pangolin
{33: tensor([ 0.0261, -0.0281, -0.0530,  ..., -0.0307, -0.0705, -0.0272])}
SARS-CoV-1_PC4-137_PC04
{33: tensor([ 0.0064, -0.0248, -0.0798,  ..., -0.0482, -0.0354, -0.0502])}
AncSARS-CoV-1_04_MAP
{33: tensor([ 0.0064, -0.0248, -0.0798,  ..., -0.0482, -0.0354, -0.0502])}
AncSARS-CoV-1_MAP
{33: tensor([ 0.0106, -0.0189, -0.0610,  ..., -0.0418, -0.0390, -0.0469])}
AncSARS1c_MAP
{33: tensor([ 0.0083, -0.0246, -0.0575,  ..., -0.0392, -0.0420, -0.0449])}
AncSARS1b_MAP
{33: tensor([ 0.0082, -0.0257, -0.0588,  ..., -0.0348, -0.0495, -0.0406])}
AncSARS1a_altALL
{33: tensor([ 0.0504, -0.0259, -0.0451,  ..., -0.0230, -0.0652, -0.0641])}
AncSARS1a_tree1
{33: tensor([ 0.0362, -0.0343, -0.0391,  ..., -0.0121, -0.0693, -0.0757])}
AncSARS-CoV-1_04_human_MAP
{33: tensor([ 0.0042, -0.0262, -0.0718,  ..., -0.0444, -0.0309, -0.0504])}
AncAsia_tree2
{33: tensor([ 0.0377, -0.0165, -0.0584,  ..., -0

In [11]:
# Specify the folders containing the .pt files
folder_paths = ['/content/coordoutputRBD.fasta', '/content/coordoutputACE2.fasta']

formatted_dict = {}

# Iterate over each folder
for folder_path in folder_paths:
    # List all files in the folder with .pt extension
    pt_files = [f for f in os.listdir(folder_path) if f.endswith('.pt')]

    # Iterate over each .pt file in the current folder
    for file_name in pt_files:
        # Construct the full path to the file
        file_path = os.path.join(folder_path, file_name)

        # Load the saved embedding dict for this sequence
        model_dict = torch.load(file_path, map_location=torch.device('cpu'))

        # Extract label and tensor values
        label = model_dict['label']
        tensor_values = model_dict['mean_representations'][33].numpy()

        # Include the first three and last three numbers in the tensor
        first_three = ','.join(map(str, tensor_values[:3]))
        last_three = ','.join(map(str, tensor_values[-3:]))
        formatted_dict[label] = f'{first_three} {last_three}'
print(formatted_dict)

{'AncClade2_alt3_del2-only': '0.03252849,-0.0057009296,-0.06025492 -0.050784655,-0.05307722,-0.054704864', 'GD-Pangolin': '0.02613955,-0.02813793,-0.053003557 -0.0306876,-0.0704816,-0.027157158', 'SARS-CoV-1_PC4-137_PC04': '0.0064305943,-0.024829453,-0.07980017 -0.04821357,-0.035376027,-0.05023123', 'AncSARS-CoV-1_04_MAP': '0.0064305943,-0.024829453,-0.07980017 -0.04821357,-0.035376027,-0.05023123', 'AncSARS-CoV-1_MAP': '0.010614625,-0.01888332,-0.0609688 -0.04184789,-0.038951673,-0.046921305', 'AncSARS1c_MAP': '0.008294074,-0.024623323,-0.057482157 -0.03919286,-0.041987896,-0.04487475', 'AncSARS1b_MAP': '0.00818473,-0.025745649,-0.058836233 -0.03479355,-0.049494307,-0.04055334', 'AncSARS1a_altALL': '0.050407987,-0.025873747,-0.04511696 -0.022991257,-0.06522087,-0.06413244', 'AncSARS1a_tree1': '0.03622056,-0.03428107,-0.039133392 -0.012105076,-0.06929726,-0.07569118', 'AncSARS-CoV-1_04_human_MAP': '0.0041503105,-0.026185496,-0.07177147 -0.044387903,-0.030921025,-0.050378557', 'AncAsia_

In [12]:
# Print one line per sequence: label, first three values, last three values
for key, value in formatted_dict.items():
    print(f'{key}: {value}')

AncClade2_alt3_del2-only: 0.03252849,-0.0057009296,-0.06025492 -0.050784655,-0.05307722,-0.054704864
GD-Pangolin: 0.02613955,-0.02813793,-0.053003557 -0.0306876,-0.0704816,-0.027157158
SARS-CoV-1_PC4-137_PC04: 0.0064305943,-0.024829453,-0.07980017 -0.04821357,-0.035376027,-0.05023123
AncSARS-CoV-1_04_MAP: 0.0064305943,-0.024829453,-0.07980017 -0.04821357,-0.035376027,-0.05023123
AncSARS-CoV-1_MAP: 0.010614625,-0.01888332,-0.0609688 -0.04184789,-0.038951673,-0.046921305
AncSARS1c_MAP: 0.008294074,-0.024623323,-0.057482157 -0.03919286,-0.041987896,-0.04487475
AncSARS1b_MAP: 0.00818473,-0.025745649,-0.058836233 -0.03479355,-0.049494307,-0.04055334
AncSARS1a_altALL: 0.050407987,-0.025873747,-0.04511696 -0.022991257,-0.06522087,-0.06413244
AncSARS1a_tree1: 0.03622056,-0.03428107,-0.039133392 -0.012105076,-0.06929726,-0.07569118
AncSARS-CoV-1_04_human_MAP: 0.0041503105,-0.026185496,-0.07177147 -0.044387903,-0.030921025,-0.050378557
AncAsia_tree2: 0.037692346,-0.016511634,-0.058424648 -0.0393
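Before using these vectors as features, it can help to stack them into a single matrix and check for exact duplicates. In the output above, `SARS-CoV-1_PC4-137_PC04` and `AncSARS-CoV-1_04_MAP` print the same leading and trailing values, which may indicate identical input sequences. The sketch below assumes a hypothetical `embeddings` dict that maps each label to its layer-33 mean vector as a numpy array, for example `model_dict['mean_representations'][33].numpy()` collected in the loading loop above. `stack_embeddings` and `duplicate_groups` are illustrative helpers, and the toy vectors stand in for the real 1280-dimensional ones.

```python
import numpy as np


def stack_embeddings(embeddings):
    """Turn {label: (D,) vector} into a sorted list of labels and an (N, D) float32 matrix."""
    labels = sorted(embeddings)
    X = np.stack([np.asarray(embeddings[k], dtype=np.float32) for k in labels])
    return labels, X


def duplicate_groups(labels, X):
    """Return groups of labels whose rows in X are exactly equal."""
    groups, seen = [], set()
    for i in range(len(labels)):
        if i in seen:
            continue
        same = [j for j in range(i, len(labels)) if np.array_equal(X[i], X[j])]
        seen.update(same)
        if len(same) > 1:
            groups.append([labels[j] for j in same])
    return groups


# Toy 3-dimensional vectors standing in for the 1280-dimensional embeddings.
toy = {"seq_b": [0.1, 0.2, 0.3], "seq_a": [0.1, 0.2, 0.3], "seq_c": [0.5, 0.0, -0.1]}
labels, X = stack_embeddings(toy)
print(labels, X.shape)              # ['seq_a', 'seq_b', 'seq_c'] (3, 3)
print(duplicate_groups(labels, X))  # [['seq_a', 'seq_b']]
```

Sorting the labels gives a fixed row order, so the feature matrix lines up with any per-label targets joined later.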