# sx.nciyes.main.ipynb

Notebook for training and inference.



## Configuration
### Running Configuration

In [1]:
cfg_development = True # Whether the testing paths are used.
cfg_use_nci = True # Temporarily unused.
cfg_singleton = False # Whether only one protein is integrated.
cfg_distinguish_by_timestamp = True # As name suggests.

### Path Configuration
Fill in the first block with paths.

In [2]:
input_path = None # Path to load all inputs.
output_path = None # Path to store intermediate and final outputs. Each run creates a folder here.
ds_path = None # Path to store .ds file for p2rank.

log_path = "./Logs/" # Paht to store log files.
p2rank_path = "../p2rank_2.3/prank" # Path of prank file.

pdb_df_fname = None # Filename of protein info dataframe.
ligand_df_fname = None # Filename of ligand info dataframe.
pdb_df_fpath = None # Full path of protein info dataframe.
ligand_df_fpath = None # Full path of ligand info dataframe.

pdb_path = "./Inputs/Dev/RenamedPDBBind/" # Path to load all the .pdb files of proteins.
ligand_path = "./Inputs/Dev/RenamedPDBBind/" # Path to load all the .mol2 files of ligands.

#### <font color="grey"><i>processing block</i></font>

In [3]:
tankbind_src_folder_path = "../tankbind/"
import sys
sys.path.insert(0, tankbind_src_folder_path)
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import time


# Development path settings.
if cfg_development:
    input_path = "./Inputs/Dev/"
    output_path = "./Outputs/Dev/"
    ds_path = "../../../"
    pdb_df_fname = "dev.pdbs.csv"
    ligand_df_fname = "dev.ligands.csv"
    nci_fname = "dev.nci.csv"
    pdb_df_fpath = "./Inputs/Dev/dev.pdbs.csv"
    ligand_df_fpath = "./Inputs/Dev/dev.ligands.csv"

# Path settings.
if pdb_df_fname and not pdb_df_fpath:
    pdb_df_fpath = f"{input_path}{pdb_df_fname}"
if ligand_df_fname and not ligand_df_fpath:
    ligand_df_fpath = f"{input_path}{ligand_df_fname}"
pdb_df_fpath = f"{input_path}{pdb_df_fname}"
ligand_df_fpath = f"{input_path}{ligand_df_fname}"
nci_fpath = f"{input_path}{nci_fname}"
p2rank = f"bash {p2rank_path}"

# Loading data.
pdb_df = pd.read_csv(pdb_df_fpath, index_col=0)
ligand_df = pd.read_csv(ligand_df_fpath, index_col=0)
pdb_code_list = list(pdb_df["pdb_code"])
pdb_fpath_list = list(pdb_df["pdb_fpath"])

# Prepare output folders.
if len(pdb_code_list) != 1:
    cfg_singleton = False
if cfg_distinguish_by_timestamp or (not cfg_singleton):
    timetag = time.strftime("%m%d%H%M")+"-"
else:
    timetag = ""
if cfg_singleton:
    main_path = f"{output_path}{timetag}{pdb_code_list[0]}/"
else:
    main_path = f"{output_path}{timetag}MultiProteins/"
os.system(f"rm -rf {main_path}")   
os.system(f"mkdir -p {main_path}")
os.system(f"rm -rf {main_path}/p2rank")
os.system(f"mkdir -p {main_path}/p2rank")


0

### Show Configurations

In [4]:
configs = {"main_path": main_path, "pdb_df_fname": pdb_df_fname, "pdb_path": pdb_path, 
           "ligand_df_fname": ligand_df_fname,
           "ligand_path": ligand_path, "p2rank_path": p2rank_path, 
           "pdb_list": pdb_code_list, 
           "cfg_development": cfg_development, "cfg_use_nci": cfg_use_nci, 
           "cfg_singleton": cfg_singleton, 
           "cfg_distinguish_by_timestamp": cfg_distinguish_by_timestamp}
print("Configs :")
for _key in configs.keys():
    print("    "+_key+" : "+str(configs[_key]))

Configs :
    main_path : ./Outputs/Dev/08040750-MultiProteins/
    pdb_df_fname : dev.pdbs.csv
    pdb_path : ./Inputs/Dev/RenamedPDBBind/
    ligand_df_fname : dev.ligands.csv
    ligand_path : ./Inputs/Dev/RenamedPDBBind/
    p2rank_path : ../p2rank_2.3/prank
    pdb_list : ['1g35']
    cfg_development : True
    cfg_use_nci : True
    cfg_singleton : False
    cfg_distinguish_by_timestamp : True


