<a href="https://colab.research.google.com/github/jamesmkrieger/ColabBio/blob/main/categorical_jacobian/esm2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**get Categorical Jacobian from ESM2**
##(aka. extract conservation and coevolution for your favorite protein)

In [1]:
%%time
#@markdown ##setup model
model_name = "esm2_t33_650M_UR50D" # @param ["esm2_t48_15B_UR50D","esm2_t36_3B_UR50D","esm2_t33_650M_UR50D","esm2_t30_150M_UR50D","esm2_t12_35M_UR50D","esm2_t6_8M_UR50D","esm1b_t33_650M_UR50S"]
# this step will take ~3mins
import torch
import os
if not os.path.isfile("utils.py"):
  os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabBio/main/categorical_jacobian/utils.py")
  os.system("apt-get install aria2 -qq")
  os.system("mkdir -p /root/.cache/torch/hub/checkpoints/")

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import softmax

import pandas as pd
import numpy as np
import bokeh.plotting
bokeh.io.output_notebook()
from bokeh.models import BasicTicker, PrintfTickFormatter
from bokeh.palettes import viridis, RdBu
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show

from matplotlib.colors import to_hex
cmap = plt.colormaps["bwr_r"]
bwr_r = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]
cmap = plt.colormaps["gray_r"]
gray = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]

def pssm_to_dataframe(pssm, esm_alphabet):
  sequence_length = pssm.shape[0]
  idx = [str(i) for i in np.arange(1, sequence_length + 1)]
  df = pd.DataFrame(pssm, index=idx, columns=list(esm_alphabet))
  df = df.stack().reset_index()
  df.columns = ['Position', 'Amino Acid', 'Probability']
  return df

def contact_to_dataframe(con):
  sequence_length = con.shape[0]
  idx = [str(i) for i in np.arange(1, sequence_length + 1)]
  df = pd.DataFrame(con, index=idx, columns=idx)
  df = df.stack().reset_index()
  df.columns = ['i', 'j', 'value']
  return df

def pair_to_dataframe(pair,esm_alphabet):
  sequence_length = pair.shape[0]
  df = pd.DataFrame(pair, index=list(esm_alphabet), columns=list(esm_alphabet))
  df = df.stack().reset_index()
  df.columns = ['aa_i', 'aa_j', 'value']
  return df

from utils import *
import tqdm.notebook

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def load_model(model_name="esm2_t36_3B_UR50D"):
  if not os.path.isfile(f"/root/.cache/torch/hub/checkpoints/{model_name}.pt"):
    os.system(f"aria2c -q -x 16 -d /root/.cache/torch/hub/checkpoints/ https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt")
    os.system(f"aria2c -q -x 16 -d /root/.cache/torch/hub/checkpoints/ https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt")
  model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
  model = model.to(DEVICE)
  model = model.eval()
  return model, alphabet

def get_logits(seq, p=1, return_jac=False):
  x,ln = alphabet.get_batch_converter()([(None,seq)])[-1],len(seq)
  if p is None: p = ln
  with torch.no_grad():
    f = lambda x: model(x)["logits"][:,1:(ln+1),4:24].detach().cpu().numpy()
    logits = np.zeros((ln,20), dtype=np.float32)
    if return_jac:
      jac = np.zeros((ln,1,ln,20), dtype=np.float32)
      fx = f(x.to(DEVICE))[0]
    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      for n in range(0,ln,p):
        m = min(n+p,ln)
        x_h = torch.tile(torch.clone(x),[m-n,1])
        for i in range(m-n):
          x_h[i,n+i+1] = alphabet.mask_idx
        fx_h = f(x_h.to(DEVICE))
        for i in range(m-n):
          logits[n+i] = fx_h[i,n+i]
          if return_jac:
            jac[n+i] = fx_h[i,None] - fx
        pbar.update(m-n)
    if return_jac:
      return jac
    else:
      return logits

