In [None]:
import pandas as pd

from plm_inv_subnetworks.dataset import data_io
from plm_inv_subnetworks.dataset.cath_dataset import get_cath_db
from plm_inv_subnetworks.dataset.data_paths import PDB_DIR, ESMFOLD_650M_PDBS, ESM_PPL_METRICS, ESM_TMALIGN_METRICS

# Root directory that contains one subdirectory per subnetwork run (e.g. class_1, residue_1).
# Change this to runs/ if evaluating new trained subnetworks.
RUN_DIR_PREFIX = "../results/inverse_subnetworks"

In [None]:

# Baseline (unmodified ESM) metrics: perplexity collapsed to one row per cath_id by
# averaging its numeric columns, then joined with the TM-align results on cath_id.
ESM_TMALIGN = pd.read_csv(ESM_TMALIGN_METRICS)
ESM_PPL = pd.read_csv(ESM_PPL_METRICS).pipe(
    lambda raw: raw.groupby("cath_id", as_index=False)[
        raw.select_dtypes(include="number").columns
    ].mean()
)
ESM_GT = ESM_PPL.merge(ESM_TMALIGN, on="cath_id")

# CATH database handle, passed to the per-run loaders below.
db = get_cath_db()

In [8]:
def add_split(df, split):
    """Label each row of ``df`` with the dataset split its ``cath_id`` belongs to.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``cath_id`` column. A ``split`` column is written to this
        frame in place, and the same frame is returned.
    split : Mapping[str, Iterable]
        Maps split names to collections of CATH IDs. Only the keys ``"train"``,
        ``"val"`` and ``"test"`` are consulted. Missing keys and keys whose
        value is ``None`` are treated as empty splits.

    Returns
    -------
    pandas.DataFrame
        ``df`` with ``split`` set to ``"train"``, ``"val"`` or ``"test"``, or to
        ``None`` when the ID appears in none of them. If an ID is listed under
        several splits, ``"train"`` wins over ``"val"``, which wins over ``"test"``.
    """
    # Build a single ID -> split dict up front. The previous per-row `x in list`
    # scans cost O(rows * ids). Lower-priority splits are inserted first, so a
    # higher-priority split overwrites them for IDs listed more than once.
    lookup = {}
    for name in ("test", "val", "train"):
        ids = split.get(name)
        if ids is None:
            continue
        for cath_id in ids:
            lookup[cath_id] = name

    # dict.get yields None for IDs that are not in any split.
    df["split"] = df["cath_id"].map(lookup.get)
    return df

In [18]:
def load_suppression_df(run_dir, db, split_name="val"):
    """Load perplexity and TM-align metrics for one subnetwork run.

    Reads ``perplexity.csv`` and ``tmalign.csv`` from
    ``{RUN_DIR_PREFIX}/{run_dir}`` and averages the numeric perplexity columns
    per ``cath_id``. It then merges the two tables on ``cath_id`` and annotates
    the result with CATH terms from ``db``. Finally it labels each row with the
    split recorded for the run and keeps only the rows in ``split_name``.

    Parameters
    ----------
    run_dir : str
        Run subdirectory under ``RUN_DIR_PREFIX`` (e.g. ``"class_1"``).
    db :
        CATH database handle, passed through to
        ``data_io.hydrate_df_with_cath_terms``.
    split_name : str, default "val"
        Which split to keep. The default reproduces the previous hard-coded
        behavior of returning only the validation rows.

    Returns
    -------
    pandas.DataFrame
        Merged, CATH-annotated metrics restricted to ``split_name``. The
        original row index is retained.
    """
    base_path = f"{RUN_DIR_PREFIX}/{run_dir}"

    # numeric_only=True: since pandas 2.0, GroupBy.mean() raises TypeError if
    # any non-key column is non-numeric. This mirrors the numeric-only
    # aggregation used for the baseline ESM_PPL table.
    ppl_df = (
        pd.read_csv(f"{base_path}/perplexity.csv")
        .groupby("cath_id", as_index=False)
        .mean(numeric_only=True)
    )
    tm_df = pd.read_csv(f"{base_path}/tmalign.csv")

    # Merge on cath_id, then attach CATH annotations.
    df = pd.merge(tm_df, ppl_df, on="cath_id")
    df = data_io.hydrate_df_with_cath_terms(df, db)

    # The first value returned by get_args_split (the run config) is not needed here.
    _config, split = data_io.get_args_split(base_path)
    df = add_split(df, split)

    return df[df["split"] == split_name]

# Load every suppression run. The run directories are listed in the same order
# as the names they are unpacked into (class_* -> SEQ, residue_* -> RES).
SUPPRESS_SEQ_ALPHA, SUPPRESS_SEQ_BETA, SUPPRESS_RES_ALPHA, SUPPRESS_RES_BETA = (
    load_suppression_df(run_name, db)
    for run_name in ("class_1", "class_2", "residue_1", "residue_2")
)


### Inspect predictions and choose structures to visualize 

In [None]:
# Subdirectory (inside each run directory) holding the inverse-subnetwork-predicted
# structures, e.g. "pred". Set this before running the cell.
YOUR_FOLDED_PROTEIN_DIR = None
# user@host that the printed scp commands copy from; change to your own login.
REMOTE_HOST = "rvinod@ssh.ccv.brown.edu"
LOCAL_DOWNLOAD_DIR = "~/Downloads"
METRIC_COLS = ["perplexity", "TM-score", "RMSD", "pLDDT"]

# Each run directory is paired with the dataframe loaded from it, so selecting
# RUN_DIR alone picks the matching DF and the two cannot drift out of sync.
SUPPRESSION_RUNS = {
    "class_1": SUPPRESS_SEQ_ALPHA,
    "class_2": SUPPRESS_SEQ_BETA,
    "residue_1": SUPPRESS_RES_ALPHA,
    "residue_2": SUPPRESS_RES_BETA,
}

# In the paper we only use IDs held out from each model: 1avcA07, 1bdyA00, 1am2A00, 1bccG00
cath_id = "1bdyA00"
RUN_DIR = "class_1"
DF = SUPPRESSION_RUNS[RUN_DIR]

print("ESM", ESM_GT.loc[ESM_GT["cath_id"] == cath_id, METRIC_COLS].round(2))
print("SUB", DF.loc[DF["cath_id"] == cath_id, METRIC_COLS].round(2))
if not (DF["cath_id"] == cath_id).any():
    print(f"Warning: {cath_id} is not in the split loaded for {RUN_DIR}; no subnetwork metrics shown.")

# (file-name tag, remote directory) for the predicted, ESMFold-650M and PDB_DIR structures.
# NOTE: paths are printed as-is. scp resolves relative remote paths (such as the
# default RUN_DIR_PREFIX) against the remote login directory, so adjust if needed.
structure_sources = []
if YOUR_FOLDED_PROTEIN_DIR is None:
    print("YOUR_FOLDED_PROTEIN_DIR is not set; skipping the predicted-structure scp command.")
else:
    structure_sources.append(("pred", f"{RUN_DIR_PREFIX}/{RUN_DIR}/{YOUR_FOLDED_PROTEIN_DIR}"))
structure_sources += [("esm", ESMFOLD_650M_PDBS), ("pdb", PDB_DIR)]

print("\n\n".join(
    f"scp {REMOTE_HOST}:{remote_dir}/{cath_id}.pdb {LOCAL_DOWNLOAD_DIR}/{RUN_DIR}_{tag}_{cath_id}.pdb"
    for tag, remote_dir in structure_sources
))


ESM      perplexity  TM-score  RMSD  pLDDT
151        9.48      0.79  1.93  70.38
SUB     perplexity  TM-score  RMSD  pLDDT
26        8.99      0.76  2.11  70.09
scp rvinod@ssh.ccv.brown.edu:../results/inverse_subnetworks/class_1/pred/1bdyA00.pdb ~/Downloads/class_1_pred_1bdyA00.pdb

scp rvinod@ssh.ccv.brown.edu:../data/esmfold_650M/pdbs/1bdyA00.pdb ~/Downloads/class_1_esm_1bdyA00.pdb

scp rvinod@ssh.ccv.brown.edu:../data/dompdb_chain/1bdyA00.pdb ~/Downloads/class_1_pdb_1bdyA00.pdb


In [None]:
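# PyMOL palette for coloring structures by pLDDT, which the commands below read
# from the B-factor field (`b`). Paste them into PyMOL without the leading "# ".

# https://kpwulab.com/2023/03/09/color-alphafold2s-plddt/
# Using AlphaFold2 colors (from https://github.com/sokrypton/ColabFold)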

# set_color n0, [0.051, 0.341, 0.827]
# set_color n1, [0.416, 0.796, 0.945]
# set_color n2, [0.996, 0.851, 0.212]
# set_color n3, [0.992, 0.490, 0.302]
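# color n0, b < 100; color n1, b < 90
# color n2, b < 70;  color n3, b < 50
# Later commands override earlier ones, so the effective bins are:
#   90 <= b < 100 -> n0,  70 <= b < 90 -> n1,  50 <= b < 70 -> n2,  b < 50 -> n3
# (a value of exactly 100 matches none of these and keeps its previous color).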