## Running!

### Get protein features

In [5]:
from sx_feature_utils import sx_get_protein_feature, get_clean_res_list
from Bio.PDB import PDBParser
def get_full_id(full_id_ls: list):
    chain_id = full_id_ls[2]
    res_id = full_id_ls[3][1]
    return chain_id + "_" + str(res_id)

In [6]:
parser = PDBParser(QUIET=True)
protein_dict = {}
protein_res_full_id_dict = {}
for _pname, _fpath in zip (pdb_code_list, pdb_fpath_list):
    s = parser.get_structure(_pname, _fpath)
    res_list = list(s.get_residues())
    clean_res_list = get_clean_res_list(res_list, ensure_ca_exist=True)
    clean_res_full_id_list = [get_full_id(x.full_id) for x in clean_res_list]
    protein_dict[_pname], protein_res_full_id_dict[_pname] = sx_get_protein_feature(clean_res_list, clean_res_full_id_list)

### Segmentation of proteins by p2rank

In [7]:
ds = f"{main_path}/protein_list.ds"
with open(ds, "w") as out:
    for _fpath in pdb_fpath_list:
        out.write(f"{ds_path}{_fpath}\n")

In [8]:
cmd = f"{p2rank} predict {ds} -o {main_path}/p2rank -threads 1"
os.system(cmd)

----------------------------------------------------------------------------------------------
 P2Rank 2.3
----------------------------------------------------------------------------------------------

predicting pockets for proteins from dataset [protein_list.ds]
processing [1g35_protein.pdb] (1/1)
predicting pockets finished in 0 hours 0 minutes 12.320 seconds
results saved to directory [/home/jovyan/TankBind/nciyes/Outputs/Dev/08040750-MultiProteins/p2rank]

----------------------------------------------------------------------------------------------
 finished successfully in 0 hours 0 minutes 13.254 seconds
----------------------------------------------------------------------------------------------


0

### Get ligand infomation

In [9]:
from rdkit import Chem
from feature_utils import get_canonical_smiles
from tqdm import tqdm

ligand_name_list = list(ligand_df["ligand_name"])
ligand_pdb_list = list(ligand_df["pdb_code"])
ligand_fpath_list = list(ligand_df["ligand_fpath"])

canonique_smiles = []
for _fpath in ligand_fpath_list:
    canonique_smiles.append(get_canonical_smiles(Chem.MolToSmiles(Chem.MolFromMol2File(_fpath))))
ligand_df["canonique_smiles"] = canonique_smiles

In [10]:
info = []
for i,line in tqdm(ligand_df.iterrows(), total=ligand_df.shape[0]):
    ligand_name = line["ligand_name"]
    pdb_code = line["pdb_code"]
    pdb_fname = line["pdb_code"] + "_protein"
    canonique_smiles = line["canonique_smiles"]
    ligand_fpath = line["ligand_fpath"]
    ligand_ftype = os.path.splitext(os.path.split(ligand_fpath)[1])[1]
    
    p2rankFile = f"{main_path}p2rank/{pdb_fname}.pdb_predictions.csv"
    pocket_df = pd.read_csv(p2rankFile)
    pocket_df.columns = pocket_df.columns.str.strip()
    pocket_coms = pocket_df[["center_x", "center_y", "center_z"]].values
    for ith_pocket, com in enumerate(pocket_coms):
        com = ",".join([str(a.round(3)) for a in com])
        info.append([ligand_name, pdb_code, canonique_smiles, ligand_fpath, ligand_ftype, f"pocket_{ith_pocket+1}", com])
info = pd.DataFrame(info, columns = ["ligand_name", "pdb_code", "canonique_smiles", 
                                     "ligand_fpath", "ligand_ftype", "pocket_name", "pocket_com"])
info

100%|██████████| 1/1 [00:00<00:00, 178.31it/s]


