In [None]:
# Additional requirements just for this notebook:

# !pip install huggingface_hub[cli]
# !pip install esm 
# !pip install py3Dmol

In [None]:
# General:
import os 
import io
from omegaconf import OmegaConf
# import huggingface_hub
import requests
import pathlib
import pickle
import pandas as pd
import numpy as np
import torch

# ESM:
# from esm.models.esmc import ESMC
# from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, LogitsConfig, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain
# from esm.utils.types import FunctionAnnotation
from biotite.database import rcsb
# from fa_helper import visualize_function_annotations, get_keywords_from_interpro, interpro2keywords
from scripts.helpers.pdb import gene_to_pdb, fetch_pdb_ids

# Visualization:
import matplotlib.pyplot as plt
from scripts.helpers.visualization import visualize_3D_protein
from PIL import Image
from IPython.display import SVG
from rdkit.Chem import Draw
from rdkit.Chem.Draw import SimilarityMaps
from synformer.models.synformer import draw_generation_results
import seaborn as sns

# Chemistry:
from rdkit import Chem
import rdkit

# Synformer:
# from synformer.chem.fpindex import FingerprintIndex
# from synformer.chem.matrix import ReactantReactionMatrix
from synformer.chem.mol import Molecule
from synformer.models.synformer import Synformer
from scripts.sample_helpers import load_model, featurize_smiles, load_protein_molecule_pairs, sample

In [None]:
# ESM settings. The (currently disabled) Hugging Face login below reads `hf_token` from this file.
assert pathlib.Path("configs/esm.yml").exists(), "Missing config file: configs/esm.yml"
esm_config = OmegaConf.load("configs/esm.yml")
# huggingface_hub.login(esm_config.hf_token)  # re-enable together with `import huggingface_hub`

### 1. Loading the data

In [None]:
# Dataset paths, relative to the notebook's working directory.
# To evaluate on the validation split instead, swap in:
#   str(pathlib.Path("data", "protein_molecule_pairs", "papyrus_val_19399.csv"))
# NOTE: the amino-acid sequence cache further down (data/other/aa_seq_test.csv) is named
# after the test split; change it too when switching splits.
protein_molecule_pairs_path = str(pathlib.Path("data", "protein_molecule_pairs", "papyrus_test_19399.csv"))
protein_embeddings_path = str(pathlib.Path("data", "protein_embeddings", "embeddings_selection_float16_4973.pth"))
synthetic_pathways_path = str(pathlib.Path("data", "synthetic_pathways", "filtered_pathways_370000.pth"))

In [None]:
# Protein–molecule pairs; reset_index() moves the loader's index into a regular column
# and renumbers the rows.
df_protein_molecule_pairs = (
    load_protein_molecule_pairs(protein_molecule_pairs_path)
    .reset_index()
)

# A few random example rows (not seeded, so they change on every run)
df_protein_molecule_pairs.sample(10)

In [None]:
# Precomputed inputs, loaded onto the CPU whatever device they were saved from.
# Later cells index `protein_embeddings` by target_id and `synthetic_pathways` by SMILES.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if these files hold
# non-tensor Python objects, loading them needs weights_only=False (trusted files only).
protein_embeddings = torch.load(protein_embeddings_path, map_location="cpu")
print(f"{len(protein_embeddings)} protein embeddings")

synthetic_pathways = torch.load(synthetic_pathways_path, map_location="cpu")
print(f"{len(synthetic_pathways)} synthetic pathways")

In [None]:
# Retrieve amino-acid sequences for given target IDs.
# TODO: the sequences may already be present in the original dataset, which would avoid
# one RCSB download per protein.

