In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import rdkit
from rdkit.Chem import AllChem as AllChem
from rdkit.SimDivFilters.rdSimDivPickers import MaxMinPicker

from utils import *

In [2]:
with open("AID_628/data.pkl","rb") as file:
    data = pickle.load(file)
    
smiles_ls = data.get("smiles_ls")
mol_ls = data.get("mol_ls")
fp_ls = data.get("fp_ls")
activity_ls = data.get("activity_ls")
ds_ls = data.get("ds_ls")
scaffold_dict = data.get("scaffold_dict")

del data

y = np.array(activity_ls)

SEED = 0
P_INIT = 0.15
P_ITER = 0.05
N_ITER = 4
N_TOTAL = len(y)



def initialize_scfs(scaffold_dict: dict) -> (list, list):
    
    # initialize scaffolds
    scaffold_ls = []
    idx_to_scf = []
    for scf_smiles in scaffold_dict:
        scaffold_ls.append(Scaffold(scf_smiles, scaffold_dict[scf_smiles], Beta_dist()))
        scf_obj = scaffold_ls[-1]
        idx_to_scf += [(scf_obj, comp) for comp in scf_obj.unsampled]

    # sort and prune list
    # resulting list will map compound index to corresponding scaffold object
    idx_to_scf = sorted(idx_to_scf, key=lambda x: x[1])
    idx_to_scf = list(zip(*idx_to_scf))[0]
    
    return scaffold_ls, idx_to_scf

scaffold_ls, idx_to_scf = initialize_scfs(scaffold_dict)

In [3]:
for idx in range(len(y)):
    scf = idx_to_scf[idx]
    scf.observe(idx, y[idx])
    
true_sorted_scf = sorted(scaffold_ls, key=lambda x: x.dist.ppf(.10))[::-1]

In [7]:
[s for s in true_sorted_scf[:20]]# if len(s.sampled.keys())>10][:20]

