In [12]:
# %load ligandnet.py
import os
import joblib
import argparse
import time
import pandas as pd
import numpy as np
from ddt.utility import FeatureGenerator
import errno
import json
from collections import OrderedDict
from tqdm import tqdm
import argparse

In [35]:
class LigandNet(object):
    MODELS_DIR = os.path.join('models/files')
    
    def __init__(self):
        self.load_models()
        
    def load_models(self):
        #TODO: Avoid loading all the models
        # Read the best models
        with open('best_models.txt', 'r') as f:
            best_models = f.read().splitlines()[:50]
        
        self.uniprot_ids = [model_path[:6] for model_path in best_models]
        self.models = [joblib.load(os.path.join(self.MODELS_DIR, model_path)) for model_path in best_models]
            
    def get_features(self, input, input_type):
        # TODO: Add functionality for reading from a smi file containing a bulk of smiles
        ft = FeatureGenerator()
        if input_type == 'smiles':
            ft.load_smiles(input)
        else:
            ft.load_sdf(input)
        cmpd_id, features = ft.extract_tpatf()
        return cmpd_id, features.reshape(-1, 2692)
    
    # Get predictions
    def get_prediction(self, input, input_type, confidence_threshold=0.5):
        results = {}
        cmpd_id, features = self.get_features(input, input_type)
        cmpd_id = np.array(cmpd_id)
        for uniprot_id, model in tqdm(zip(self.uniprot_ids, self.models), total=703):
            pred = model.predict_proba(features)[:, 1]
            mask = pred >= confidence_threshold
            for _id, _pred in zip(cmpd_id[mask], pred[mask]):
                # Create a dictionary for compound if not exists
                if _id not in results.keys():
                    results[_id] = {}
                # Update the compound result dictionary
                results[_id].update({uniprot_id:_pred})
        return results

In [36]:
l = LigandNet()

In [49]:
r = l.get_prediction('CCCC', 'smiles', 0.005)

  7%|▋         | 50/703 [00:00<00:11, 54.57it/s]


In [50]:
r

{'Cmpd1': {'O00206': 0.01,
  'O00329': 0.12453546272950158,
  'O00398': 0.007955998,
  'O00519': 0.047452199945055215,
  'O00748': 0.013333333333333334,
  'O00767': 0.015485651370918459,
  'O14649': 0.0080336165,
  'O14672': 0.3962052124407,
  'O14920': 0.13783955337360168,
  'O15054': 0.06776798,
  'O15379': 0.026473444502453654,
  'O15496': 0.4388301674112895,
  'O15530': 0.2853474377955889,
  'O43318': 0.013770750231887178,
  'O60341': 0.19,
  'O60427': 0.47349631405766146,
  'O60502': 0.13518072481187046,
  'O60563': 0.006062961713391305,
  'O60725': 0.0075}}

In [52]:
for k, v in r.items():
    print(type(k))
    for _k, _v in v.items():
#         v[_k] = float(_/v)
        print(_k, _v)
        print(type(k), type(_v))

<class 'numpy.str_'>
O00206 0.01
<class 'numpy.str_'> <class 'numpy.float64'>
O00329 0.12453546272950158
<class 'numpy.str_'> <class 'numpy.float64'>
O00398 0.007955998
<class 'numpy.str_'> <class 'numpy.float32'>
O00519 0.047452199945055215
<class 'numpy.str_'> <class 'numpy.float64'>
O00748 0.013333333333333334
<class 'numpy.str_'> <class 'numpy.float64'>
O00767 0.015485651370918459
<class 'numpy.str_'> <class 'numpy.float64'>
O14649 0.0080336165
<class 'numpy.str_'> <class 'numpy.float32'>
O14672 0.3962052124407
<class 'numpy.str_'> <class 'numpy.float64'>
O14920 0.13783955337360168
<class 'numpy.str_'> <class 'numpy.float64'>
O15054 0.06776798
<class 'numpy.str_'> <class 'numpy.float32'>
O15379 0.026473444502453654
<class 'numpy.str_'> <class 'numpy.float64'>
O15496 0.4388301674112895
<class 'numpy.str_'> <class 'numpy.float64'>
O15530 0.2853474377955889
<class 'numpy.str_'> <class 'numpy.float64'>
O43318 0.013770750231887178
<class 'numpy.str_'> <class 'numpy.float64'>
O60341 0.19

