<a href="https://colab.research.google.com/github/espickle1/esmfold_colabfold/blob/main/ESMFold_ColabFold_second.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
### ColabFold: running ESMFold

In [None]:
## Install necessary packages
%%time
version = "1"
model_name = "esmfold_v0.model" if version == "0" else "esmfold.model"

import os, time

if not os.path.isfile(model_name):
  # download esmfold params
  os.system("apt-get install aria2 -qq")
  os.system(f"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &")

  if not os.path.isfile("finished_install"):
    # install libs
    print("installing libs...")
    os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    print("installing openfold...")
    # install openfold
    os.system(f"pip install -q git+https://github.com/sokrypton/openfold.git")

    print("installing esmfold...")
    # install esmfold
    os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  while not os.path.isfile(model_name):
    time.sleep(5)
  if os.path.isfile(f"{model_name}.aria2"):
    print("downloading params...")
  while os.path.isfile(f"{model_name}.aria2"):
    time.sleep(5)

In [None]:
## Import dependencies
import torch
from jax.tree_util import tree_map
import gc

from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np

import pandas as pd
import matplotlib.pyplot as plt
from scipy.special import softmax
import py3Dmol

# Mount Google Drive at /content/drive so input CSVs and output structures
# under /content/drive/MyDrive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
## Parsing outputs and get unique hash
def parse_output(output):
  """Extract masked confidence and geometry arrays from a CPU-side ESMFold output.

  Only batch element 0 of each array is used. Residues are kept where
  ``atom37_atom_exists[0, :, 1] == 1`` (atom slot 1, presumably CA in the
  atom37 ordering -- confirm against openfold residue_constants).

  Args:
    output: dict of numpy arrays (the model output after ``.cpu().numpy()``).

  Returns:
    dict with, in this order:
      "pae": (probs * bin_index).mean(-1) * 31 over the 64 aligned-confidence
        bins, restricted to kept residues on both axes;
      "plddt": plddt[0, :, 1] for kept residues;
      "sm_contacts": softmaxed distogram mass in bins whose edge value is < 8
        (presumably Angstrom), restricted to kept residues on both axes;
      "xyz": positions[-1, 0, :, 1] (last entry of the leading axis) for kept residues.
  """
  keep_res = output["atom37_atom_exists"][0, :, 1] == 1
  pair_sel = np.ix_(keep_res, keep_res)  # same rows/cols as [mask,:][:,mask]

  bin_index = np.arange(64)
  pae_full = (output["aligned_confidence_probs"][0] * bin_index).mean(-1) * 31

  dist_edges = np.append(0, np.linspace(2.3125, 21.6875, 63))
  dist_probs = softmax(output["distogram_logits"], -1)[0]
  contact_full = dist_probs[..., dist_edges < 8].sum(-1)

  return {
      "pae": pae_full[pair_sel],
      "plddt": output["plddt"][0, :, 1][keep_res],
      "sm_contacts": contact_full[pair_sel],
      "xyz": output["positions"][-1, 0, :, 1][keep_res],
  }

