In [139]:
import os
import pandas as pd
from rdkit import Chem
import csv

# Path layout: hit lists are read from ../data/generated and the sampled
# molecules are written to ../data/generated/zairachem.
DATA = os.path.abspath("../data")
INPUT_PATH = os.path.join(DATA, "generated")
OUTPUT_PATH = os.path.join(INPUT_PATH, "zairachem")

def read_smiles(file_name):
    """Read SMILES strings from the first column of a CSV file.

    Parameters
    ----------
    file_name : str
        Path to the CSV file. Only the first cell of each row is used.

    Returns
    -------
    list of str
        First-column values of every non-empty row, in file order. No header
        detection is performed: a header row is returned like any other row.
    """
    # newline="" is what the csv module expects for files it parses.
    with open(file_name, "r", newline="") as f:
        # Blank lines come back as empty rows; skip them instead of failing
        # with IndexError on row[0].
        return [row[0] for row in csv.reader(f) if row]

def write_smiles(smiles, file_name, output_dir=None):
    """Write SMILES strings to a one-column, headerless CSV file.

    Parameters
    ----------
    smiles : iterable of str
        SMILES strings, written one per row in iteration order.
    file_name : str
        File name, joined onto ``output_dir``.
    output_dir : str, optional
        Target directory. Defaults to the notebook-level ``OUTPUT_PATH``.
        The directory is created if it does not exist yet.
    """
    if output_dir is None:
        output_dir = OUTPUT_PATH
    os.makedirs(output_dir, exist_ok=True)
    # newline="" lets csv.writer control line endings; without it Windows
    # produces an extra blank line between rows.
    with open(os.path.join(output_dir, file_name), "w", newline="") as f:
        csv.writer(f).writerows([smi] for smi in smiles)

In [143]:
from exmol import run_stoned
from tqdm import tqdm
import numpy as np
import time
import random
from rdkit import DataStructs
from rdkit.Chem import AllChem


class StonedSampler(object):
    """Per-molecule STONED sampler: a thin wrapper over exmol's ``run_stoned``.

    Parameters
    ----------
    max_mutations : int
        Forwarded to ``run_stoned`` as ``max_mutations``.
    min_mutations : int
        Forwarded to ``run_stoned`` as ``min_mutations``.
    """

    def __init__(self, max_mutations=2, min_mutations=1):
        self.min_mutations = min_mutations
        self.max_mutations = max_mutations

    def sample(self, smiles, n):
        """Request ``n`` STONED samples around one SMILES string.

        The result of ``run_stoned`` is returned unchanged.
        """
        mutation_bounds = {
            "max_mutations": self.max_mutations,
            "min_mutations": self.min_mutations,
        }
        return run_stoned(smiles, num_samples=n, **mutation_bounds)