Unnamed: 0,ligand_name,pdb_code,canonique_smiles,ligand_fpath,ligand_ftype,pocket_name,pocket_com
0,AHF,1g35,COC(=O)c1ccc(CN2[C@H](COc3ccccc3)[C@H](O)[C@@H...,./Inputs/Dev/RenamedPDBBind/1g35_ligand_AHF.mol2,.mol2,pocket_1,"12.492,23.016,5.769"
1,AHF,1g35,COC(=O)c1ccc(CN2[C@H](COc3ccccc3)[C@H](O)[C@@H...,./Inputs/Dev/RenamedPDBBind/1g35_ligand_AHF.mol2,.mol2,pocket_2,"26.141,16.353,-7.039"
2,AHF,1g35,COC(=O)c1ccc(CN2[C@H](COc3ccccc3)[C@H](O)[C@@H...,./Inputs/Dev/RenamedPDBBind/1g35_ligand_AHF.mol2,.mol2,pocket_3,"7.319,39.411,11.192"
3,AHF,1g35,COC(=O)c1ccc(CN2[C@H](COc3ccccc3)[C@H](O)[C@@H...,./Inputs/Dev/RenamedPDBBind/1g35_ligand_AHF.mol2,.mol2,pocket_4,"15.226,34.484,-10.408"


### Construct dataset

#### <i>Prepare nci dataframe (will be moved to sx.nciyes.process.ipynb)</i>

In [11]:
nci_df=pd.read_csv(nci_fpath, index_col = 0)


In [12]:
import torch
torch.set_num_threads(1)
from torch_geometric.data import Dataset
from sx_utils import sx_construct_data_from_graph_gvp
import rdkit.Chem as Chem    # conda install rdkit -c rdkit if import failure.
from feature_utils import extract_torchdrug_feature_from_mol
from sx_feature_utils import sx_extract_torchdrug_feature_from_mol
from sx_new_utils import sx_ligand_dedocking
from sx_new_utils import sx_get_nci_matrix

#### <i>Definition of dataset (will be moved to definition block)</i>

In [13]:
class MyDataset_VS(Dataset):
    def __init__(self, root, data=None, protein_dict=None, proteinMode=0, compoundMode=1,
                pocket_radius=20, shake_nodes=None, 
                transform=None, pre_transform=None, pre_filter=None, generate_3D_conf = False,
                protein_res_full_id_dict=None, nci_df=None 
                ):
        self.data = data
        self.protein_dict = protein_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.nci_df=nci_df
        self.protein_res_full_id_dict = protein_res_full_id_dict
        
        self.proteinMode = proteinMode
        self.pocket_radius = pocket_radius
        self.compoundMode = compoundMode
        self.shake_nodes = shake_nodes
        self.generate_3D_conf = generate_3D_conf
        #self.printflag = True
    @property
    def processed_file_names(self):
        return ['data.pt', 'protein.pt']

    def process(self):
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])

    def len(self):
        return len(self.data)

    def get(self, idx):
        line = self.data.iloc[idx]
        canonique_smiles = line['canonique_smiles']
        pocket_com = line['pocket_com']
        pocket_com = np.array(pocket_com.split(",")).astype(float) if type(pocket_com) == str else pocket_com
        pocket_com = pocket_com.reshape((1, 3))
        use_whole_protein = line['use_whole_protein'] if "use_whole_protein" in line.index else False
        protein_name = line['pdb_code']
        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v = self.protein_dict[protein_name]
        protein_res_full_id = self.protein_res_full_id_dict[protein_name]
        ligand_fpath = line['ligand_fpath']
        ligand_ftype = line['ligand_ftype']
        ligand_name = line['ligand_name']
        
        if ligand_ftype == ".mol2":
            mol = Chem.MolFromMol2File(ligand_fpath)
        elif ligand_ftype == ".mol":
            mol = Chem.MolFromMolFile(ligand_fpath)
        # mol dedocking before computing features
        mol = sx_ligand_dedocking(mol, self.generate_3D_conf)
        mol.Compute2DCoords()  
        coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = sx_extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
        '''
        except:
            print(f"something wrong with {ligand_name} (pair_id:{pair_id}, smiles: {canonique_smiles}), to prevent this stops our screening, we repalce it with a placeholder smiles 'CCC'")
            canonique_smiles = 'CCC'
            mol = Chem.MolFromSmiles(canonique_smiles)
            mol.Compute2DCoords()
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = sx_extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
        '''
        # y is distance map, instead of contact map.
        data, input_node_list, keepNode = sx_construct_data_from_graph_gvp(protein_node_xyz, protein_seq, protein_node_s, 
                              protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v,
                              coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
                              pocket_radius=self.pocket_radius, use_whole_protein=use_whole_protein, includeDisMap=True,
                              use_compound_com_as_pocket=False, chosen_pocket_com=pocket_com, compoundMode=self.compoundMode)
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)
        res_full_id = [_full_id for (_full_id, _keep) in zip(protein_res_full_id, keepNode) if _keep]
        atom_names = [_atom.GetPropsAsDict()["_TriposAtomName"] for _atom in mol.GetAtoms()]
        data.nci_matrix = sx_get_nci_matrix(protein_name, ligand_name, res_full_id, atom_names, self.nci_df)
        
        #data.nci_matrix = get_nci_infos(protein_name, ligand_name, pair_id)
        print(data)
        #print("NCINums", np.sum(data.nci_matrix))
        return data


In [14]:
dataset_path = f"{main_path}/dataset/"
os.system(f"rm -r {dataset_path}")
os.system(f"mkdir -p {dataset_path}")
dataset = MyDataset_VS(dataset_path, data=info, protein_dict=protein_dict, protein_res_full_id_dict=protein_res_full_id_dict, nci_df=nci_df)

['Outputs/Dev/08040750-MultiProteins/dataset/processed/data.pt', 'Outputs/Dev/08040750-MultiProteins/dataset/processed/protein.pt']


rm: cannot remove './Outputs/Dev/08040750-MultiProteins//dataset/': No such file or directory
Processing...
Done!


### Now its turn!

In [15]:
import logging
from torch_geometric.loader import DataLoader
from tqdm import tqdm    # pip install tqdm if fails.
from model import get_model

dataset_path = f"{main_path}/dataset/"
os.system(f"rm -r {dataset_path}")
os.system(f"mkdir -p {dataset_path}")
dataset = MyDataset_VS(dataset_path, data=info, protein_dict=protein_dict, protein_res_full_id_dict=protein_res_full_id_dict, nci_df=nci_df)

['Outputs/Dev/08040750-MultiProteins/dataset/processed/data.pt', 'Outputs/Dev/08040750-MultiProteins/dataset/processed/protein.pt']


Processing...
Done!


In [16]:
bias_list = []
batched_energy_list = []

batch_size = 13
device = 'cuda:4' if torch.cuda.is_available() else 'cpu'
# device= 'cpu'
logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)
# modelFile = "../saved_models/re_dock.pt"
# self-dock model
modelFile = "../saved_models/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=8)
affinity_pred_list = []
y_pred_list = []
for data in tqdm(data_loader):
    data = data.to(device)
    y_pred, affinity_pred = model(data)
    affinity_pred_list.append(affinity_pred.detach().cpu())
    if False:
        # we don't need to save the predicted distance map in HTVS setting.
        for i in range(data.y_batch.max() + 1):
            y_pred_list.append((y_pred[data['y_batch'] == i]).detach().cpu())
affinity_pred_list = torch.cat(affinity_pred_list)

07:50:32   5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN


  0%|          | 0/1 [00:00<?, ?it/s]

HeteroData(
  dis_map=[6110],
  node_xyz=[130, 3],
  coords=[47, 3],
  y=[6110],
  seq=[130],
  compound_pair=[2209, 16],
  nci_matrix=[130, 47],
  [1mprotein[0m={
    node_s=[130, 6],
    node_v=[130, 3, 3]
  },
  [1mcompound[0m={ x=[47, 56] },
  [1m(protein, p2p, protein)[0m={
    edge_index=[2, 3310],
    edge_s=[3310, 32],
    edge_v=[3310, 1, 3]
  },
  [1m(compound, c2c, compound)[0m={
    edge_index=[2, 102],
    edge_weight=[102],
    edge_attr=[102, 19]
  }
)
HeteroData(
  dis_map=[3901],
  node_xyz=[83, 3],
  coords=[47, 3],
  y=[3901],
  seq=[83],
  compound_pair=[2209, 16],
  nci_matrix=[83, 47],
  [1mprotein[0m={
    node_s=[83, 6],
    node_v=[83, 3, 3]
  },
  [1mcompound[0m={ x=[47, 56] },
  [1m(protein, p2p, protein)[0m={
    edge_index=[2, 2076],
    edge_s=[2076, 32],
    edge_v=[2076, 1, 3]
  },
  [1m(compound, c2c, compound)[0m={
    edge_index=[2, 102],
    edge_weight=[102],
    edge_attr=[102, 19]
  }
)
HeteroData(
  dis_map=[4324],
  node_xyz=[92,

100%|██████████| 1/1 [00:01<00:00,  1.92s/it]

Zzzzsdgesfgaersghsdrghaerdsgegsredgasdhsdfgs efhad torch.Size([4, 130, 47, 128])
fadsldgkdq;sjaf ]]]]]] 0 <class 'numpy.ndarray'>
ldsjvzsdjga (130, 47)
fadsldgkdq;sjaf ]]]]]] 1 <class 'numpy.ndarray'>
ldsjvzsdjga (83, 47)
fadsldgkdq;sjaf ]]]]]] 2 <class 'numpy.ndarray'>
ldsjvzsdjga (92, 47)
fadsldgkdq;sjaf ]]]]]] 3 <class 'numpy.ndarray'>
ldsjvzsdjga (83, 47)