[O=c1[nH]c(=O)c2c(nc(SCCN3CCCCC3)n2Cc2ccccc2)[nH]1 | 6 / 6,
 O=S(=O)(Nc1ccc2[nH]c(CCN3CCCCC3)nc2c1)c1ccccc1 | 5 / 5,
 c1ccc(CC2CCCCC2)cc1 | 4 / 4,
 c1ccc(CNc2ncc(-c3ccc4c(c3)OCO4)[nH]2)cc1 | 4 / 4,
 O=C(Nc1ccccc1)c1cnc(N2CCNCC2)c2ccccc12 | 4 / 4,
 c1ccc2nc(NC3=NCN(CCCN4CCOCC4)CN3)ncc2c1 | 3 / 3,
 O=C(CCN1CCN(C2CCCCC2)CC1)Nc1c[nH]c2ccccc12 | 3 / 3,
 O=c1c2c3c(sc2ncn1C1CCCCC1)CNCC3 | 3 / 3,
 O=C(Cc1ccccc1)OC1CC2CCC(C1)[NH2+]2 | 3 / 3,
 O=C(OCCN1CCCCC1)c1cc2ccccc2o1 | 3 / 3,
 N=c1[nH]c2nc3ccccn3c(=O)c2cc1C(=O)NCCc1ccccc1 | 3 / 3,
 O=S(=O)(c1ccccc1)N1CCN(C2CCN(Cc3ccccc3)CC2)CC1 | 3 / 3,
 O=c1[nH]c2ccc(S(=O)(=O)N3CCc4ccccc43)cc2o1 | 3 / 3,
 c1ccc2nc(NC3=NCN(Cc4ccc5c(c4)OCO5)CN3)ncc2c1 | 3 / 3,
 O=C(CCN1CCOCC1)N1c2ccccc2Sc2ccccc21 | 3 / 3,
 O=C1C=C(Nc2c[nH]n(-c3ccccc3)c2=O)C(=O)N1c1ccccc1 | 3 / 3,
 c1ccc2c(c1)Nc1ccccc1S2 | 7 / 9,
 N=c1[nH]c2nc3ccccn3c(=O)c2cc1C(=O)NCc1ccc2c(c1)OCO2 | 4 / 5,
 O=C(Nc1ccccc1)Nc1ccc2[nH]c(CCN3CCCCC3)nc2c1 | 2 / 2,
 O=S(=O)(c1cccc(-c2cn3ccccc3n2)c1)N1CCCCC1 | 2 /

In [6]:
def random_pick(fp_ls: List[rdkit.DataStructs.cDataStructs.ExplicitBitVect], 
                n: int, 
                seed: int = 42) -> List[int]:
    """Random diverse compound picking based on Rdkit MaxMinPicker"""
    picker = MaxMinPicker()
    return list(picker.LazyBitVectorPick(fp_ls, len(fp_ls), n, seed=seed))

# random initial pick
init_idx = random_pick(fp_ls, int(P_INIT*len(fp_ls)))

In [14]:
def update_scaffolds(idx_ls, scf_ls, idx_to_scf, activity_ls):
    for idx in idx_ls:
        scf = idx_to_scf[idx]
        scf.observe(idx, activity_ls[idx])
        
sampled = [] + init_idx

update_scaffolds(init_idx, scaffold_ls, idx_to_scf, y)

In [15]:
n_next = int(P_ITER * N_TOTAL)
print(f"Adding {n_next} new samples each iteration")

for _ in range(N_ITER):
    print(f"Iteration {_}")
    
    # sample distributions
    probs = []
    for scf in scaffold_ls:
        probs += scf.sample()
    probs = sorted(probs[::-1])

    sample_idxs = list(list(zip(*probs[:n_next]))[1])
    sampled += sample_idxs

    update_scaffolds(sample_idxs, scaffold_ls, idx_to_scf, y)
    
    print(f"Hit rate: {sum(y[sample_idxs])/len(sample_idxs):.4f}")

Adding 3183 new samples each iteration
Iteration 0
Hit rate: 0.0339
Iteration 1
Hit rate: 0.0280
Iteration 2
Hit rate: 0.0314
Iteration 3
Hit rate: 0.0280


In [16]:
print(f"Sampled {len(sampled)/len(y):.4f} of the library")
print(f"Hit rate: {sum(y[sampled])/len(sampled):.4f}, or {sum(y[sampled])/sum(y):.4f} of all hits")

Sampled 0.3500 of the library
Hit rate: 0.0335, or 0.3427 of all hits


In [17]:
sorted_scf = sorted(scaffold_ls, key=lambda x: x.dist.ppf25())[::-1]

In [9]:
len(scaffold_ls)

31353

In [10]:
len(y)

63662

In [19]:
[s for s in sorted_scf[:200]]

[c1ccc2c(c1)oc1c(Nc3ccncc3)ncnc12 | 2 / 2,
 c1ncc2c3c(sc2n1)CNCC3 | 2 / 2,
 N=c1[nH]c2nc3ccccn3c(=O)c2cc1C(=O)NCc1ccc2c(c1)OCO2 | 2 / 2,
 c1ccc(CNc2ncc(-c3ccccc3)[nH]2)cc1 | 2 / 2,
 C1=Nc2sc3c(c2C2=NCCN12)CCCC3 | 1 / 1,
 c1ccc(CNc2ccc3c(c2)ncn3C2CCCCC2)nc1 | 1 / 1,
 c1ccc(N2CCN(CCc3nc4ccccc4[nH]3)CC2)cc1 | 1 / 1,
 O=c1[nH]cnc2oc3ccccc3c(=O)c12 | 1 / 1,
 c1ccc(CNCC2CCCN2)cc1 | 1 / 1,
 O=S(=O)(Nc1ccc2[nH]c(CCN3CCCCC3)nc2c1)c1ccccc1 | 1 / 1,
 O=c1cc(CN2CCNCC2)c2ccccc2o1 | 1 / 1,
 c1cc(-c2nnc3sc(C4CCCCC4)nn23)n[nH]1 | 1 / 1,
 O=c1[nH]c(=O)c2c(nc(N3CCN(c4ccccc4)CC3)n2CCCSc2nncs2)[nH]1 | 1 / 1,
 O=C1C2ON(c3ccccc3)C(c3ccccc3)C2C(=O)N1Cc1ccccc1 | 1 / 1,
 O=S(=O)(c1cccc(-c2cn3ccccc3n2)c1)N1CCCCC1 | 1 / 1,
 O=c1[nH]c(=O)c2c(nc(N3CCN(c4ccccc4)CC3)n2CCSc2nccs2)[nH]1 | 1 / 1,
 O=C(NCCCn1ccnc1)C1c2ccccc2Oc2ccccc21 | 1 / 1,
 c1coc(CNc2ncnc3sc4c(c23)CCNC4)c1 | 1 / 1,
 c1ccc(-c2nc(NC3=NCN(CCN4CCOCC4)CN3)nc3ccccc23)cc1 | 1 / 1,
 O=c1[nH]c(-c2cccs2)nc2sc3c(c12)CCOC3 | 1 / 1,
 C(=Cc1ccccc1)CSc1nnco1 | 1 /