In [None]:
import sys
import pickle
import logging
import h5py 
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import multiprocessing as mp
from typing import Dict, List
from tqdm import tqdm
from deepchem.feat import Featurizer
from deepchem.utils.coordinate_box_utils import CoordinateBox
from deepchem.utils.rdkit_utils import load_molecule
from pathlib import Path
from deepchem.feat import RdkitGridFeaturizer, BindingPocketFeaturizer
from deepchem.utils.coordinate_box_utils import CoordinateBox
from deepchem.utils.rdkit_utils import load_molecule

### Load the PDBbind 2019 paths

In [None]:
# Root of the raw PDBbind v2019 download. The recursive "**/" listing below is a
# lazy generator; a later cell rebinds pdbbind_2019_subdirs to a one-level "*/"
# glob before anything iterates it.
pdbbind_2019_path = Path("/p/lustre2/jones289/data/raw_data/v2019")
pdbbind_2019_subdirs = pdbbind_2019_path.glob("**/")

### Instantiate the DeepChem featurizers
* RdkitGridFeaturizer
* BindingPocketFeaturizer

In [None]:
# Grid featurizer applied to each (pocket, ligand) pair. Several feature types
# are switched off; the reason for each is noted next to it.
rdkit_grid_feat = RdkitGridFeaturizer(
    feature_types=[
        'ecfp',
        # 'splif',        # 'Atom' object has no attribute 'GetIndex'
        # 'sybyl',        # not implemented
        # 'salt_bridge',  # including this feature results in extremely large values (np.inf)
        'charge',
        # 'hbond',        # this causes an index error
        'pi_stack',       # this feature may be equal to 0 much of the time
        'cation_pi',      # this feature may be equal to 0 much of the time
    ],
    sanitize=True,
    voxel_width=0.5,
)
rdkit_grid_feat

In [None]:
# Default-configured pocket featurizer, echoed as the cell output for inspection.
binding_pocket_feat = (
    BindingPocketFeaturizer()
)
binding_pocket_feat

In [None]:
def boxes_to_atoms(coords: np.ndarray, boxes: "List[CoordinateBox]"
                   ) -> "Dict[CoordinateBox, List[int]]":
    """Map each box to the indices of the atoms that fall inside it.

    Parameters
    ----------
    coords: np.ndarray
        Atom coordinates; each element of ``coords`` (one row, e.g. of an
        ``(N, 3)`` array) is tested for membership with ``row in box``.
    boxes: list
        ``CoordinateBox`` objects (anything hashable supporting ``in``).

    Returns
    -------
    Dict[CoordinateBox, List[int]]
        For every box, the ascending list of atom indices it contains. Boxes
        that compare equal share one key (the last one wins).
    """
    return {
        box: [atom_ind for atom_ind, atom in enumerate(coords) if atom in box]
        for box in boxes
    }

In [None]:
def compute_box(xyz, padding=1.0):
    """Axis-aligned CoordinateBox enclosing a set of 3-D points, padded on every side.

    Parameters
    ----------
    xyz : array-like
        Point coordinates laid out as consecutive (x, y, z) triples, e.g. shape
        ``(N, 3)`` or ``(1, N, 3)``.
    padding : float
        Margin subtracted from each per-axis minimum and added to each maximum
        (default 1.0, the previously hard-coded value; same units as ``xyz``).

    Returns
    -------
    CoordinateBox

    Raises
    ------
    ValueError
        If ``xyz`` contains no points.
    """
    # reshape instead of squeeze: squeeze turns a single point of shape (1, 3)
    # into a 1-D (3,) array, which then breaks the per-axis reduction.
    points = np.asarray(xyz).reshape(-1, 3)
    if points.shape[0] == 0:
        raise ValueError("compute_box needs at least one coordinate")

    lower = points.min(axis=0) - padding
    upper = points.max(axis=0) + padding

    crystal_box = CoordinateBox(x_range=(lower[0], upper[0]),
                                y_range=(lower[1], upper[1]),
                                z_range=(lower[2], upper[2]))
    return crystal_box

    