class StonedBatchSampler(object):
    """Iteratively expand a seed set of SMILES with STONED mutations.

    Each round mutates every molecule in the current pool, keeps candidates
    whose similarity lies inside ``[min_similarity, max_similarity]`` and
    retains the highest-scoring ones according to ``scorer``. Rounds repeat
    until ``time_budget_sec`` of STONED sampling time is spent, or at most
    ``n`` rounds have run.

    Parameters
    ----------
    min_similarity, max_similarity : float
        Inclusive similarity window. It is applied to the similarity returned
        by ``run_stoned`` and to the maximum Morgan (radius 2) Tanimoto
        similarity against the seed molecules.
    scorer : object, optional
        Object with a ``predict(smiles)`` method returning a sortable score;
        higher scores are kept. Defaults to a ``SybaClassifier`` after
        ``fitDefaultScore()``.
    inflation : float
        Over-sampling factor. It scales the number of STONED samples requested
        per seed, and each round keeps at most about ``1 / inflation`` of the
        candidates that passed the ``run_stoned`` similarity window.
    time_budget_sec : float
        Budget, in seconds, for the STONED calls made during one ``sample``
        call. Fingerprinting and scoring time is not counted.
    """

    def __init__(self, min_similarity=0.6, max_similarity=0.9, scorer=None, inflation=2, time_budget_sec=60):
        self.min_similarity = min_similarity
        self.max_similarity = max_similarity
        self.sampler = StonedSampler(max_mutations=5, min_mutations=1)
        if scorer is None:
            # Imported here so the default does not depend on a notebook cell
            # that may run after this class is defined.
            from syba.syba import SybaClassifier
            scorer = SybaClassifier()
            scorer.fitDefaultScore()
        self.scorer = scorer
        self.inflation = inflation
        self.time_budget_sec = time_budget_sec
        self.elapsed_time = 0
        self.finished = False

    def _in_window(self, sim):
        """True unless ``sim`` is below ``min_similarity`` or above ``max_similarity``."""
        # Written as a negated "outside" test so NaN similarities are treated
        # exactly as in the original checks (i.e. not rejected).
        return not (sim < self.min_similarity or sim > self.max_similarity)

    @staticmethod
    def _fingerprint(smi):
        """Morgan (radius 2) bit-vector fingerprint of ``smi``, or None if RDKit cannot parse it."""
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        return AllChem.GetMorganFingerprintAsBitVect(mol, 2)

    def _sample(self, smiles_list, n):
        """Run one round of STONED expansion over ``smiles_list``.

        ``smiles_list`` is shuffled in place. Each molecule is mutated until
        the time budget runs out. The return value is the set of candidates
        that pass both similarity filters and the score cut.
        """
        if not smiles_list:
            return set()
        random.shuffle(smiles_list)
        # Samples requested per seed, clipped to [100, 1000].
        n_individual = int(np.clip(self.inflation * n / len(smiles_list), 100, 1000))
        available_time = int(self.time_budget_sec - self.elapsed_time) + 1
        # Rough throughput guess used to shrink per-seed sampling for big pools.
        # NOTE(review): the estimate counts seeds only (not seeds * n_individual);
        # confirm that this is the intended heuristic.
        samples_per_sec = 100
        estimated_time = len(smiles_list) / samples_per_sec
        if estimated_time > available_time:
            n_individual = 10
        sampled_smiles = []
        sampled_sim = []
        for smi in tqdm(smiles_list):
            t0 = time.time()
            sampled = self.sampler.sample(smi, n_individual)
            # run_stoned result: [0] -> SMILES, [1] -> similarity to ``smi``.
            sampled_smiles += sampled[0]
            sampled_sim += sampled[1]
            self.elapsed_time += time.time() - t0
            if self.elapsed_time > self.time_budget_sec:
                self.finished = True
                break
        candidates = [
            smi for smi, sim in zip(sampled_smiles, sampled_sim) if self._in_window(sim)
        ]
        n_keep = int(len(candidates) / self.inflation + 1)
        candidates = self._select_by_similarity(candidates)
        candidates = self._select_by_score(candidates, n_keep)
        return set(candidates)

    def _select_by_score(self, smiles, n):
        """Return the ``n`` highest-scoring SMILES (all of them if fewer), in ascending score order."""
        smiles = list(smiles)
        scores = [self.scorer.predict(smi) for smi in tqdm(smiles)]
        df = pd.DataFrame({"smiles": smiles, "score": scores})
        return list(df.sort_values(by="score").tail(n)["smiles"])

    def _select_by_similarity(self, smiles):
        """Keep SMILES whose max Tanimoto similarity to ``self.seed_fps`` is inside the window.

        SMILES that RDKit cannot parse are dropped.
        """
        sel_smiles = []
        for smi in tqdm(smiles):
            fp = self._fingerprint(smi)
            if fp is None:
                continue
            sim = np.max(DataStructs.BulkTanimotoSimilarity(fp, self.seed_fps))
            if self._in_window(sim):
                sel_smiles.append(smi)
        return sel_smiles

    def sample(self, smiles_list, n):
        """Generate up to ``n`` new molecules around the seeds in ``smiles_list``.

        Parameters
        ----------
        smiles_list : iterable of str
            Seed SMILES. Entries that RDKit cannot parse are ignored.
        n : int
            Maximum number of molecules to return; also the maximum number of
            expansion rounds.

        Returns
        -------
        list of str
            Sampled SMILES inside the similarity window to the seeds. If more
            than ``n`` qualify, only the ``n`` highest-scoring are returned.
            Empty when there is no parseable seed.
        """
        # Reset the budget up front as well as in ``finally`` so an earlier
        # interrupted call (e.g. a KeyboardInterrupt) cannot leak an exhausted
        # time budget into this one.
        self.elapsed_time = 0
        self.finished = False
        try:
            seeds = []
            seed_fps = []
            # Single pass, so generators are not consumed twice.
            for smi in smiles_list:
                fp = self._fingerprint(smi)
                if fp is not None:
                    seeds.append(smi)
                    seed_fps.append(fp)
            self.seed_smiles = seeds
            self.seed_fps = seed_fps
            if not seeds:
                return []
            pool = set(seeds)
            sampled_smiles = set()
            for _ in range(n):
                new_smiles = self._sample(list(pool), n)
                sampled_smiles.update(new_smiles)
                pool.update(new_smiles)
                if self.finished:
                    break
            smiles = self._select_by_similarity(list(sampled_smiles))
            if len(smiles) > n:
                smiles = self._select_by_score(smiles, n)
            return smiles
        finally:
            self.elapsed_time = 0
            self.finished = False