def get_amino_acid_sequence(target_id, chain_id="A"):
    """Fetch the amino-acid sequence of a protein target from one of its PDB structures.

    Args:
        target_id: Target identifier. Everything from the first "_" onwards (e.g. a
            suffix like "_WT") is dropped before looking up PDB IDs with `fetch_pdb_ids`.
        chain_id: Chain of the downloaded PDB structure whose sequence is returned.
            Defaults to "A", the chain this notebook has always used.

    Returns:
        The sequence string of the chosen chain, or None when `fetch_pdb_ids` finds no PDB ID.

    Raises:
        Any exception from `fetch_pdb_ids`, `rcsb.fetch` or `ProteinChain.from_pdb`
        (presumably network errors or a structure without the requested chain);
        the caching loop below catches and reports these per target.
    """
    short_id = target_id.split("_")[0]
    pdb_id, _ = fetch_pdb_ids(short_id)
    if pdb_id is None:
        return None
    protein_chain = ProteinChain.from_pdb(rcsb.fetch(pdb_id, "pdb"), chain_id=chain_id)
    return ESMProtein.from_protein_chain(protein_chain).sequence

# Amino-acid sequence per target, cached on disk so the RCSB downloads only happen once.
# Targets without a PDB ID are stored with an empty sequence (read back as NaN); targets whose
# download/parsing raised are reported and left out of the table.
# NOTE: the cache file name is tied to the test split; delete it (or change the path) when
# switching `protein_molecule_pairs_path` to another split.
unique_proteins = df_protein_molecule_pairs["target_id"].unique()
aa_seq_path = os.path.join("data", "other", "aa_seq_test.csv")
if os.path.exists(aa_seq_path):
    df_aa_seq = pd.read_csv(aa_seq_path)
else:
    aa_seq = []
    for target_id in unique_proteins:
        try:
            aa_seq.append({
                "target_id": target_id,
                "aa_seq": get_amino_acid_sequence(target_id),
            })
        except Exception as e:
            # Best effort: report the failing target and continue with the others.
            print(f"({target_id})", e)
    # Explicit columns keep the schema even if every fetch failed (later cells index by target_id).
    df_aa_seq = pd.DataFrame(aa_seq, columns=["target_id", "aa_seq"])
    # to_csv does not create missing directories.
    os.makedirs(os.path.dirname(aa_seq_path), exist_ok=True)
    df_aa_seq.to_csv(aa_seq_path, index=False)

### 2. Loading the model

In [None]:
# Sampling settings for the trained Protein-Synformer checkpoint.
# config_path = None passes no separate config file to load_model (presumably the config stored
# with the checkpoint is used then); alternative: config_path = "configs/prot2drug.yml".
model_name = "epoch=23-step=28076"
model_path = os.path.join("data", "trained_weights", model_name + ".ckpt")
config_path = None
device = "cpu"

In [None]:
# Trained model plus `fpindex` and `rxn_matrix` (presumably the FingerprintIndex and
# ReactantReactionMatrix that sampling needs).
# NOTE: `model` is rebound to a DeepPurpose model in section 4.2; re-run this cell before
# sampling with Synformer again after that point.
model, fpindex, rxn_matrix = load_model(
    model_path,
    config_path,
    device,
)

### 3. Example

In [None]:
# Draw one random protein–molecule pair and look up its precomputed inputs.
# The unpacking requires the dataframe to have exactly three columns (SMILES, target, short target).
ex_smiles, ex_target_id, ex_short_target_id = tuple(df_protein_molecule_pairs.sample().iloc[0])

# .float() casts to float32 (the embeddings file name suggests float16 storage).
ex_protein_embeddings = protein_embeddings[ex_target_id].float()
ex_synthetic_pathway_true = synthetic_pathways[ex_smiles]

print("SMILES:", ex_smiles)
print("Target:", ex_target_id)
print("Protein embeddings:", ex_protein_embeddings.shape)
print("True synthetic pathway:", ex_synthetic_pathway_true)

In [None]:
# Look up a PDB structure for the example target and show it in 3D.
ex_pdb_id, ex_df_pdb_ids = fetch_pdb_ids(ex_short_target_id)

