In [1]:
tankbind_src_folder_path = "./tankbind/"
import sys
sys.path.insert(0, tankbind_src_folder_path)

In [2]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

In [3]:
df = pd.read_csv('Mcule_10000.csv', index_col=0)

In [4]:
from feature_utils import get_protein_feature, get_clean_res_list
from Bio.PDB import PDBParser

In [5]:
def extract_protein_names(file_path):
    protein_names = []
    with open(file_path, 'r') as file:
        for line in file:
            # Get the base name without directory path
            base_name = os.path.basename(line.strip())
            # Split the base name to get the protein name without extension
            protein_name = os.path.splitext(base_name)[0]
            protein_names.append(protein_name)
    return protein_names

file_path = 'pdb_dataset.ds'
protein_names = extract_protein_names(file_path)
print(protein_names[0:5])

['1A07', '1A08', '1A09', '1A0N', '1A1A']


In [6]:
protein_names.__len__()

5963

In [11]:
import random
n = len(protein_names) // 5
sampled_protein_names = random.sample(protein_names, n)
print(sampled_protein_names[0:15], len(sampled_protein_names))

['5CSJ', '6DI1', '4JOA', '6V6K', '3QRI', '3NNU', '5CNN', '5TQ8', '2OWB', '1IVO', '2EXM', '3ML9', '3QZF', '1HCL', '2XYN'] 1192


In [12]:
def process_proteins(protein_names, pdb_directory):
    parser = PDBParser(QUIET=True)
    protein_dict = {}
    i = 0
    for proteinName in sampled_protein_names:
        try:
            # print(i, ': ', proteinName)
            proteinFile = f"{pdb_directory}/{proteinName}.pdb"
            s = parser.get_structure(proteinName, proteinFile)
            res_list = list(s.get_residues())
            clean_res_list = get_clean_res_list(res_list, ensure_ca_exist=True)
            protein_dict[proteinName] = get_protein_feature(clean_res_list)
        except Exception as e:
            print(f"Error processing {proteinName}: {e}")
        finally:
            i += 1
    
    return protein_dict

pdb_directory = "PDB_files"  # Directory containing PDB files

protein_dict = process_proteins(sampled_protein_names, pdb_directory)

In [13]:

protein_dict.keys().__len__() # didin't work for 2 proteins 1IAN & 6XR4 (indexError: Dimension out of range (expected to be in range of [-1, 0], but got 1))

1192

In [14]:
import torch
torch.save(protein_dict, 'precalculated_protein_features_small.pt')

In [15]:
protein_names_new = list(protein_dict.keys())
protein_names_new.__len__()

1192

In [16]:
info = []
for protein_name in protein_names_new:
    for i, line in tqdm(df.iterrows(), total=len(df)):
        smiles = line['smiles']
        compund_name = ''
        protein_name = protein_name
        com = ",".join([str(x.round(3)) for x in protein_dict[protein_name][0].mean(axis=0).numpy()])
        info.append([protein_name, compund_name, smiles, "protein_center", com])

info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'smiles', 'pocket_name', 'pocket_com'])


