<a href="https://colab.research.google.com/github/cydneymartell/ColabAggregationPrediction/blob/main/Martell_2025_Aggregation_Predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Aggregation predictions using models from [Martell et al. 2025](https://www.biorxiv.org/content/10.1101/2025.11.11.687847v1)

### This notebook predicts protein aggregation at 50 &deg;C, 75 &deg;C, or pH 4 from protein sequences. Additionally, for a protein of interest, you can predict and plot a deep mutational scan (DMS).



##### Portions of this workflow incorporate scripts and adapted functions from SaProtHub (Su et al., 2024), which provides a framework for training and using fine-tuned SaProt models.

##### Su, J., Li, Z., Han, C., Zhou, Y., Shan, J., Zhou, X., Ma, D., The OPMC, Ovchinnikov, S., & Yuan, F. (2024). SaProtHub: Making Protein Modeling Accessible to All Biologists.

##### This workflow also incorporates scripts from ThermoMPNN to predict stability changes.

##### Dieckhaus, H., Brocidiacono, M., Randolph, N. Z., & Kuhlman, B. (2024). Transfer learning to leverage larger datasets for improved prediction of protein stability changes. Proceedings of the National Academy of Sciences, 121(6), e2314853121. https://doi.org/10.1073/pnas.2314853121

# **Aggregation Predictions**

In [None]:
#@title Install SaProt and Requirements
import subprocess
import os, sys
from tqdm.notebook import tqdm
from google.colab import files
import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from io import StringIO

root_dir = os.getcwd()
saprot_path = os.path.join(root_dir, "SaprotHub")

# Clone the SaProtHub repository if it is missing. check=True stops the cell
# here if the clone fails, instead of failing later on a confusing import.
if not os.path.exists(saprot_path):
    subprocess.run(
        ["git", "clone", "-q", "https://github.com/westlake-repl/SaprotHub.git", saprot_path],
        check=True,
    )
if saprot_path not in sys.path:
    sys.path.append(saprot_path)

# Install SaProtHub's requirements with the kernel's own interpreter. stdout is
# silenced, but stderr is captured so a failed install is reported here.
print("Installing requirements...")
pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r",
     os.path.join(saprot_path, "requirements.txt"), "-q", "-q"],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.PIPE,
    text=True,
)
if pip_result.returncode != 0:
    print("Installing requirements failed:\n" + pip_result.stderr)

# Smoke-test the saprot package now that its requirements are installed.
try:
    import saprot
    print("Saprot imported successfully!")
except Exception as e:
    print("Import failed:", e)

import torch
from transformers import AutoTokenizer, AutoModel,EsmForProteinFolding
import torch.nn as nn
from huggingface_hub import snapshot_download
from pathlib import Path
import json
import copy
from saprot.config.config_dict import Default_config
from saprot.scripts.training import my_load_model
from saprot.config.config_dict import Default_config



In [None]:
#@title Generate Structurally Aware (SA) Sequences
#@markdown #### (Optional) You only need to generate these once; the saved structurally aware sequences can be reused as input for the predictions below.
#@markdown #### For multiple structures, the following need to be uploaded:
#@markdown 1. **A ZIP file** containing PDB or CIF structure files
#@markdown 2. **A CSV file** with the following columns:
#@markdown    - `file_name` – The name of the structure file for that protein (e.g., protein1.pdb).
#@markdown    - `chain` – The specific chain to analyze.
#@markdown    - `AF_predicted` – Indicates whether the structure is AlphaFold-predicted (True or False).

# The form annotations stay above the imports: Colab expects `#@title` on the
# first line of the cell.
from utils.foldseek_util import get_struc_seq
import tempfile
import zipfile
import re

#@markdown ### Select structure upload
Num_structures = "Single Structure"  # @param ['Single Structure', 'Multiple Structures']

display(Markdown("### Enter your structure information"))

# Defaults, so the generate-button callback can detect a cancelled upload
# instead of failing with a NameError.
pdb_zip_upload = {}
pdb_zip_path = None
uploaded = {}
uploaded_csv = None

if Num_structures == "Multiple Structures":
    # Inputs for multiple structures: a ZIP of structure files plus a CSV describing them.
    display(Markdown("### Upload ZIP of PDB files"))

    pdb_zip_upload = files.upload()
    if pdb_zip_upload:
        zip_name = list(pdb_zip_upload.keys())[0]
        zip_bytes = pdb_zip_upload[zip_name]

        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
            tmp.write(zip_bytes)
            pdb_zip_path = tmp.name

        display(Markdown(f"**Uploaded PDB ZIP:** {zip_name}"))

    display(Markdown("### Upload CSV specifying structure type, chain, and file name"))
    # Unindented lines: Markdown renders indented lines as a code block, not a list.
    display(Markdown(
        "CSV must include the following columns:\n\n"
        "- `file_name` – The name of the pdb file for that protein (e.g., `protein1.pdb`).\n"
        "- `chain` – The specific chain to analyze.\n"
        "- `AF_predicted` – Indicates whether the structure is AlphaFold-predicted (`True` or `False`).\n"
    ))
    uploaded = files.upload()
    if uploaded:
        file_name = list(uploaded.keys())[0]
        # utf-8-sig drops a byte-order mark (e.g. from Excel exports) that would
        # otherwise be glued onto the first column name.
        uploaded_csv = pd.read_csv(StringIO(uploaded[file_name].decode("utf-8-sig")))
        display(Markdown(f"**Uploaded CSV:** {file_name} ({len(uploaded_csv)} rows)"))

    output_area_sa = widgets.Output()
    display(output_area_sa)

elif Num_structures == "Single Structure":
    # Inputs for a single structure: chain ID, AlphaFold flag and the structure file.
    chain_box = widgets.Text(
        placeholder="Enter chain",
        description="Chain:",
        layout=widgets.Layout(width="400px")
    )
    display(chain_box)

    # AlphaFold selection; passed on as plddt_mask when generating the SA sequence.
    AF_dropdown = widgets.Dropdown(
        options=[('Yes (True)', True), ('No (False)', False)],
        value=False,
        description="AlphaFold-Predicted?",
        layout=widgets.Layout(width="300px"), style={'description_width': '150px'}
    )
    display(AF_dropdown)

    # Upload the PDB/CIF file.
    uploaded = files.upload()

    output_area_sa = widgets.Output()
    display(output_area_sa)

generate_btn = widgets.Button(
    description="Generate SA Sequences",
    layout=widgets.Layout(width="300px", height="40px",)
)
generate_btn.button_style = 'success'
display(generate_btn)

# Foldseek binary bundled in the SaprotHub checkout. The hard-coded /content
# prefix assumes the repository was cloned under /content (the Colab default cwd).
foldseek_path = "/content/SaprotHub/bin/foldseek"
# IPython shell escape: mark the bundled binary as executable so get_struc_seq can run it.
!chmod +x /content/SaprotHub/bin/foldseek