def featurize_complex_job(parent_dir, use_prot=False, use_pocket=True, verbose=False):
    """Featurize one PDBbind complex directory.

    Parameters
    ----------
    parent_dir : pathlib.Path
        Complex directory; its stem is used as the pdbid. The receptor file is
        ``<parent_dir>/{pdbid}_pocket.pdb`` (or ``_protein.pdb``) and the ligand
        is read from the sibling ``ligands/{pdbid}_ligand.pdb``.
    use_prot, use_pocket : bool
        Exactly one must be True. Only the pocket path is featurized.
    verbose : bool
        If True, print the receptor path before featurizing.

    Returns
    -------
    tuple or None
        ``(pdbid, feats)``, where ``feats`` is a length-2 object array
        ``[rdkit_grid_features, binding_pocket_features]``. Returns ``None`` when
        loading or featurization fails (the error is printed with the pdbid) and
        when ``use_prot`` is selected.
    """
    assert use_prot != use_pocket

    pdbid = parent_dir.stem

    if use_prot:
        mol_path = parent_dir.with_name(pdbid) / f"{pdbid}_protein.pdb"
    if use_pocket:
        # the RDKitGridFeaturizer is choking on these file...and the error is being thrown by mdtraj..files are coming directly from pdbbind
        mol_path = parent_dir.with_name(pdbid) / f"{pdbid}_pocket.pdb"

    lig_path = parent_dir.with_name("ligands") / f"{pdbid}_ligand.pdb"

    if verbose:
        print(mol_path)

    if not use_pocket:
        # Protein-level featurization was never implemented; keep yielding no
        # result for this complex, as before.
        return None

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            # Loading is inside the try so that one unreadable file is reported
            # and skipped instead of raising out of the Pool worker and aborting
            # the whole imap over the dataset.
            pocket_coords, _ = load_molecule(str(mol_path), add_hydrogens=False, calc_charges=False)
            crystal_box = compute_box(pocket_coords)
            # NOTE(review): confirm the (receptor, ligand) order of this tuple
            # against the installed deepchem RdkitGridFeaturizer._featurize.
            rdkit_feats = rdkit_grid_feat._featurize((str(mol_path), str(lig_path)))
            bind_pocket_feats = binding_pocket_feat.featurize(str(mol_path), pockets=[crystal_box])
        except Exception as e:
            print(pdbid, e)
            return None

    # The two feature sets have different shapes (describe_feats treats them as a
    # 4-D grid and a 2-D count matrix), so np.asarray([...]) would be a ragged
    # array, which recent NumPy rejects. Fill a 1-D object array explicitly.
    feats = np.empty(2, dtype=object)
    feats[0] = rdkit_feats
    feats[1] = bind_pocket_feats
    return pdbid, feats

def process_data(pdbbind_2019_subdir_list, processes=None, chunksize=1):
    """Run featurize_complex_job over every complex directory in a process pool.

    Parameters
    ----------
    pdbbind_2019_subdir_list : list of pathlib.Path
        Complex directories; each one is passed to ``featurize_complex_job``.
    processes : int, optional
        Number of worker processes; defaults to ``mp.cpu_count()`` (the
        previously hard-coded value). Lower it when running on a shared node.
    chunksize : int
        Forwarded to ``Pool.imap``; larger values reduce IPC overhead when there
        are many small jobs.

    Returns
    -------
    list
        One entry per input directory, in input order: the value
        ``featurize_complex_job`` returned for it.
    """
    if processes is None:
        processes = mp.cpu_count()

    with mp.Pool(processes) as pool:
        jobs = pool.imap(featurize_complex_job, pdbbind_2019_subdir_list, chunksize=chunksize)
        result_list = list(tqdm(jobs, total=len(pdbbind_2019_subdir_list)))

    return result_list
    
def _split_feature_pair(feats):
    """Split one featurize_complex_job payload into (grid_feats, pocket_feats) float32 arrays."""
    # A length-2 tuple/list/object-array holds [grid feats, pocket feats] of
    # different shapes, which h5py cannot store as a single dataset. Anything
    # else is treated as binding-pocket features alone (grid_feats is None).
    is_pair_container = isinstance(feats, (tuple, list)) or (
        isinstance(feats, np.ndarray) and feats.dtype == object)
    if is_pair_container and len(feats) == 2:
        return (np.asarray(feats[0], dtype=np.float32),
                np.asarray(feats[1], dtype=np.float32))
    return None, np.asarray(feats, dtype=np.float32)


