# VESM: Getting Started (Quickstart & Inference)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ntranoslab/vesm/blob/main/notebooks/VESM_Getting_Started.ipynb)

### Requirements
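- Python 3.9+ (recommended), `torch`, `transformers`, `huggingface_hub`, `numpy`, `pandas`, `matplotlib`, `seaborn`, `plotly`, `requests` (`py3Dmol`, `biopython`, and `esm` are installed from within the notebook when they are needed)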
- GPU is optional but recommended for speed

## Table of contents
1. [Setup & Imports](#setup--imports)
2. [Load a VESM Model](#load-a-model)
3. [Run Inference & Get Scores](#vesm_inference)
4. [Visualize Results](#visualize-results)  
5. [Download Prediction Scores](#download)
6. [VESM3 Inference](#VESM3)

<a id="setup--imports"></a>

## 1. Setup & Imports

In [None]:
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, EsmForMaskedLM
import numpy as np
import pandas as pd
import warnings
warnings.simplefilter("ignore", FutureWarning) # ignore future warnings from transformers package
warnings.simplefilter("ignore", UserWarning) # ignore user warnings

In [None]:
# Folder on local disk where downloaded checkpoints are stored.
local_dir = 'vesm'

# Start from the CPU and switch to the first CUDA device only if one is detected.
device = 'cpu'

print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print('GPU name:', torch.cuda.get_device_name(0))

<a id="load-a-model"></a>

## 2. Load VESM model

In [None]:
# Maps each VESM checkpoint name to the base ESM model its weights are loaded into.
esm_dict = {
    "VESM_3B": 'facebook/esm2_t36_3B_UR50D',
    "VESM_650M": 'facebook/esm2_t33_650M_UR50D',
    "VESM_150M": 'facebook/esm2_t30_150M_UR50D',
    "VESM_35M": 'facebook/esm2_t12_35M_UR50D',
    "VESM3": "esm3_sm_open_v1"
}

def load_vesm(model_name="VESM_3B", local_dir="vesm", device='cuda'):
    """Download a VESM checkpoint and load it on top of its base ESM model.

    Args:
        model_name: One of the keys of ``esm_dict``.
        local_dir: Directory that ``{model_name}.pth`` is downloaded into.
        device: Device the base model is placed on.

    Returns:
        A ``(model, tokenizer)`` tuple. If ``model_name`` is not a key of
        ``esm_dict``, prints "Model not found" and returns ``None`` instead,
        so callers that unpack the result should pass a valid name.
    """
    if model_name not in esm_dict:
        print("Model not found")
        return None
    ckt = esm_dict[model_name]

    # download weights; keep the path reported by the hub client instead of
    # rebuilding it by hand from local_dir
    weights_path = hf_hub_download(
        repo_id="ntranoslab/vesm",
        filename=f"{model_name}.pth",
        local_dir=local_dir,
    )

    # load base model
    if model_name == "VESM3":
        # Imported lazily: the `esm` package is only needed for VESM3.
        from esm.models.esm3 import ESM3
        model = ESM3.from_pretrained(ckt, device=device).to(torch.float)
        tokenizer = model.tokenizers.sequence
    else:
        model = EsmForMaskedLM.from_pretrained(ckt).to(device)
        tokenizer = AutoTokenizer.from_pretrained(ckt)

    # load pretrained VESM
    # map_location='cpu' lets a checkpoint saved from a GPU be restored on a
    # CPU-only machine; load_state_dict then copies the values into the
    # parameters of `model`, which already live on the requested device.
    # NOTE(review): torch.load unpickles the downloaded file; consider
    # weights_only=True if the installed torch version supports it.
    state_dict = torch.load(weights_path, map_location='cpu')
    # strict=False: keys that do not match the base model are not treated as errors.
    model.load_state_dict(state_dict, strict=False)
    return model, tokenizer

We first load the VESM_3B checkpoint from Hugging Face.

In [None]:
# Load the VESM_3B checkpoint and its tokenizer onto the device selected earlier.
model_name = 'VESM_3B'
model, tokenizer = load_vesm(model_name, local_dir=local_dir, device=device)
# Vocabulary as returned by the tokenizer.
sequence_vocabs = tokenizer.get_vocab()

# VESM Inference

The following function computes log-likelihood ratio (LLR) scores for every possible single missense mutation of a given sequence.

In [None]:
def get_LLR(sequence, esm_model, device=None):
  """
    Compute log-likelihood ratio (LLR) scores for all single missense mutations.

    Relies on the notebook-level `tokenizer` that was created together with the model.

    @param sequence: str, input protein sequence
    @param esm_model: loaded VESM model
    @param device: device the tokens are moved to before the forward pass;
                   None (default) uses the device of the model's own parameters,
                   so the call also works when the model lives on the CPU
    @return: pd.DataFrame, LLR scores for all missense mutations; rows are the
             20 amino acids in AAorder, columns are '<wild-type residue> <1-based position>'
  """
  if device is None:
    # Follow the model instead of assuming CUDA is available.
    device = next(esm_model.parameters()).device
  batch_tokens = tokenizer(sequence, return_tensors='pt')['input_ids']
  esm_model.eval()
  with torch.no_grad():
    logits = torch.log_softmax(esm_model(batch_tokens.to(device))['logits'], dim=-1)[0, :, :].cpu()
  tok = batch_tokens[0, :].cpu()
  # Log-probability assigned to the input (wild-type) token at every position.
  wt_norm = logits[np.arange(len(tok)), tok].unsqueeze(1)
  LLR = logits - wt_norm
  # Drop the first and last token positions and keep vocabulary columns 4..23.
  # NOTE(review): assumes the tokenizer adds one special token at each end and that
  # ids 4-23 are the 20 standard amino acids - confirm if the tokenizer changes.
  LLR = LLR[1:-1, 4:24].numpy()
  AAorder = ['K','R','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']
  LLR_ = pd.DataFrame(LLR, columns=tokenizer.all_tokens[4:24], index=list(sequence)).T.loc[AAorder]
  # Column labels become '<residue> <position>', e.g. 'M 1'.
  LLR_.columns = [j.split('.')[0] + ' ' + str(i + 1) for i, j in enumerate(LLR_.columns)]
  return LLR_

<a id="vesm_inference"></a>
## 3. Inference on a protein sequence

In [None]:
# Short example protein sequence (one-letter amino-acid codes) used for the first inference run.
protein_sequence = 'MVTLGVISLLENILVIVAIAKNKLHSPMYFFICSLAVADMLVSVSNGSET'

Compute log-likelihood ratio (LLR) scores for all missense mutations with VESM.

In [None]:
# The returned LLR table is the cell's last expression, so the notebook displays it.
get_LLR(protein_sequence, model)

<a id="visualize-results"></a>

## 4. Visualizing Results

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objs as go
from google.colab import output

# Amino-acid ordering (the same list get_LLR uses for its rows).
AAorder=['K','R','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']

def plot_interactive(LLR,higher_than_wt=False,thresh=2,zmax=0,cmap='Viridis_r'):
  """
    Show an interactive Plotly heatmap of an LLR table.

    @param LLR: pd.DataFrame whose values are drawn; its columns label the x axis and its index the y axis
    @param higher_than_wt: when True, overlay markers on every cell whose score exceeds `thresh`
    @param thresh: score threshold used for the marker overlay
    @param zmax: upper end of the color scale
    @param cmap: Plotly color scale name
  """
  # Pick the Plotly template that matches the page theme reported by Colab.
  dark_theme = output.eval_js('document.documentElement.matches("[theme=dark]")')
  if dark_theme:
    template = 'plotly_dark'
  else:
    template = 'plotly_white'

  axis_labels = dict(y="Amino acid change", x="Protein sequence", color="LLR")
  fig = px.imshow(
      LLR.values,
      x=LLR.columns,
      y=LLR.index,
      color_continuous_scale=cmap,
      zmax=zmax,
      labels=axis_labels,
      template=template,
      title='',
  )

  # Initial x window with a range slider underneath for scrolling along the sequence.
  fig.update_xaxes(tickangle=-90, range=[0, 99], rangeslider=dict(visible=True), dtick=1)
  fig.update_yaxes(dtick=1)

  transparent = 'rgba(0, 0, 0, 0)'
  fig.update_layout(
      {'plot_bgcolor': transparent, 'paper_bgcolor': transparent},
      font={'family': 'Arial', 'size': 11},
      hoverlabel=dict(font=dict(family='Arial', size=14)),
  )
  fig.update_traces(hovertemplate="<b>%{x} %{y}</b> (%{z:.2f})<extra></extra>")

  if higher_than_wt:
    # Gather (column, row, score) for every cell above the threshold.
    marker_x, marker_y, marker_scores = [], [], []
    for column in LLR.columns:
      for amino_acid in LLR.index[LLR[column] > thresh]:
        marker_x.append(column)
        marker_y.append(amino_acid)
        marker_scores.append(LLR.loc[amino_acid, column])

    overlay = go.Scatter(
        x=marker_x,
        y=marker_y,
        customdata=marker_scores,
        mode='markers',
        marker=dict(size=8),
        hovertemplate="<b>%{x} %{y}</b> (%{customdata:.2f})<extra></extra>",
    )
    fig.add_trace(overlay)

  fig.show()


def plot_heatmap(LLR,figname=None,vmin=None):
  """
    Plot an LLR table as a static heatmap.
      Args:
          LLR: 2D table of scores (e.g. the DataFrame returned by get_LLR);
               the figure width scales with its number of columns.
          figname: Optional basename; when given, the figure is also saved as "<figname>.png".
          vmin: Lower bound of the color scale (the upper bound is fixed at 0).
  """
  # Width grows with the number of columns; clamp to 1 so that inputs with
  # fewer than 3 columns do not round down to a zero-width figure.
  width = max(1, int(np.round(LLR.shape[1]*90/390)))
  plt.figure(figsize=(width,5))
  sns.heatmap(LLR, cmap='viridis_r', xticklabels=True, yticklabels=True, vmax=0, vmin=vmin)
  if figname is not None:
    plt.savefig(f"{figname}.png", dpi=300, bbox_inches='tight')
  plt.show()

### Visualizing LLRs

In [None]:
# MC4R
sequence = 'MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILVIVAIAKNKNLHSPMYFFICSLAVADMLVSVSNGSETIVITLLNSTDTDAQSFTVNIDNVIDSVICSSLLASICSLLSIAVDRYFTIFYALQYHNIMTVKRVGIIISCIWAACTVSGILFIIYSDSSAVIICLITMFFTMLALMASLYVHMFLMARLHIKRIAVLPGTGAIRQGANMKGAITLTILIGVFVVCWAPFFLHLIFYISCPQNPYCVCFMSHFNLYLILIMCNSIIDPLIYALRSQELRKTFKEIICCYPLGGLCDLSSRY'
# Score the sequence with the loaded model and render the interactive heatmap.
plot_interactive(get_LLR(sequence, model))

<a id="visualize_structure"></a>

### Visualize on Structure

In [None]:
!pip install --quiet py3Dmol biopython

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol
from Bio.SeqUtils import seq1
from Bio.PDB import PDBParser

def show_pdb(position_scores, pdb_file, width=800, height=600, surface=False):
    """Render a PDB structure in py3Dmol, colored by per-position scores.

    Args:
        position_scores: One score per position; entry ``i`` colors the residue
            selected with ``resi = i + 1``.
        pdb_file: Path to the PDB file to display.
        width: Viewer width passed to py3Dmol.
        height: Viewer height passed to py3Dmol.
        surface: When True, also add a semi-transparent surface per residue and
            switch the background to white.

    Returns:
        The configured ``py3Dmol`` view.
    """
    view = py3Dmol.view(width=width, height=height, js='https://3dmol.org/build/3Dmol.js')
    # Read through a context manager so the file handle is closed right away.
    with open(pdb_file, 'r') as handle:
        view.addModel(handle.read(), 'pdb')

    # Decreasing logistic transform: lower scores map to values nearer 1.
    sigmoid_scores = 1 / (1 + np.exp((0.6 * np.array(position_scores) + 6)))
    # Compute each residue's color once; both the cartoon and the surface reuse it.
    hex_colors = [
        matplotlib.colors.rgb2hex(plt.cm.coolwarm(value)[:3])
        for value in sigmoid_scores
    ]

    # NOTE(review): selecting by str(i + 1) assumes residue numbering in the
    # file starts at 1 and has no gaps - confirm for non-AlphaFold inputs.
    for i, hexcol in enumerate(hex_colors):
        view.setStyle({'resi': str(i+1)}, {'cartoon': {'color': hexcol}})
    view.setBackgroundColor('#383838')

    if surface:
        for i, hexcol in enumerate(hex_colors):
            view.addSurface(
                py3Dmol.VDW,
                { 'opacity': 0.7, 'color': hexcol },
                { 'resi': str(i+1) }
            )
        view.setBackgroundColor('white')

    view.zoomTo()
    return view

In [None]:
import requests
def download_latest_af_pdb(uniprot_id, version_start=20, version_end=1, out_name=None, timeout=30):
    """
    Download the latest AF PDB model version for a UniProt ID.

    Versions are probed from ``version_start`` down to ``version_end``
    (inclusive); the first one the server answers with HTTP 200 is downloaded.

    Args:
        uniprot_id (str): UniProt accession, e.g., "P12345"
        version_start (int): Highest version to check, e.g., 20
        version_end (int): Lowest version to check, e.g., 1
        out_name (str): Optional output filename. If None, use the server's name.
        timeout (float): Seconds to wait on each HTTP request before giving up,
            so an unresponsive server cannot hang the notebook.

    Returns:
        The path of the written file, or None when no version was found.

    Raises:
        requests.HTTPError: if the download of a found version fails.
    """
    base_url = "https://alphafold.ebi.ac.uk/files"

    for v in range(version_start, version_end - 1, -1):
        url = f"{base_url}/AF-{uniprot_id}-F1-model_v{v}.pdb"
        # Cheap existence probe before fetching the whole file.
        head = requests.head(url, timeout=timeout)
        if head.status_code != 200:
            continue

        print(f"Found version v{v}: {url}")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        # Determine filename
        filename = out_name or url.split("/")[-1]
        with open(filename, "wb") as f:
            f.write(response.content)
        print(f"Downloaded to: {filename}")
        return filename

    print(f"No model found for UniProt ID '{uniprot_id}' in versions v{version_start}–v{version_end}.")
    return None

Example structure from AFDB

In [None]:
# Probe model versions v10 down to v4 for this accession and keep the first one found.
# pdb_file is None when no version in that range exists.
uniprot_id = 'P32245'
pdb_file = download_latest_af_pdb(uniprot_id, version_start=10, version_end=4, out_name=None)

Extract the amino-acid sequence from the PDB file.

In [None]:
# Build the one-letter sequence from the structure's residues, keeping only those
# whose id[0] is a blank string (presumably Biopython's marker for standard,
# non-hetero residues - confirm against the Bio.PDB docs).
sequence = "".join(
    seq1(res.get_resname())
    for res in PDBParser(QUIET=True).get_structure("model", pdb_file).get_residues()
    if res.id[0] == " "
)

Average the LLR scores at each position and color the structure by them.

In [None]:
# mean(0) averages over the amino-acid rows, giving one score per sequence position.
position_scores = get_LLR(sequence, model).values.mean(0)
show_pdb(position_scores,pdb_file)

<a id="download"></a>
## Download Prediction Scores

- **sequence**: input any amino acid sequence for LLR inference.
- **seq_name**: name the sequence for the file name (optional).
- **heatmap**: check if you want to download the LLR heatmap

In [None]:
# Form fields: the sequence to score, an optional label used in the output file
# names, and whether the heatmap image is saved alongside the CSV.
sequence = 'MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILVIVAIAKNKNLHSPMYFFICSLAVADMLVSVSNGSETIVITLLNSTDTDAQSFTVNIDNVIDSVICSSLLASICSLLSIAVDRYFTIFYALQYHNIMTVKRVGIIISCIWAACTVSGILFIIYSDSSAVIICLITMFFTMLALMASLYVHMFLMARLHIKRIAVLPGTGAIRQGANMKGAITLTILIGVFVVCWAPFFLHLIFYISCPQNPYCVCFMSHFNLYLILIMCNSIIDPLIYALRSQELRKTFKEIICCYPLGGLCDLSSRY' #@param {type:"string"}
seq_name = "" #@param {type:"string"}
heatmap = True #@param {type:"boolean"}

import hashlib, zipfile
from google.colab import files

def short_hash(seq, length=8):
    """Return the first `length` hex characters of the SHA-1 digest of `seq`."""
    digest = hashlib.sha1(seq.encode()).hexdigest()
    return digest[:length]

def meltLLR(LLR,gene_prefix=None,ignore_pos=False):
  """
    Reshape a wide LLR table into a long table with one row per variant.

    @param LLR: pd.DataFrame with amino acids as index and '<residue> <position>' column labels
    @param gene_prefix: optional string prepended (with '_') to every variant name
    @param ignore_pos: when True, the integer 'pos' column is not added
    @return: pd.DataFrame indexed by variant (e.g. 'M1K') with a 'score' column and, unless ignore_pos, a 'pos' column
  """
  long_table = LLR.melt(ignore_index=False)

  # Variant name = column label without its space + substituted amino acid.
  variant_names = []
  for column_label, mutant_aa in zip(long_table['variable'], long_table.index):
    variant_names.append(column_label.replace(' ', '') + mutant_aa)
  long_table['variant'] = variant_names
  long_table['score'] = long_table['value']
  long_table = long_table.set_index('variant')

  if not ignore_pos:
    # The position is whatever sits between the first and last character of the name.
    long_table['pos'] = [int(name[1:-1]) for name in long_table.index]

  long_table = long_table.drop(columns=['variable', 'value'])

  if gene_prefix is not None:
    long_table.index = gene_prefix+'_'+long_table.index
  return long_table

# Name the outputs after seq_name, or after a short hash of the sequence when no name was given.
if len(seq_name) > 0:
    base_name = f'vesm_LLR_{seq_name}'
else:
    base_name = f'vesm_LLR_{short_hash(sequence)}'
csv_file = base_name + ".csv"
LLR_scores = get_LLR(sequence, model)

# Write the long-format table (one row per variant) to CSV.
meltLLR(LLR_scores).to_csv(csv_file)

if not heatmap:
    # Only the CSV is offered for download.
    file_name = csv_file
else:
    png_file = base_name + ".png"
    plot_heatmap(LLR_scores, figname=base_name, vmin=None)

    # Bundle the CSV and the heatmap image into one archive.
    zip_name = base_name + "_results.zip"
    with zipfile.ZipFile(zip_name, 'w') as zipf:
        zipf.write(csv_file)
        zipf.write(png_file)
    file_name = zip_name

print(f"Saved results to: {file_name}")
files.download(file_name)

<a id="VESM3"></a>

# VESM3 Inference

Downloading the base ESM3-open model requires logging in to Hugging Face.

In [None]:
# Authenticate with Hugging Face before the base ESM3 model is downloaded.
from huggingface_hub import login
login()

In [None]:
!pip install --quiet esm

Remove the previously loaded model to free memory (necessary when using Colab's T4 GPU).

In [None]:
# Drop the notebook's reference to the previously loaded model so its memory can be reused.
if 'model' in globals():
    import gc
    del model
    # Collect first so objects that are only kept alive by reference cycles are
    # actually freed; only then can empty_cache() hand their memory back.
    gc.collect()
    torch.cuda.empty_cache()

Load VESM3 checkpoint

In [None]:
# Load VESM3; the tokenizer returned alongside the model is not kept.
vesm3, _ = load_vesm("VESM3", local_dir=local_dir, device=device)

In [None]:
def get_vesm3_LLR_sequence(sequence, esm3_model):
  """
    Compute VESM3 LLR scores for all single substitutions from the sequence alone.

    @param sequence: str, input protein sequence
    @param esm3_model: loaded VESM3 model
    @return: pd.DataFrame, rows are the 20 amino acids in AAorder, columns are '<wild-type residue> <1-based position>'
  """
  from esm.sdk.api import ESMProtein

  seq_tokens = esm3_model.encode(ESMProtein(sequence=sequence)).sequence
  with torch.no_grad():
    model_output = esm3_model.forward(sequence_tokens=seq_tokens.reshape(1, -1))
    log_probs = torch.log_softmax(model_output.sequence_logits[0, :, :], dim=-1).cpu()

  # Subtract the log-probability of the input token at each position.
  wt_ids = seq_tokens.cpu()
  wt_log_probs = log_probs[np.arange(len(wt_ids)), wt_ids].unsqueeze(1)
  llr = log_probs - wt_log_probs

  AAorder = ['K','R','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']
  vocab = esm3_model.tokenizers.sequence.vocab
  aa_ids = [vocab[aa] for aa in AAorder]

  # Drop the first and last token positions and keep only the 20 amino-acid columns.
  scores = llr[1:-1, aa_ids].numpy()
  table = pd.DataFrame(scores, columns=AAorder, index=list(sequence)).T
  table.columns = [
      f"{residue.split('.')[0]} {position}"
      for position, residue in enumerate(table.columns, start=1)
  ]
  return table


def get_vesm3_LLR_structure(pdb_file, esm3_model):
  """
    Compute VESM3 LLR scores for all single substitutions, using both the
    sequence tokens and the structure tokens encoded from a PDB file.

    @param pdb_file: path to the PDB file the protein is read from
    @param esm3_model: loaded VESM3 model
    @return: pd.DataFrame, rows are the 20 amino acids in AAorder, columns are '<wild-type residue> <1-based position>'
  """
  from esm.sdk.api import ESMProtein
  protein = ESMProtein.from_pdb(pdb_file)
  tokens = esm3_model.encode(protein)
  seq_tokens = tokens.sequence
  struct_tokens = tokens.structure

  with torch.no_grad():
    model_output = esm3_model.forward(
        sequence_tokens=seq_tokens.reshape(1, -1),
        structure_tokens=struct_tokens.reshape(1, -1),
    )
    logits = torch.log_softmax(model_output.sequence_logits[0, :, :], dim=-1).cpu()

  # Subtract the log-probability of the input token at each position.
  tok = seq_tokens.cpu()
  wt_norm = logits[np.arange(len(tok)), tok].unsqueeze(1)
  LLR = logits - wt_norm

  AAorder = ['K','R','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']
  order = [esm3_model.tokenizers.sequence.vocab[x] for x in AAorder]
  # Drop the first and last token positions and keep only the 20 amino-acid columns.
  LLR = LLR[1:-1, order].numpy()
  # Label positions with the sequence parsed from the PDB itself rather than a
  # notebook-level `sequence` variable, which may hold a different protein.
  LLR_ = pd.DataFrame(LLR, columns=AAorder, index=list(protein.sequence)).T
  LLR_.columns = [j.split('.')[0] + ' ' + str(i + 1) for i, j in enumerate(LLR_.columns)]
  return LLR_

In [None]:
### Example structure from AFDB
# Probe model versions v10 down to v1 and keep the first one found (None if there is none).
uniprot_id = 'P32245'
pdb_file = download_latest_af_pdb(uniprot_id, version_start=10, version_end=1, out_name=None)

In [None]:
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1
# Extract the one-letter sequence from the PDB file, keeping only residues whose
# id[0] is a blank string (presumably standard, non-hetero residues - confirm).
sequence = "".join(
    seq1(res.get_resname())
    for res in PDBParser(QUIET=True).get_structure("model", pdb_file).get_residues()
    if res.id[0] == " "
)

### Inference with sequence only

In [None]:
# LLR table computed from the sequence tokens only.
LLR_seq = get_vesm3_LLR_sequence(sequence, vesm3)

### Inference with structure

In [None]:
# LLR table computed from the sequence and structure tokens of the PDB file.
LLR_struct = get_vesm3_LLR_structure(pdb_file, vesm3)

### Visualize Results

Visualize differences between sequence- and structure- derived LLRs with VESM3

In [None]:
# LLR_seq vs LLR_struct
import matplotlib
import matplotlib.pyplot as plt
# One point per (amino acid, position) cell of the two tables.
plt.plot(LLR_seq.values.reshape(-1), LLR_struct.values.reshape(-1), '.')
# Dashed y = x reference line from -20 to 5.
plt.plot([-20, 5], [-20, 5], 'k--')
plt.xlabel('LLR from Sequence')
plt.ylabel('LLR from Structure')
plt.grid(True)
plt.show()

Plot the heatmap of differences

In [None]:
def plot_diff_heatmap(diff,figname=None,range=4):
  """
    Plot a table of LLR differences as a heatmap with a symmetric color scale.
      Args:
          diff: 2D table of differences; the figure width scales with its number of columns.
          figname: Optional basename; when given, the figure is also saved as "<figname>.png".
          range: Half-width of the color scale, which spans [-range, range].
                 (The name shadows the builtin inside this function; it is kept
                 because callers may pass it by keyword.)
  """
  # Clamp the width to 1 so inputs with fewer than 3 columns do not round down to zero.
  width = max(1, int(np.round(diff.shape[1]*90/390)))
  plt.figure(figsize=(width,5))
  sns.heatmap(diff, cmap='coolwarm_r', xticklabels=True, yticklabels=True, vmax=range, vmin=-range)
  if figname is not None:
    plt.savefig(f"{figname}.png", dpi=300, bbox_inches='tight')
  plt.show()

# Element-wise difference between the structure-based and sequence-based LLR tables.
diff = LLR_struct - LLR_seq
plot_heatmap(LLR_seq)
threshold = -5
# Keep only cells where exactly one of the two tables falls below the threshold.
plot_diff_heatmap(diff[((LLR_struct<threshold) & (LLR_seq>threshold)) | ((LLR_struct>threshold) & (LLR_seq<threshold))])
plot_heatmap(LLR_struct)