def get_hash(x):
  """Return the hexadecimal SHA-1 digest of string *x* encoded as UTF-8."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

In [None]:
## Load model
# Load the model only when none is in memory yet, or when `model_name` differs
# from the name recorded at the last load (`model_name_`).
if "model" not in dir() or model_name != model_name_:
  if "model" in dir():
    # Drop the previous model and release cached GPU memory before reloading.
    del model
    gc.collect()
    if torch.cuda.is_available():
      torch.cuda.empty_cache()

  # The checkpoint is a whole pickled module (it has .eval()/.cuda() below), not
  # a state_dict. Since PyTorch 2.6 torch.load defaults to weights_only=True,
  # which refuses to unpickle such objects, so opt out explicitly.
  # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
  model = torch.load(model_name, weights_only=False)
  model.eval().cuda().requires_grad_(False)
  model_name_ = model_name

In [None]:
## Read and clean one sequence row from the input table
def sequence_read(sequence_input, position, copies):
  """Read, clean and replicate one protein sequence from a DataFrame row.

  Uses the notebook globals ``jobname`` and ``get_hash``.

  Args:
    sequence_input: pandas DataFrame with 'Translation' (amino-acid sequence;
      chains separated by ':' or '/') and 'meta' columns.
    position: row label, looked up with ``.loc``.
    copies: number of copies of the sequence to join with ':' for a
      homo-oligomer. ``None``, ``""`` or a value <= 0 means 1.

  Returns:
    (sequence_clean, meta, copies, ID, u_seqs, length, lengths):
      sequence_clean: cleaned sequence repeated ``copies`` times (':'-joined);
        this is the string that should be folded.
      meta: the row's 'meta' value.
      copies: the effective (int) copy number.
      ID: "Number_<position>_<jobname>_<first 5 hex chars of SHA-1 of sequence_clean>".
      u_seqs: distinct chains in first-seen order.
      length / lengths: total / per-chain residue counts.

  Raises:
    ValueError: if 'Translation' is not a string (e.g. NaN) or no residues
      remain after cleaning.
  """
  row = sequence_input.loc[position]
  raw = row['Translation']
  if not isinstance(raw, str):
    raise ValueError(f"row {position!r}: 'Translation' is not a string: {raw!r}")

  # Each step builds on the previous result (previously every step re-read the
  # raw string, so only the final substitution took effect).
  sequence_clean = re.sub("[^A-Z:]", "", raw.replace("/", ":").upper())  # letters + chain breaks only
  sequence_clean = re.sub(":+", ":", sequence_clean)  # collapse repeated breaks
  sequence_clean = sequence_clean.strip(":")          # no leading/trailing break
  if not sequence_clean:
    raise ValueError(f"row {position!r}: no residues left after cleaning")

  meta = row['meta']

  copies = 1 if copies is None or copies == "" else int(copies)
  if copies <= 0:
    copies = 1
  # Replicate the *cleaned* sequence so the returned value carries all copies.
  sequence_clean = ":".join([sequence_clean] * copies)

  ID = f"Number_{position}_{jobname}_{get_hash(sequence_clean)[:5]}"
  seqs = sequence_clean.split(":")
  lengths = [len(s) for s in seqs]
  length = sum(lengths)
  print("length", length)

  # dict.fromkeys de-duplicates while keeping a deterministic first-seen order.
  u_seqs = list(dict.fromkeys(seqs))

  return sequence_clean, meta, copies, ID, u_seqs, length, lengths

In [None]:
## Predict protein structure from input sequence
def prediction_block(sequence, ID, row_number, meta=None):
  """Fold one sequence with the loaded ESMFold model and save PDB + PAE files.

  Uses notebook globals ``model``, ``num_recycles``, ``chain_linker`` and
  ``parse_output``. Files are written to the current working directory.

  Args:
    sequence: cleaned amino-acid sequence; chains separated by ':'.
    ID: job identifier used as the output filename stem.
    row_number: row label, used only in the timing message.
    meta: label for the timing message. When omitted, falls back to a
      notebook-global ``meta`` (the previous implicit behaviour), or "".

  Returns:
    (pdb_str, prefix, O): PDB text, output filename prefix (no extension),
    and the dict returned by ``parse_output``.
  """
  if meta is None:
    # Backward compatibility: callers used to rely on a global `meta`.
    meta = globals().get("meta", "")

  start = time.perf_counter()

  torch.cuda.empty_cache()
  output = model.infer(
      sequence,
      num_recycles=num_recycles,
      chain_linker="X" * chain_linker,
      residue_index_offset=512,
  )

  pdb_str = model.output_to_pdb(output)[0]
  # Convert every tensor in the output to a host-side numpy array.
  output = tree_map(lambda x: x.cpu().numpy(), output)
  ptm = output["ptm"][0]
  plddt = output["plddt"][0, ..., 1].mean()
  O = parse_output(output)
  print(f'ptm: {ptm:.3f} plddt: {plddt:.3f}')

  elapsed = time.perf_counter() - start
  print(f"Inference time for entry {row_number}, {meta}: {elapsed:.2f}s")

  prefix = f"{ID}_ptm{ptm:.3f}_r{num_recycles}_default"
  np.savetxt(f"{prefix}.pae.txt", O["pae"], "%.3f")
  with open(f"{prefix}.pdb", "w") as out:
    out.write(pdb_str)

  return pdb_str, prefix, O

In [None]:
## Import settings: manual settings
alphabet_list = list(ascii_uppercase+ascii_lowercase)

num_recycles = 3   # passed to model.infer(num_recycles=...)
chain_linker = 25  # length of the "X" linker string passed to model.infer
multimer_n = 1     # copies of the sequence to fold together (1 = monomer)
row_number = 6     # row label in the input CSV (looked up with .loc)

# Strip non-word characters and cap at 50 chars so the job name is filename-safe.
jobname = "dir_test"
jobname = re.sub(r'\W+', '', jobname)[:50]

input_directory = "/content/drive/MyDrive/ww_virome/esmfold_colab/sequences/"
output_directory = "/content/drive/MyDrive/ww_virome/esmfold_colab/structures"
# Create the output folder on first use: os.chdir raises FileNotFoundError for a
# missing directory. Result files are written relative to this working directory.
os.makedirs(output_directory, exist_ok=True)
os.chdir(output_directory)
# os.path.join works whether or not input_directory ends with a separator.
file_path = os.path.join(input_directory, "big_merge_norovirus_translation.csv")
sequence_file = pd.read_csv(file_path)

# Observed free GPU memory (kept for reference):
# memory.free [MiB] A100: 40506 MiB -
# memory.free [MiB] L4: 22692 MiB - 1000
# memory.free [MiB]T4: 15095 MiB - 700
# A previously used alternative keyed on sequence length: chunk size 64 when
# length > 700, otherwise 128.

# Choose ESMFold's chunk size from the GPU's total memory, assuming ~80% usable.
# NOTE(review): this does not subtract memory already held by the loaded model.
total_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 2)  # Convert to MiB
free_memory = total_memory * 0.8  # Assume 80% available

if free_memory > 20000:  # 20 GB+
    chunk_size = 256
elif free_memory > 10000:  # 10-20 GB (T4, 15GB)
    chunk_size = 128
else:
    chunk_size = 64  # Low VRAM GPUs

model.set_chunk_size(chunk_size)
print(f"Using chunk_size: {chunk_size}")

In [None]:
%%time
sequence_clean, meta, copies, ID, u_seqs, length, lengths = sequence_read(
    sequence_file,
    row_number,
    copies=multimer_n
    )

pdb_str, prefix, O = prediction_block(
    sequence_clean,
    ID,
    row_number
    )