from Bio.PDB import MMCIFParser, PDBIO

def convert_cif_to_pdb(cif_bytes):
    """Convert mmCIF file contents into a temporary PDB file.

    Args:
        cif_bytes: raw bytes of an mmCIF file; decoded as UTF-8.

    Returns:
        Path of a newly created ``.pdb`` temporary file. The file is created
        with ``delete=False``, so it is not removed automatically.
    """
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure("structure", StringIO(cif_bytes.decode("utf-8")))

    # Only the name is needed, because PDBIO reopens the path itself. Closing
    # the handle here avoids leaking an open file descriptor for every conversion.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as tmp:
        pdb_path = tmp.name

    pdb_io = PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(pdb_path)
    return pdb_path

def _parse_af_flag(value):
    """Interpret an ``AF_predicted`` CSV cell as a boolean.

    pandas already turns True/False columns into real booleans, but text such as
    "False" or "no" is truthy under ``bool()``, so strings are parsed explicitly.
    Missing values count as False (no pLDDT masking).
    """
    if isinstance(value, str):
        return value.strip().lower() in {"true", "t", "yes", "y", "1"}
    if pd.isna(value):
        return False
    return bool(value)


def _structure_bytes_to_pdb(file_name, raw_bytes):
    """Write uploaded structure bytes to a temporary PDB file and return its path.

    ``.cif`` uploads are converted with ``convert_cif_to_pdb``; anything else is
    written unchanged with a ``.pdb`` suffix.
    """
    if file_name.lower().endswith(".cif"):
        return convert_cif_to_pdb(raw_bytes)
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as tmp:
        tmp.write(raw_bytes)
        return tmp.name


def _generate_single_sa():
    """Print the sequence, Foldseek string and SA sequence of the single uploaded structure."""
    if not uploaded:
        print("Please upload a structure file")
        return
    chain = chain_box.value.strip()
    if not chain:
        print("Please enter the chain")
        return

    file_name = list(uploaded.keys())[0]
    try:
        pdb_path = _structure_bytes_to_pdb(file_name, uploaded[file_name])
        seq_dict = get_struc_seq(foldseek_path, pdb_path, [chain], plddt_mask=AF_dropdown.value)
        seq, foldseek_seq, combined_seq = seq_dict[chain]
    except Exception as e:
        print("Error:", e)
        return

    print("Sequence:", seq)
    print("Foldseek sequence:", foldseek_seq)
    print("Combined structurally aware (SA) sequence:", combined_seq)


def _generate_multiple_sa():
    """Generate SA sequences for every CSV row and download them as structurally_aware_sequences.csv."""
    if not pdb_zip_upload:
        print("Please upload a ZIP file.")
        return
    # uploaded_csv may never have been bound if the CSV upload was skipped.
    csv_df = globals().get("uploaded_csv")
    if csv_df is None:
        print("Please upload a CSV file.")
        return

    print("Extracting PDB ZIP...")
    extract_dir = tempfile.mkdtemp()
    with zipfile.ZipFile(pdb_zip_path, 'r') as z:
        z.extractall(extract_dir)
    print(f"Extracted to: {extract_dir}\n")

    names, sequences, sa = [], [], []
    for _, row in csv_df.iterrows():
        pdb_file = str(row["file_name"])
        chain = str(row["chain"])
        af_flag = _parse_af_flag(row["AF_predicted"])

        # Full path to the structure inside the extracted ZIP.
        pdb_path = os.path.join(extract_dir, pdb_file)
        if not os.path.exists(pdb_path):
            print(f"Missing PDB: {pdb_file}")
            continue

        try:
            if pdb_path.lower().endswith(".cif"):
                # Same CIF -> PDB conversion as the single-structure path.
                with open(pdb_path, "rb") as fh:
                    pdb_path = convert_cif_to_pdb(fh.read())
            seq, _, combined_seq = get_struc_seq(
                foldseek_path,
                pdb_path,
                [chain],
                plddt_mask=af_flag
            )[chain]
        except Exception as e:
            print(f"Failed: {pdb_file}: {e}")
            continue

        names.append(pdb_file.split('/')[-1])
        sequences.append(seq)
        sa.append(combined_seq)

    # Save the table of structurally aware sequences and download it.
    df = pd.DataFrame(zip(names, sequences, sa), columns=["ID", "seq_no_sa", "protein"])
    df.to_csv("structurally_aware_sequences.csv", index=False)
    files.download("structurally_aware_sequences.csv")
    print("Downloaded to structurally_aware_sequences.csv")


def on_generate_click(b):
    """Button callback: generate SA sequence(s) for the upload mode selected in the form.

    All messages, including those from the multi-structure path, are printed
    into ``output_area_sa`` so they are visible under the cell.
    """
    with output_area_sa:
        output_area_sa.clear_output()
        if Num_structures == "Single Structure":
            _generate_single_sa()
        elif Num_structures == "Multiple Structures":
            _generate_multiple_sa()


generate_btn.on_click(on_generate_click)


In [None]:

#@title Load Aggregation Model
#@markdown #### Fine-tuned adapters are deposited on [HuggingFace](https://hf.co/collections/cmartell/martell-et-al-2025-aggregation-models-and-datasets)
#Function to initialize the saprot regression model based on the LoRA configs
# Holds at most one loaded model, keyed by its configuration (see model_initialization).
_LOADED_MODEL = {}