def dump_result_to_h5(result_list, output_path, metadata_df=None):
    """Write featurization results to HDF5, one group per pdbid.

    Parameters
    ----------
    result_list : list
        ``(pdbid, feats)`` pairs as produced by ``process_data``. ``None``
        entries (complexes that failed to featurize) and pairs whose ``feats``
        is ``None`` are skipped instead of crashing the dump.
    output_path : str or path-like
        Destination file; opened with mode ``'w'``, so an existing file is
        truncated.
    metadata_df : pandas.DataFrame, optional
        Table with ``'pdbid'`` and ``'-logKd/Ki'`` columns; the matching
        affinity values are stored as the group attribute ``'-logKd/Ki'``.
        Defaults to the notebook-global ``pdbbind_2019_df`` (previous behavior).

    Each group gets a ``BindingPocketFeaturizer`` float32 dataset and, when the
    payload is a (grid, pocket) pair, an ``RdkitGridFeaturizer`` dataset too.
    """
    assert output_path is not None
    if metadata_df is None:
        metadata_df = pdbbind_2019_df

    with h5py.File(output_path, 'w') as h5_file:
        for result in tqdm(result_list, desc="dumping output to hdf5 file..."):
            if result is None or result[1] is None:
                continue
            pdbid, feats = result
            grid_feats, pocket_feats = _split_feature_pair(feats)

            # NOTE(review): the subdirectory filter elsewhere matches on a 'name'
            # column; confirm this metadata frame also has 'pdbid'.
            affinity = metadata_df.loc[metadata_df['pdbid'] == pdbid, '-logKd/Ki'].to_numpy()
            pdbid_group = h5_file.require_group(pdbid)
            pdbid_group.attrs['-logKd/Ki'] = affinity
            pdbid_group.require_dataset("BindingPocketFeaturizer",
                                        data=pocket_feats,
                                        shape=pocket_feats.shape,
                                        dtype=np.float32)
            if grid_feats is not None:
                pdbid_group.require_dataset("RdkitGridFeaturizer",
                                            data=grid_feats,
                                            shape=grid_feats.shape,
                                            dtype=np.float32)

    print("done.")

In [None]:
pdbbind_2019_df = pd.read_csv("/p/lustre2/jones289/data/pdbbind/metadata/pdbbind_v2019_metadata.csv")
print(pdbbind_2019_df.head())
pdbbind_2019_path = Path("/p/lustre2/jones289/data/raw_data/v2019")
pdbbind_2019_subdirs = pdbbind_2019_path.glob("*/")
# Build the set of known ids once: calling `.values.tolist()` inside the
# comprehension rebuilt an O(M) list for every one of the N directories.
known_pdbids = set(pdbbind_2019_df['name'])
pdbbind_2019_subdir_list = [x for x in pdbbind_2019_subdirs if x.name in known_pdbids]
print(len(pdbbind_2019_subdir_list))

In [None]:
# Featurize every matched complex directory in parallel.
result_list = process_data(pdbbind_2019_subdir_list=pdbbind_2019_subdir_list)

In [None]:
# Second element of the first complex's feature payload (the binding-pocket
# features, given how featurize_complex_job orders them).
first_pdbid, first_feats = result_list[0]
first_feats[1]

In [None]:
dump_result_to_h5(result_list, output_path="deepchem_baseline_feats.h5")

In [None]:
# Read one group back from the dump as a sanity check.
with h5py.File("deepchem_baseline_feats.h5", 'r') as h5_file:
    pocket_ds = h5_file['1a30']['BindingPocketFeaturizer']
    print(list(pocket_ds))

In [None]:
refined_df = pd.read_csv(
    '/g/g13/jones289/workspace/hd-cuda-master/datasets/pdbbind_fingerprints/'
    'pdbbind_2016_fps_new/v_2016_refined_pdbid_list.csv',
    index_col=0,
)
refined_df