if ex_pdb_id is not None:
    print("PDB ID:", ex_pdb_id)
    ex_protein_chain = ProteinChain.from_pdb(rcsb.fetch(ex_pdb_id, "pdb"), chain_id="A")

    # ESMProtein with the ground-truth tracks taken from the structure. Function annotations
    # are not filled in automatically; they would have to be fetched separately and assigned
    # to `ex_known_protein.function_annotations`.
    ex_known_protein = ESMProtein.from_protein_chain(ex_protein_chain)
    print(ex_known_protein.sequence)

    # TODO: have ESM predict the binding site and visualize it too (already done in the
    # Binding Site notebook). A sequence-only input for predicting the other tracks would be
    # ESMProtein(sequence=ex_protein_chain.sequence).

    visualize_3D_protein(ex_known_protein, style="cartoon")
else:
    print(f"No PDB ID found for {ex_short_target_id}")

In [None]:
# Sample molecules for the example protein with Synformer (50 repeats).
# `true_smiles` is passed along, presumably so `info` can report each result's similarity
# to the ground truth (used in the next cell).
# NOTE: `model` must still be the Synformer model here, not the DeepPurpose model from 4.2.
info, result = sample(
    ex_target_id, model, fpindex, rxn_matrix,
    protein_embeddings, device,
    true_smiles=ex_smiles,
    repeat=50,
)
# sample() prints per result: analog.sim(mol), cnt_rxn, log_likelihood, analog.smiles

In [None]:
# Sampled molecule with the highest similarity to the true molecule.
best_idx = pd.Series({idx: pred["similarity"] for idx, pred in info.items()}).idxmax()

print("True:", ex_smiles)
print("Pred:", info[best_idx]["smiles"])

info[best_idx]

In [None]:
# TODO
# draw_generation_results(result)[best_idx]  







In [None]:
# Ground-truth molecule first, best predicted molecule second.
Draw.MolsToGridImage(tuple(
    Chem.MolFromSmiles(smiles) for smiles in (ex_smiles, info[best_idx]["smiles"])
))

In [None]:
# Similarity map: colours the atoms of the predicted (probe) molecule by their contribution to
# its Morgan-fingerprint (radius 2, bit vector) similarity with the true (reference) molecule.
# Source: https://greglandrum.github.io/rdkit-blog/posts/2020-01-03-similarity-maps-with-new-drawing-code.html
d = Draw.MolDraw2DCairo(400, 400)  # 400 x 400 px canvas
_, max_weight = SimilarityMaps.GetSimilarityMapForFingerprint(
    Chem.MolFromSmiles(ex_smiles),                 # reference: true molecule
    Chem.MolFromSmiles(info[best_idx]["smiles"]),  # probe: best predicted molecule
    lambda mol, atom_idx: SimilarityMaps.GetMorganFingerprint(mol, atom_idx, radius=2, fpType="bv"),
    draw2d=d,
)
d.FinishDrawing()

# Display the PNG bytes produced by the Cairo drawer.
Image.open(io.BytesIO(d.GetDrawingText()))

### 4. Evaluation

In [None]:
# Saved evaluation run: `infos` maps each evaluated target_id to a dict of its predictions
# (each prediction is a dict with at least a "smiles" key).
# NOTE: pickle.load can execute arbitrary code; only open evaluation files produced locally.
infos_path = os.path.join("data", "evaluations", model_name, "infos_210of300_2025-06-09 22-32-20.pkl")
with open(infos_path, "rb") as f:  # `with` closes the file handle once loaded
    infos = pickle.load(f)

# Alternative run with infos and results saved separately:
# timestamp = "2025-06-09 19-51-21"
# infos = pickle.load(open(f"data/evaluations/{model_name}/infos {timestamp}.pkl", "rb"))
# results = pickle.load(open(f"data/evaluations/{model_name}/results {timestamp}.pkl", "rb"))

#### 4.1. Tanimoto similarity

In [None]:
# Tanimoto similarity of every generated molecule against every ground-truth ligand of the
# same protein, for all proteins in this evaluation run.
similarity_data = []

# Unique ground-truth SMILES per protein, always as an array. (A plain
# `.loc[target_id, "SMILES"]` returns a bare string for proteins with a single ligand,
# and iterating that string would compare against individual characters.)
true_smiles_by_target = df_protein_molecule_pairs.groupby("target_id")["SMILES"].unique()

