# Data Preparation

In this notebook we are going to load the PDBBind dataset and prepare it for the training phase.

## 1. Download dataset
The PDBBind dataset can be found in [pdbbind 2016 datasets](http://www.pdbbind.org.cn). You will be asked to register (or log in) under a license agreement. Then go to the CASF tab and download the CASF-2016 file. This file includes three folders: general-set-except-refined, refined-set and core-set. The general set contains 13308 protein−ligand binding complexes in total, the refined set contains 4057 complexes and the core set contains 290 complexes. The refined set is a smaller subset of the general set. Both the general and refined sets can be used for training and validation. The core set will be used for testing. You can also download the dataset from the Download tab by scrolling down to the PDBbind v2016 entry. Notice, however, that the core set is not included in that tab.

In [15]:
import numpy as np
import pandas as pd
import random
import os

## 2. Extract csv from index files in pdbbind

We extract the names of the folders containing the different complexes and store them in a csv file.

In [1]:
# path to the extracted PDBbind dataset
path = './'

In [2]:
%%bash -s $path --out missing

path=$1

echo 'pdbid,-logKd/Ki' > affinity_data.csv
cat $path/general-set-except-refined/index/INDEX_general_PL_data.2016 | while read l1 l2 l3 l4 l5; do
    if [[ ! $l1 =~ "#" ]]; then
        echo $l1,$l4
    fi
done >> affinity_data.csv


# Find affinities without structural data (i.e. with missing directories)

cut -f 1 -d ',' affinity_data.csv | tail -n +2 | while read l;
    do if [ ! -e $path/general-set-except-refined/$l ] && [ ! -e $path/refined-set/$l ]; then 
        echo $l;
    fi
done
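
For readers who prefer to stay in Python, the index parsing done by the bash cell above can be sketched as follows. This is a minimal sketch, not the notebook's method: `parse_index` is a hypothetical helper, the two sample lines are made up for illustration (only the ids and affinities match the output shown later), and comment lines are detected by a leading `#` rather than by the regex used in the bash loop.

```python
import csv
import io


def parse_index(lines):
    """Yield (pdbid, affinity) pairs from PDBbind-style index lines."""
    for line in lines:
        line = line.strip()
        # Skip blank lines and the commented header block.
        if not line or line.startswith("#"):
            continue
        fields = line.split()
        # Field 1 is the PDB id and field 4 is -logKd/Ki ($l1 and $l4 in the bash loop).
        yield fields[0], float(fields[3])


# Hypothetical excerpt of an index file, for illustration only.
sample = io.StringIO(
    "# PDB code, resolution, release year, -logKd/Ki, Kd/Ki\n"
    "3zzf  2.20  2012   0.40  Ki=400mM      // 3zzf.pdf (NLG)\n"
    "3gww  2.46  2009   0.45  IC50=355mM    // 3gww.pdf (SFX)\n"
)
rows = list(parse_index(sample))

# Write the same two-column csv that the bash cell produces.
out = io.StringIO()
writer = csv.writer(out, lineterminator="\n")
writer.writerow(["pdbid", "-logKd/Ki"])
writer.writerows(rows)
csv_text = out.getvalue()
```

In practice the `io.StringIO` objects would be replaced by the opened index file and `affinity_data.csv`.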

In [3]:
missing = set(missing.split())
len(missing)

1

In [6]:
affinity_data = pd.read_csv('affinity_data.csv', comment='#')
affinity_data = affinity_data[~np.in1d(affinity_data['pdbid'], list(missing))]
affinity_data.head()

Unnamed: 0,pdbid,-logKd/Ki
0,3zzf,0.4
1,3gww,0.45
2,1w8l,0.49
3,3fqa,0.49
4,1zsb,0.6


In [7]:
# Check for NaNs
affinity_data['-logKd/Ki'].isnull().any()

False

We separate the refined, general and core sets.

In [10]:
core_set = {f.name for f in os.scandir('./core-set') if f.is_dir()}  # folder names are the PDB ids

refined_set = ! grep -v '#' $path/general-set-except-refined/index/INDEX_refined_data.2016 | cut -f 1 -d ' '
refined_set = set(refined_set)

general_set = set(affinity_data['pdbid'])


assert core_set & refined_set == core_set
assert refined_set & general_set == refined_set

len(general_set), len(refined_set), len(core_set)

(13307, 4057, 285)

In [11]:
affinity_data.loc[np.in1d(affinity_data['pdbid'], list(general_set)), 'set'] = 'general'

affinity_data.loc[np.in1d(affinity_data['pdbid'], list(refined_set)), 'set'] = 'refined'

affinity_data.loc[np.in1d(affinity_data['pdbid'], list(core_set)), 'set'] = 'core'

affinity_data.head()

Unnamed: 0,pdbid,-logKd/Ki,set
0,3zzf,0.4,general
1,3gww,0.45,general
2,1w8l,0.49,general
3,3fqa,0.49,general
4,1zsb,0.6,general
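
Because the sets are nested (core ⊂ refined ⊂ general, as the assertions above check), the order of the three assignments matters: each later assignment overwrites the earlier label, so every complex ends up with its most specific set. A minimal sketch of that precedence rule, using made-up ids and a hypothetical `label_for` helper:

```python
# Hypothetical ids; in the notebook these sets come from the index files and folders.
general_set = {"1abc", "2def", "3ghi", "4jkl"}
refined_set = {"2def", "3ghi"}
core_set = {"3ghi"}


def label_for(pdbid):
    """Return the most specific set containing pdbid (core > refined > general)."""
    if pdbid in core_set:
        return "core"
    if pdbid in refined_set:
        return "refined"
    if pdbid in general_set:
        return "general"
    return None


labels = {pdbid: label_for(pdbid) for pdbid in sorted(general_set)}
```

With this rule, rows labelled `general` are exactly the complexes of the general set that are not in the refined set.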


## 3. Separate training, validation and test sets

For the general and refined sets, we compute binding affinity quantiles of each set independently (the code below uses ten equal-probability bins, i.e. deciles), then sample 10% of the data from each bin to form the validation set, keeping the remaining data for the corresponding training set. As a result, the general and refined sets are partitioned into general-train, general-val, refined-train, and refined-val.
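
The protocol can be condensed into one reusable function. The sketch below is dependency-free and rank-based (it sorts by affinity and cuts the sorted list into equal-sized bins), which is a simplification of the quantile-edge binning used in the cells that follow. `stratified_split`, its parameters and the affinity values are all hypothetical.

```python
import random


def stratified_split(values, n_bins=10, val_fraction=0.1, seed=0):
    """Split the indices of `values` into train/val, drawing val from every affinity bin."""
    rng = random.Random(seed)
    # Rank-based bins: sort indices by affinity, then cut into n_bins contiguous chunks.
    order = sorted(range(len(values)), key=lambda i: values[i])
    train, val = [], []
    for b in range(n_bins):
        start = b * len(order) // n_bins
        stop = (b + 1) * len(order) // n_bins
        bin_idx = order[start:stop]
        # Same rounding as the notebook: the training share is truncated.
        n_train = int((1 - val_fraction) * len(bin_idx))
        chosen = set(rng.sample(bin_idx, n_train))
        train += [i for i in bin_idx if i in chosen]
        val += [i for i in bin_idx if i not in chosen]
    return train, val


# Made-up affinities, for illustration only.
affinities = [2.0 + 0.1 * i for i in range(100)]
train_idx, val_idx = stratified_split(affinities)
```

Fixing the seed makes the split reproducible, which the cells below do not do since they call `random.sample` on the global generator.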

In [12]:
general = affinity_data[affinity_data.set=='general']
refined = affinity_data[affinity_data.set=='refined']
core = affinity_data[affinity_data.set=='core']

In [16]:
# General set
qs = np.quantile(general['-logKd/Ki'], np.arange(0,1.1,0.1)) #Get quantiles
qs[0]-=0.01
idxs_train = []
idxs_val = []
for i in range(len(qs)-1):
    q1 = qs[i]
    q2 = qs[i+1]
    gen_q = general[(general['-logKd/Ki']>q1) & (general['-logKd/Ki']<=q2)].index.values.tolist()
    size_train = int(0.9*len(gen_q)) # 90% training, 10% validation
    idx_train = random.sample(gen_q,size_train)
    idx_val = [idx for idx in gen_q if idx not in idx_train]
    idxs_train+=idx_train
    idxs_val+=idx_val
general_train = general.loc[idxs_train]
general_train.set = 'general_train'
general_val = general.loc[idxs_val]
general_val.set = 'general_val'

In [17]:
# Refined set
qs = np.quantile(refined['-logKd/Ki'], np.arange(0,1.1,0.1)) #Get quantiles
qs[0]-=0.01
idxs_train = []
idxs_val = []
for i in range(len(qs)-1):
    q1 = qs[i]
    q2 = qs[i+1]
    ref_q = refined[(refined['-logKd/Ki']>q1) & (refined['-logKd/Ki']<=q2)].index.values.tolist()
    size_train = int(0.9*len(ref_q)) # 90% training, 10% validation
    idx_train = random.sample(ref_q,size_train)
    idx_val = [idx for idx in ref_q if idx not in idx_train]
    idxs_train+=idx_train
    idxs_val+=idx_val
refined_train = refined.loc[idxs_train]
refined_train.set = 'refined_train'
refined_val = refined.loc[idxs_val]
refined_val.set = 'refined_val'

In [18]:
affinity_data = pd.concat([general_train, general_val, refined_train, refined_val, core])
affinity_data

Unnamed: 0,pdbid,-logKd/Ki,set
823,4a4e,3.32,general_train
544,3nk8,3.00,general_train
358,4h4b,2.70,general_train
250,1gbq,2.46,general_train
709,3fty,3.21,general_train
...,...,...,...
13226,5dwr,11.22,core
13238,4f2w,11.30,core
13241,2x00,11.33,core
13269,3o9i,11.82,core


In [None]:
affinity_data.to_csv('../data/affinity_data_cleaned.csv', index=False)

## 4. Prepare complexes with chimera

We begin with the preprocessing of the data. All protein−ligand binding complexes are protonated and partial charges are assigned using UCSF Chimera, with AMBER ff14SB for standard residues and AM1-BCC for nonstandard residues (the program's default settings). No additional steps are taken for crystal structure data.

We store the paths of all the .pdb files in a csv.

In [19]:
%%bash 

find . -name '*.pdb' >> pbd_files.csv

Now we run the preprocessing step. This will create a .mol2 file for each .pdb file. Notice that this can take a while (36h for me).

In [20]:
# path to the csv listing the .pdb files
path = './pbd_files.csv'

In [None]:
%%bash -s $path
path=$1

# get list of pdb files from stdin and iterate over them. each instance of this script appends
# its PID to the tmp.mol2 file in order to prevent race conditions, enabling this to be run with
# gnu parallel

        tmp_file=$$_tmp.mol2

        echo "my tmp file is ${tmp_file}"

cat $path | while read pdbfile; do

                echo ${pdbfile}
                mol2file=${pdbfile%pdb}mol2
                # NOTICED THAT SOME INPUTS seem to never finish chimera step
                echo -e "open ${pdbfile} \n addh \n addcharge \n write format mol2 0 $$_tmp.mol2 \n stop" | chimera --silent --nogui 
                # Do not use TIP3P atom types, pybel cannot read them
                sed 's/H\.t3p/H    /' ${tmp_file} | sed 's/O\.t3p/O\.3  /' > $mol2file


done
echo "finished processing"
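
The two `sed` substitutions in the cell above rewrite the TIP3P water atom types so that pybel can read the file. A Python equivalent might look like the sketch below. `fix_tip3p_types` is a hypothetical helper and the two atom records are invented for illustration. The replacements keep the field width unchanged, as the `sed` expressions do.

```python
def fix_tip3p_types(mol2_text):
    """Replace TIP3P water atom types with generic types of the same width."""
    fixed = []
    for line in mol2_text.splitlines(keepends=True):
        # count=1 mirrors sed's default of replacing only the first match per line.
        line = line.replace("H.t3p", "H    ", 1)
        line = line.replace("O.t3p", "O.3  ", 1)
        fixed.append(line)
    return "".join(fixed)


# Hypothetical atom records, for illustration only.
sample = (
    "      1 O          1.000    2.000    3.000 O.t3p   1 HOH  -0.8340\n"
    "      2 H1         1.500    2.500    3.500 H.t3p   1 HOH   0.4170\n"
)
cleaned = fix_tip3p_types(sample)
```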

## 5. Generate hdf5 from the csv and the original pdbbind dataset

We extract the features from each complex and store them all in an HDF5 file.

We consider only the heavy atoms from each biological structure and heteroatoms (e.g., oxygens from crystallized water molecules).
• Element type: one-hot encoding of B, C, N, O, P, S, Se, halogen, or metal.
• Atom hybridization (1, 2, or 3).
• Number of heavy atom bonds (i.e., heavy valence).
• Structural properties: bit vector (1 where present) encoding of hydrophobic, aromatic, acceptor, donor, ring.
• Partial charge.
• Molecule type to indicate protein atom versus ligand atom (−1 for protein, 1 for ligand).
• Van der Waals radius.
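
To make the first and sixth items concrete, the sketch below one-hot encodes an atomic number into the nine element classes and appends the molecule-type code. It is an illustration under assumptions, not the `Featurizer` used later: `encode_element` and `encode_atom` are hypothetical names, and the metal set is a small illustrative subset rather than a complete list.

```python
ELEMENT_CLASSES = ["B", "C", "N", "O", "P", "S", "Se", "halogen", "metal"]

# Atomic numbers of the classes that correspond to a single element.
SINGLE = {5: "B", 6: "C", 7: "N", 8: "O", 15: "P", 16: "S", 34: "Se"}
HALOGENS = {9, 17, 35, 53, 85}  # F, Cl, Br, I, At
# Illustrative subset only: Na, Mg, K, Ca, Mn, Fe, Co, Ni, Cu, Zn.
METALS = {11, 12, 19, 20, 25, 26, 27, 28, 29, 30}


def encode_element(atomic_num):
    """Return a 9-element one-hot list; all zeros if the element is in no class."""
    encoding = [0] * len(ELEMENT_CLASSES)
    if atomic_num in SINGLE:
        encoding[ELEMENT_CLASSES.index(SINGLE[atomic_num])] = 1
    elif atomic_num in HALOGENS:
        encoding[ELEMENT_CLASSES.index("halogen")] = 1
    elif atomic_num in METALS:
        encoding[ELEMENT_CLASSES.index("metal")] = 1
    return encoding


def encode_atom(atomic_num, is_ligand):
    """Element one-hot followed by the molecule code (1 for ligand, -1 for protein)."""
    return encode_element(atomic_num) + [1 if is_ligand else -1]


carbon_in_ligand = encode_atom(6, is_ligand=True)
zinc_in_pocket = encode_atom(30, is_ligand=False)
```

The remaining features (hybridization, heavy valence, structural bits and partial charge) would be appended to the same vector in the same way.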

We use OpenBabel (version 3.X) to do this.

We begin by copying all the .mol2 files to the folder ../data/raw_data/

In [None]:
%%bash
find . -name '*ligand.mol2' >> ligand_mol2_files.csv
find . -name '*pocket.mol2' >> pocket_mol2_files.csv

In [None]:
path='ligand_mol2_files.csv'

In [None]:
%%bash -s $path
path=$1
cat $path | while read pdbfile; do
        #echo ${pdbfile}
        cp -R ${pdbfile} '../data/raw_data/'
done
echo "finished processing"

In [None]:
path='pocket_mol2_files.csv'

In [None]:
%%bash -s $path
path=$1
cat $path | while read pdbfile; do
        #echo ${pdbfile}
        cp -R ${pdbfile} '../data/raw_data/'
done
echo "finished processing"

Now we run the script to extract the features.

In [None]:
import os
from tf_bio_data import Featurizer
import numpy as np
import h5py
import argparse
from openbabel import pybel as pb
import warnings
#from data_generator.atomfeat_util import read_pdb, rdkit_atom_features, rdkit_atom_coords
#from data_generator.chem_info import g_atom_vdw_ligand, g_atom_vdw_protein
import xml.etree.ElementTree as ET
from rdkit.Chem.rdmolfiles import MolFromMol2File
import rdkit
from rdkit import Chem
from rdkit.Chem import rdchem
from openbabel.pybel import Atom
import pandas as pd
from tqdm import tqdm
from glob import glob

ob_log_handler = pb.ob.OBMessageHandler()
ob_log_handler.SetOutputLevel(0)

# TODO: compute rdkit features and store them in the output hdf5 file
# TODO: instead of making a file for each split, squash into one?


# TODO: not sure setting these to defaults is a good idea...
parser = argparse.ArgumentParser()
parser.add_argument("--input-pdbbind", default="/workspace/data/raw_data/") ## Change this path for the path of your *ligand.mol2 files and *pocket.mol2
parser.add_argument("--input-docking", default="/workspace/data/raw_data/")
parser.add_argument("--use-docking", default=False, action="store_true") # No docking
parser.add_argument("--use-exp", default=True, action="store_true")
parser.add_argument("--output", default="/workspace/data/processed") # Path to store the processed hdf
parser.add_argument(
    "--metadata", default="/workspace/data/affinity_data_cleaned.csv")
# parse_known_args ignores the extra arguments Jupyter passes to the kernel
args, _ = parser.parse_known_args()


def parse_element_description(desc_file):
    element_info_dict = {}
    element_info_xml = ET.parse(desc_file)
    for element in element_info_xml.iter():
        if "comment" in element.attrib.keys():
            continue
        else:
            element_info_dict[int(element.attrib["number"])] = element.attrib

    return element_info_dict


def parse_mol_vdw(mol, element_dict):
    vdw_list = []

    if isinstance(mol, pb.Molecule):
        for atom in mol.atoms:
            # NOTE: to be consistent between featurization methods, throw out the hydrogens
            if int(atom.atomicnum) == 1:
                continue
            if int(atom.atomicnum) == 0:
                continue
            else:
                vdw_list.append(
                    float(element_dict[atom.atomicnum]["vdWRadius"]))

    elif isinstance(mol, rdkit.Chem.rdchem.Mol):
        for atom in mol.GetAtoms():
            # NOTE: to be consistent between featurization methods, throw out the hydrogens
            if int(atom.GetAtomicNum()) == 1:
                continue
            else:
                vdw_list.append(
                    float(element_dict[atom.GetAtomicNum()]["vdWRadius"]))
    else:
        raise RuntimeError("must provide a pybel mol or an RDKIT mol")

    return np.asarray(vdw_list)


def featurize_pybel_complex(ligand_mol, pocket_mol, name, dataset_name):

    featurizer = Featurizer()
    charge_idx = featurizer.FEATURE_NAMES.index('partialcharge')

    # get ligand features
    ligand_coords, ligand_features = featurizer.get_features(
        ligand_mol, molcode=1)

    # require at least one atom with a non-zero partial charge
    if not (ligand_features[:, charge_idx] != 0).any():
        raise RuntimeError(
            "invalid charges for the ligand {} ({} set)".format(name, dataset_name))

    # get processed pocket features
    pocket_coords, pocket_features = featurizer.get_features(
        pocket_mol, molcode=-1)
    if not (pocket_features[:, charge_idx] != 0).any():
        raise RuntimeError(
            "invalid charges for the pocket {} ({} set)".format(name, dataset_name))

    # center the coordinates on the ligand coordinates
    centroid_ligand = ligand_coords.mean(axis=0)
    ligand_coords -= centroid_ligand

    pocket_coords -= centroid_ligand
    data = np.concatenate((np.concatenate((ligand_coords, pocket_coords)),
                           np.concatenate((ligand_features, pocket_features))), axis=1)

    return data


def main():

    affinity_data = pd.read_csv(args.metadata)

    element_dict = parse_element_description("./elements.xml")

    failure_dict = {"name": [], "partition": [], "set": [], "error": []}

    for dataset_name, data in tqdm(affinity_data.groupby('set')):
        print("found {} complexes in {} set".format(len(data), dataset_name))

        if not os.path.exists(args.output):
            os.makedirs(args.output)

        with h5py.File('%s/%s.hdf' % (args.output, dataset_name), 'w') as f:

            for idx, row in tqdm(data.iterrows(), total=data.shape[0]):

                name = row['pdbid']

                affinity = row['-logKd/Ki']

                #receptor_path = row['receptor_path']

                '''
                    here is where the ligand(s) for both the experimental structure and the docking data need to be loaded.
                    * In order to do this, need an input path for both the experimental data as well as the docking data
                    * For docking data:
                        > Need to know how many poses there are, potentially up to 10 but not always the case
                        > May not have ligand/pocket data for names, need to handle this possibility

                    ######################################################################################################



                            BREAK THE MAIN LOOP INTO TWO PARTS....PROCESS DOCKING and PROCESS CRYSTAL STRUCTURES



                    ######################################################################################################

                '''

                ############################## CREATE THE PDB GROUP ##################################################
                # this is here in order to ensure any dataset that is created has passed the quality check, i.e. no failed complexes enter the output file

                grp = f.create_group(str(name))
                grp.attrs['affinity'] = affinity
                pybel_grp = grp.create_group("pybel")
                processed_grp = pybel_grp.create_group("processed")

                ############################### PROCESS THE DOCKING DATA ###############################
                if args.use_docking:
                    # READ THE DOCKING LIGAND POSES

                    # pose_path_list = glob("{}/{}/{}_ligand_pose_*.pdb".format(args.input_docking, name, name))
                    pose_path_list = glob(
                        "{}/{}/{}_ligand_pose_*.mol2".format(args.input_docking, name, name))

                    # if there are poses to read then we will read them, otherwise skip to the crystal structure loop
                    if len(pose_path_list) > 0:

                        # READ THE DOCKING POCKET DATA

                        # receptor_path is not defined in this script, so build the pocket path from the docking directory
                        docking_pocket_file = "{}/{}/{}_pocket.mol2".format(
                            args.input_docking, name, name)

                        if not os.path.exists(docking_pocket_file):
                            warnings.warn("{} does not exists...this is likely due to failure in chimera preprocessing step, skipping to next complex...".format(
                                docking_pocket_file))
                            # NOTE: not putting a continue here because there may be crystal structure data
                        else:

                            # some docking files are corrupt (have nans for coords) and pybel doesn't do a great job of handling that
                            with open(docking_pocket_file, 'r') as handle:
                                data = handle.read()
                                if "nan" in data:
                                    warnings.warn("{} contains corrupt data, nan's".format(
                                        docking_pocket_file))
                                    # continue #TODO: THIS MAY PREVENT THE CRYSTAL STRUCTURE DATA FROM BEING PROCESSED

                                else:

                                    pose_pocket_vdw = []

                                    try:
                                        #docking_pocket = next(pybel.readfile('pdb', docking_pocket_file))
                                        docking_pocket = next(pb.readfile(
                                            'mol2', docking_pocket_file))
                                        pose_pocket_vdw = parse_mol_vdw(
                                            mol=docking_pocket, element_dict=element_dict)

                                    except StopIteration:
                                        error = "pybel failed to read {} docking pocket file".format(
                                            name)
                                        warnings.warn(error)
                                        failure_dict["name"].append(name), failure_dict["partition"].append(
                                            "docking"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)

                                    # in some, albeit strange, cases the pocket consists purely of hydrogen, skip over these if that is the case
                                    if len(pose_pocket_vdw) < 1:
                                        error = "{} docking pocket contains no heavy atoms, unable to store vdw radii".format(
                                            name)
                                        warnings.warn(error)
                                        failure_dict["name"].append(name), failure_dict["partition"].append(
                                            "docking"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)

                                    else:

                                        docking = processed_grp.create_group(
                                            "docking")
                                        for pose_path in pose_path_list:

                                            try:
                                                #pose_ligand = next(pybel.readfile('pdb', pose_path))
                                                pose_ligand = next(
                                                    pb.readfile('mol2', pose_path))
                                                # do not add the hydrogens! they were already added in chimera and it would reset the charges
                                            except:
                                                error = "no ligand for {} ({} set)".format(
                                                    name, dataset_name)
                                                warnings.warn(error)
                                                failure_dict["name"].append(name), failure_dict["partition"].append(
                                                    "docking"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                                                continue  # TODO:THIS MAY PREVENT THE CRYSTAL STRUCTURE DATA FROM BEING PROCESSED

                                            # extract the van der waals radii for the ligand/pocket
                                            pose_ligand_vdw = parse_mol_vdw(
                                                mol=pose_ligand, element_dict=element_dict)

                                            # in case the ligand consists purely of hydrogen, skip over these if that is the case
                                            if len(pose_ligand_vdw) < 1:
                                                error = "{} ligand consists purely of hydrogen, no heavy atoms to featurize".format(
                                                    name)
                                                warnings.warn(error)
                                                failure_dict["name"].append(name), failure_dict["partition"].append(
                                                    "docking"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                                                continue  # TODO: THIS MAY PREVENT THE CRYSTAL STRUCTURE DATA FROM BEING PROCESSED

                                            try:
                                                pose_data = featurize_pybel_complex(
                                                    ligand_mol=pose_ligand, pocket_mol=docking_pocket, name=name, dataset_name=dataset_name)
                                            except RuntimeError as error:
                                                failure_dict["name"].append(name), failure_dict["partition"].append(
                                                    "docking"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                                                continue  # TODO:THIS MAY PREVENT THE CRYSTAL STRUCTURE DATA FROM BEING PROCESSED

                                            pose_ligand_pocket_vdw = np.concatenate([pose_ligand_vdw.reshape(-1),
                                                                                     pose_pocket_vdw.reshape(-1)], axis=0)

                                            # enforce a constraint that the number of atoms for which we have features is equal to number for which we have VDW radii
                                            assert pose_ligand_pocket_vdw.shape[0] == pose_data.shape[0]

                                            # CREATE THE DOCKING POSE GROUP
                                            #pose_idx = pose_path.split(".pdb")[0].split("_")[-1]
                                            pose_idx = pose_path.split(
                                                ".mol2")[0].split("_")[-1]
                                            pose_grp = docking.create_group(
                                                pose_idx)

                                            # Now that we have passed the try/except blocks, featurize and store the docking data
                                            pose_grp.attrs["van_der_waals"] = pose_ligand_pocket_vdw

                                            pose_dataset = pose_grp.create_dataset("data", data=pose_data,
                                                                                   shape=pose_data.shape, dtype='float32', compression='lzf')

                else:
                    error = "{} does not contain any pose data".format(name)
                    # tqdm.write(error)
                    # failure_dict["name"].append(name), failure_dict["partition"].append(
                    #    "docking"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)

                ############################### PROCESS THE CRYSTAL STRUCTURE DATA ###############################

                if args.use_exp:
                    # BEGIN QUALITY CONTROL: do not create the dataset until data has been verified
                    try:
                        crystal_ligand = next(pb.readfile(
                            'mol2', '%s/%s_ligand.mol2' % (args.input_pdbbind, name)))

                    # do not add the hydrogens! they were already added in chimera and it would reset the charges
                    except:
                        error = "no ligand for {} ({} set)".format(
                            name, dataset_name)
                        warnings.warn(error)
                        failure_dict["name"].append(name), failure_dict["partition"].append(
                            "crystal"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                        continue

                    try:
                        crystal_pocket = next(pb.readfile(
                            'mol2', '%s/%s_pocket.mol2' % (args.input_pdbbind, name)))

                    except:
                        error = "no pocket for {} ({} set)".format(
                            name, dataset_name)
                        warnings.warn(error)
                        failure_dict["name"].append(name), failure_dict["partition"].append(
                            "crystal"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                        continue

                    # extract the van der waals radii for the ligand/pocket
                    crystal_ligand_vdw = parse_mol_vdw(
                        mol=crystal_ligand, element_dict=element_dict)

                    # in some, albeit strange, cases the ligand consists purely of hydrogen, skip over these if that is the case
                    if len(crystal_ligand_vdw) < 1:
                        error = "{} ligand consists purely of hydrogen, no heavy atoms to featurize".format(
                            name)
                        warnings.warn(error)
                        failure_dict["name"].append(name), failure_dict["partition"].append(
                            "crystal"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                        continue

                    crystal_pocket_vdw = parse_mol_vdw(
                        mol=crystal_pocket, element_dict=element_dict)
                    # in some, albeit strange, cases the pocket consists purely of hydrogen, skip over these if that is the case
                    if len(crystal_pocket_vdw) < 1:
                        error = "{} pocket consists purely of hydrogen, no heavy atoms to featurize".format(
                            name)
                        warnings.warn(error)
                        failure_dict["name"].append(name), failure_dict["partition"].append(
                            "crystal"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                        continue

                    crystal_ligand_pocket_vdw = np.concatenate(
                        [crystal_ligand_vdw.reshape(-1), crystal_pocket_vdw.reshape(-1)], axis=0)
                    try:
                        crystal_data = featurize_pybel_complex(
                            ligand_mol=crystal_ligand, pocket_mol=crystal_pocket, name=name, dataset_name=dataset_name)
                    except RuntimeError as error:
                        failure_dict["name"].append(name), failure_dict["partition"].append(
                            "crystal"), failure_dict["set"].append(dataset_name), failure_dict["error"].append(error)
                        continue

                    # enforce a constraint that the number of atoms for which we have features is equal to number for which we have VDW radii
                    assert crystal_ligand_pocket_vdw.shape[0] == crystal_data.shape[0]

                    # END QUALITY CONTROL: made it past the try/except blocks....now featurize the data and store into the .hdf file
                    crystal_grp = processed_grp.create_group("pdbbind")
                    crystal_grp.attrs["van_der_waals"] = crystal_ligand_pocket_vdw
                    crystal_dataset = crystal_grp.create_dataset("data", data=crystal_data,
                                                                 shape=crystal_data.shape, dtype='float32', compression='lzf')

    failure_df = pd.DataFrame(failure_dict)
    failure_df.to_csv(
        "{}/failure_summary.csv".format(args.output), index=False)


main()

## 6. Generate 3D representation for 3D CNN training
The 3D atomic representation is described as follows. The input volume dimension is N × N × N × C, where N is the voxel grid size in each axis and C is the number of atomic features described in the previous section (19 in our experiment). We chose N=48. The volume size in each dimension is approximately 48 Å with a voxel size of 1 Å, which is sufficient to cover the entire pocket region while minimizing collisions between atoms. Each atom is assigned to one or more voxels, depending on its Van der Waals radius or the user-defined size. When atoms collide, we apply element-wise addition to the atom features. Once all atoms are voxelized, a Gaussian blur with σ = 1 is applied to spread the atom features into neighboring voxels and to avoid an overly sparse representation in the 3D voxel grid.
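
The coordinate-to-voxel mapping and the element-wise addition on collisions can be sketched in a few lines. This is a simplified illustration of the description above, not the script that follows: it stores only occupied voxels in a dictionary, places each atom in a single voxel, and omits the Gaussian blur. `voxel_index`, `voxelize` and the sample atoms are hypothetical.

```python
GRID = 48    # voxels per axis (N)
BOX = 48.0   # box edge length in angstroms, centred on the origin


def voxel_index(coord, lo=-BOX / 2, hi=BOX / 2, n=GRID):
    """Map one coordinate to a voxel index, or None if it lies outside the box."""
    if coord < lo or coord > hi:
        return None
    return int((coord - lo) / (hi - lo) * (n - 1))


def voxelize(coords, feats):
    """Return {(ix, iy, iz): feature list}, summing features of atoms that collide."""
    grid = {}
    for (x, y, z), feat in zip(coords, feats):
        key = (voxel_index(x), voxel_index(y), voxel_index(z))
        if None in key:
            continue  # atom falls outside the box
        if key in grid:
            # Collision: element-wise addition of the atom features.
            grid[key] = [a + b for a, b in zip(grid[key], feat)]
        else:
            grid[key] = list(feat)
    return grid


# Made-up atoms: two share a voxel, one lies outside the box.
coords = [(0.1, 0.1, 0.1), (0.2, 0.2, 0.2), (100.0, 0.0, 0.0)]
feats = [[1, 0], [0, 2], [5, 5]]
grid = voxelize(coords, feats)
```

Since the coordinates are centred on the ligand centroid in the previous section, a box centred on the origin keeps the ligand in the middle of the grid.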

In [None]:
import os
import sys
import shutil
import argparse
import csv
import h5py
import numpy as np
import scipy as sp
import scipy.ndimage



parser = argparse.ArgumentParser()
parser.add_argument(
    "--main-dir", default="/workspace/data/processed", help="main dataset directory") # Path of the hdf files
parser.add_argument("--use-external", default=True,
                    help="whether external test file is used or not") # Always set to True
parser.add_argument("--input-file", default="pdbbind2016_core_test.hdf", help="input test HDF filename") # Change for general/refined and train/val
parser.add_argument("--output-file", default="core_test_3dnn.hdf",
                    help="output test HDF filename") # Name of the output file that it will create
# parse_known_args ignores the extra arguments Jupyter passes to the kernel
args, _ = parser.parse_known_args()


# do not change unless the hdf structure is changed
g_csv_header = ['ligand_id', 'file_prefix', 'label', 'train_test_split', 'atom_count',
                'xsize', 'ysize', 'zsize', 'p_atom_count1', 'p_atom_count2', 'p_xsize', 'p_ysize', 'p_zsize']

g_feat_tool_list = ['pybel', 'rdkit']
g_feat_tool_ind = 0

g_feat_type_list = ['raw', 'processed']
g_feat_type_ind = 1

g_feat_pdbbind_type_list = ['crystal', 'docking']  # for display
g_feat_pdbbind_type_list2 = ['pdbbind', 'docking']  # for reading hdf
g_feat_pdbbind_type_ind = 0

g_feat_data_str = 'data'


g_main_dir = "/workspace/data/processed"

g_target_dataset = "pdbbind2016"
g_target_trainval_type = "refined"
g_external_test = True  # default is False
if g_external_test:
    g_feat_suffix = "%s_%s_%s" % (
        g_feat_tool_list[g_feat_tool_ind], g_feat_type_list[g_feat_type_ind], g_feat_pdbbind_type_list[g_feat_pdbbind_type_ind])
else:
    g_feat_suffix = "%s_%s_%s_%s_%s" % (g_target_dataset, g_target_trainval_type,
                                        g_feat_tool_list[g_feat_tool_ind], g_feat_type_list[g_feat_type_ind], g_feat_pdbbind_type_list[g_feat_pdbbind_type_ind])

g_3D_relative_size = False
g_3D_size_angstrom = 48  # valid only when g_3D_relative_size = False
g_3D_size_dim = 48  # 48
g_3D_atom_radius = 1
g_3D_atom_radii = False
g_3D_sigma = 1
if g_feat_tool_ind == 0:
    g_3D_dim = [g_3D_size_dim, g_3D_size_dim, g_3D_size_dim, 19]
else:
    g_3D_dim = [g_3D_size_dim, g_3D_size_dim, g_3D_size_dim, 75]

size_angstrom = g_3D_size_angstrom
if g_3D_relative_size:
    size_angstrom = 0

if g_3D_atom_radii:
    g_3D_suffix = "%d_%d_radii_sigma%d_rot0" % (
        size_angstrom, g_3D_size_dim, g_3D_sigma)
else:
    g_3D_suffix = "%d_%d_radius%d_sigma%d_rot0" % (
        size_angstrom, g_3D_size_dim, g_3D_atom_radius, g_3D_sigma)

if g_external_test:
    g_input_test_hd_file = "core_test.hdf"
else:
    g_input_train_hd_file = "%s_%s_train.hdf" % (
        g_target_dataset, g_target_trainval_type)
    g_input_val_hd_file = "%s_%s_val.hdf" % (
        g_target_dataset, g_target_trainval_type)
    g_input_test_hd_file = "%s_core_test.hdf" % (g_target_dataset)

g_output_hd_compress = True
g_output_train_hd_file = "%s_%s_train.hdf" % (g_feat_suffix, g_3D_suffix)
g_output_val_hd_file = "%s_%s_val.hdf" % (g_feat_suffix, g_3D_suffix)
g_output_test_hd_file = "%s_%s_test.hdf" % (g_feat_suffix, g_3D_suffix)
g_output_csv = "%s_%s_info.csv" % (g_feat_suffix, g_3D_suffix)


# for argument setting
g_main_dir = args.main_dir
if args.use_external:
    g_input_test_hd_file = args.input_file
    g_output_test_hd_file = args.output_file
    g_output_csv = args.output_file[:-4] + ".csv"


def rotate_3D(input_data):
    rotation_angle = np.random.uniform() * 2 * np.pi
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    rotation_matrix = np.array(
        [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]])
    # rotate the coordinates about the y axis by a random angle in [0, 2*pi)
    rotated_data = np.dot(input_data, rotation_matrix)
    return rotated_data


def get_3D_bound(xyz_array):
    # axis-aligned bounding box of the atom coordinates
    xmin, ymin, zmin = xyz_array[:, 0:3].min(axis=0)
    xmax, ymax, zmax = xyz_array[:, 0:3].max(axis=0)
    return xmin, ymin, zmin, xmax, ymax, zmax


def get_3D_all(xyz, feat, vol_dim, xmin, ymin, zmin, xmax, ymax, zmax, atom_radius=1, atomtype_ind=-1, sigma=0):

    # initialize volume
    vol_data = np.zeros(
        (vol_dim[0], vol_dim[1], vol_dim[2], vol_dim[3]), dtype=np.float32)
    vol_tag = np.zeros((vol_dim[0], vol_dim[1], vol_dim[2]), dtype=np.int32)

    # voxel size (assume the voxel size is the same along all axes)
    vox_size = (zmax - zmin) / vol_dim[0]

    # assign xyz (only center)
    for ind in range(xyz.shape[0]):
        x = xyz[ind, 0]
        y = xyz[ind, 1]
        z = xyz[ind, 2]
        if x < xmin or x > xmax or y < ymin or y > ymax or z < zmin or z > zmax:
            continue

        cx = (x - xmin) / (xmax - xmin) * (vol_dim[2] - 1)
        cy = (y - ymin) / (ymax - ymin) * (vol_dim[1] - 1)
        cz = (z - zmin) / (zmax - zmin) * (vol_dim[0] - 1)

        vol_tag[int(cz), int(cy), int(cx)] += 1
        if vol_tag[int(cz), int(cy), int(cx)] == 1:
            vol_data[int(cz), int(cy), int(cx), :] = feat[ind, :]

    # assign xyz
    for ind in range(xyz.shape[0]):
        x = xyz[ind, 0]
        y = xyz[ind, 1]
        z = xyz[ind, 2]
        if x < xmin or x > xmax or y < ymin or y > ymax or z < zmin or z > zmax:
            continue

        # compute van der Waals radius and atomic density, use 1 if not available
        if atomtype_ind >= 0:
            vdw_radius = g_atom_vdw_ligand[feat[ind, atomtype_ind]]
            atom_radius = 1 + vdw_radius * vox_size

        # setup atom ranges
        cx = (x - xmin) / (xmax - xmin) * (vol_dim[2] - 1)
        cy = (y - ymin) / (ymax - ymin) * (vol_dim[1] - 1)
        cz = (z - zmin) / (zmax - zmin) * (vol_dim[0] - 1)

        vx_from = max(0, int(cx - atom_radius))
        vx_to = min(vol_dim[2] - 1, int(cx + atom_radius))
        vy_from = max(0, int(cy - atom_radius))
        vy_to = min(vol_dim[1] - 1, int(cy + atom_radius))
        vz_from = max(0, int(cz - atom_radius))
        vz_to = min(vol_dim[0] - 1, int(cz + atom_radius))

        # uniform density
        for vz in range(vz_from, vz_to + 1):
            for vy in range(vy_from, vy_to + 1):
                for vx in range(vx_from, vx_to + 1):
                    if vol_tag[vz, vy, vx] == 0:
                        vol_data[vz, vy, vx, :] = feat[ind, :]

    # gaussian filter
    if sigma > 0:
        for i in range(vol_data.shape[-1]):
            # scipy.ndimage.filters is a deprecated namespace; call gaussian_filter directly
            vol_data[:, :, :, i] = sp.ndimage.gaussian_filter(
                vol_data[:, :, :, i], sigma=sigma, truncate=2)

    return vol_data, vol_tag


def get_3D_all2(xyz, feat, vol_dim, relative_size=True, size_angstrom=48, atom_radii=None, atom_radius=1, sigma=0):

    # get 3d bounding box
    xmin, ymin, zmin, xmax, ymax, zmax = get_3D_bound(xyz)

    # initialize volume
    vol_data = np.zeros(
        (vol_dim[0], vol_dim[1], vol_dim[2], vol_dim[3]), dtype=np.float32)

    if relative_size:
        # voxel size (assume the voxel size is the same along all axes)
        vox_size = float(zmax - zmin) / float(vol_dim[0])
    else:
        # fixed cubic box of side size_angstrom, centered on the bounding box
        vox_size = float(size_angstrom) / float(vol_dim[0])
        xmid = (xmin + xmax) / 2.0
        ymid = (ymin + ymax) / 2.0
        zmid = (zmin + zmax) / 2.0
        xmin = xmid - (size_angstrom / 2)
        ymin = ymid - (size_angstrom / 2)
        zmin = zmid - (size_angstrom / 2)
        xmax = xmid + (size_angstrom / 2)
        ymax = ymid + (size_angstrom / 2)
        zmax = zmid + (size_angstrom / 2)

    # assign each atom to voxels
    for ind in range(xyz.shape[0]):
        x = xyz[ind, 0]
        y = xyz[ind, 1]
        z = xyz[ind, 2]
        if x < xmin or x > xmax or y < ymin or y > ymax or z < zmin or z > zmax:
            continue

        # compute van der Waals radius and atomic density, use 1 if not available
        if atom_radii is not None:
            vdw_radius = atom_radii[ind]
            atom_radius = 1 + vdw_radius * vox_size

        cx = (x - xmin) / (xmax - xmin) * (vol_dim[2] - 1)
        cy = (y - ymin) / (ymax - ymin) * (vol_dim[1] - 1)
        cz = (z - zmin) / (zmax - zmin) * (vol_dim[0] - 1)

        vx_from = max(0, int(cx - atom_radius))
        vx_to = min(vol_dim[2] - 1, int(cx + atom_radius))
        vy_from = max(0, int(cy - atom_radius))
        vy_to = min(vol_dim[1] - 1, int(cy + atom_radius))
        vz_from = max(0, int(cz - atom_radius))
        vz_to = min(vol_dim[0] - 1, int(cz + atom_radius))

        for vz in range(vz_from, vz_to + 1):
            for vy in range(vy_from, vy_to + 1):
                for vx in range(vx_from, vx_to + 1):
                    vol_data[vz, vy, vx, :] += feat[ind, :]

    # gaussian filter
    if sigma > 0:
        for i in range(vol_data.shape[-1]):
            # scipy.ndimage.filters is a deprecated namespace; call gaussian_filter directly
            vol_data[:, :, :, i] = sp.ndimage.gaussian_filter(
                vol_data[:, :, :, i], sigma=sigma, truncate=2)

    return vol_data


###############################################################################
# start the main script
g_prefix = ''

if g_external_test:
    input_test_hdf = h5py.File(os.path.join(
        g_main_dir, g_input_test_hd_file), 'r')
    output_test_hdf = h5py.File(os.path.join(
        g_main_dir, g_output_test_hd_file), 'w')
else:
    # open input hd5
    input_train_hdf = h5py.File(os.path.join(
        g_main_dir, g_input_train_hd_file), 'r')
    input_val_hdf = h5py.File(os.path.join(
        g_main_dir, g_input_val_hd_file), 'r')
    input_test_hdf = h5py.File(os.path.join(
        g_main_dir, g_input_test_hd_file), 'r')

    # create output hd5
    output_train_hdf = h5py.File(os.path.join(
        g_main_dir, g_output_train_hd_file), 'w')
    output_val_hdf = h5py.File(os.path.join(
        g_main_dir, g_output_val_hd_file), 'w')
    output_test_hdf = h5py.File(os.path.join(
        g_main_dir, g_output_test_hd_file), 'w')

# create output csv
output_csv_fp = open(os.path.join(g_main_dir, g_output_csv), 'w')
output_csv = csv.writer(output_csv_fp, delimiter=',')
output_csv.writerow(g_csv_header)


###############################################################################
# generate 3D for ligand and complex

feat_tool_str = g_feat_tool_list[g_feat_tool_ind]
feat_type_str = g_feat_type_list[g_feat_type_ind]
feat_pdbbind_str = g_feat_pdbbind_type_list2[g_feat_pdbbind_type_ind]

if g_external_test:
    input_hdfs = [input_test_hdf]
    output_hdfs = [output_test_hdf]
    output_prefixes = [g_output_test_hd_file[:-4]]
    traintest_splits = [2]
else:
    input_hdfs = [input_train_hdf, input_val_hdf, input_test_hdf]
    output_hdfs = [output_train_hdf, output_val_hdf, output_test_hdf]
    output_prefixes = [g_output_train_hd_file[:-4],
                       g_output_val_hd_file[:-4], g_output_test_hd_file[:-4]]
    traintest_splits = [0, 1, 2]

for input_hdf, output_hdf, output_prefix, split in zip(input_hdfs, output_hdfs, output_prefixes, traintest_splits):
    for lig_id in input_hdf.keys():
        # if len(g_prefix) > 0 and not lig_id.startswith(g_prefix):
        #     continue

        feat_tool_list = input_hdf[lig_id]
        if feat_tool_str not in feat_tool_list:
            continue
        feat_type_list = feat_tool_list[feat_tool_str]
        if feat_type_str not in feat_type_list:
            continue

        feat_pdbbind_list = feat_type_list[feat_type_str]
        if feat_pdbbind_str not in feat_pdbbind_list:
            continue

        print("processing %s" % lig_id)

        if g_feat_pdbbind_type_ind == 1:
            feat_data_0 = feat_pdbbind_list[feat_pdbbind_str]
            for n in range(1, 11):  # assuming pose1 to pose10
                if not str(n) in feat_data_0:
                    continue

                feat_data = feat_data_0[str(n)]
                input_data = feat_data[g_feat_data_str]
                input_radii = None
                if g_3D_atom_radii:
                    input_radii = feat_data.attrs['van_der_waals']
                input_affinity = input_hdf[lig_id].attrs['affinity']

                # first three columns are atom coordinates, the rest are features
                input_xyz = input_data[:, 0:3]
                input_feat = input_data[:, 3:]

                output_3d_data = get_3D_all2(input_xyz, input_feat, g_3D_dim, g_3D_relative_size,
                                             g_3D_size_angstrom, input_radii, g_3D_atom_radius, g_3D_sigma)
                print(input_data.shape, 'is converted into ',
                      output_3d_data.shape)

                lig_id_pose = lig_id + '_' + str(n)
                if g_output_hd_compress:
                    output_hdf.create_dataset(
                        lig_id_pose, data=output_3d_data, shape=output_3d_data.shape, dtype='float32', compression='lzf')
                else:
                    output_hdf.create_dataset(
                        lig_id_pose, data=output_3d_data, shape=output_3d_data.shape, dtype='float32')
                output_hdf[lig_id_pose].attrs['affinity'] = input_affinity

                lig_prefix = '%d/%s/%s' % (split, output_prefix, lig_id_pose)
                output_csv.writerow(
                    [lig_id_pose, lig_prefix, input_affinity, split, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        else:
            feat_data = feat_pdbbind_list[feat_pdbbind_str]
            input_data = feat_data[g_feat_data_str]
            input_radii = None
            if g_3D_atom_radii:
                input_radii = feat_data.attrs['van_der_waals']
            input_affinity = input_hdf[lig_id].attrs['affinity']

            # first three columns are atom coordinates, the rest are features
            input_xyz = input_data[:, 0:3]
            input_feat = input_data[:, 3:]

            output_3d_data = get_3D_all2(input_xyz, input_feat, g_3D_dim, g_3D_relative_size,
                                         g_3D_size_angstrom, input_radii, g_3D_atom_radius, g_3D_sigma)
            print(input_data.shape, 'is converted into ', output_3d_data.shape)

            #dgroup = output_hdf.create_group(lig_id)
            if g_output_hd_compress:
                output_hdf.create_dataset(
                    lig_id, data=output_3d_data, shape=output_3d_data.shape, dtype='float32', compression='lzf')
            else:
                output_hdf.create_dataset(
                    lig_id, data=output_3d_data, shape=output_3d_data.shape, dtype='float32')
            output_hdf[lig_id].attrs['affinity'] = input_affinity

            lig_prefix = '%d/%s/%s' % (split, output_prefix, lig_id)
            output_csv.writerow(
                [lig_id, lig_prefix, input_affinity, split, 0, 0, 0, 0, 0, 0, 0, 0, 0])

output_csv_fp.close()
if g_external_test:
    output_test_hdf.close()
else:
    output_train_hdf.close()
    output_val_hdf.close()
    output_test_hdf.close()


Done!

The files pdbbind2016_general_train.hdf and general_train_3dnn.csv are the ones passed to the training algorithm. To train on the refined set instead, use the corresponding refined-set files.
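
Before handing the info csv to training, it can help to check which complexes landed in which split. The snippet below is a minimal sketch of such a check, not part of the original pipeline. It reads the rows by position, following the `writerow()` calls in the conversion script (id, prefix, affinity, split), and maps the split codes 0, 1, 2 to train, val and test as in `traintest_splits`. The helper name `read_info_rows`, the header names and the sample rows are made up for illustration; the real header comes from `g_csv_header`.

```python
import csv
import io

# Split codes written by the conversion script: 0 = train, 1 = val, 2 = test.
SPLIT_NAMES = {0: "train", 1: "val", 2: "test"}


def read_info_rows(fp):
    """Yield (complex_id, prefix, affinity, split_name) for each data row.

    Columns are read by position, since the header names are an assumption
    here; the positions follow the writerow() calls in the script above.
    """
    reader = csv.reader(fp, delimiter=",")
    next(reader)  # skip the header row
    for row in reader:
        if not row:
            continue  # ignore blank lines
        yield row[0], row[1], float(row[2]), SPLIT_NAMES[int(row[3])]


# Hypothetical sample: header names and values are illustrative only.
sample = io.StringIO(
    "id,path,affinity,split,c4,c5,c6,c7,c8,c9,c10,c11,c12\n"
    "lig_a,0/example_train/lig_a,6.5,0,0,0,0,0,0,0,0,0,0\n"
    "lig_b,2/example_test/lig_b,4.25,2,0,0,0,0,0,0,0,0,0\n"
)
rows = list(read_info_rows(sample))
for complex_id, prefix, affinity, split_name in rows:
    print(complex_id, prefix, affinity, split_name)
```

To use it on a real file, pass an open file handle instead of the in-memory sample, for example `open(path_to_info_csv, newline='')`, where `path_to_info_csv` is whatever path the script wrote.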