In [None]:
def visualize_label_dist(df, bind_thresh=4):
    """Plot counts of binary bind / no-bind labels obtained by thresholding affinity.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'-logKd/Ki'`` column. It is not modified.
    bind_thresh : float
        Rows with ``-logKd/Ki`` strictly greater than this are labelled 1
        (bind); all others, including NaN, are labelled 0.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the count plot.
    """
    # Threshold the *argument* frame, vectorized, without writing a 'label'
    # column back into the caller's DataFrame.
    labels = (df['-logKd/Ki'] > bind_thresh).astype(int).rename('label')

    f, ax = plt.subplots(1, 1)
    ax.set_title(f"no-bind (0) and bind (1) counts with thresh={bind_thresh}")
    # Pass the Series as x= explicitly; a positional first argument is `data`
    # in seaborn >= 0.12.
    sns.countplot(x=labels, ax=ax)
    return ax

In [None]:
# Overall distribution of binding affinities in the refined set. sns.distplot is
# deprecated since seaborn 0.11; histplot(stat="density", kde=True) reproduces
# its default density histogram with a KDE overlay, on an explicit axes.
dist_fig, dist_ax = plt.subplots(1, 1)
sns.histplot(refined_df['-logKd/Ki'], stat="density", kde=True, ax=dist_ax)
dist_ax.set_title("-logKd/Ki distribution")

# One bind / no-bind count plot per candidate threshold.
for thresh in [2, 4, 6, 8, 10]:
    visualize_label_dist(refined_df, bind_thresh=thresh)

In [None]:
# NOTE(review): `data` is never defined in this notebook, so it must come from
# leftover kernel state (presumably a {pdbid: features} dict loaded elsewhere).
# On a fresh kernel this cell raises NameError. TODO: load it explicitly.
id_list = list(data.keys())
id_list

In [None]:
# Stack every feature array in `data` (dict iteration order) and drop singleton axes.
data_values = np.asarray(list(data.values())).squeeze()

In [None]:
np.shape(data_values)

In [None]:
# Number of ids in `data`; expected to equal the first axis of data_values.
n_ids = len(id_list)
n_ids

In [None]:
core_set_df = pd.read_csv(
    "/p/lustre2/jones289/data"
    "/pdbbind/metadata/"
    "pdbbind_2016_core_test.csv"
)
core_set_df

In [None]:
# Group affinity rows by pdbid once instead of filtering the whole frame (and
# scanning its pdbid array) for every key in `data`. Each value is the array
# of all '-logKd/Ki' rows for that pdbid, in file order.
core_affinities = {
    pdbid: rows['-logKd/Ki'].values
    for pdbid, rows in core_set_df.groupby('pdbid', sort=False)
}
core_test_dict = {key: {'data': value,
                        '-logKd/Ki': core_affinities[key]}
                  for key, value in data.items() if key in core_affinities}
core_test_dict

In [None]:
# Number of core-set complexes that have features in `data`.
n_core = len(core_test_dict)
n_core

In [None]:
pdbbind_2016_df = pd.read_csv('/p/lustre2/jones289/data/pdbbind/metadata/pdbbind_2016_train_val_test.csv')
# Vectorized column comparison instead of a per-row apply(axis=1).
refined_set_df = pdbbind_2016_df[pdbbind_2016_df['pdbbind_set'] == 'refined']
refined_set_df

In [None]:
# Drop every refined-set complex that is in the core (test) set, using a
# vectorized membership test instead of a per-row Python apply.
refined_no_core_set_df = refined_set_df[~refined_set_df['pdbid'].isin(list(core_test_dict))]
refined_no_core_set_df

In [None]:
# Affinity rows per pdbid, grouped once. This replaces a per-key full-frame
# filter and an earlier dict comprehension whose result was immediately
# overwritten.
refined_affinities = {
    pdbid: rows['-logKd/Ki'].values
    for pdbid, rows in refined_no_core_set_df.groupby('pdbid', sort=False)
}
refined_train_dict = {key: {'data': value,
                            '-logKd/Ki': refined_affinities[key]}
                      for key, value in data.items() if key in refined_affinities}
len(refined_train_dict)