# `target_info` (not `info`) so the example section's `info` is not overwritten.
for target_id, target_info in infos.items():
    # Build the ground-truth molecules once per protein instead of once per prediction.
    true_mols = [(smiles, Molecule(smiles)) for smiles in true_smiles_by_target.loc[target_id]]

    for pred in target_info.values():
        pred_smiles = pred["smiles"]
        mol_pred = Molecule(pred_smiles)

        for true_smiles, mol_true in true_mols:
            similarity_data.append({
                "true_smiles": true_smiles,
                "pred_smiles": pred_smiles,
                "target_id": target_id,
                # Tanimoto similarity; alternative score: mol_pred.dice_similarity(mol_true)
                "similarity": mol_pred.sim(mol_true),
            })

df_similarity = pd.DataFrame(similarity_data).drop_duplicates()

# Example entries (at most 10, so small runs do not raise):
df_similarity.sample(min(10, len(df_similarity)))

In [None]:
# Summary statistics of df_similarity; by default describe() only summarizes numeric columns,
# here the similarity scores.
df_similarity.describe()

In [None]:
# One random (protein, true molecule) pair and all predictions compared with it, most similar first.
rand_target_id, rand_true_smiles = df_similarity[["target_id", "true_smiles"]].sample().iloc[0]
(
    df_similarity
    .query("target_id == @rand_target_id and true_smiles == @rand_true_smiles")
    .sort_values("similarity", ascending=False)
)

In [None]:
# The single most similar prediction over all (protein, molecule) pairs.
best_pred = df_similarity.loc[df_similarity["similarity"].idxmax()]
print(best_pred)

# True molecule first, predicted molecule second.
Draw.MolsToGridImage(tuple(
    Chem.MolFromSmiles(best_pred[column]) for column in ("true_smiles", "pred_smiles")
))

In [None]:
# For each (protein, true molecule) pair, keep the row of its most similar prediction.
best_pred_per_pair = df_similarity.loc[
    lambda frame: frame.groupby(["target_id", "true_smiles"])["similarity"].idxmax()
]
print(best_pred_per_pair.shape[0])

# best_pred_per_pair.sample(10)  # some examples

best_pred_per_pair.describe()

In [None]:
# For each protein, keep the row of its most similar prediction (over all of its ligands).
best_pred_per_protein = df_similarity.loc[
    lambda frame: frame.groupby(["target_id"])["similarity"].idxmax()
]
print(best_pred_per_protein.shape[0])

# best_pred_per_protein.sample(10)  # some examples

best_pred_per_protein.describe()

In [None]:
# Density of similarity scores: all predictions, the best prediction per (protein, molecule)
# pair, and the best prediction per protein.
plt.title("Distribution of similarity scores")
for curve_label, similarities in (
    ("all", df_similarity["similarity"]),
    ("best per pair", df_similarity.groupby(["target_id", "true_smiles"])["similarity"].max()),
    ("best per protein", df_similarity.groupby(["target_id"])["similarity"].max()),
):
    sns.kdeplot(similarities, fill=True, label=curve_label)
plt.legend();

In [None]:
# SMILES string lengths of predicted vs. true molecules.
# In df_similarity every prediction is repeated once per ground-truth ligand of its protein (and
# vice versa), so each molecule is counted once here instead of once per comparison.
# stat="density" keeps the two histograms comparable although the sets differ in size.
plt.title("Distribution of SMILES lengths")
sns.histplot(df_similarity["pred_smiles"].drop_duplicates().str.len(), fill=True, label="pred", binwidth=1, alpha=0.5, stat="density")
sns.histplot(df_similarity["true_smiles"].drop_duplicates().str.len(), fill=True, label="true", binwidth=1, alpha=0.5, stat="density")
plt.xlabel("SMILES length (characters)")
plt.legend();

#### 4.2. Binding affinity predictions

In [None]:
# !pip install deeppurpose
# !pip install git+https://github.com/bp-kelley/descriptastorus 
# !pip install pandas-flavor

In [None]:
from DeepPurpose import DTI as models
import DeepPurpose.utils as utils