def get_categorical_jacobian(seq, layer=None, fast=False):
  # ∂in/∂out
  x, ln = alphabet.get_batch_converter()([("seq", seq)])[-1], len(seq)
  with torch.no_grad():
    if layer is None:
      f = lambda x: model(x)["logits"][..., 1:(ln+1), 4:24].detach().cpu().numpy()
    else:
      f = lambda x: model(x, repr_layers=[layer])["representations"][layer][..., 1:(ln+1), :].detach().cpu().numpy()

    fx = f(x.to(DEVICE))[0]
    fx_h = np.zeros([ln, 1 if fast else 20, ln, fx.shape[-1]], dtype=np.float32)
    x = x.to(DEVICE) if fast else torch.tile(x, [20, 1]).to(DEVICE)
    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      for n in range(ln):  # for each position
        x_h = torch.clone(x)

        # mutate to all 20 aa
        x_h[:, n+1] = alphabet.mask_idx if fast else torch.arange(4, 24)
        fx_h[n] = f(x_h)
        pbar.update(1)

  # note: direction here differs from manuscript
  # positive = good
  # negative = bad
  return fx_h - fx

model, alphabet = load_model(model_name)
esm_alphabet_len = len(alphabet.all_toks)
esm_alphabet = list("".join(alphabet.all_toks[4:24]))
ALPHABET = "AFILVMWYDEKRHNQSTGPC"
ALPHABET_map = [esm_alphabet.index(a) for a in ALPHABET]

def jac_to_con(jac, center=True, diag="remove", apc=True):

  X = jac.copy()
  Lx,Ax,Ly,Ay = X.shape
  if Ax == 20:
    X = X[:,ALPHABET_map,:,:]

  if Ay == 20:
    X = X[:,:,:,ALPHABET_map]
    if symm and Ax == 20:
      X = (X + X.transpose(2,3,0,1))/2

  if center:
    for i in range(4):
      if X.shape[i] > 1:
        X -= X.mean(i,keepdims=True)

  contacts = np.sqrt(np.square(X).sum((1,3)))

  if symm and (Ax != 20 or Ay != 20):
    contacts = (contacts + contacts.T)/2

  if diag == "remove":
    np.fill_diagonal(contacts,0)

  if diag == "normalize":
    contacts_diag = np.diag(contacts)
    contacts = contacts / np.sqrt(contacts_diag[:,None] * contacts_diag[None,:])

  if apc:
    ap = contacts.sum(0,keepdims=True) * contacts.sum(1, keepdims=True) / contacts.sum()
    contacts = contacts - ap

  if diag == "remove":
    np.fill_diagonal(contacts,0)

  return {"jac":X, "contacts":contacts}

Downloading: "https://github.com/facebookresearch/esm/zipball/main" to /root/.cache/torch/hub/main.zip


CPU times: user 10.9 s, sys: 3.77 s, total: 14.7 s
Wall time: 54.6 s


In [10]:
#@markdown ##enter sequence

sequence = "GFPNTISIGGLFMRNTVQEHSAFRFAVQLYNTNQNTTE KPFHLNYHVDHLDSSNSFSVTNAFCSQFSRGVYAIFGFYDQMSMNTLTSFCGALHTSFVT PSFPTDADVQFVIQMRPALKGAILSLLSYYKWEKFVYLYDTERGFSVLQAIMEAAVQNNW QVTARSVGNIKDVQEFRRIIEEMDRRQEKRYLIDCEVERINTILEQVVILGKHSRGYHYM LANLGFTDILLERVMHGGANITGFQIVNNENPMVQQFIQRWVRLDEREFPEAKNAPLKYT SALTHDAILVIAEAFRYLRRQRVDVSRRGSAGDCLANPAVPWSQGIDIERALKMVQVQGM TGNIQFDTYGRRTNYTIDVYEMKVSGSRKAGYWNEYERFVPFS" # @param {type:"string"}
sequence = sequence.upper().replace(' ', '')
print(len(sequence))
sequence = ''.join([i for i in sequence if i.isalpha()])

PARALLEL = 20
if len(sequence) > 1500:
  PARALLEL = 10
elif len(sequence) > 2400:
  PARALLEL = 1

os.makedirs("output",exist_ok=True)
with open("output/README.txt","w") as handle:
  handle.write("conservation_logits.txt = (L, A) matrix\n")
  handle.write("coevolution.txt = (L, L) matrix\n")
  handle.write("jac.npy = ((L*L-L)/2, A, A) tensor\n")
  handle.write("jac index can be recreated with np.triu_indices(L,1)\n")
  handle.write(f"[A]lphabet: {ALPHABET}\n")
  handle.write(f"sequence: {sequence}\n")

381


In [11]:
#@markdown ##compute conservation

logits = get_logits(sequence, p=PARALLEL)
logits = logits[:,ALPHABET_map]
np.savetxt(f"output/conservation_logits_{model_name}.txt",logits)
pssm = softmax(logits,-1)
df = pssm_to_dataframe(pssm, ALPHABET)