100%|██████████| 10000/10000 [00:01<00:00, 8029.12it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8599.54it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8624.00it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8430.75it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8823.90it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8774.90it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7286.31it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8908.09it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8732.26it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8529.25it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8678.24it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8483.69it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8879.74it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8894.99it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8816.00it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7979.60it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8838.87it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8347.

In [17]:
info.head()
info.__len__()


11920000

In [18]:
info

Unnamed: 0,protein_name,compound_name,smiles,pocket_name,pocket_com
0,5CSJ,,CC1=C(C2=CC=C(F)C=C2)N=C(NC(C2=C(C)C3=C(CCCC3=...,protein_center,"-11.36,-12.253,19.679"
1,5CSJ,,CC(N(C(NC1C2OC(CC2)C1)=O)C)CS(C)(=O)=O,protein_center,"-11.36,-12.253,19.679"
2,5CSJ,,S(=O)(=O)(C1C=CC2=C(C=1)N=C(N2C1CCCC1)C)N1CCC(...,protein_center,"-11.36,-12.253,19.679"
3,5CSJ,,N1CCCCC1CC1=CC=NC=C1,protein_center,"-11.36,-12.253,19.679"
4,5CSJ,,N1(CCC2C=CC(CNC(=O)NC3=CC=CC=C3C)=CC1=2)C(C1=C...,protein_center,"-11.36,-12.253,19.679"
...,...,...,...,...,...
11919995,3S00,,CCCN(C(OC(C)(C)C)=O)CCNC(C1=C(SC)C=CC=C1)=O,protein_center,"-78.279,102.779,-49.735"
11919996,3S00,,N[C@H](C1C=NC=CC=1Br)C(F)(F)F,protein_center,"-78.279,102.779,-49.735"
11919997,3S00,,C12N=CN(C(=O)C=1N=NN2CC1C=CC(=CC=1)F)CC(=O)NC1...,protein_center,"-78.279,102.779,-49.735"
11919998,3S00,,N1(CCC[C@H](O)[C@@H]1CC1=CC=CC=C1F)CC1C=CC=CC=...,protein_center,"-78.279,102.779,-49.735"


In [41]:
small_info = info[0:20000]

In [38]:
# info.to_csv('info.csv')

### Construct the Dataset

In [19]:
import torch
torch.set_num_threads(1)

In [20]:
from torch_geometric.data import Dataset
from utils import construct_data_from_graph_gvp
import rdkit.Chem as Chem    # conda install rdkit -c rdkit if import failure.
from feature_utils import extract_torchdrug_feature_from_mol, get_canonical_smiles

In [29]:
class MyDataset_VS(Dataset):
    def __init__(self, root, data=None, protein_dict=None, proteinMode=0, compoundMode=1,
                 pocket_radius=20, shake_nodes=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.data = data
        self.protein_dict = protein_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.proteinMode = proteinMode
        self.pocket_radius = pocket_radius
        self.compoundMode = compoundMode
        self.shake_nodes = shake_nodes

    @property
    def processed_file_names(self):
        return ['data.pt', 'proteins.pt']

    def process(self):
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])

    def len(self):
        return len(self.data)

    def get(self, idx):
        line = self.data.iloc[idx]
        smiles = line['smiles']
        pocket_com = line['pocket_com']
        pocket_com = np.array(pocket_com.split(",")).astype(float) if isinstance(pocket_com, str) else pocket_com
        pocket_com = pocket_com.reshape((1, 3))
        use_whole_protein = line.get('use_whole_protein', False)

        protein_name = line['protein_name']
        protein_data = self.protein_dict.get(protein_name)
        
        if protein_data is None:
            raise ValueError(f"Protein {protein_name} not found in pre-calculated protein dictionary")

        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v = protein_data

        try:
            smiles = get_canonical_smiles(smiles)
            mol = Chem.MolFromSmiles(smiles)
            mol.Compute2DCoords()
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
        except Exception as e:
            print(f"Error processing {smiles}: {e}. Using placeholder 'CCC'.")
            smiles = 'CCC'
            mol = Chem.MolFromSmiles(smiles)
            mol.Compute2DCoords()
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
        
        data, input_node_list, keepNode = construct_data_from_graph_gvp(
            protein_node_xyz, protein_seq, protein_node_s, protein_node_v, 
            protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            pocket_radius=self.pocket_radius, use_whole_protein=use_whole_protein, includeDisMap=True,
            use_compound_com_as_pocket=False, chosen_pocket_com=pocket_com, compoundMode=self.compoundMode
        )
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)

        return data

In [32]:
# Load the dataset
dataset_path = f"dataset1192"
os.system(f"rm -r {dataset_path}")
os.system(f"mkdir -p {dataset_path}")
# dataset = MyDataset_VS(root=dataset_path, data=info, protein_dict='precalculated_protein_features_small.pt')

dataset = MyDataset_VS(root=dataset_path, data=info, protein_dict=protein_dict)


Processing...
Done!


In [33]:
dataset.__len__()

11920000

### Model Testing (only for 5 proteins, very compute intensive, for more use google collab)

In [20]:
import logging
from torch_geometric.loader import DataLoader
from tqdm import tqdm    # pip install tqdm if fails.
from model import get_model

#### Filter dataset (only use 10% and remove large proteins/pockets)

In [34]:
# Assuming `dataset` is a list or a dataset object that supports indexing and slicing
# Reduce dataset size for evaluation (e.g., take only 10% of the dataset)
# reduced_dataset_size = int(len(dataset) * 0.1)
# reduced_dataset = dataset[:reduced_dataset_size]

# Filter out large proteins/pockets
# Assuming `dataset` has a method or a way to access `dis_map` for each item
# filtered_dataset = [item for item in dataset if item.dis_map.shape[0] <= 10000]

In [24]:
batch_size = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device= 'cpu'
logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)

# self-dock model
modelFile = "./model/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=8)
affinity_pred_list = []
y_pred_list = []
for data in tqdm(data_loader):
    data = data.to(device)
    y_pred, affinity_pred = model(data)
    affinity_pred_list.append(affinity_pred.detach().cpu())
    if False:
        # we don't need to save the predicted distance map in HTVS setting.
        for i in range(data.y_batch.max() + 1):
            y_pred_list.append((y_pred[data['y_batch'] == i]).detach().cpu())

affinity_pred_list = torch.cat(affinity_pred_list)

09:01:39   5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN


AttributeError: module 'torch' has no attribute '_six'