<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/ESMFold_advanced.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **ESMFold_advanced**
For more details, see [GitHub](https://github.com/facebookresearch/esm/tree/main/esm) and the [preprint](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

#### **Tips and Instructions**
- Click the ▶ play icon to the left of each cell below to run it.

#### **Colab Limitations**
- On a Tesla T4 (the typical free Colab GPU), the maximum total sequence length is ~900.



In [None]:
%%time
#@title install
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)

import os, time

# The whole setup is skipped when the parameter file is already on disk
# (for example when this cell is re-run in the same session).
if not os.path.isfile("esmfold.model"):
  # Start fetching the parameters right away; the trailing "&" backgrounds
  # the transfer so the package installs below can run in the meantime.
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &")

  # Python dependencies, installed strictly in this order:
  # helper libraries, dllogger, openfold (pinned commit), then esmfold.
  commit = "6908936b68ae89f67755240e2f588c09ec31d4c8"
  for pip_cmd in (
      "pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol",
      "pip install -q git+https://github.com/NVIDIA/dllogger.git",
      f"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}",
      "pip install -q git+https://github.com/sokrypton/esm.git@beta",
  ):
    os.system(pip_cmd)

  if os.path.isfile("esmfold.model"):
    # The background download has produced a file: block until the
    # "esmfold.model.aria2" side file is gone (presumably aria2c's control
    # file, removed once the transfer is complete - confirm).
    while os.path.isfile("esmfold.model.aria2"):
      time.sleep(5)
  else:
    # Nothing arrived from the primary URL: fetch from the backup source,
    # this time in the foreground.
    os.system("aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model")

In [None]:
#@title ##run **ESMFold**
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax

def parse_output(output):
  """Reduce a raw (numpy) model output to per-residue summaries.

  Only the first batch element is used. Residues are kept when entry 1 of
  their ``atom37_atom_exists`` row equals 1 (atom index 1 is presumably the
  C-alpha atom - TODO confirm against the atom37 ordering).

  Returns a dict with:
    "pae":         bin probabilities weighted by bin index 0..63, averaged
                   over the 64 bins and scaled by 31
    "plddt":       column 1 of ``plddt`` for each kept residue
    "sm_contacts": softmax of the distogram logits summed over bins whose
                   edge is below 8
    "xyz":         coordinates of atom index 1 taken from the last entry
                   along the first axis of ``positions``
    "lm_contacts": only when ``output["lm_output"]`` has "contacts";
                   cast to float
  Pairwise maps are cropped to kept residues on both axes.
  """
  keep = output["atom37_atom_exists"][0,:,1] == 1

  def crop_pairwise(matrix):
    # Apply the residue mask to rows, then to columns.
    return matrix[keep,:][:,keep]

  confidence_probs = output["aligned_confidence_probs"][0]
  pae = (confidence_probs * np.arange(64)).mean(-1) * 31

  # 64 bin edges: 0 followed by 63 evenly spaced values.
  bins = np.append(0,np.linspace(2.3125,21.6875,63))
  distogram_probs = softmax(output["distogram_logits"],-1)[0]
  contact_probs = distogram_probs[...,bins<8].sum(-1)

  parsed = {"pae":crop_pairwise(pae),
            "plddt":output["plddt"][0,:,1][keep],
            "sm_contacts":crop_pairwise(contact_probs),
            "xyz":output["positions"][-1,0,:,1][keep]}

  lm_output = output["lm_output"]
  if "contacts" in lm_output:
    parsed["lm_contacts"] = crop_pairwise(lm_output["contacts"].astype(float)[0])
  return parsed