# plot pssm
num_colors = 256  # You can adjust this number
palette = viridis(256)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="CONSERVATION",
           x_range=[str(x) for x in range(1,len(sequence)+1)],
           y_range=list(ALPHABET)[::-1],
           width=900, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('Position', '@Position'), ('Amino Acid', '@{Amino Acid}'), ('Probability', '@Probability')])

r = p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
           fill_color=linear_cmap('Probability', palette, low=0, high=1),
           line_color=None)
p.xaxis.visible = False  # Hide the x-axis
show(p)

  0%|          | 0/381 [elapsed: 00:00 remaining: ?]

In [13]:
#@markdown ##compute coevolution
#@markdown Set output `layer` and postprocessing options such as to `center`, [`symm`]etrize, remove [`diag`]onal and to perform average product correction `apc`.
#@markdown The `fast` approximation only perturbs the mask token.
fast = True # @param {type:"boolean"}
layer = None # @param ["0","1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","32","33","None"] {type:"raw"}
center = True # @param {type:"boolean"}
symm = True # @param {type:"boolean"}
diag = "remove" # @param ["remove", "normalize", "none"]
apc = True # @param {type:"boolean"}
settings = dict(layer=layer,
                sequence=sequence,
                fast=fast)
if not "jac" in dir() or settings != settings_:
  if fast and layer is None:
    jac = get_logits(sequence, p=PARALLEL, return_jac=True)
  else:
    jac = get_categorical_jacobian(sequence, layer=layer, fast=fast)
  settings_ = settings.copy()

con = jac_to_con(jac, center=center, diag=diag, apc=apc)

np.savetxt(f"output/coevolution_{model_name}.txt",con["contacts"])
if layer is not None:
  i,j = np.triu_indices(len(sequence),1)
  np.save(f"output/jac_{model_name}.npy",con["jac"][i,:,j,:].astype(np.float16))

df = contact_to_dataframe(con["contacts"])
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="COEVOLUTION",
          x_range=[str(x) for x in range(1,len(sequence)+1)],
          y_range=[str(x) for x in range(1,len(sequence)+1)][::-1],
          width=800, height=800,
          tools=TOOLS, toolbar_location='below',
          tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

r = p.rect(x="i", y="j", width=1, height=1, source=df,
          fill_color=linear_cmap('value', gray, low=df.value.min(), high=df.value.max()),
          line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the x-axis
show(p)

  0%|          | 0/381 [elapsed: 00:00 remaining: ?]

In [14]:
#@markdown ##show table of top covarying positions
from google.colab import data_table

sub_df = df[df["j"]>df["i"]].sort_values('value',ascending=False)
data_table.DataTable(sub_df, include_index=False, num_rows_per_page=20, min_width=10)



Unnamed: 0,i,j,value
135621,356,367,4.232102
69170,182,210,3.580831
5733,16,19,3.566924
86333,227,228,3.305645
53484,141,145,2.933220
...,...,...,...
118566,312,76,-0.216349
311,1,312,-0.218741
50321,133,30,-0.223345
72847,192,77,-0.244394


In [None]:
#@markdown ##select pair of residues to investigate
#@markdown Note: 1-indexed (first position is 1)

position_i = 15 # @param {type:"integer"}
position_j = 57 # @param {type:"integer"}
if layer is None:
  if fast:
    print("this function is only supported when `fast=True`")
  else:
    i = position_i - 1
    j = position_j - 1
    df = pair_to_dataframe(con["jac"][i,:,j,:], ALPHABET)

    # plot pssm
    TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
    p = figure(title=f"coevolution between {position_i} {position_j}",
              x_range=list(ALPHABET),
              y_range=list(ALPHABET)[::-1],
              width=400, height=400,
              tools=TOOLS, toolbar_location='below',
              tooltips=[('aa_i', '@aa_i'), ('aa_j', '@aa_j'), ('value', '@value')])
    p.xaxis.axis_label = f"{sequence[i]}{position_i}"
    p.yaxis.axis_label = f"{sequence[j]}{position_j}"

    r = p.rect(x="aa_i", y="aa_j", width=1, height=1, source=df,
               fill_color=linear_cmap('value', bwr_r, low=-3.0, high=3.0),
               line_color=None, dilate=True)
    show(p)
else:
  print("this function is only supported when `layer=None`")

In [None]:
#@title download results (optional)
from google.colab import files
os.system(f"zip -r output.zip output/")
files.download(f'output.zip')