In [1]:
%load_ext autoreload
%autoreload 2

import typing
from typing import List, Iterable
import pickle
import numpy as np
import pandas as pd

import rdkit
from rdkit.Chem import AllChem as AllChem
from rdkit.SimDivFilters.rdSimDivPickers import MaxMinPicker

from sklearn.ensemble import RandomForestClassifier

In [2]:
def random_pick(fp_ls: List[rdkit.DataStructs.cDataStructs.ExplicitBitVect], 
                n: int, 
                seed: int = 42) -> List[int]:
    """Random diverse compound picking based on Rdkit MaxMinPicker"""
    picker = MaxMinPicker()
    return list(picker.LazyBitVectorPick(fp_ls, len(fp_ls), n, seed=seed))


class Indice():
    def __init__(self, size: int) -> None:
        self.unsampled = list(range(size))
        self.sampled = []
        
    def add(self, idxs: Iterable[int]) -> None:
        self.sampled = list(set(self.sampled + list(idxs)))
        self.unsampled = list(set(self.unsampled) - set(self.sampled))

# test
idxs = Indice(10)

idxs.add([1,2])
print(idxs.sampled)
print(idxs.unsampled)

idxs.add([1,2,8])
print(idxs.sampled)
print(idxs.unsampled)

idxs.add([1,2,8])
print(idxs.sampled)
print(idxs.unsampled)

[1, 2]
[0, 3, 4, 5, 6, 7, 8, 9]
[8, 1, 2]
[0, 3, 4, 5, 6, 7, 9]
[8, 1, 2]
[0, 3, 4, 5, 6, 7, 9]


In [3]:
SEED = 0
P_INIT = 0.15
P_ITER = 0.05
N_ITER = 4
P_EXPLOIT = 0.8
N_JOBS = 8
    
def iHTS(path: str):

    with open(path,"rb") as file:
        data = pickle.load(file)

    smiles_ls = data.get("smiles_ls")
    mol_ls = data.get("mol_ls")
    fp_ls = data.get("fp_ls")
    activity_ls = data.get("activity_ls")
    ds_ls = data.get("ds_ls")
    N_TOTAL = len(fp_ls)
    
    del data

    X = np.concatenate([np.array(ds_ls), np.array(fp_ls)], axis=1)
    y = np.array(activity_ls)
    

    # define model
    rf = RandomForestClassifier(n_estimators=1200, class_weight="balanced",
        max_features='log2',bootstrap=True,min_samples_split = 8,
        min_samples_leaf = 3, n_jobs = N_JOBS,random_state=SEED)
    
    
    clf = RandomForestClassifier(n_estimators=1200, class_weight="balanced",
    max_features='log2',bootstrap=True,min_samples_split = 8,
    min_samples_leaf = 3, n_jobs = 4,random_state=SEED)

    # select initial subset
    init_idx = random_pick(fp_ls, int(P_INIT*len(fp_ls)))
    idxs = Indice(len(fp_ls))
    idxs.add(init_idx)

    for it in range(N_ITER):
        print(it)
        
        # train supversied model
        # form dataset
        X_sampled = X[idxs.sampled]
        y_sampled = y[idxs.sampled]
        X_unsampled = X[idxs.unsampled]

        # fit & make prediction
        clf.fit(X_sampled, y_sampled)
        probs = clf.predict_proba(X_unsampled)[:,0]  # prob of being inactive

        # select next batch
        # add exploitation set
        n_exploit = int(P_ITER * N_TOTAL * P_EXPLOIT)
        idx_exploit = np.argsort(probs)[:n_exploit]  # sorted from low -> high
        idxs.add(np.array(idxs.unsampled)[idx_exploit])

        # add exploration set
        n_explore = int(P_ITER * N_TOTAL * (1-P_EXPLOIT))
        idx_explore = random_pick([fp_ls[i] for i in idxs.unsampled], n_explore)
        idxs.add(np.array(idxs.unsampled)[idx_explore])

    return idxs

In [4]:
idxs1 = iHTS("dataset/AID_628/data.pkl")

(63662, 1121) (63662,)
0
1
2
3


In [7]:
try:
    idxs2 = iHTS("dataset/test/AID_1259354/data.pkl")
except:
    pass

(77674, 1121) (77674,)
0


  array = numpy.asarray(array, order=order, dtype=dtype)


In [8]:
try:
    idxs3 = iHTS("dataset/test/AID_488969/data.pkl")
except:
    pass

(2166, 1121) (2166,)
0
1
2
3


In [9]:
try:
    idxs4 = iHTS("dataset/test/AID_598/data.pkl")
except:
    pass

(85210, 1121) (85210,)
0
1
2
3