def get_hash(x):
  """Return the SHA-1 hex digest of the string ``x`` (encoded with str.encode())."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

# Single-letter labels: "A".."Z" followed by "a".."z" (52 entries).
alphabet_list = [*ascii_uppercase, *ascii_lowercase]

jobname = "test" #@param {type:"string"}
# Keep only word characters and cap the name at 50 characters.
jobname = re.sub(r'\W+', '', jobname)[:50]

sequence = "GWSTELEKHREELKEFLKKEGITNVEIRIDNGRLEVRVEGGTERLKRFLEELRQKLEKKGYTVDIKIE" #@param {type:"string"}
# Normalise the input: "/" is accepted as an alias for the ":" chain break,
# everything is upper-cased, and any character other than A-Z or ":" is dropped.
sequence = sequence.replace("/",":").upper()
sequence = re.sub("[^A-Z:]", "", sequence)
# Collapse repeated chain breaks and remove leading/trailing ones.
sequence = re.sub(":+",":",sequence).strip(":")

#@markdown ---
#@markdown ###**Advanced Options**
# num_recycles is forwarded to model.infer(); get_LM_contacts is forwarded
# there as return_contacts.
num_recycles = 3 #@param ["0", "1", "2", "3", "6", "12"] {type:"raw"}
get_LM_contacts = False #@param {type:"boolean"}

#@markdown **multimer options (experimental)**
#@markdown - use "/" to specify chainbreaks, (eg. sequence="AAA/AAA")
#@markdown - for homo-oligomeric predictions, set copies > 1

copies = 1 #@param {type:"integer"}
# chain_linker is the number of "X" characters passed to model.infer() as the
# linker between chains.
chain_linker = 25 #@param {type:"number"}
# An empty or non-positive value falls back to a single copy.
if copies == "" or copies <= 0: copies = 1
# Repeat the (possibly multi-chain) sequence `copies` times, joined by ":".
sequence = ":".join([sequence] * copies)

#@markdown **sampling options (experimental)**
#@markdown - Samples are generated via random masking (defined by `masking_rate`) 
#@markdown of input sequence (stochastic_mode="LM") and/or via dropout within structure module (stochastic_mode="SM").
# samples=None requests a single unseeded prediction; an integer requests one
# prediction per seed in range(samples).
samples = None #@param ["None", "1", "4", "8", "16", "32", "64"] {type:"raw"}
# masking_rate is only applied when "LM" occurs in stochastic_mode.
masking_rate = 0.15 #@param {type:"number"}
stochastic_mode = "LM" #@param ["LM", "LM_SM", "SM"]

# Split the ":"-delimited input into chains and record their sizes.
seqs = sequence.split(":")
lengths = list(map(len, seqs))
length = sum(lengths)

# Output tag: the sanitised job name plus the first five hex digits of the
# sequence hash.
ID = f"{jobname}_{get_hash(sequence)[:5]}"
print("length",length)

# Classify the job: one chain, several identical chains, or several
# different chains.
u_seqs = list(set(seqs))
if len(seqs) == 1:
  mode = "mono"
else:
  mode = "homo" if len(u_seqs) == 1 else "hetero"

# torch is imported unconditionally so the code below (and later cells) can
# rely on it even when `model` already exists in the namespace.
import torch

# Load the model only once per session; re-running the cell reuses it.
if "model" not in dir():
  # esmfold.model is a pickled model object (not a state dict), so it must be
  # loaded with weights_only=False; recent PyTorch releases default to
  # weights_only=True and would refuse to unpickle it.
  # NOTE(security): unpickling can execute arbitrary code - only load a
  # parameter file obtained from a source you trust.
  model = torch.load("esmfold.model", weights_only=False)
  model.cuda().requires_grad_(False)

# optimized for Tesla T4: a smaller chunk size is used for long inputs.
model.trunk.set_chunk_size(64 if length > 700 else 128)

# Best sample seen so far (highest pTM) and the parsed summary of every sample.
best_pdb_str = None
best_ptm = 0
best_output = None
traj = []

# Create the output folder once, up front, without going through a shell.
os.makedirs(ID, exist_ok=True)

num_samples = 1 if samples is None else samples
for seed in range(num_samples):
  torch.cuda.empty_cache()
  if samples is None:
    # Single prediction: no input masking, training-mode behaviour off,
    # and no explicit seeding.
    seed = "default"
    mask_rate = 0.0
    model.train(False)
  else:
    # Seeded stochastic sample: masking is used when "LM" is in the mode
    # name, training mode is switched on when "SM" is in the mode name.
    torch.manual_seed(seed)
    mask_rate = masking_rate if "LM" in stochastic_mode else 0.0
    model.train("SM" in stochastic_mode)

  output = model.infer(sequence,
                       num_recycles=num_recycles,
                       # int() guards against a float coming from the
                       # "number" form field, which cannot repeat a string.
                       chain_linker="X"*int(chain_linker),
                       residue_index_offset=512,
                       mask_rate=mask_rate,
                       return_contacts=get_LM_contacts)

  # The PDB text is built from the raw output before it is converted to numpy.
  pdb_str = model.output_to_pdb(output)[0]
  output = tree_map(lambda x: x.cpu().numpy(), output)
  ptm = output["ptm"][0]
  plddt = output["plddt"][0,:,1].mean()
  traj.append(parse_output(output))
  print(f'{seed} ptm: {ptm:.3f} plddt: {plddt:.1f}')

  # Always accept the first sample so the display/plot cells have something
  # to show even when its pTM is not above the initial best_ptm of 0.
  if best_output is None or ptm > best_ptm:
    best_pdb_str = pdb_str
    best_ptm = ptm
    best_output = output

  if samples is None:
    pdb_filename = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_{seed}.pdb"
  else:
    pdb_filename = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_seed{seed}_{stochastic_mode}_m{masking_rate:.2f}.pdb"

  with open(pdb_filename,"w") as out:
    out.write(pdb_str)

In [None]:
#@title display (optional) {run: "auto"}
#@markdown Note: If samples selected, the model with max pTM is displayed.
import py3Dmol


# Hex colours handed out to chains, in order, when show_pdb colours by chain
# (presumably mirroring PyMOL's colour cycle, going by the name).
pymol_color_list = [
  "#33ff33", "#00ffff", "#ff33cc", "#ffff00",
  "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
  "#7fff7f", "#199999", "#ff007f", "#ffdd5e",
  "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
  "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5",
  "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
  "#00ff7f", "#337fcc", "#d8337f", "#bfff3f",
  "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
  "#339933", "#66b2b2", "#ba8c84", "#84bf00",
  "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
  """Build a py3Dmol view of a PDB string drawn as a cartoon.

  Args:
    pdb_str: PDB-formatted text to load.
    show_sidechains: add stick/sphere styles for side-chain atoms.
    show_mainchains: add sticks for the C, O, N and CA atoms.
    color: "pLDDT" (gradient on the 'b' property between vmin and vmax),
      "rainbow" (spectrum) or "chain" (one colour per chain letter).
      Any other value leaves the default style untouched.
    chains: number of chains to colour in "chain" mode; defaults to
      len(Ls), or 1 when Ls is None.
    vmin, vmax: gradient limits for the "pLDDT" scheme.
    size: (width, height) of the viewer.
    hbondCutoff: passed through to the model loader options.
    Ls: optional list of chain lengths, used only to derive ``chains``.
    animate: load the text as multiple frames and start the animation.

  Returns:
    The configured py3Dmol view; call ``.show()`` on it to display.
  """
  if chains is None:
    chains = len(Ls) if Ls is not None else 1

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])
  load_options = {'hbondCutoff':hbondCutoff}
  if not animate:
    view.addModel(pdb_str,'pdb',load_options)
  else:
    view.addModelsAsFrames(pdb_str,'pdb',load_options)

  if color == "pLDDT":
    gradient = {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}
    view.setStyle({'cartoon': {'colorscheme': gradient}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # Pair chain letters with palette colours, limited to `chains` entries.
    for _, chain_id, chain_color in zip(range(chains),alphabet_list,pymol_color_list):
      view.setStyle({'chain':chain_id},{'cartoon': {'color':chain_color}})

  white_carbon = {'colorscheme':"WhiteCarbon",'radius':0.3}
  if show_sidechains:
    backbone = ['C','O','N']
    # Everything except GLY/PRO: all atoms that are not C, O or N.
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':backbone,'invert':True}]},
                  {'stick':white_carbon})
    # GLY: mark the CA atom with a sphere.
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':white_carbon})
    # PRO: all atoms except C and O (so N stays part of the drawn ring).
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':white_carbon})
  if show_mainchains:
    backbone = ['C','O','N','CA']
    view.addStyle({'atom':backbone},{'stick':white_carbon})

  view.zoomTo()
  if animate:
    view.animate()
  return view

color = "confidence" #@param ["confidence", "rainbow", "chain"]
# The form offers "confidence"; show_pdb knows that scheme as "pLDDT".
color = "pLDDT" if color == "confidence" else color
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
# Display the best-scoring structure from the run cell.
viewer = show_pdb(best_pdb_str,
                  color=color,
                  show_sidechains=show_sidechains,
                  show_mainchains=show_mainchains,
                  Ls=lengths)
viewer.show()

In [None]:
#@title plot confidence (optional)
#@markdown Note: If samples selected, the model with max pTM is displayed.

# Figure resolution handed to plot_confidence (used there as plt.figure's dpi).
dpi = 100 #@param {type:"integer"}

def plot_ticks(Ls):
  """Draw chain boundaries and chain-letter ticks on the current square plot.

  Ls is the list of chain lengths. A black horizontal and vertical line is
  drawn at every boundary between consecutive chains, and the y axis gets
  one tick per chain, placed at the chain's midpoint and labelled with the
  matching entry of ``alphabet_list``.
  """
  total = sum(Ls)
  # Cumulative offsets: 0, end of chain 1, end of chain 2, ..., total.
  edges = np.cumsum([0]+Ls)
  for edge in edges[1:-1]:
    plt.plot([0,total],[edge,edge],color="black")
    plt.plot([edge,edge],[0,total],color="black")
  midpoints = (edges[1:] + edges[:-1])/2
  plt.yticks(midpoints,alphabet_list[:len(midpoints)])

def plot_confidence(output, Ls=None, dpi=100):
  """Plot the confidence summaries of one prediction side by side.

  Panels, left to right: per-residue pLDDT, predicted aligned error,
  language-model contacts (only when parse_output provides "lm_contacts")
  and structure-module contacts.

  Args:
    output: raw (numpy) model output, reduced here with parse_output().
    Ls: optional list of chain lengths; used to mark chain boundaries.
    dpi: figure resolution.

  Returns:
    The ``plt`` module, so the caller can show or save the figure.
  """
  O = parse_output(output)
  n_panels = 4 if "lm_contacts" in O else 3
  mark_chains = Ls is not None and len(Ls) > 1

  # Each panel is 5 wide, giving a 20x4 or 15x4 figure.
  plt.figure(figsize=(5*n_panels,4), dpi=dpi)

  # --- panel 1: pLDDT along the sequence ---
  plt.subplot(1,n_panels,1)
  plt.title('Predicted lDDT')
  plt.plot(O["plddt"])
  if Ls is not None:
    boundary = 0
    for chain_len in Ls[:-1]:
      boundary += chain_len
      plt.plot([boundary,boundary],[0,100],color="black")
  plt.xlim(0,O["plddt"].shape[0])
  plt.ylim(0,100)
  plt.ylabel('plDDT')
  plt.xlabel('position')

  # --- panel 2: predicted aligned error ---
  plt.subplot(1,n_panels,2)
  plt.title('Predicted Aligned Error')
  Ln = O["pae"].shape[0]
  square_extent = (0, Ln, Ln, 0)
  plt.imshow(O["pae"],cmap="bwr",vmin=0,vmax=30,extent=square_extent)
  if mark_chains:
    plot_ticks(Ls)
  plt.colorbar()
  plt.xlabel('Scored residue')
  plt.ylabel('Aligned residue')

  def contact_panel(index, title, matrix):
    # Shared layout for the two contact maps.
    plt.subplot(1,n_panels,index)
    plt.title(title)
    plt.imshow(matrix,cmap="Greys",vmin=0,vmax=1,extent=square_extent)
    if mark_chains:
      plot_ticks(Ls)

  # --- remaining panels: contact maps ---
  if "lm_contacts" in O:
    contact_panel(3, "contacts from LM", O["lm_contacts"])
  contact_panel(n_panels, "contacts from Structure Module", O["sm_contacts"])
  return plt

# Draw the confidence panels for the best-scoring model; plot_confidence
# returns the pyplot module, whose show() displays the figure.
confidence_plt = plot_confidence(best_output, Ls=lengths, dpi=dpi)
confidence_plt.show()

In [None]:
#@title download predictions
from google.colab import files

# Bundle every file in the job folder into <ID>.zip and hand it to the browser.
archive_name = f"{ID}.zip"
os.system(f"zip {archive_name} {ID}/*")
files.download(archive_name)

In [None]:
#@title animate outputs (optional)
#@markdown if more than one sample generated
import os
# "git clone" creates a directory named after the repository ("ColabFold"),
# which is also the package name imported below, so that is the name to test
# for. Checking the lower-case "colabfold" never matched on a case-sensitive
# filesystem and re-ran the clone on every execution of this cell.
if not os.path.isdir("ColabFold"):
  os.system("git clone https://github.com/sokrypton/ColabFold.git")
import ColabFold.beta.colabfold as cf

# Resolution of the animation figure (passed to mk_animation as dpi).
dpi =  100#@param {type:"integer"}
# Colour frames by per-residue pLDDT instead of the default/chain colouring.
color_by_plddt = True #@param {type:"boolean"}
# When True and more than one sample exists, frames are reordered by their
# position along the first principal component of the pairwise distances.
use_pca = True
# When True the ordered frames are followed by the interior frames in reverse,
# so the sequence runs forward and then back.
cycle = True 

import matplotlib
from matplotlib import animation
from IPython.display import HTML
from sklearn.decomposition import PCA

def mk_animation(xyz, labels, plddt, ref=0, Ls=None, line_w=2.0, dpi=100,color_by_plddt=False):
  """Render a series of structures as an HTML5 video, one structure per frame.

  Args:
    xyz: per-frame coordinates, indexable by frame
      (assumed shape (frames, residues, 3) - TODO confirm).
    labels: one caption per frame, drawn with cf.add_text.
    plddt: one per-residue array per frame; used as the colour values when
      color_by_plddt is set.
    ref: index of the frame that defines the viewing orientation.
    Ls: optional list of chain lengths. With more than one chain, frames are
      superposed on the first chain only and, unless color_by_plddt is set,
      coloured by chain index.
    line_w: line width forwarded to cf.plot_pseudo_3D.
    dpi: figure resolution.
    color_by_plddt: colour by plddt (range 50-90) instead of the default or
      per-chain colouring.

  Returns:
    The string produced by ArtistAnimation.to_html5_video().
  """
  single_chain = Ls is None or len(Ls) == 1

  def superpose(mobile, target):
    # Fit on the whole structure, or only on the first chain when several
    # chains are present; the transform is then applied to all of `mobile`.
    if single_chain:
      fit_mobile, fit_target = mobile, target
    else:
      fit_mobile, fit_target = mobile[:Ls[0]], target[:Ls[0]]
    mobile_center = fit_mobile.mean(0,keepdims=True)
    target_center = fit_target.mean(0,keepdims=True)
    # cf.kabsch presumably returns the rotation that best maps the centred
    # fit sets onto each other - confirm against colabfold.
    rotation = cf.kabsch(fit_mobile - mobile_center, fit_target - target_center)
    return ((mobile - mobile_center) @ rotation) + target_center

  # Orient the centred reference frame with the matrix cf.kabsch returns for
  # return_v=True, then superpose every frame onto that orientation.
  anchor = xyz[ref] - xyz[ref].mean(0,keepdims=True)
  view_frame = anchor @ cf.kabsch(anchor,anchor,return_v=True)
  pos = np.asarray([superpose(frame, view_frame) for frame in xyz])

  fig, ax1 = plt.subplots(1)
  fig.set_figwidth(5)
  fig.set_figheight(5)
  fig.set_dpi(dpi)

  # One shared, padded x/y range so every frame is drawn at the same scale.
  xy_min = pos[...,:2].min() - 1
  xy_max = pos[...,:2].max() + 1
  ax1.set_xlim(xy_min, xy_max)
  ax1.set_ylim(xy_min, xy_max)
  ax1.axis(False)

  ims = []
  for caption, frame_xyz, frame_plddt in zip(labels,pos,plddt):
    if color_by_plddt:
      colour_kwargs = dict(c=frame_plddt, cmin=50, cmax=90)
    elif single_chain:
      colour_kwargs = {}
    else:
      # One integer per residue: the index of the chain it belongs to.
      chain_index = np.concatenate([[n]*L for n,L in enumerate(Ls)])
      colour_kwargs = dict(c=chain_index, cmap=cf.pymol_cmap, cmin=0, cmax=39)
    img = cf.plot_pseudo_3D(frame_xyz, line_w=line_w, ax=ax1, **colour_kwargs)
    ims.append([cf.add_text(f"{caption}", ax1),img])

  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120)
  plt.close()
  return ani.to_html5_video()

# One caption, coordinate set and pLDDT profile per recorded sample.
labels = np.array([f"seed:{k}" for k, _ in enumerate(traj)])
pos = np.array([sample["xyz"] for sample in traj])
plddt = np.array([sample["plddt"] for sample in traj])

if use_pca and pos.shape[0] > 1:
  # Describe each sample by its intra-structure pairwise distances
  # (first chain only for multi-chain inputs) ...
  if lengths is not None and len(lengths) > 1:
    fit_xyz = pos[:,:lengths[0]]
  else:
    fit_xyz = pos
  upper_i, upper_j = np.triu_indices(fit_xyz.shape[1],k=1)
  deltas = fit_xyz[:,None,:,:] - fit_xyz[:,:,None]
  dist_features = np.sqrt(np.square(deltas).sum(-1))[:,upper_i,upper_j]
  # ... and order the samples along the first principal component.
  order = PCA(1).fit_transform(dist_features)[:,0].argsort()
  if cycle:
    # Append the interior frames reversed so the sequence runs there and back.
    order = np.append(order,order[1:-1][::-1])
  pos, labels, plddt = pos[order], labels[order], plddt[order]

HTML(mk_animation(pos, labels, plddt, Ls=lengths, dpi=dpi, color_by_plddt=color_by_plddt))