In [None]:
# Load pretrained DeepPurpose drug-target interaction model (MPNN for drug, CNN for protein).
# NOTE: this rebinds `model`, which until now held the Synformer model from section 2; the
# Synformer sampling cells only work again after re-running the `load_model` cell.
# NOTE(review): the checkpoint name suggests training on the DAVIS dataset; confirm which scale
# it predicts (e.g. pKd vs. raw Kd) before interpreting the affinity differences below.
model = models.model_pretrained(model="MPNN_CNN_DAVIS")

In [None]:
# Example: predicted binding affinity of the true and the predicted molecule of `best_pred`.
# (Without the cache: seq = get_amino_acid_sequence(best_pred["target_id"]))
seq = df_aa_seq.set_index("target_id").loc[best_pred["target_id"], "aa_seq"]
# Targets whose sequence could not be retrieved are NaN in df_aa_seq; fail with a clear message.
assert isinstance(seq, str), f"No amino-acid sequence available for {best_pred['target_id']}"

# Both SMILES strings are paired with the same protein sequence.
X_drug = [best_pred["true_smiles"], best_pred["pred_smiles"]]
X_target = [seq] * len(X_drug)

# data_process takes one label per pair; for pure prediction these are placeholders.
labels = np.zeros(len(X_drug))

# Must match the encoders of the pretrained MPNN_CNN_* model.
drug_encoding = "MPNN"
target_encoding = "CNN"

X = utils.data_process(
    X_drug,
    X_target,
    labels,
    drug_encoding,
    target_encoding,
    split_method="no_split",
)

# Predicted affinities, in the order [true molecule, predicted molecule]
y = model.predict(X)
y

#### 4.2.1. Binding affinity of predicted molecules most similar to true molecules

In [None]:
# Best prediction per (protein, true molecule) pair, with the protein's amino-acid sequence
# attached as column `aa_seq`. dropna() removes rows with any missing value, in particular
# those whose sequence could not be retrieved.
df_binding = (
    best_pred_per_pair
    .set_index("target_id")
    .join(df_aa_seq.set_index("target_id"))
    .reset_index()
    .dropna()
)

In [None]:
%%time
# Predict binding affinity for true SMILES (from Papyrus dataset)
# and for our predicted SMILES

# Could be done more efficiently, since it's probably doing a lot of pairs multiple times.
# But it's not too slow, so it's fine for now

X_true = utils.data_process(
    df_binding["true_smiles"].values, 
    df_binding["aa_seq"].values, 
    np.zeros(len(df_binding)), 
    drug_encoding,  # same as above
    target_encoding,  # same as above
    split_method="no_split"
)

X_pred = utils.data_process(
    df_binding["pred_smiles"].values, 
    df_binding["aa_seq"].values, 
    np.zeros(len(df_binding)), 
    drug_encoding,  # same as above
    target_encoding,  # same as above
    split_method="no_split"
)

df_binding["binding_affinity_true"] = model.predict(X_true)
df_binding["binding_affinity_pred"] = model.predict(X_pred)

# Not needed anymore; removing it makes the table easier read
del df_binding["aa_seq"]

In [None]:
# Predicted affinity of our molecule minus that of the true molecule:
#   > 0: predicted to bind more strongly, = 0: equally, < 0: less strongly.
# NOTE(review): this reading assumes larger model outputs mean stronger binding (true for a
# pKd-like scale, inverted for raw Kd); confirm what the DeepPurpose model outputs.
df_binding["binding_affinity_diff"] = df_binding["binding_affinity_pred"].sub(df_binding["binding_affinity_true"])

In [None]:
# Ten rows with the largest predicted gain of our molecule over the true molecule
df_binding.sort_values("binding_affinity_diff", ascending=False).head(10)

In [None]:
# Ten rows where our molecule falls furthest behind the true molecule
df_binding.sort_values("binding_affinity_diff", ascending=True).head(10)

In [None]:
# Summary statistics of the numeric columns: similarity, both predicted affinities and their difference.
df_binding.describe()

