In [None]:
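# Additional requirements just for this notebook
# (use %pip rather than !pip so packages install into the active kernel's
#  environment; quote the extras spec so shells such as zsh don't glob the brackets):

# %pip install "huggingface_hub[cli]"
# %pip install esm
# %pip install py3Dmol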

In [1]:
# General:
import os 
from omegaconf import OmegaConf
# import huggingface_hub
import requests
import pathlib
import pickle
import pandas as pd
import torch

# ESM:
# from esm.models.esmc import ESMC
# from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, LogitsConfig, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain
# from esm.utils.types import FunctionAnnotation
from biotite.database import rcsb
# from fa_helper import visualize_function_annotations, get_keywords_from_interpro, interpro2keywords
from scripts.helpers.pdb import gene_to_pdb, fetch_pdb_ids

# Visualization:
import matplotlib.pyplot as plt
from scripts.helpers.visualization import visualize_3D_protein

# Synformer:
# from synformer.chem.fpindex import FingerprintIndex
# from synformer.chem.matrix import ReactantReactionMatrix
# from synformer.chem.mol import Molecule
from synformer.models.synformer import Synformer
from scripts.sample_naive import load_model, featurize_smiles, load_protein_molecule_pairs, sample

In [2]:
# ESM-related settings (including the Hugging Face token used below) are kept
# in a YAML config; stop right away with a clear message if it is missing.
esm_config_file = "configs/esm.yml"
assert os.path.exists(esm_config_file), f"Missing config file: {esm_config_file}"
esm_config = OmegaConf.load(esm_config_file)

# To authenticate with the Hugging Face Hub, uncomment the huggingface_hub
# import in the imports cell and then this line:
# huggingface_hub.login(esm_config.hf_token)

### Data: protein–molecule pairs, protein embeddings and synthetic pathways

In [3]:
# Input data, relative to the directory the notebook is launched from:
#   - validation split of protein–molecule pairs (CSV)
#   - precomputed protein embeddings, looked up by target_id further below
#   - reference synthetic pathways, looked up by SMILES further below
DATA_DIR = "data"
protein_molecule_pairs_val_path = f"{DATA_DIR}/protein_molecule_pairs/papyrus_val_19399.csv"
protein_embeddings_path = f"{DATA_DIR}/protein_embeddings/embeddings_selection_float16_4973.pth"
synthetic_pathways_path = f"{DATA_DIR}/synthetic_pathways/filtered_pathways_370000.pth"

In [4]:
# Load the validation protein–molecule pairs.
# NOTE(review): reset_index() (drop=False by default) keeps the previous index
# as an extra column. Later cells select columns by name, so this is harmless;
# switch to reset_index(drop=True) if that column is not needed.
df_protein_molecule_pairs = load_protein_molecule_pairs(protein_molecule_pairs_val_path)
df_protein_molecule_pairs = df_protein_molecule_pairs.reset_index()

# Some example entries (fixed random_state so the preview is reproducible on re-run)
df_protein_molecule_pairs.sample(10, random_state=0)

Unnamed: 0,SMILES,target_id,short_target_id
15471,O=C(CSc1c2oc3c(cccc3)c2ncn1)Nc1c(C(F)(F)F)cccc1,Q9HBX9_WT,Q9HBX9
9252,Cc1cc(C(=O)CSc2nnnn2-c2ccccc2)c(C)n1-c1ccc2OCO...,Q13526_WT,Q13526
6814,COc1ccc(C(CC(=O)N2CCC3(CC2)OCCO3)NS(=O)(=O)c2c...,P02791_WT,P02791
13282,Cc1cc(C)c(CCN2CCCC2C)c(C)c1,Q9QYN8_WT,Q9QYN8
7619,COc1c(S(=O)(=O)N2CCCc3ccccc32)cc(C(=O)Nc2nc3c(...,Q9NR56_WT,Q9NR56
1603,COc1ccc(C(=O)NN=Cc2c(CO)cnc(C)c2O)cc1,P52292_WT,P52292
14378,CC(=O)c1ccc(NC(=O)c2ccc3nc(C)sc3c2)cc1,Q9NR56_WT,Q9NR56
11568,O=C1c2c(cccc2)-c2c1cc(-c1ccc(O)cc1)nn2,P19643_WT,P19643
4393,COc1ccc2nc(NC(=O)C(NS(=O)(=O)c3cccs3)C(C)C)sc2c1,Q9NR56_WT,Q9NR56
6123,Cc1ccc(C(=O)C=Cc2cc(Br)c(F)cc2)o1,P21397_WT,P21397


In [5]:
# Both files are torch-serialized objects; map every tensor onto the CPU.
# NOTE(review): torch.load unpickles its input, so only point these paths at
# trusted files.
protein_embeddings = torch.load(protein_embeddings_path, map_location="cpu")
print(f"{len(protein_embeddings)} protein embeddings")

synthetic_pathways = torch.load(synthetic_pathways_path, map_location="cpu")
print(f"{len(synthetic_pathways)} synthetic pathways")

4973 protein embeddings
70936 synthetic pathways


### Protein-SynFormer model

In [6]:
# Protein-SynFormer sampling settings.
device = "cpu"
model_path = "data/trained_weights/epoch=23-step=28076.ckpt"

# config_path is passed straight to load_model. With None, load_model
# presumably falls back to the config stored with the checkpoint (confirm in
# scripts/sample_naive.py). To force an explicit model config instead, set
# e.g. config_path = "configs/prot2drug.yml".
config_path = None

In [7]:
# Restore the trained model plus the two auxiliary objects load_model returns
# (presumably a FingerprintIndex and a ReactantReactionMatrix, judging by the
# commented-out synformer.chem imports — confirm in scripts/sample_naive.py).
(model, fpindex, rxn_matrix) = load_model(model_path, config_path, device)

### Example

In [12]:
# Random (but reproducible) example: change EXAMPLE_SEED to inspect another pair.
EXAMPLE_SEED = 0

# Only consider pairs whose target has a precomputed embedding and whose SMILES
# has a reference synthetic pathway, so the lookups below cannot raise KeyError.
is_eligible = (
    df_protein_molecule_pairs["target_id"].isin(list(protein_embeddings.keys()))
    & df_protein_molecule_pairs["SMILES"].isin(list(synthetic_pathways.keys()))
)
assert is_eligible.any(), "No pair has both a protein embedding and a synthetic pathway"
ex_row = df_protein_molecule_pairs[is_eligible].sample(random_state=EXAMPLE_SEED).iloc[0]

# Select fields by column name, not position: reset_index() in the loading cell
# may add the old index as an extra column, which breaks positional unpacking.
ex_smiles = ex_row["SMILES"]
ex_target_id = ex_row["target_id"]
ex_short_target_id = ex_row["short_target_id"]

# .float() casts the stored embedding to float32.
ex_protein_embeddings = protein_embeddings[ex_target_id].float()
ex_synthetic_pathway_true = synthetic_pathways[ex_smiles]

print("SMILES:", ex_smiles)
print("Target:", ex_target_id)
print("Protein embeddings:", ex_protein_embeddings.shape)
print("True synthetic pathway:", ex_synthetic_pathway_true)

SMILES: CN1C2CCC1C(=Cc1cccn1C)C(=O)C2=Cc1cccn1C
Target: O95398_WT
Protein embeddings: torch.Size([925, 1152])
True synthetic pathway: [(1, -1), (3, 4449), (3, 1974), (2, 3), (3, 1974), (2, 3), (0, -1)]


In [13]:
# Look up an experimental structure for the example target.
# NOTE(review): assumes fetch_pdb_ids returns a falsy ID (None / "") when no
# structure is found — confirm in scripts/helpers/pdb.py.
ex_pdb_id, ex_df_pdb_ids = fetch_pdb_ids(ex_short_target_id)
if not ex_pdb_id:
    raise ValueError(f"No PDB structure found for target {ex_short_target_id}")
print("PDB ID:", ex_pdb_id)

# Chain to extract from the PDB entry. Not every entry necessarily names its
# relevant chain "A"; change this if ProteinChain.from_pdb cannot find it.
EX_CHAIN_ID = "A"
ex_protein_chain = ProteinChain.from_pdb(rcsb.fetch(ex_pdb_id, "pdb"), chain_id=EX_CHAIN_ID)

# ESMProtein carrying the ground-truth tracks available from the structure.
# Function annotations are not filled in by from_protein_chain. ESM provides
# no automatic way to fetch them; they have to be fetched separately and
# assigned to ex_known_protein.function_annotations.
ex_known_protein = ESMProtein.from_protein_chain(ex_protein_chain)

print(ex_known_protein.sequence)

# TODO: build a sequence-only protein, ESMProtein(sequence=ex_protein_chain.sequence),
# so the other tracks can be predicted and compared against the ground truth.
# TODO: have ESM predict the binding site and visualize it as well
# (already done in the Binding Site notebook).

visualize_3D_protein(ex_known_protein, style="cartoon")

PDB ID: 6H7E
ASTERVLRAGRQLHRHLLATCPNLIRDRKYHLRLYRQCCSGRELVDGILALGHSRSQVVGICQVLLDEGALCHVKHDWAFQDRDAQFYRFPGPEPEPVEELAEAVALLSQRGPDALLTVALRKPPGQRTDEELDLIFEELLHIKAVAHLSNSVKRELAAVLLFEPHSKAGTVLFSQGDKGTSWYIIWKGSVNVVTHGKGLVTTLHEGDDFGQLALVNDAPRAATIILREDNCHFLRVDKQDFNRIIK


<py3Dmol.view at 0x149b21840>

In [18]:
# Run the naive sampler from scripts.sample_naive on the example pair,
# passing the full protein_embeddings mapping; `repeat` is forwarded as-is.
n_sampling_repeats = 10
sample(
    ex_smiles, ex_target_id,
    model, fpindex, rxn_matrix,
    protein_embeddings, device,
    repeat=n_sampling_repeats,
)

  0%|          | 0/23 [00:00<?, ?it/s]

Total: 0 / 1