In [None]:
def convert_dict_to_tup(mydict, binary_only=False, upper_thresh=8, lower_thresh=6):
    """Flatten a ``{pdbid: {'data': ..., '-logKd/Ki': ...}}`` dict into parallel lists.

    Each affinity is classed 1 if ``> upper_thresh``, otherwise 0 if
    ``< lower_thresh``, otherwise 2 (checked in that order). With
    ``binary_only=True``, class-2 entries are dropped from all three lists.

    Returns
    -------
    (pdbids, data, class_labels) : tuple of lists
        ``data`` holds each entry's ``'data'`` array flattened to 1-D.
    """
    def _classify(affinity):
        if affinity > upper_thresh:
            return 1
        if affinity < lower_thresh:
            return 0
        return 2

    pdbids = list(mydict)
    data = [entry['data'].flatten() for entry in mydict.values()]
    class_labels = [_classify(entry['-logKd/Ki']) for entry in mydict.values()]

    if not binary_only:
        return pdbids, data, class_labels

    kept = [i for i, cls in enumerate(class_labels) if cls != 2]
    return ([pdbids[i] for i in kept],
            [data[i] for i in kept],
            [class_labels[i] for i in kept])



In [None]:
def dump_dataset(core_data_dict, refined_data_dict,
                 thresholds=((6, 8), (4, 4), (6, 6))):
    """Pickle binary-labelled core (test) and refined (train) splits per threshold pair.

    Parameters
    ----------
    core_data_dict, refined_data_dict : dict
        ``{pdbid: {'data': array, '-logKd/Ki': affinity}}`` mappings.
    thresholds : iterable of (lower, upper)
        Affinities below ``lower`` become class 0 and above ``upper`` class 1.
        Everything in between is dropped (``binary_only=True``).

    For each pair, writes ``deepchem_aa_prot_feats_core_2_class_{lower}_{upper}_thresh.pkl``
    and ``deepchem_aa_prot_feats_refined_2_class_{lower}_{upper}_thresh.pkl`` in
    the working directory, each holding a pickled ``(data, labels)`` tuple.
    """
    for lower, upper in thresholds:
        # Unpack as (lower, upper). Unpacking as (upper, lower) turned (6, 8) into
        # upper=6 / lower=8, i.e. a single cut at 6 instead of the 6..8 gap that
        # matches convert_dict_to_tup's defaults.
        _, core_data, core_labels = convert_dict_to_tup(core_data_dict, binary_only=True,
                                                        upper_thresh=upper, lower_thresh=lower)
        with open(f'deepchem_aa_prot_feats_core_2_class_{lower}_{upper}_thresh.pkl', 'wb') as handle:
            pickle.dump((core_data, core_labels), handle)

        # Use the argument, not the notebook-global refined_train_dict.
        _, refined_data, refined_labels = convert_dict_to_tup(refined_data_dict, binary_only=True,
                                                              upper_thresh=upper, lower_thresh=lower)
        with open(f'deepchem_aa_prot_feats_refined_2_class_{lower}_{upper}_thresh.pkl', 'wb') as handle:
            pickle.dump((refined_data, refined_labels), handle)


In [None]:
dump_dataset(core_data_dict=core_test_dict, refined_data_dict=refined_train_dict)

In [None]:
def describe_feats(feats):
    """Print per-channel statistics for a (voxel grid, amino-acid count) feature pair.

    ``feats`` unpacks into ``voxels``, indexed as ``voxels[:, :, :, channel]``
    (so it must be 4-D with channels on the last axis), and ``aa_count_feat``.
    For each channel, this prints the min, max, mean and percentage of non-zero
    cells. It then prints ``aa_count_feat`` and its sums along ``axis=1``.
    """
    voxels, aa_count_feat = feats
    for channel in range(voxels.shape[-1]):
        grid = voxels[:, :, :, channel]
        lo, hi, avg = grid.min(), grid.max(), grid.mean()

        flat = grid.flatten()
        occ_rate = 100 * (np.count_nonzero(flat) / flat.shape[0])
        print(f"i: {channel}, min={lo}, max={hi}, mean_vox={avg}, occupancy%: {occ_rate:0.4f}")

    print(aa_count_feat)
    print(aa_count_feat.sum(axis=1))


In [None]:
# `feats` only exists as a local inside featurize_complex_job, so the bare name
# is undefined on a fresh kernel. Describe the first successfully featurized
# complex from result_list instead; failed complexes appear there as None.
first_pdbid, first_feats = next(r for r in result_list if r is not None)
describe_feats(first_feats)