In [14]:
if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser(
        description="Ligand activity prediction using LigandNet")
    parser.add_argument('--sdf', action='store',
                        dest='sdf', help='SDF file location')
    parser.add_argument('--smiles', action='store', type=str,
                        dest='smiles', help='SMILES')
#     parser.add_argument('--out', action='store', dest='out',
#                         required=False, help='Output directory')
    parser.add_argument('--confidence', action='store', dest='confidence', type=float,
                        default=0.50, help='Minimum confidence to consider for prediction. Default is 0.5')
    
    args = parser.parse_args()
    
    if not (args.smiles or args.sdf):
        parser.error('No input found. Provide --smiles or --sdf')
    
    print(f"Loading the LigandNet models ...")
    l = LigandNet()
    
    if args.sdf is not None:
        if not os.path.isfile(args.sdf):
            raise FileNotFoundError(errno.ENOENT, os.strerror(
                errno.ENOENT), args.sdf)

        results = l.get_prediction(args.sdf, 'sdf', args.confidence)
        print(results)

    if args.smiles is not None:
        results = l.get_prediction(args.smiles,'smiles', args.confidence)
        print(results)


usage: ipykernel_launcher.py [-h] [--sdf SDF] [--smiles SMILES]
                             [--confidence CONFIDENCE]
ipykernel_launcher.py: error: unrecognized arguments: -f /data/mhassan/.local/share/jupyter/runtime/kernel-bcbfbfd7-590f-4097-8e5f-4e28b9b031ea.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [15]:
import json

In [19]:
d = {'Cmpd1': {'O00206': 0.01, 'O00329': 0.12, 'O00398': 0.01, 'O00519': 0.05, 'O00748': 0.01, 'O00767': 0.02, 'O14649': 0.01, 'O14672': 0.4, 'O14920': 0.14, 'O15054': 0.07, 'O15379': 0.03, 'O15496': 0.44, 'O15530': 0.29, 'O43318': 0.01, 'O60341': 0.19, 'O60427': 0.47, 'O60502': 0.14, 'O60563': 0.01, 'O60725': 0.01, 'O60885': 0.05, 'O60911': 0.02, 'O95749': 0.01, 'O75173': 0.25, 'O75460': 0.02, 'O75762': 0.99, 'O75907': 0.35, 'O94953': 0.22, 'O95665': 0.02, 'O95822': 0.04, 'P00338': 0.31, 'P00374': 0.01, 'P08311': 0.29, 'P00492': 0.05, 'P00734': 0.04, 'P00740': 0.2, 'P00797': 0.05, 'P02708': 0.63, 'P02766': 0.01, 'P03951': 0.44, 'P04066': 0.15}}

In [22]:
json.dumps(d)

'{"Cmpd1": {"O00206": 0.01, "O00329": 0.12, "O00398": 0.01, "O00519": 0.05, "O00748": 0.01, "O00767": 0.02, "O14649": 0.01, "O14672": 0.4, "O14920": 0.14, "O15054": 0.07, "O15379": 0.03, "O15496": 0.44, "O15530": 0.29, "O43318": 0.01, "O60341": 0.19, "O60427": 0.47, "O60502": 0.14, "O60563": 0.01, "O60725": 0.01, "O60885": 0.05, "O60911": 0.02, "O95749": 0.01, "O75173": 0.25, "O75460": 0.02, "O75762": 0.99, "O75907": 0.35, "O94953": 0.22, "O95665": 0.02, "O95822": 0.04, "P00338": 0.31, "P00374": 0.01, "P08311": 0.29, "P00492": 0.05, "P00734": 0.04, "P00740": 0.2, "P00797": 0.05, "P02708": 0.63, "P02766": 0.01, "P03951": 0.44, "P04066": 0.15}}'