# ESMFold batch

* This notebook is based on the ESMFold implementation of [ColabFold](https://github.com/sokrypton/ColabFold#making-protein-folding-accessible-to-all-via-google-colab).

  * Input: a fasta file with your proteins
  * Output: predicted structures saved in your Google Drive and a CSV file with quality metrics.
* All `.pdb` files are saved to a folder in your Google Drive. **You have to create this folder in the root of your Drive yourself before running the notebook**.
* Due to the memory limits of the free GPUs provided by Google (Tesla T4), the maximum length of a protein sequence that can be predicted is ~800 aa.
* This notebook is designed to predict thousands of structures. If you only want to predict a single sequence shorter than 400 aa, try the [ESMFold webserver](https://esmatlas.com/resources?action=fold).

----
* Written by GAMA ([@miangoar](https://twitter.com/miangoar) on X/Twitter)
* Date: 08/2024

In [1]:
%%time
#@title 1) Install ESMFold and OpenFold utils
version = "1" #
model_name = "esmfold_v0.model" if version == "0" else "esmfold.model"
import os, time
if not os.path.isfile(model_name):
  # download esmfold params
  os.system("apt-get install aria2 -qq")
  os.system(f"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &")

  if not os.path.isfile("finished_install"):
    # install libs
    print("installing libs...")
    os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    print("installing openfold...")
    # install openfold
    os.system(f"pip install -q git+https://github.com/sokrypton/openfold.git")

    print("installing esmfold...")
    # install esmfold
    os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  while not os.path.isfile(model_name):
    time.sleep(5)
  if os.path.isfile(f"{model_name}.aria2"):
    print("downloading params...")
  while os.path.isfile(f"{model_name}.aria2"):
    time.sleep(5)

installing libs...
installing openfold...
installing esmfold...
CPU times: user 891 ms, sys: 158 ms, total: 1.05 s
Wall time: 4min 2s


In [2]:
%%time
#@title 2) Connect to google drive (It will ask for permissions)

# import libs
from pydrive2.drive import GoogleDrive
from pydrive2.auth import GoogleAuth
from google.colab import auth
from oauth2client.client import GoogleCredentials
from tqdm import tqdm
import pandas as pd
import plotly.express as px
from google.colab import files

# google drive login: authorize this Colab session, then give PyDrive2 the
# application-default credentials so `drive` can list and upload files
auth.authenticate_user()
gauth = GoogleAuth()
default_creds = GoogleCredentials.get_application_default()
gauth.credentials = default_creds
drive = GoogleDrive(gauth)

CPU times: user 1.05 s, sys: 127 ms, total: 1.17 s
Wall time: 14.6 s


In [3]:
%%time
#@title 3) Set the name of your directory in Google Drive where the pdb files will be saved
#@markdown Note: you have to manually create the directory in your Drive, `ESMFold_results` is just an arbitrary name of example.

# get the name of the dir
folder_id = 'ESMFold_results' #@param {type:"string"}
file_list = drive.ListFile({'q': "mimeType='application/vnd.google-apps.folder' and trashed=false"}).GetList()
for file in file_list:
    if file['title'] == folder_id:
        folder_id = file['id']
        break

CPU times: user 132 ms, sys: 7.74 ms, total: 140 ms
Wall time: 3.68 s


In [4]:
#@title ## 4) Import utilities
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
import torch
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax
import gc
from google.colab import files
import os
import pandas as pd

def parse_output(output):
  """Extract masked pairwise / per-residue arrays from an ESMFold output dict.

  Args:
    output: dict of numpy arrays (as produced after converting `model.infer`
      output with `.cpu().numpy()`); index 0 of the leading axis is taken as
      the single batch entry.

  Returns:
    dict with keys "pae", "plddt", "sm_contacts" and "xyz", each restricted to
    the residues whose `atom37_atom_exists[0, :, 1]` flag equals 1. Atom slot 1
    is presumably CA in the atom37 ordering -- TODO confirm.
  """
  keep = output["atom37_atom_exists"][0, :, 1] == 1
  pair = np.ix_(keep, keep)  # same selection as x[keep, :][:, keep]

  # PAE: probability-weighted bin index averaged over the 64 bins, scaled by 31
  pae_full = (output["aligned_confidence_probs"][0] * np.arange(64)).mean(-1) * 31

  # contact probability: softmaxed distogram mass in the bins whose boundary
  # value below is < 8 (presumably Angstrom)
  bin_edges = np.append(0, np.linspace(2.3125, 21.6875, 63))
  distogram = softmax(output["distogram_logits"], -1)[0]
  contacts_full = distogram[..., bin_edges < 8].sum(-1)

  return {
      "pae": pae_full[pair],
      "plddt": output["plddt"][0, :, 1][keep],
      "sm_contacts": contacts_full[pair],
      "xyz": output["positions"][-1, 0, :, 1][keep],
  }

def get_hash(x):
  """Return the SHA-1 hex digest of the string `x` (encoded with str.encode's default UTF-8)."""
  hasher = hashlib.sha1()
  hasher.update(x.encode())
  return hasher.hexdigest()

# the 52 ASCII letters, A..Z followed by a..z
alphabet_list = [*ascii_uppercase, *ascii_lowercase]

# seqkit install
! curl -s -O -L https://github.com/shenwei356/seqkit/releases/download/v2.8.0/seqkit_linux_amd64.tar.gz
! tar -xzvf seqkit_linux_amd64.tar.gz
! cp seqkit /usr/local/bin/

seqkit
CPU times: user 3.18 s, sys: 329 ms, total: 3.51 s
Wall time: 8.39 s


In [10]:
%%time
#@title ## 5) Load your protein sequences in fasta format
#@markdown

#@title 4) Upload your custom homologous sequences (in .fasta format)

uploaded = files.upload()

# get the original name and rename it
uploaded_filename = list(uploaded.keys())[0]
os.rename(uploaded_filename, "input_proteins.fasta")

# convert the fasta to csv
! seqkit fx2tab input_proteins.fasta -l -Q > input_proteins.csv

# create a df from the msa
df = pd.read_csv("input_proteins.csv", sep="\t",
                 names=["protein_name","seq", "length"], usecols=[0,1,2])

# rename protein names
df["protein_name"] = df["protein_name"].str.replace(r"\W+", "_", regex=True)
df["protein_name"] = df["protein_name"].str.rstrip("_")

# print stats
print("\n###########################################")
print("Descriptive statistics from input sequences")
print("###########################################\n")
! seqkit stats -a input_proteins.fasta

df

Saving proteins.fasta to proteins.fasta

###########################################
Descriptive statistics from input sequences
###########################################

file                  format  type     num_seqs  sum_len  min_len  avg_len  max_len  Q1  Q2  Q3  sum_gap  N50  N50_num  Q20(%)  Q30(%)  AvgQual  GC(%)
input_proteins.fasta  FASTA   Protein        21    1,433       65     68.2       75  66  68  69        0   68        4       0       0        0    6.7
CPU times: user 60.7 ms, sys: 13.2 ms, total: 73.9 ms
Wall time: 7.62 s


Unnamed: 0,protein_name,seq,length
0,seq_0,DPLEWTPEHVQQWLSWVSKKFSLDPIDPDRFPMNGKELCALSKEDF...,67
1,seq_1,QPIYWSRDDVAQWLKWAENEFSLRPIDSNTFEMNGKALLLLTKEDF...,66
2,seq_2,DPRQWTETHVRDWVMWAVNEFSLKGVDFQKFCMSGAALCALGKECF...,67
3,seq_3,QPQFWSKTQVLDWISYQVEKNKYDASAIDFSRCDMDGATLCNCALE...,68
4,seq_4,DPIHWSTDQVLHWVVWVMKEFSMTDIDLTTLNISGRELCSLNQEDF...,65
5,seq_5,DPTLWTQEHVRQWLEWAIKEYSLMEIDTSFFQNMDGKELCKMNKED...,68
6,seq_6,DPMDWSPSNVQKWLLWTEHQYRLPPMGKAFQELAGKELCAMSEEQF...,66
7,seq_7,DPRDWTRADVWKWLINMAVSEGLEVTAELPQKFPMNGKALCLMSLD...,68
8,seq_8,DPLEWTNTHIKSWLSWCSRKFSLNPKPDFEKFPTTGKELCELTRTD...,69
9,seq_9,DPAEWNSEHVSQWLNWTTKKFRLNPKPDCDKFPKTGVELCELTKSD...,69


In [11]:
#@title ## 6) Run ESMFold
%%time

model_lst = []
plddt_lst = []
ptm_lst = []
length_lst = []

print("Performing structure prediction with ESMFold v1")
pbar = tqdm(total=len(df), )

for idx, row in df.iterrows():

  protein_name = str(row["protein_name"])
  seq = str(row["seq"])
  jobname = protein_name
  jobname = re.sub(r'\W+', '', jobname)[:50]
  sequence = seq

  sequence = re.sub("[^A-Z:]", "", sequence.replace("/",":").upper())
  sequence = re.sub(":+",":",sequence)
  sequence = re.sub("^[:]+","",sequence)
  sequence = re.sub("[:]+$","",sequence)
  copies = 1
  if copies == "" or copies <= 0: copies = 1
  sequence = ":".join([sequence] * copies)
  num_recycles = 3
  chain_linker = 25

  ID = jobname+"_"+get_hash(sequence)[:5]
  seqs = sequence.split(":")
  lengths = [len(s) for s in seqs]
  length = sum(lengths)
  # print("length",length)

  u_seqs = list(set(seqs))
  if len(seqs) == 1: mode = "mono"
  elif len(u_seqs) == 1: mode = "homo"
  else: mode = "hetero"

  if "model" not in dir() or model_name != model_name_:
    if "model" in dir():
      # delete old model from memory
      del model
      gc.collect()
      if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = torch.load(model_name)
    model.eval().cuda().requires_grad_(False)
    model_name_ = model_name

  # optimized for Tesla T4
  if length > 700:
    model.set_chunk_size(64)
  else:
    model.set_chunk_size(128)

  torch.cuda.empty_cache()
  output = model.infer(sequence,
                       num_recycles=num_recycles,
                       chain_linker="X"*chain_linker,
                       residue_index_offset=512)

  pdb_str = model.output_to_pdb(output)[0]
  output = tree_map(lambda x: x.cpu().numpy(), output)
  ptm = output["ptm"][0]
  plddt = output["plddt"][0,...,1].mean()
  O = parse_output(output)
  #print(f'ptm: {ptm:.3f} plddt: {plddt:.3f}')
  #os.system(f"mkdir -p {ID}")
  #prefix = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_default"
  #np.savetxt(f"{prefix}.pae.txt",O["pae"],"%.3f")
  with open(f"{protein_name}.pdb","w") as out:
    out.write(pdb_str)

  model_lst.append(jobname)
  ptm_lst.append(ptm)
  plddt_lst.append(plddt)
  length_lst.append(length)

  # upload the results to google drive
  uploaded = drive.CreateFile({'title': f"{protein_name}.pdb", 'parents': [{'id': folder_id}]})
  uploaded.SetContentFile(f"{protein_name}.pdb")
  uploaded.Upload()
  # update the progress bar
  pbar.update(1)
pbar.close()

Performing structure prediction with ESMFold v1


100%|██████████| 21/21 [01:30<00:00,  4.32s/it]

CPU times: user 39.9 s, sys: 355 ms, total: 40.2 s
Wall time: 1min 30s





In [12]:
#@title 8) Plot the quality of the models

# create the table of quality metrics (one row per predicted structure)
dfq = pd.DataFrame({
    "model": model_lst,
    "plddt": plddt_lst,
    "ptm": ptm_lst,
    "length": length_lst,
})

# One bar per row: px.bar plots each value as-is, whereas px.histogram would
# sum the pLDDT of rows that share a model name.
fig = px.bar(dfq, x="model", y="plddt", height=600, width=1200,
             hover_name="model", hover_data=["ptm", "length"])
fig.update_traces(marker=dict(color="red", line=dict(width=1, color="black")))
fig.update_layout(template="plotly_white", yaxis_title="Average pLDDT")
fig.update_yaxes(showline=True, linewidth=1, linecolor='LightGrey')
fig.update_xaxes(showline=True, linewidth=1, linecolor='LightGrey')
fig.show()

In [13]:
#@title 9) Download a csv with the quality metrics

# write the metrics table (one row per predicted model) and send it to the browser
metrics_csv = "ESMFold_predictions_metrics.csv"
dfq.to_csv(metrics_csv, index=False)
files.download(metrics_csv)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>