In [1]:
# env imports
import sys
import os
import logging
import numpy as np
import random
import pandas as pd
import torch
from tqdm import tqdm
import pickle

import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from Bio.PDB import PDBParser
import torchmetrics
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

In [3]:
# Put the TankBind source folder first on sys.path so that the modules imported
# in the next cell (feature_utils, utils, model, generation_utils, metrics)
# resolve as top-level modules. The path is relative to the working directory.
tankbind_src_folder_path = "./tankbind/"
sys.path.insert(0, tankbind_src_folder_path)

In [4]:
# imports from tankbind
from feature_utils import get_protein_feature, get_clean_res_list, extract_torchdrug_feature_from_mol, get_canonical_smiles
from utils import construct_data_from_graph_gvp, evaulate_with_affinity, evaulate
from model import get_model
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
from metrics import print_metrics, myMetric

# Protein_dict

In [None]:
def process_proteins(protein_names, pdb_directory='PDB_files'):
    """
    Build protein features for every name in ``protein_names``.

    For each name, ``<pdb_directory>/<name>.pdb`` is parsed with Bio.PDB's
    ``PDBParser``, its residues are filtered through
    ``get_clean_res_list(..., ensure_ca_exist=True)`` and the filtered list is
    handed to ``get_protein_feature``.

    protein_names: iterable of names; each one is used both as the file stem
        and as the key of the returned dict
    pdb_directory: folder that contains the ``.pdb`` files

    Returns a dict mapping name -> result of ``get_protein_feature``. A name
    whose processing raises is reported with ``print`` and left out of the
    dict, so the result may hold fewer entries than ``protein_names``.
    """
    features_by_name = {}
    pdb_parser = PDBParser(QUIET=True)

    for pdb_name in protein_names:
        try:
            structure = pdb_parser.get_structure(
                pdb_name, f"{pdb_directory}/{pdb_name}.pdb"
            )
            residues = get_clean_res_list(
                list(structure.get_residues()), ensure_ca_exist=True
            )
            features_by_name[pdb_name] = get_protein_feature(residues)
        except Exception as err:
            # Best-effort: one unreadable or malformed structure must not
            # abort the whole batch; it is reported and skipped.
            print(f"Error processing {pdb_name}: {err}")

    return features_by_name

In [5]:
# Load the KIBA interaction table. Later cells read the 'Smiles' and 'pdb_id'
# columns from it.
# NOTE(review): the 'Unnamed: 0.1' / 'Unnamed: 0' columns in the preview look
# like index columns left over from earlier CSV round-trips - confirm.
kiba_df = pd.read_csv('kiba_data_df_with_ids.csv')
kiba_df.head()

Unnamed: 0.1,Unnamed: 0,Smiles,molecules,target_affinity,uniprot_id,pdb_id
0,0,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,MTVKTEAAKGTLTYSRMRGMVAILIAFMKQRRMGLNDFIQKIANNS...,11.1,O00141,2R5T
1,1,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQIAIKQC...,11.1,O14920,3BRT
2,2,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,MERPPGLRPGAGGPWEMRERLGTGGFGNVCLYQHRELDLKIAIKSC...,11.1,O15111,3BRT
3,3,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFED...,11.1,P00533,1IVO
4,4,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,MELAALCRWGLLLALLPPGAASTQVCTGTDMKLRLPASPETHLDML...,11.1,P04626,1MFG


In [7]:
# Collect the distinct PDB IDs (the result of .unique(), not a Python list).
# The count below is the number of distinct pdb_id values: in the preview of
# kiba_df, two different uniprot_id values share the pdb_id 3BRT, so this is
# not necessarily the number of distinct proteins.
protein_names = kiba_df['pdb_id'].unique()
print(f"Number of unique proteins: {len(protein_names)}")

Number of unique proteins: 226


In [8]:
# Generate protein features --> protein_dict (keyed by PDB ID).
# Uses the default pdb_directory of process_proteins ('PDB_files'), so each
# structure is read from PDB_files/<pdb_id>.pdb.
protein_dict = process_proteins(protein_names)

In [None]:
# Sanity check: process_proteins skips any structure that raised, so this
# count should match the number of unique PDB IDs printed earlier.
print(f"Number of proteins in protein_dict: {len(protein_dict)}")

Number of proteins in protein_dict: 226


In [None]:
# Save protein_dict to a pickle file (only for demonstration, not needed in practice)
with open("protein_dict.pkl", "wb") as f:
    pickle.dump(protein_dict, f)

# Load protein_dict back from the pickle file (only for demonstration, not needed in practice).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
# Here the file was written by the statement just above.
with open("protein_dict.pkl", "rb") as f:
    protein_dict = pickle.load(f)

# Save protein_dict as a .pt file
torch.save(protein_dict, "protein_dict.pt")

# Molecule_dict

In [12]:

def process_molecules(smiles):
    """
    Create a dictionary of molecule features from a collection of SMILES strings.

    Each input string is canonicalised with ``get_canonical_smiles``, parsed
    with RDKit, given 2D coordinates, and passed to
    ``extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)``.

    smiles: iterable (e.g. ndarray) of unique SMILES strings

    Returns a dict whose keys are the SMILES strings exactly as given in
    ``smiles`` (not their canonical form) and whose values are the extracted
    features. An empty input yields an empty dict.

    Raises ValueError if RDKit cannot parse a canonicalised SMILES string.
    """
    molecule_dict = {}
    for molecule in smiles:
        # Use a dedicated local name: rebinding `smiles` here would shadow the
        # parameter that is being iterated over.
        canonical_smiles = get_canonical_smiles(molecule)
        mol = Chem.MolFromSmiles(canonical_smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # with the offending input instead of an AttributeError below.
            raise ValueError(
                f"RDKit could not parse SMILES {molecule!r} "
                f"(canonical form: {canonical_smiles!r})"
            )
        mol.Compute2DCoords()
        # Keyed by the original input string, not the canonical form.
        molecule_dict[molecule] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
    return molecule_dict

In [13]:
# Get the distinct SMILES strings exactly as written in the 'Smiles' column
# (no canonicalisation happens at this step).
smiles = kiba_df['Smiles'].unique()
print(f"Number of unique SMILES: {len(smiles)}")

Number of unique SMILES: 2068


In [14]:
# Generate molecule features --> molecule_dict, keyed by the SMILES strings
# as they appear in `smiles`.
molecule_dict = process_molecules(smiles)

In [15]:
# Sanity check: there should be one entry per unique SMILES string.
print(f"Number of molecules in molecule_dict: {len(molecule_dict)}")

Number of molecules in molecule_dict: 2068


In [16]:
# Save molecule_dict to a .pt file in the current working directory.
torch.save(molecule_dict, "molecule_dict.pt")