In [6]:
from syba.syba import SybaClassifier
# Build the SYBA classifier with its default score tables (fitDefaultScore).
# It is passed as `scorer` when the StonedBatchSampler is constructed.
syba = SybaClassifier()
syba.fitDefaultScore()


In [144]:
# Seconds of STONED sampling allowed per smp.sample() call (5 minutes).
TIME_BUDGET_SEC = 5 * 60
smp = StonedBatchSampler(time_budget_sec=TIME_BUDGET_SEC, scorer=syba)

In [146]:
# STONED expansion of docking_hits.csv, saved as stoned-docking.csv.
hits_path = os.path.join(INPUT_PATH, "docking_hits.csv")
smiles = read_smiles(hits_path)
sampled_smiles = smp.sample(smiles, n=10000)
write_smiles(sampled_smiles, file_name="stoned-docking.csv")

 24%|█████████▏                            | 8227/34053 [05:02<15:50, 27.18it/s]
100%|███████████████████████████████████████| 5417/5417 [00:57<00:00, 94.39it/s]
100%|█████████████████████████████████████| 3826/3826 [00:01<00:00, 1983.57it/s]
100%|███████████████████████████████████████| 2709/2709 [00:31<00:00, 85.03it/s]


In [147]:
# STONED expansion of docking_top100_hits.csv, saved as stoned-docking_top100.csv.
hits_path = os.path.join(INPUT_PATH, "docking_top100_hits.csv")
smiles = read_smiles(hits_path)
sampled_smiles = smp.sample(smiles, n=10000)
write_smiles(sampled_smiles, file_name="stoned-docking_top100.csv")

100%|█████████████████████████████████████████| 100/100 [01:18<00:00,  1.27it/s]
100%|█████████████████████████████████████| 1387/1387 [00:00<00:00, 3008.80it/s]
100%|███████████████████████████████████████| 953/953 [00:00<00:00, 1840.93it/s]
 70%|████████████████████████████▌            | 554/794 [03:41<01:36,  2.50it/s]
100%|█████████████████████████████████████| 4347/4347 [00:01<00:00, 2937.62it/s]
100%|███████████████████████████████████████| 947/947 [00:00<00:00, 1836.91it/s]
100%|█████████████████████████████████████| 1634/1634 [00:00<00:00, 2802.99it/s]


In [148]:
# STONED expansion of docking_top1000_hits.csv, saved as stoned-docking_top1000.csv.
hits_path = os.path.join(INPUT_PATH, "docking_top1000_hits.csv")
smiles = read_smiles(hits_path)
sampled_smiles = smp.sample(smiles, n=10000)
write_smiles(sampled_smiles, file_name="stoned-docking_top1000.csv")

 79%|███████████████████████████████▊        | 794/1000 [05:01<01:18,  2.64it/s]
100%|█████████████████████████████████████| 5427/5427 [00:04<00:00, 1137.72it/s]
100%|█████████████████████████████████████| 3621/3621 [00:01<00:00, 1945.81it/s]
100%|█████████████████████████████████████| 2713/2713 [00:02<00:00, 1128.64it/s]