In [17]:
info = dataset.data
info['affinity'] = affinity_pred_list

In [18]:
chosen = info.loc[info.groupby(['protein_name', 'smiles'],sort=False)['affinity'].agg('idxmax')].reset_index()
chosen.to_csv(f"{pre}/TBAff_{target_name}_{pdb_code}_TBChosen.csv")

KeyError: 'protein_name'

In [None]:
info.to_csv(f"{pre}/result_info.csv")
if given_pocket:
    info_right_pocket = info[info["pocket_name"]==right_pocket]
    info_right_pocket.to_csv(f"{pre}/result_info_rightpocket.csv")

## Related-Atoms Process

In [None]:
use_related = False
if use_related:
    energy_list = []
    for _batch in batched_energy_list:
        for item in _batch:
            energy_list.append(item)
            
    related_atoms_list = []
    num_atoms_list = []
    for p in info.iterrows():
        sm = p[1]["smiles"]
        #print(sm)
        smiles = get_canonical_smiles(sm)
        #print(smiles)
        mol = Chem.MolFromSmiles(smiles)

        related = []
        for _atom in mol.GetAtoms():
            if _atom.GetAtomicNum() in [7, 8]:
                related.append(_atom.GetIdx())
        related_atoms_list.append(related)
        num_atoms_list.append(mol.GetNumHeavyAtoms())
    
    colsum_energy_list = [i.sum(axis=0) for i in energy_list]
    related_energy_list = [colsum_energy_list[i][related_atoms_list[i]].sum() for i in range(len(colsum_energy_list))]
    total_energy_list = [i.sum() for i in energy_list]

    related_energy_tensor = torch.tensor(related_energy_list)
    total_energy_tensor = torch.tensor(total_energy_list)
    nonrelated_energy_tensor = total_energy_tensor - related_energy_tensor

    bias = bias_list[0]
    outleaky = torch.nn.LeakyReLU()

    related_affinity = outleaky(bias + related_energy_tensor)
    nonrelated_affinity = outleaky(bias + nonrelated_energy_tensor)
    
    info['related_affinity'] = related_affinity
    info['nonrelated_affinity'] = nonrelated_affinity
    
    related_chosen = info.loc[info.groupby(['protein_name', 'smiles'],sort=False)['related_affinity'].agg('idxmax')].reset_index()
    nonrelated_chosen = info.loc[info.groupby(['protein_name', 'smiles'],sort=False)['nonrelated_affinity'].agg('idxmax')].reset_index()
    related_chosen.to_csv(f"{pre}/result_chosen_related.csv")