def model_initialization(base_model, lora_kwargs):
    """Build the SaProt regression model with the given LoRA settings, in eval mode.

    Args:
        base_model: base SaProt checkpoint name/path, used as the model's config_path.
        lora_kwargs: LoRA settings forwarded to SaProtHub's ``my_load_model``.

    Returns:
        ``(model, device)``, where device is "cuda" when available, otherwise "cpu".

    The loaded model is reused for later calls with the same configuration, so
    each prediction no longer reloads the 650M-parameter checkpoint. Only the
    most recent model is kept, so switching adapters does not accumulate models.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # default=str serializes Path values (e.g. the adapter path) into the key.
    cache_key = (str(base_model), json.dumps(lora_kwargs, sort_keys=True, default=str), device)
    cached = _LOADED_MODEL.get(cache_key)
    if cached is not None:
        return cached, device

    # Drop the previous model before loading another, so both are never held here.
    _LOADED_MODEL.clear()
    config = copy.deepcopy(Default_config)
    config.model.model_py_path = "saprot/saprot_regression_model"
    config.model.kwargs.config_path = base_model
    config.model.kwargs.lora_kwargs = lora_kwargs
    model = my_load_model(config.model)
    model.to(device)
    model.eval()

    _LOADED_MODEL[cache_key] = model
    return model, device

#Function to predict aggregation for multiple protein sequences
def predict_aggregation_multiple(sequences, base_model, lora_kwargs):
    """Predict an aggregation score for each structurally aware (SA) sequence.

    Args:
        sequences: iterable of SA sequences (e.g. a pandas Series or a list).
        base_model, lora_kwargs: forwarded to ``model_initialization``.

    Returns:
        List of ``pred.item()`` values, one per input sequence, in input order.
    """
    model, device = model_initialization(base_model, lora_kwargs)

    pred_labels = []
    for sa_seq in tqdm(list(sequences)):
        # Each sequence is tokenized on its own, as in predict_aggregation_single.
        # Its prediction then does not depend on the padding length of the other
        # inputs, and short sequences are not run at the longest sequence's length.
        encoded = tokenizer(sa_seq, return_tensors="pt", truncation=True)
        inputs = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            pred = model(inputs)
        pred_labels.append(pred.item())
    return pred_labels

#Function to predict aggregation for a single protein sequence
def predict_aggregation_single(sa_seq, base_model, lora_kwargs):
    """Predict the aggregation score of one structurally aware (SA) sequence.

    Args:
        sa_seq: the SA sequence string.
        base_model, lora_kwargs: forwarded to ``model_initialization``.

    Returns:
        ``pred.item()`` of the model output (assumes a single-element tensor).
    """
    model, device = model_initialization(base_model, lora_kwargs)

    # Tokenize, then move every tensor to the model's device once, as a plain dict.
    encoded = tokenizer(sa_seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        pred = model(inputs)
    return pred.item()

#Selection for aggregation stress prediction
#@markdown #### <br>Select Aggregation Model
Aggregation_stress = "50 C Aggregation"  # @param ['50 C Aggregation', '75 C Aggregation', 'pH 4 Aggregation']

# Hugging Face repository holding the fine-tuned LoRA adapter for each stress condition.
model_dict = {
    '50 C Aggregation': 'cmartell/Model-50C_Aggregation-650M',
    '75 C Aggregation': 'cmartell/Model-75C_Aggregation-650M',
    'pH 4 Aggregation': 'cmartell/Model-pH4_Aggregation-650M',
}

# Download the selected adapter into SaprotHub/adapters/<repo id>.
ADAPTER_HOME = Path(f'{root_dir}/SaprotHub/adapters')
model_arg = model_dict[Aggregation_stress]
adapter_path = ADAPTER_HOME / model_arg
snapshot_download(repo_id=model_arg, repo_type="model", local_dir=adapter_path)

# Base SaProt checkpoint: supplies the tokenizer and is the model config the adapter attaches to.
base_model_name = "westlake-repl/SaProt_650M_AF2"
base_model = base_model_name
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# One inference-only LoRA adapter, read from the downloaded directory.
lora_kwargs = {
    "is_trainable": False,
    "num_lora": 1,
    "config_list": [{"lora_config_path": adapter_path}],
}

# Adapter metadata shipped with the download (loaded here; not used further in this cell).
metadata_path = adapter_path / "metadata.json"
metadata = json.loads(metadata_path.read_text())

print("Fine-tuned adapters loaded successfully")

In [None]:
#@title Make Aggregation Predictions
#@markdown #### Predictions can be made for a single protein of interest or multiple sequences. Predictions use the structurally aware (SA) sequence, which can be produced above from the AlphaFold-predicted structure or PDB file.
#@markdown #### <br> For multiple-sequence predictions, the uploaded CSV must contain:
#@markdown - Column named `protein` with the SA sequences
#@markdown - Column named `ID` with the corresponding protein names
#@markdown #### <br> Input Selection
Input_type = "Single SA sequence"  # @param ["Single SA sequence", "Multiple SA sequences"]

# Module-level state read by the predict-button callback.
uploaded_name = None
uploaded_sequence = None
uploaded_csv = None

if Input_type == "Single SA sequence":
    # Inputs for a single SA sequence.
    display(Markdown("### Enter your sequence information"))

    # Text input for the sequence name (also used to name the output CSV).
    name_box = widgets.Text(
        placeholder="Enter sequence name",
        description="Name:",
        layout=widgets.Layout(width="400px")
    )
    display(name_box)

    # Text area for the SA sequence.
    seq_box = widgets.Textarea(
        placeholder="Paste your sequence here",
        description="Sequence:",
        layout=widgets.Layout(width="600px", height="150px")
    )
    display(seq_box)

elif Input_type == "Multiple SA sequences":
    # Input for multiple sequences: a CSV upload. A space after "###" is needed
    # for Markdown to render the heading.
    display(Markdown("### Upload a CSV file containing multiple sequences"))

    uploaded = files.upload()
    if uploaded:
        file_name = list(uploaded.keys())[0]
        # utf-8-sig drops a byte-order mark (e.g. from Excel exports) that would
        # otherwise be glued onto the first column name.
        uploaded_csv = pd.read_csv(StringIO(uploaded[file_name].decode("utf-8-sig")))
        display(Markdown(f"**Uploaded CSV:** {file_name} ({len(uploaded_csv)} rows)"))


# Created once; the callback below writes its messages into output_area_pred.
predict_button = widgets.Button(description="Make Predictions", button_style="success", disabled=False)
output_area_pred = widgets.Output()

def on_predict_click(b):
    """Score the input(s) selected in the form and download the predictions as CSV.

    Reads the module-level ``Input_type``, ``name_box``/``seq_box`` (single mode)
    and ``uploaded_csv`` (multiple mode) when the button is clicked.
    """
    output_area_pred.clear_output()
    with output_area_pred:
        if Input_type == "Single SA sequence":
            uploaded_name = name_box.value.strip()
            uploaded_sequence = seq_box.value.strip()
            if not (uploaded_name and uploaded_sequence):
                # Tell the user, rather than silently doing nothing.
                display(Markdown("No input detected."))
                return

            display(Markdown(f"Running prediction on single sequence: {uploaded_name}"))
            result = predict_aggregation_single(uploaded_sequence, base_model, lora_kwargs)
            out_csv = f"{uploaded_name}_predictions.csv"
            pd.DataFrame(
                [[uploaded_name, uploaded_sequence, result]],
                columns=["ID", "protein", "score"],
            ).to_csv(out_csv, index=False)
            files.download(out_csv)
            print(f"Aggregation prediction for {uploaded_name} is {result:.3f}")
            display(Markdown("Predictions were downloaded as " + out_csv))

        elif uploaded_csv is not None:
            if "protein" not in uploaded_csv.columns:
                display(Markdown("The uploaded CSV must contain a `protein` column with the SA sequences."))
                return
            display(Markdown(f"Running prediction on {len(uploaded_csv)} sequences from CSV"))
            uploaded_csv["score"] = predict_aggregation_multiple(uploaded_csv["protein"], base_model, lora_kwargs)
            uploaded_csv.to_csv("predictions.csv", index=False)
            files.download("predictions.csv")
            display(Markdown("Predictions were downloaded as predictions.csv"))

        else:
            display(Markdown("No input detected."))


predict_button.on_click(on_predict_click)
display(predict_button, output_area_pred)


In [None]:
#@title Make Aggregation DMS Predictions
#@markdown #### Predicts aggregation for all single point mutants for a protein of interest. These predictions use the structurally aware sequence using the wild-type structure for each mutant prediction.

#generate DMS sequences
def generate_dms_SA_seq(seq, name):
    """Build every single-residue substitution of a structurally aware (SA) sequence.

    In an SA sequence, residues are the uppercase characters. Every other
    character (the interleaved structure tokens) is left untouched. Each residue
    is replaced by each amino acid in ``AA`` (cysteine is deliberately absent),
    skipping the wild-type residue itself so no synonymous X->X rows are produced.

    Args:
        seq: wild-type SA sequence.
        name: identifier used as the ID prefix.

    Returns:
        DataFrame with columns ID, protein, original_AA, position_mut, mutant_AA.
        The first row is the wild type (original_AA/position_mut/mutant_AA = "wt").
        Mutant IDs look like "<name>_<WT><position><mutant>", with 1-based residue positions.
    """
    AA = ["A", "T", "S", "V", "I", "K", "P", "Y", "M", "N", "H", "G", "F", "R", "D", "E", "W", "Q", "L"]

    names = [name]
    dms_seq = [seq]
    original_AA = ["wt"]
    mutant_AA = ["wt"]
    position_mut = ["wt"]

    wt_s = seq
    for mut_aa in AA:
        residue_number = 0
        for i, wt_char in enumerate(wt_s):
            if not wt_char.isupper():
                continue  # structure token, not a residue
            # Count every residue (even skipped ones) so positions stay aligned.
            residue_number += 1
            if wt_char == mut_aa:
                continue  # synonymous substitution: identical to the wild type
            dms_seq.append(wt_s[:i] + mut_aa + wt_s[i + 1:])
            names.append(name + "_" + wt_char + str(residue_number) + mut_aa)
            original_AA.append(wt_char)
            mutant_AA.append(mut_aa)
            position_mut.append(residue_number)

    df_dms = pd.DataFrame(
        list(zip(names, dms_seq, original_AA, position_mut, mutant_AA)),
        columns=["ID", "protein", "original_AA", "position_mut", "mutant_AA"],
    )
    return df_dms


display(Markdown("### Enter your sequence information"))
# These widgets get their own names. The prediction cell's button callback looks
# up the globals `name_box`/`seq_box` when clicked, so reusing those names here
# would silently point that cell at this cell's inputs.
dms_name_box = widgets.Text(
    placeholder="Enter sequence name",
    description="Name:",
    layout=widgets.Layout(width="400px")
)
dms_seq_box = widgets.Textarea(
    placeholder="Paste your sequence here",
    description="Sequence:",
    layout=widgets.Layout(width="600px", height="150px")
)
display(dms_name_box, dms_seq_box)
dms_button = widgets.Button(description="Generate DMS Predictions", button_style="success", disabled=False)
output_area_dms = widgets.Output()

df_dms = None


def on_dms_click(b):
    """Score every single mutant of the entered SA sequence and download the table.

    The scored table is stored in the global ``df_dms`` for the plotting cell.
    It is published only after scoring succeeds, so the plot never sees an unscored table.
    """
    global df_dms
    output_area_dms.clear_output()
    with output_area_dms:
        uploaded_name = dms_name_box.value.strip()
        uploaded_sequence = dms_seq_box.value.strip()
        if not (uploaded_name and uploaded_sequence):
            display(Markdown("No input detected."))
            return

        display(Markdown(f"Running DMS predictions on protein: {uploaded_name}"))
        dms_table = generate_dms_SA_seq(uploaded_sequence, uploaded_name)
        dms_table["score"] = predict_aggregation_multiple(dms_table["protein"], base_model, lora_kwargs)
        df_dms = dms_table

        csv_path = f"/content/{uploaded_name}_agg_dms_predictions.csv"
        df_dms.to_csv(csv_path, index=False)
        files.download(csv_path)
        display(Markdown("DMS Predictions were downloaded and saved as " + f"{uploaded_name}_agg_dms_predictions.csv"))


dms_button.on_click(on_dms_click)
display(dms_button, output_area_dms)

In [None]:

#@title Plot Aggregation DMS Predictions
#@markdown #### Plots the DMS predictions for the protein of interest generated above. The heatmap is colored by the Δ log2(Fold Change), where more positive values (green) indicate substitutions predicted to be more aggregation resistant. The black dots indicate the wild-type residue at each position. The model does not predict substitutions to cysteines because the models weren't trained on proteins with these residues.
AminoAcids = ['C', 'D', 'E', 'R','K', 'Q', 'N', 'H','M','A','G', 'S', 'T', 'I',  'L','V', 'F','W', 'Y','P']

if globals().get("df_dms") is None:
    raise RuntimeError("No DMS predictions found. Run 'Make Aggregation DMS Predictions' first.")

# Strip the interleaved structure tokens: keep every other character (the residues).
df_dms["no_sa"] = [c[::2] for c in df_dms["protein"]]
is_wt = df_dms["position_mut"].astype(str) == "wt"

wt_score_agg = df_dms.loc[is_wt, "score"].values[0]
df_dms["delta_score"] = df_dms["score"] - wt_score_agg

wtseq = df_dms.loc[is_wt, "no_sa"].values[0]
n = df_dms.loc[is_wt, "ID"].values[0]
pos = np.arange(1, len(wtseq) + 1, 1)
aa_row = {aa: r for r, aa in enumerate(AminoAcids)}

# Rows = substituted amino acid, columns = residue position. Cells without a
# prediction (e.g. cysteine) stay NaN and render blank.
matrix_agg = np.full((len(AminoAcids), len(wtseq)), np.nan)
mutants = df_dms[~is_wt]
for position, mut_aa, delta in zip(mutants["position_mut"], mutants["mutant_AA"], mutants["delta_score"]):
    r = aa_row.get(mut_aa)
    c = int(position) - 1
    # Keep the first prediction seen for a cell.
    if r is not None and 0 <= c < len(wtseq) and np.isnan(matrix_agg[r, c]):
        matrix_agg[r, c] = delta

# Wild-type cells take the WT row's delta (0 by construction). Every column gets
# a "<WT><position>" label, so the label count always matches the column count.
wt_delta = df_dms.loc[is_wt, "delta_score"].values[0]
xlabel = [f"{wt_aa}{p}" for wt_aa, p in zip(wtseq, pos)]
for c, wt_aa in enumerate(wtseq):
    if wt_aa in aa_row:
        matrix_agg[aa_row[wt_aa], c] = wt_delta

#Plot heatmap
plt.figure(figsize = (20,4))
ax = sns.heatmap(matrix_agg, yticklabels=AminoAcids, xticklabels=xlabel, center=0, cmap="PRGn", cbar_kws={'label': 'Δ log2(Fold Change)'})
plt.title("Aggregation DMS for " + n)

# Black dots mark the wild-type residue at each position.
for c, wt_aa in enumerate(wtseq):
    if wt_aa in aa_row:
        plt.scatter(c + 0.5, aa_row[wt_aa] + 0.5, s=30, color='black')

heatmap_path = f"/content/{n}_agg_DMS_heatmap.png"
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
plt.show()

print("Heatmap saved as " + heatmap_path)

# **Compare Stability and Aggregation Predictions**

In [None]:
#@title Install ThermoMPNN and the Dependencies
import yaml
# Work from /content so the repository is cloned to /content/ThermoMPNN.
target_dir = "/content/"

# Only change directory if not already there
if os.getcwd() != os.path.abspath(target_dir):
    os.chdir(target_dir)

repo_url = "https://github.com/Kuhlman-Lab/ThermoMPNN.git"
repo_name = "ThermoMPNN"

#Clone ThermoMPNN Repo
if not os.path.exists(repo_name):
    print("Cloning ThermoMPNN...")
    !git clone $repo_url -q
    print("Cloned ThermoMPNN")
    # NOTE(review): `!cd` runs in a throwaway subshell and does not change the
    # kernel's working directory; the os.chdir further down is what takes effect.
    !cd "ThermoMPNN/analysis"
else:
    print("ThermoMPNN already exists — skipping clone.")

# Make the ThermoMPNN package importable from this kernel.
sys.path.insert(0, "/content/ThermoMPNN")

#Install Miniconda (a separate interpreter for ThermoMPNN's own environment).
miniconda_path = "/content/miniconda"
conda_bin = os.path.join(miniconda_path, "bin")

if not os.path.exists(miniconda_path):
    !wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh > /dev/null 2>&1
    !bash /tmp/miniconda.sh -b -p {miniconda_path} > /dev/null 2>&1
    !rm /tmp/miniconda.sh
    print("Miniconda installed.")
else:
    print("Miniconda already exists — skipping install.")


# Prepend conda's bin dir to PATH so the `!conda ...` shell escapes below find it.
os.environ["PATH"] = conda_bin + ":" + os.environ["PATH"]
# NOTE(review): sys.path is Python's module search path; adding a bin directory
# there appears to have no effect — confirm before relying on it.
sys.path.append(conda_bin)

# Accept the Anaconda channel terms of service non-interactively.
!conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main
!conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r

#Create ThermoMPNN environment
env_name = "thermoMPNN"
yaml_path = "/content/ThermoMPNN/environment.yaml"

# Check if environment already exists (substring match on `conda env list` output).
envs_list = !conda env list
if env_name in "\n".join(envs_list):
    print(f"\nEnvironment {env_name} already exists")
else:
    print(f"\nCreating environment {env_name}...")
    !conda env create -f {yaml_path} -q --yes
    # Reinstall torch from the CUDA 11.8 wheel index inside the new environment.
    !/content/miniconda/envs/thermoMPNN/bin/python -m pip install -q --upgrade --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 > /dev/null 2>&1

# The prediction cell runs custom_inference.py relative to this directory
# (--model_path ../models/...) and collects its CSV output from here.
target_dir = "/content/ThermoMPNN/analysis"

if os.getcwd() != os.path.abspath(target_dir):
    print(f"Working directory {os.getcwd()}")
    os.chdir(target_dir)
    print(f"Changed working directory to: {os.getcwd()}")
else:
    print(f"Already in the correct directory: {os.getcwd()}")



# Point ThermoMPNN's local.yaml at this checkout (sets platform.thermompnn_dir).
yaml_path = "/content/ThermoMPNN/local.yaml"
with open(yaml_path, 'r') as f:
    config = yaml.safe_load(f)
config['platform']['thermompnn_dir'] = "/content/ThermoMPNN"
with open(yaml_path, 'w') as f:
    yaml.safe_dump(config, f)
print("local.yaml updated successfully!")

In [None]:
#@title Make DMS stability predictions with ThermoMPNN
#@markdown #### Predicts ddG for all single point mutants for a protein of interest. The predictions require the protein structure as input.

import glob
import shutil
from ipywidgets import widgets, Layout


#Widgets to upload files and start predictions
# Inputs for a ThermoMPNN run: the protein name (used to name the output files),
# a single .pdb upload, and the button that starts the prediction.
protein_input = widgets.Text(
    placeholder='Enter protein name',
    description='Protein Name:',
    style={'description_width': '120px'},
    layout=Layout(width='40%'),
)

structure_upload_widget = widgets.FileUpload(
    multiple=False,
    accept='.pdb',
    description='Upload PDB',
)

run_button = widgets.Button(button_style='success', description='Run ThermoMPNN')

# Messages from the run callback are printed into this widget.
output = widgets.Output()

def _read_uploaded_pdb(upload_widget):
    """Return ``(file_name, content_bytes)`` for the file held by a FileUpload widget.

    Supports both ipywidgets 7 (``value`` is a dict keyed by file name whose
    entries carry ``metadata`` and ``content``) and ipywidgets 8 (``value`` is a
    tuple of dicts with ``name`` and ``content``).
    """
    value = upload_widget.value
    if isinstance(value, dict):
        entry = list(value.values())[0]
        return entry['metadata']['name'], bytes(entry['content'])
    entry = value[0]
    return entry['name'], bytes(entry['content'])


def _reset_upload(upload_widget):
    """Clear the FileUpload widget so the next run needs a fresh upload."""
    if isinstance(upload_widget.value, dict):
        upload_widget.value.clear()
        upload_widget._counter = 0
    else:
        upload_widget.value = ()


def run_prediction(b):
    """Run ThermoMPNN on the uploaded PDB and save its CSV as /content/<name>_thermoMPNN_predictions.csv.

    ThermoMPNN's combined stdout/stderr goes to /content/<name>_thermompnn_output.log.
    """
    import subprocess

    with output:
        clear_output()
        if not protein_input.value:
            print("Please enter a protein name.")
            return
        if not structure_upload_widget.value:
            print("Please upload a PDB file.")
            return

        protein_name = protein_input.value
        pdb_filename, pdb_bytes = _read_uploaded_pdb(structure_upload_widget)
        pdb_path = os.path.join('/content', pdb_filename)

        # Save the uploaded file to the content folder.
        with open(pdb_path, 'wb') as f:
            f.write(pdb_bytes)

        print(f"Running ThermoMPNN on {pdb_filename} for protein {protein_name}...")

        output_log = f"/content/{protein_name}_thermompnn_output.log"
        analysis_dir = "/content/ThermoMPNN/analysis"

        # The argument list (no shell) passes file names with spaces or shell
        # metacharacters through intact. cwd fixes both the relative model path
        # and the directory the result CSV is written to, whatever the kernel's cwd is.
        with open(output_log, "w") as log_file:
            completed = subprocess.run(
                [
                    "/content/miniconda/envs/thermoMPNN/bin/python",
                    os.path.join(analysis_dir, "custom_inference.py"),
                    "--pdb", pdb_path,
                    "--model_path", "../models/thermoMPNN_default.pt",
                ],
                stdout=log_file,
                stderr=subprocess.STDOUT,
                cwd=analysis_dir,
            )

        csv_files = glob.glob(os.path.join(analysis_dir, "ThermoMPNN_inference_*.csv"))
        if completed.returncode != 0:
            print(f"ThermoMPNN exited with code {completed.returncode}. Output log saved to {output_log}.")
        elif csv_files:
            # Take the newest file, never an arbitrary leftover from an earlier run.
            original_csv = max(csv_files, key=os.path.getmtime)
            new_csv = f"/content/{protein_name}_thermoMPNN_predictions.csv"
            shutil.move(original_csv, new_csv)
            print(f"ThermoMPNN prediction complete. Output log saved to {output_log}. Predictions saved to {new_csv}")
        else:
            print(f"No CSV output found. Check if ThermoMPNN ran correctly. Output log saved to {output_log}.")

        _reset_upload(structure_upload_widget)


run_button.on_click(run_prediction)


display(protein_input, structure_upload_widget, run_button, output)


In [None]:
#@title Plot ThermoMPNN Predictions
#@markdown #### Plots the DMS stability predictions for the protein of interest. The heatmap is colored by the ΔΔG, where more negative values (green) indicate substitutions predicted to be stabilizing. The black dots indicate the wild-type residue at each position. Substitutions to cysteines aren't shown because they aren't predicted for aggregation.

thermompnn_predictions = pd.read_csv(f"/content/{protein_input.value}_thermoMPNN_predictions.csv")
thermompnn_predictions["position_mut"] = [str(p+1) for p in thermompnn_predictions["position"]]
thermompnn_predictions["mutant_AA"] = thermompnn_predictions["mutation"]
thermompnn_predictions["original_AA"] = thermompnn_predictions["wildtype"]

# Wild-type sequence rebuilt from ThermoMPNN's per-position wildtype column.
wt_df = thermompnn_predictions[["position", "wildtype"]].drop_duplicates().sort_values("position")
wt_sequence = "".join(wt_df["wildtype"].tolist())
thermompnn_predictions["no_sa"] = [wt_sequence] * len(thermompnn_predictions)
AminoAcids = [ 'C','D', 'E', 'R','K', 'Q', 'N', 'H','M','A','G', 'S', 'T', 'I',  'L','V', 'F','W', 'Y','P']
pos = np.arange(1, len(wt_sequence) + 1, 1)

aa_row = {aa: r for r, aa in enumerate(AminoAcids)}
# One column per reported position, in sorted order, so a gap in ThermoMPNN's
# numbering does not shift every later column.
col_of_position = {position: c for c, position in enumerate(wt_df["position"])}

# Rows = substituted amino acid, columns = position. Missing predictions stay NaN (blank).
matrix_agg = np.full((len(AminoAcids), len(wt_sequence)), np.nan)
first_per_cell = thermompnn_predictions.drop_duplicates(subset=["position", "mutant_AA"], keep="first")
for position, mut_aa, ddg in zip(first_per_cell["position"], first_per_cell["mutant_AA"], first_per_cell["ddG_pred"]):
    r = aa_row.get(mut_aa)
    c = col_of_position.get(position)
    if r is not None and c is not None:
        matrix_agg[r, c] = ddg

# Wild-type cells are the reference (ΔΔG = 0), whether or not ThermoMPNN also
# reports a self-substitution for that position.
xlabel = []
for c, (position, wt_aa) in enumerate(zip(wt_df["position"], wt_df["wildtype"])):
    xlabel.append(f"{wt_aa}{position + 1}")
    if wt_aa in aa_row:
        matrix_agg[aa_row[wt_aa], c] = 0

#Plot heatmap
plt.figure(figsize = (25,4))
ax = sns.heatmap(matrix_agg, yticklabels=AminoAcids, xticklabels=xlabel, center=0, cmap="PRGn_r", cbar_kws={'label': 'ΔΔG Predictions'})

n = protein_input.value
plt.title("ThermoMPNN DMS for " + n + "\nΔΔG Predictions")

# Black dots mark the wild-type residue at each position.
for c, wt_aa in enumerate(wt_df["wildtype"]):
    if wt_aa in aa_row:
        plt.scatter(c + 0.5, aa_row[wt_aa] + 0.5, s=30, color='black')

# Saved under /content like the aggregation heatmap (the kernel cwd was moved into ThermoMPNN/analysis).
heatmap_path = f"/content/{n}_ThermoMPNN_DMS_heatmap.png"
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
plt.show()
print("Heatmap saved as " + heatmap_path)

In [None]:
#@title Combine and Download csv with ThermoMPNN and Aggregation Predictions
#@markdown #### Enter the protein name to combine and download the ThermoMPNN and aggregation DMS predictions generated in this session.

#Widgets to get protein name
protein_name = widgets.Text(
    description='Protein Name:',
    placeholder='Enter protein name',
    layout=widgets.Layout(width='40%'),
    style={'description_width': '120px'}
)

download_button = widgets.Button(
    description="Download Combined CSV",
    button_style="success",  # green
    layout=widgets.Layout(width="220px")
)

# A dedicated output widget. The ThermoMPNN run callback prints into the global
# `output`, so rebinding that name here would redirect its messages into this cell.
combine_output = widgets.Output()

display(protein_name, download_button, combine_output)
dfs = {}  # Filled by the callback; read by the comparison-heatmap cell.


def on_download_clicked(b):
    """Merge the ThermoMPNN and aggregation DMS CSVs for one protein and download the result.

    Reads /content/<name>_thermoMPNN_predictions.csv and
    /content/<name>_agg_dms_predictions.csv, joins them on
    (original_AA, position_mut, mutant_AA), stores the frames in ``dfs`` and
    writes /content/<name>_thermompnn_agg_predictions.csv.
    """
    with combine_output:
        combine_output.clear_output()

        n = protein_name.value
        if n == "":
            print("Please enter a protein name.")
            return

        # Read the ThermoMPNN and aggregation predictions for the protein of interest.
        try:
            thermompnn_predictions = pd.read_csv(f"/content/{n}_thermoMPNN_predictions.csv")
            aggregation_dms_predictions = pd.read_csv(f"/content/{n}_agg_dms_predictions.csv")
        except FileNotFoundError:
            print(f"Could not find prediction files for protein: {n}")
            print(f"Expected the following:\n  /content/{n}_thermoMPNN_predictions.csv\n  /content/{n}_agg_dms_predictions.csv")
            return

        # Residue-only sequence: drop the interleaved structure tokens.
        aggregation_dms_predictions["no_sa"] = [c[::2] for c in aggregation_dms_predictions["protein"]]

        wt_score_agg = aggregation_dms_predictions.query("position_mut == 'wt'")["score"].values[0]
        aggregation_dms_predictions["delta_score"] = aggregation_dms_predictions["score"] - wt_score_agg

        # ThermoMPNN positions are shifted by +1 and stored as text, so they join with the aggregation table.
        thermompnn_predictions["position_mut"] = [str(p+1) for p in thermompnn_predictions["position"]]
        thermompnn_predictions["mutant_AA"] = thermompnn_predictions["mutation"]
        thermompnn_predictions["original_AA"] = thermompnn_predictions["wildtype"]

        agg_thermompnn = pd.merge(
            thermompnn_predictions,
            aggregation_dms_predictions,
            on=["original_AA", "position_mut", "mutant_AA"])

        agg_thermompnn["score_agg"] = agg_thermompnn["score"]
        agg_thermompnn["delta_score_agg"] = agg_thermompnn["delta_score"]

        # Published before the consistency check, so the frames stay inspectable even on a mismatch.
        dfs["agg"] = aggregation_dms_predictions
        dfs["thermo"] = thermompnn_predictions
        dfs["merged"] = agg_thermompnn

        # Every real substitution in the aggregation table must have a ThermoMPNN
        # match. The WT row and synonymous X->X rows are excluded, because either
        # tool may or may not report self-substitutions.
        agg_subs = aggregation_dms_predictions[
            (aggregation_dms_predictions["position_mut"] != "wt")
            & (aggregation_dms_predictions["original_AA"] != aggregation_dms_predictions["mutant_AA"])
        ]
        merged_subs = agg_thermompnn[agg_thermompnn["original_AA"] != agg_thermompnn["mutant_AA"]]
        if len(merged_subs) != len(agg_subs):
            print("Error merging predictions. Check the prediction files.")
            return

        output_path = f"/content/{n}_thermompnn_agg_predictions.csv"
        agg_thermompnn[['ID', 'protein', 'no_sa', 'position_mut', 'mutant_AA', 'original_AA', 'score_agg', 'delta_score_agg', 'ddG_pred']].to_csv(output_path, index=False)

        files.download(output_path)
        print(f"Combined Aggregation and ThermoMPNN predictions downloaded as: {output_path}")


download_button.on_click(on_download_clicked)


In [None]:
#@title Heatmap Comparing ThermoMPNN  and Aggregation Predictions
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
# Row order of the heatmap: one row per substituted amino acid.
AminoAcids = [ 'C','D', 'E', 'R','K', 'Q', 'N', 'H','M','A','G', 'S', 'T', 'I',  'L','V', 'F','W', 'Y','P']
agg_thermompnn = dfs["merged"]
aggregation_dms_predictions = dfs["agg"]

# Wild-type sequence without structure tokens ('no_sa' column of the 'wt' row),
# looked up once outside the loops.
wt_no_sa = aggregation_dms_predictions.query(""" position_mut == "wt" """)["no_sa"].values[0]
seq_len = len(agg_thermompnn["no_sa"].values[0])
pos = np.arange(1, seq_len + 1, 1)
matrix_agg = np.zeros((len(AminoAcids), seq_len))

# One x tick label per column ("<WT residue><position>"). Building it directly from
# the WT sequence keeps labels aligned with columns even when a WT residue is not
# one of AminoAcids (e.g. 'X').
xlabel = [f"{wt_aa}{p}" for wt_aa, p in zip(wt_no_sa, pos)]

# (position_mut, mutant_AA) -> (delta_score, ddG_pred), built in one pass so each
# matrix cell is a dict lookup instead of a DataFrame.query over the whole table.
# setdefault keeps the first matching row for duplicated keys.
mut_lookup = {}
for p_mut, aa_mut, d_agg, ddg_val in zip(
        agg_thermompnn["position_mut"], agg_thermompnn["mutant_AA"],
        agg_thermompnn["delta_score"], agg_thermompnn["ddG_pred"]):
    mut_lookup.setdefault((p_mut, aa_mut), (d_agg, ddg_val))

# Columns are WT positions, rows are substitutions. Encoding:
#   0    WT residue
#   1    aggregation resistant & stabilizing   (delta_score > 0, ddG_pred < 0)
#   0.5  aggregation resistant & destabilizing (delta_score > 0, ddG_pred > 0)
#  -0.5  aggregation promoting & stabilizing   (delta_score < 0, ddG_pred < 0)
#  -1    everything else, incl. exact zeros / NaN (aggregation promoting & destabilizing)
#   NaN  no merged prediction for this substitution
for col_wt, position in enumerate(pos):
    wt_aa = wt_no_sa[col_wt]
    # position_mut is compared as a string, matching the merged table's keys.
    key_pos = str(position)
    for row_sub, aa in enumerate(AminoAcids):
        hit = mut_lookup.get((key_pos, aa))
        if aa == wt_aa:
            matrix_agg[row_sub, col_wt] = 0
        elif hit is None:
            matrix_agg[row_sub, col_wt] = np.nan
        else:
            delta, ddg = hit
            if delta > 0 and ddg < 0:      # agg res, stabilizing
                matrix_agg[row_sub, col_wt] = 1
            elif delta > 0 and ddg > 0:    # agg res, destabilizing
                matrix_agg[row_sub, col_wt] = 0.5
            elif delta < 0 and ddg < 0:    # agg promoting, stabilizing
                matrix_agg[row_sub, col_wt] = -0.5
            else:                          # agg promoting, destabilizing
                matrix_agg[row_sub, col_wt] = -1

# Category value in matrix_agg -> integer bin (0-4); the bin indexes the
# ListedColormap built below.
value_to_bin = {
    -1: 0,
    -0.5: 1,
    0.0: 2,
    0.5: 3,
    1: 4
}

# Legend text for each category value.
legend_labels = {
    -1: "Aggregation Promoting &\nDestabilizing Mutation",
    -0.5: "Aggregation Promoting &\nStabilizing Mutation",
    0.0: "WT",
    0.5: "Aggregation Resistance &\nDestabilizing Mutation",
    1: "Aggregation Resistance &\nStabilizing Mutation"
}

# Fill colour for each category value.
legend_colors = {
    -1: "tomato",
    -0.5: "deepskyblue",
    0.0: "white",
    0.5: "mediumorchid",
    1: "yellowgreen"
}

# Colours ordered by bin index, so colors[value_to_bin[v]] == legend_colors[v].
colors = [legend_colors[v] for v, _ in sorted(value_to_bin.items(), key=lambda item: item[1])]
cmap = mcolors.ListedColormap(colors)

# Same shape as matrix_agg: NaN cells (no prediction) stay NaN, every other cell
# is replaced by its integer bin.
cat_matrix = np.full_like(matrix_agg, fill_value=np.nan, dtype=float)
has_value = ~np.isnan(matrix_agg)
cat_matrix[has_value] = [value_to_bin[v] for v in matrix_agg[has_value]]

# One outlined legend patch per category, in legend_labels order.
handles = [
    mpatches.Patch(facecolor=legend_colors[val], label=label, edgecolor='black')
    for val, label in legend_labels.items()
]

plt.figure(figsize=(20,6))

# Heatmap of the bin indices in cat_matrix. vmin=0 / vmax=4 pin the five bins to
# the five entries of cmap; the colour bar is replaced by the legend below.
ax = sns.heatmap(
    cat_matrix,
    cmap=cmap,
    vmin=0, vmax=4,
    cbar=False,
    yticklabels=AminoAcids,
    xticklabels=xlabel,
    linewidths=0.1,
    linecolor='lightgray'
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=12)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=12)

# Mark the wild-type residue of each column with a black dot at the cell centre.
# Residues not in AminoAcids (e.g. 'X') have no heatmap row, so they are skipped
# instead of raising ValueError from list.index.
wtseq = aggregation_dms_predictions.query(""" position_mut == "wt" """)["no_sa"].values[0]
for i, wt_aa in enumerate(wtseq):
    if wt_aa in AminoAcids:
        plt.scatter(i+0.5, AminoAcids.index(wt_aa)+0.5, s=30, color='black')

# NOTE(review): n is not used in this cell; kept because it is a notebook-level
# global that other cells may read.
n = aggregation_dms_predictions.query(""" position_mut == "wt" """)["ID"].values[0]
plt.title("Comparison of Aggregation and ThermoMPNN Predictions", fontsize = 16)

# Legend on the right outside plot
plt.legend(
    handles=handles,
    loc='center left',
    bbox_to_anchor=(1.02, 0.5),   # push legend outside right side
    borderaxespad=0,
    frameon=False, fontsize = 14
)

plt.tight_layout()
plt.show()



In [None]:
#@title Scatterplot Comparing ThermoMPNN  and Aggregation Predictions
# Install plotly quietly (IPython shell escape) before importing plotly.express,
# which draws the interactive scatterplot in this cell.
!pip install -q plotly

import plotly.express as px


# Define aggregation effect categories
def combined_effect(row):
    """Classify one merged prediction row into an aggregation/stability quadrant.

    ``row`` must support ``row['delta_score']`` (aggregation score minus the WT
    score; positive counts as aggregation resistance) and ``row['ddG_pred']``
    (predicted ddG; negative counts as stabilizing).

    Returns one of four label strings. Rows matching none of the first three
    quadrants -- including exact zeros and NaNs -- are labelled
    'Aggregation Promoting & Destabilizing'.
    """
    delta = row['delta_score']
    ddg = row['ddG_pred']
    if ddg < 0:
        if delta > 0:
            return 'Aggregation Resistance & Stabilizing'
        if delta < 0:
            return 'Aggregation Promoting & Stabilizing'
    elif ddg > 0 and delta > 0:
        return 'Aggregation Resistance & Destabilizing'
    return 'Aggregation Promoting & Destabilizing'

# Read the merged table stored by the download step, like the other result cells,
# instead of depending on a global left behind by the heatmap cell (NameError or
# stale data if that cell was skipped or the download was re-run).
agg_thermompnn = dfs["merged"]

# Label every mutation with its aggregation/stability quadrant.
agg_thermompnn['Combined_Effect'] = agg_thermompnn.apply(combined_effect, axis=1)

# Map heatmap colors
agg_colors = {
    'Aggregation Resistance & Stabilizing': 'yellowgreen',
    'Aggregation Promoting & Stabilizing': 'deepskyblue',
    'Aggregation Resistance & Destabilizing': 'mediumorchid',
    'Aggregation Promoting & Destabilizing': 'tomato'
}

# Alias of delta_score used for the y axis and the hover text.
agg_thermompnn["delta_agg_score"]=agg_thermompnn["delta_score"]

# Columns attached to each point as customdata; their order matches the
# customdata[0..4] indices used in the hovertemplate below.
hover_columns = ['original_AA', 'position_mut', 'mutant_AA', 'delta_agg_score', 'ddG_pred']

fig = px.scatter(
    agg_thermompnn,
    x='ddG_pred',
    y='delta_agg_score',
    color='Combined_Effect',
    custom_data=hover_columns,
    color_discrete_map=agg_colors,
)
fig.update_layout(
    width=1200, height=800, plot_bgcolor='white', paper_bgcolor='white',
    xaxis_title="ΔΔG (kcal/mol)<br>Stability Prediction",
    yaxis_title="Aggregation Prediction<br>Δlog2(Fold Change)",
)

# Data extent, used so the dashed zero lines span the plotted points.
x_min, x_max = agg_thermompnn['ddG_pred'].min(), agg_thermompnn['ddG_pred'].max()
y_min, y_max = agg_thermompnn['delta_agg_score'].min(), agg_thermompnn['delta_agg_score'].max()

# Vertical dashed line at x=0 (ddG sign boundary)
fig.add_shape(
    type="line",
    x0=0, x1=0,
    y0=y_min, y1=y_max,
    line=dict(color="black", width=2.5, dash="dash")
)

# Horizontal dashed line at y=0 (aggregation delta sign boundary)
fig.add_shape(
    type="line",
    x0=x_min, x1=x_max,
    y0=0, y1=0,
    line=dict(color="black", width=2.5, dash="dash")
)
fig.update_xaxes(showgrid=True, gridcolor='lightgrey', gridwidth=1)
fig.update_yaxes(showgrid=True, gridcolor='lightgrey', gridwidth=1)

fig.update_traces(
    hovertemplate=(
        "Original AA: %{customdata[0]}<br>"
        "Position: %{customdata[1]}<br>"
        "Mutant AA: %{customdata[2]}<br>"
        "ΔAgg Score: %{customdata[3]:.2g}<br>"
        "ΔΔG Pred: %{customdata[4]:.2g}<br>"
        "<extra></extra>"))

fig.show()



In [None]:
#@title Display Combined Results as an Interactive Table
from google.colab import data_table

# Switch Colab to its interactive DataFrame renderer, then show the merged
# aggregation + ThermoMPNN predictions, 10 rows per page.
data_table.enable_dataframe_formatter()
df = dfs["merged"][[
    'ID', 'no_sa', 'position_mut', 'mutant_AA',
    'original_AA', 'score_agg', 'delta_score_agg', 'ddG_pred',
]]
data_table.DataTable(df, include_index=True, num_rows_per_page=10)