In [149]:
# STONED expansion of pocketvec_hits.csv, saved as stoned-pocketvec.csv.
hits_path = os.path.join(INPUT_PATH, "pocketvec_hits.csv")
smiles = read_smiles(hits_path)
sampled_smiles = smp.sample(smiles, n=10000)
write_smiles(sampled_smiles, file_name="stoned-pocketvec.csv")

100%|███████████████████████████████████████████| 80/80 [00:35<00:00,  2.25it/s]
100%|███████████████████████████████████████| 576/576 [00:00<00:00, 5141.24it/s]
100%|███████████████████████████████████████| 342/342 [00:00<00:00, 4837.00it/s]
100%|█████████████████████████████████████████| 366/366 [01:15<00:00,  4.86it/s]
100%|█████████████████████████████████████| 1564/1564 [00:00<00:00, 5430.63it/s]
100%|███████████████████████████████████████| 438/438 [00:00<00:00, 3333.72it/s]
100%|█████████████████████████████████████████| 776/776 [02:50<00:00,  4.56it/s]
100%|█████████████████████████████████████| 3622/3622 [00:00<00:00, 5113.28it/s]
100%|███████████████████████████████████████| 865/865 [00:00<00:00, 3740.69it/s]
  6%|██▎                                      | 87/1549 [00:20<05:44,  4.24it/s]
100%|███████████████████████████████████████| 443/443 [00:00<00:00, 4693.82it/s]
100%|█████████████████████████████████████████| 98/98 [00:00<00:00, 2698.68it/s]
100%|███████████████████████

In [150]:
# STONED expansion of pocketvec_mw250_hits.csv, saved as stoned-pocketvec_mw250.csv.
hits_path = os.path.join(INPUT_PATH, "pocketvec_mw250_hits.csv")
smiles = read_smiles(hits_path)
sampled_smiles = smp.sample(smiles, n=10000)
write_smiles(sampled_smiles, file_name="stoned-pocketvec_mw250.csv")

100%|███████████████████████████████████████████| 15/15 [00:40<00:00,  2.71s/it]
100%|███████████████████████████████████████| 803/803 [00:00<00:00, 4536.04it/s]
100%|███████████████████████████████████████| 569/569 [00:00<00:00, 3026.67it/s]
100%|█████████████████████████████████████████| 409/409 [01:57<00:00,  3.48it/s]
100%|█████████████████████████████████████| 2694/2694 [00:00<00:00, 4847.90it/s]
100%|███████████████████████████████████████| 601/601 [00:00<00:00, 2812.21it/s]
 50%|████████████████████▌                    | 484/965 [02:23<02:22,  3.38it/s]
100%|█████████████████████████████████████| 3226/3226 [00:00<00:00, 4610.86it/s]
100%|███████████████████████████████████████| 662/662 [00:00<00:00, 2921.90it/s]
100%|█████████████████████████████████████| 1548/1548 [00:00<00:00, 4209.03it/s]


In [152]:
# STONED expansion of known_hits.csv, saved as stoned-known.csv.
hits_path = os.path.join(INPUT_PATH, "known_hits.csv")
smiles = read_smiles(hits_path)
sampled_smiles = smp.sample(smiles, n=10000)
write_smiles(sampled_smiles, file_name="stoned-known.csv")

100%|███████████████████████████████████████████| 16/16 [00:47<00:00,  2.95s/it]
100%|███████████████████████████████████████| 775/775 [00:00<00:00, 4738.15it/s]
100%|███████████████████████████████████████| 434/434 [00:00<00:00, 2920.32it/s]
100%|█████████████████████████████████████████| 403/403 [02:02<00:00,  3.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 2242/2242 [00:00<00:00, 4444.11it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 278/278 [00:00<00:00, 2957.11it/s]
 64%|██████████████████████████████████████████████████████                              | 427/663 [02:10<01:12,  3.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 2565/2565 [00:00<00:00, 4409.32it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 356/356 [00:00<00:00, 2732.15it/s]
100%|███████████████████████████████████████████████████