# 以下暂时不用

In [None]:
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
# pick one with affinity greater than 7.
chosen = info.loc[info.groupby(['protein_name', 'smiles'],sort=False)['affinity'].agg('idxmax')].reset_index()
chosen = chosen.query("affinity > 7").reset_index(drop=True)
line = chosen.iloc[0]
idx = line['index']
one_data = dataset[idx]
data_with_batch_info = next(iter(DataLoader(dataset[idx:idx+1], batch_size=1, 
                         follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=1)))
y_pred, affinity_pred = model(data_with_batch_info)

coords = one_data.coords.to(device)
protein_nodes_xyz = one_data.node_xyz.to(device)
n_compound = coords.shape[0]
n_protein = protein_nodes_xyz.shape[0]
y_pred = y_pred.reshape(n_protein, n_compound).to(device).detach()
y = one_data.dis_map.reshape(n_protein, n_compound).to(device)
compound_pair_dis_constraint = torch.cdist(coords, coords)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

In [None]:
smiles = line['smiles']
print(smiles)
mol = Chem.MolFromSmiles(smiles)
mol.Compute2DCoords()
LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint, 
                              LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                              n_repeat=1, show_progress=False)
toFile = f'{base_pre}/one_tankbind.sdf'
new_coords = info.sort_values("loss")['coords'].iloc[0].astype(np.double)
write_with_new_coords(mol, new_coords, toFile)

In [None]:
import nglview   # conda install nglview -c conda-forge if import failure

proteinName = "6dlo"
proteinFile = f"{base_pre}/{proteinName}.pdb"
view = nglview.show_file(nglview.FileStructure(proteinFile), default=False)
view.add_representation('cartoon', selection='protein', color='white')

predictedFile = f'{base_pre}/one_tankbind.sdf'
rdkit = view.add_component(nglview.FileStructure(predictedFile), default=False)
rdkit.add_ball_and_stick(color='red')
view

In [None]:
view.render_image()

In [None]:
view._display_image()