In [None]:
# Density of predicted-minus-true affinity differences for the best prediction per pair.
ax = sns.kdeplot(df_binding["binding_affinity_diff"], fill=True)
ax.set_title("Distribution of binding affinity differences");

#### 4.2.2. Binding affinity of all predicted molecules

Above, I only kept, for each (protein, true molecule) pair, the predicted molecule most similar to the true molecule (`best_pred_per_pair`).  
But similarity might not be the most important metric: even very dissimilar predictions can have a high *predicted* binding affinity.  
So let's predict the binding affinity of all predictions:

In [None]:
# One row per unique (protein, predicted molecule) combination, with the protein's amino-acid
# sequence attached as column `aa_seq`. dropna() removes rows with any missing value, in
# particular those whose sequence could not be retrieved.
# NOTE(review): drop_duplicates keeps only the first row of each combination, so `true_smiles`
# is just the first of the protein's ground-truth ligands in df_similarity's row order; the
# "true" affinity and the difference computed below are relative to that ligand only.
df_binding_all = (
    df_similarity
    .drop_duplicates(["target_id", "pred_smiles"])
    .set_index("target_id")
    .join(df_aa_seq.set_index("target_id"))
    .reset_index()
    .dropna()
)

In [None]:
# Unique (protein, predicted molecule) combinations among the best-per-pair predictions:
print(len(best_pred_per_pair.drop_duplicates(subset=["target_id", "pred_smiles"])))

# ...and how many remain after dropping proteins without an amino-acid sequence:
print(len(df_binding.drop_duplicates(subset=["target_id", "pred_smiles"])))

In [None]:
# Unique (protein, predicted molecule) combinations over all predictions:
print(len(df_similarity.drop_duplicates(subset=["target_id", "pred_smiles"])))

# ...and how many remain after dropping proteins without an amino-acid sequence:
print(df_binding_all.shape[0])

In [None]:
%%time
# Predict binding affinity for true SMILES (from Papyrus dataset)
# and for our predicted SMILES

# Could be done more efficiently, since it's probably doing a lot of pairs multiple times.
# But it's not too slow, so it's fine for now

X_true = utils.data_process(
    df_binding_all["true_smiles"].values, 
    df_binding_all["aa_seq"].values, 
    np.zeros(len(df_binding_all)), 
    drug_encoding,  # same as above
    target_encoding,  # same as above
    split_method="no_split"
)

X_pred = utils.data_process(
    df_binding_all["pred_smiles"].values, 
    df_binding_all["aa_seq"].values, 
    np.zeros(len(df_binding_all)), 
    drug_encoding,  # same as above
    target_encoding,  # same as above
    split_method="no_split"
)

df_binding_all["binding_affinity_true"] = model.predict(X_true)
df_binding_all["binding_affinity_pred"] = model.predict(X_pred)

# Not needed anymore; removing it makes the table easier read
del df_binding_all["aa_seq"]

In [None]:
# Predicted affinity of our molecule minus that of the row's true molecule:
#   > 0: predicted to bind more strongly, = 0: equally, < 0: less strongly.
# NOTE(review): this reading assumes larger model outputs mean stronger binding (true for a
# pKd-like scale, inverted for raw Kd); confirm what the DeepPurpose model outputs.
df_binding_all["binding_affinity_diff"] = df_binding_all["binding_affinity_pred"].sub(df_binding_all["binding_affinity_true"])

In [None]:
# Ten predictions with the largest predicted gain over the true molecule
df_binding_all.sort_values("binding_affinity_diff", ascending=False).head(10)

In [None]:
# Ten predictions that fall furthest behind the true molecule
df_binding_all.sort_values("binding_affinity_diff", ascending=True).head(10)

In [None]:
# Summary statistics of the numeric columns: similarity, both predicted affinities and their difference.
df_binding_all.describe()

In [None]:
# Density of predicted-minus-true affinity differences over all unique predictions
# (section 4.2.2 data, i.e. df_binding_all rather than the best-per-pair df_binding).
plt.title("Distribution of binding affinity differences (all predictions)")
sns.kdeplot(df_binding_all["binding_affinity_diff"], fill=True);