In [1]:
import os
import pandas as pd
from rdkit import Chem
import csv

# Root data folder, resolved to an absolute path from the current working directory.
DATA = os.path.abspath("../data")
# Seed hit lists are read from <DATA>/generated ...
INPUT_PATH = os.path.join(DATA, "generated")
# ... and sampled molecules are written one level below, in <DATA>/generated/zairachem.
OUTPUT_PATH = os.path.join(INPUT_PATH, "zairachem")

def read_smiles(file_name):
    """Read the first column of a CSV file as a list of SMILES strings.

    Every non-empty row is treated as data; no header row is skipped.

    Args:
        file_name: Path of the CSV file to read.

    Returns:
        list[str]: The first cell of each non-empty row, in file order.
    """
    # newline="" lets the csv module handle line endings itself, as its
    # documentation requires for files passed to csv.reader.
    with open(file_name, "r", newline="") as f:
        # csv.reader yields [] for blank lines; indexing those with [0] used to
        # raise IndexError, so they are skipped here.
        return [row[0] for row in csv.reader(f) if row]

def write_smiles(smiles, file_name):
    """Write SMILES strings to ``OUTPUT_PATH/file_name``, one per CSV row.

    Args:
        smiles: Iterable of SMILES strings; each becomes a single-cell row.
        file_name: File name (relative to ``OUTPUT_PATH``) of the CSV to create.
            An existing file with that name is overwritten.
    """
    out_path = os.path.join(OUTPUT_PATH, file_name)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # Create the target folder on first use instead of failing with
        # FileNotFoundError when it does not exist yet.
        os.makedirs(out_dir, exist_ok=True)
    # newline="" stops text mode from translating the writer's own "\r\n" row
    # terminator (which would become "\r\r\n" on Windows).
    with open(out_path, "w", newline="") as f:
        csv.writer(f).writerows([smi] for smi in smiles)

In [2]:
from exmol import run_stoned
from tqdm import tqdm
import numpy as np
import time
import random
from rdkit import DataStructs
from rdkit.Chem import AllChem


class StonedSampler:
    """Small wrapper around ``run_stoned`` with fixed mutation bounds."""

    def __init__(self, max_mutations=2, min_mutations=1):
        """Remember the mutation bounds forwarded on every ``sample`` call."""
        self.min_mutations = min_mutations
        self.max_mutations = max_mutations

    def sample(self, smiles, n):
        """Call ``run_stoned`` on ``smiles`` asking for ``n`` samples.

        Returns whatever ``run_stoned`` returns, unchanged.
        """
        mutation_bounds = {
            "max_mutations": self.max_mutations,
            "min_mutations": self.min_mutations,
        }
        return run_stoned(smiles, num_samples=n, **mutation_bounds)


class StonedBatchSampler(object):
    """Grow a set of seed SMILES with STONED sampling under a time budget.

    Candidates are produced by ``StonedSampler`` (1 to 5 mutations), kept only
    when their similarity lies within ``[min_similarity, max_similarity]``, and
    ranked with ``scorer.predict`` (higher scores are kept).
    """

    def __init__(self, min_similarity=0.6, max_similarity=0.9, scorer=None, inflation=2, time_budget_sec=60):
        """Configure the sampler.

        Args:
            min_similarity: Lower bound of the accepted similarity window.
            max_similarity: Upper bound of the accepted similarity window.
            scorer: Object exposing ``predict(smiles)``. When ``None`` a
                ``SybaClassifier`` is created and ``fitDefaultScore()`` is
                called on it.
            inflation: Over-sampling factor used when deciding how many
                candidates to request and how many to keep per round.
            time_budget_sec: Seconds of STONED sampling allowed per
                ``sample`` call.
        """
        self.min_similarity = min_similarity
        self.max_similarity = max_similarity
        self.sampler = StonedSampler(max_mutations=5, min_mutations=1)
        if scorer is None:
            # Imported here so the default scorer does not depend on a later
            # notebook cell having already imported SybaClassifier.
            from syba.syba import SybaClassifier

            self.scorer = SybaClassifier()
            self.scorer.fitDefaultScore()
        else:
            self.scorer = scorer
        self.inflation = inflation
        self.time_budget_sec = time_budget_sec
        # Seeds of the most recent ``sample`` call and their fingerprints.
        self.seed_smiles = []
        self.seed_fps = []
        self._reset_budget()

    def _reset_budget(self):
        """Clear the per-call sampling clock and the budget-exhausted flag."""
        self.elapsed_time = 0
        self.finished = False

    @staticmethod
    def _fingerprint(smi):
        """Return the radius-2 Morgan bit vector of ``smi``, or ``None`` if RDKit cannot parse it."""
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        return AllChem.GetMorganFingerprintAsBitVect(mol, 2)

    def _sample(self, smiles_list, n):
        """Run one sampling round over ``smiles_list``.

        Args:
            smiles_list: SMILES to mutate in this round.
            n: Overall number of molecules requested; only used to size the
                per-seed request.

        Returns:
            set[str]: Filtered and score-selected SMILES from this round
            (empty when ``smiles_list`` is empty).
        """
        # Work on a copy so the caller's list is not reordered.
        smiles_list = list(smiles_list)
        if not smiles_list:
            # Nothing to mutate; also avoids dividing by len(smiles_list) == 0.
            return set()
        random.shuffle(smiles_list)
        # Samples requested per seed: inflation * n spread over the seeds,
        # clipped to [100, 1000].
        n_individual = int(np.clip(self.inflation*n/len(smiles_list), 100, 1000))
        available_time = int((self.time_budget_sec - self.elapsed_time)) + 1
        # Rough throughput guess (seeds per second) used for the estimate below.
        # NOTE(review): it does not depend on n_individual — confirm intended.
        samples_per_sec = 100
        estimated_time = len(smiles_list)/samples_per_sec
        if estimated_time > available_time:
            # Too many seeds for the remaining budget: ask for few samples each.
            n_individual = 10
        sampled_smiles = []
        sampled_sim = []
        for smi in tqdm(smiles_list):
            t0 = time.time()
            # The first element is used as SMILES and the second as their
            # similarity score (presumably relative to ``smi`` — confirm
            # against the exmol documentation).
            sampled = self.sampler.sample(smi, n_individual)
            sampled_smiles += sampled[0]
            sampled_sim += sampled[1]
            self.elapsed_time += time.time() - t0
            if self.elapsed_time > self.time_budget_sec:
                self.finished = True
                break
        smiles = [
            smi
            for smi, sim in zip(sampled_smiles, sampled_sim)
            if not (sim < self.min_similarity or sim > self.max_similarity)
        ]
        # Keep roughly 1/inflation of the candidates that passed the first filter.
        n = int(len(smiles)/self.inflation+1)
        smiles = self._select_by_similarity(smiles)
        smiles = self._select_by_score(smiles, n)
        return set(smiles)

    def _select_by_score(self, smiles, n):
        """Return the ``n`` SMILES with the highest ``scorer.predict`` value, in ascending score order."""
        smiles = list(smiles)
        scores = [self.scorer.predict(smi) for smi in tqdm(smiles)]
        df = pd.DataFrame({"smiles": smiles, "score": scores})
        return list(df.sort_values(by="score").tail(n)["smiles"])

    def _select_by_similarity(self, smiles):
        """Keep SMILES whose highest Tanimoto similarity to any seed lies in the accepted window.

        SMILES that RDKit cannot parse are dropped instead of raising.
        """
        sel_smiles = []
        for smi in tqdm(smiles):
            fp = self._fingerprint(smi)
            if fp is None:
                continue
            sims = DataStructs.BulkTanimotoSimilarity(fp, self.seed_fps)
            sim = np.max(sims)
            if sim < self.min_similarity or sim > self.max_similarity:
                continue
            sel_smiles += [smi]
        return sel_smiles

    def sample(self, smiles_list, n):
        """Sample new molecules around the seeds in ``smiles_list``.

        Rounds of ``_sample`` are run over the seeds plus everything generated
        so far, until the time budget is exhausted or ``n`` rounds have run.

        Args:
            smiles_list: Seed SMILES. Seeds RDKit cannot parse are ignored.
            n: Maximum number of SMILES to return.

        Returns:
            list[str]: At most ``n`` generated SMILES; ``[]`` when there is no
            usable seed.
        """
        # Always start from a clean budget, even if an earlier call on this
        # instance was interrupted half-way.
        self._reset_budget()
        seeds = []
        seed_fps = []
        for smi in smiles_list:
            fp = self._fingerprint(smi)
            if fp is not None:
                seeds.append(smi)
                seed_fps.append(fp)
        self.seed_smiles = seeds
        self.seed_fps = seed_fps
        if not seeds:
            return []
        try:
            smiles = set(seeds)
            sampled_smiles = set()
            for _ in range(n):
                new_smiles = self._sample(list(smiles), n)
                sampled_smiles.update(new_smiles)
                # Generated molecules become seeds of the next round.
                smiles.update(new_smiles)
                if self.finished:
                    break
            smiles = self._select_by_similarity(list(sampled_smiles))
            if len(smiles) > n:
                smiles = self._select_by_score(smiles, n)
            return smiles
        finally:
            # Reset on every exit path (including errors and interrupts) so a
            # later call on the same instance gets the full time budget.
            self._reset_budget()

In [3]:
from syba.syba import SybaClassifier
# Build the scorer once here; it is passed as `scorer=` to the batch sampler below.
syba = SybaClassifier()
# Presumably loads SYBA's packaged default scoring data — confirm against the SYBA docs.
syba.fitDefaultScore()


In [5]:
# One sampler instance, reused by every sampling cell below.
smp = StonedBatchSampler(
    min_similarity=0.7,
    max_similarity=0.95,
    scorer=syba,
    time_budget_sec=5 * 60,  # 300 seconds of sampling per call
)

In [9]:
# Seeds: docking_top100_hits_lib_aug.csv under INPUT_PATH; up to 10000 sampled
# molecules are written to stoned-docking_top100_lib_aug.csv.
smiles = read_smiles(
    file_name=os.path.join(INPUT_PATH, "docking_top100_hits_lib_aug.csv")
)
sampled_smiles = smp.sample(smiles_list=smiles, n=10000)
write_smiles(smiles=sampled_smiles, file_name="stoned-docking_top100_lib_aug.csv")

 24%|███████████████████▋                                                              | 1803/7524 [05:00<15:54,  5.99it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 6270/6270 [00:05<00:00, 1201.63it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 3972/3972 [00:00<00:00, 4784.07it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 3133/3133 [00:02<00:00, 1188.85it/s]


In [8]:
# Seeds: pocketvec_hits_lib_aug.csv under INPUT_PATH; up to 10000 sampled
# molecules are written to stoned-pocketvec_lib_aug.csv.
smiles = read_smiles(
    file_name=os.path.join(INPUT_PATH, "pocketvec_hits_lib_aug.csv")
)
sampled_smiles = smp.sample(smiles_list=smiles, n=10000)
write_smiles(smiles=sampled_smiles, file_name="stoned-pocketvec_lib_aug.csv")

 64%|████████████████████████████████████████████████████▏                             | 3337/5241 [05:00<02:51, 11.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 4616/4616 [00:02<00:00, 1770.77it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 2961/2961 [00:00<00:00, 8302.47it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 2279/2279 [00:01<00:00, 1762.94it/s]


In [7]:
# Seeds: pocketvec_mw250_hits_lib_aug.csv under INPUT_PATH; up to 10000 sampled
# molecules are written to stoned-pocketvec_mw250_lib_aug.csv.
smiles = read_smiles(
    file_name=os.path.join(INPUT_PATH, "pocketvec_mw250_hits_lib_aug.csv")
)
sampled_smiles = smp.sample(smiles_list=smiles, n=10000)
write_smiles(smiles=sampled_smiles, file_name="stoned-pocketvec_mw250_lib_aug.csv")

100%|██████████████████████████████████████████████████████████████████████████████████| 1324/1324 [02:40<00:00,  8.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 3623/3623 [00:00<00:00, 4248.85it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 2451/2451 [00:00<00:00, 7212.35it/s]
 37%|██████████████████████████████                                                    | 1139/3107 [02:20<04:02,  8.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 3465/3465 [00:00<00:00, 4239.21it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 1324/1324 [00:00<00:00, 6995.33it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 3038/3038 [00:00<00:00, 4175.87it/s]


In [6]:
# Seeds: known_hits_lib_aug.csv under INPUT_PATH; up to 10000 sampled
# molecules are written to stoned-known_lib_aug.csv.
smiles = read_smiles(
    file_name=os.path.join(INPUT_PATH, "known_hits_lib_aug.csv")
)
sampled_smiles = smp.sample(smiles_list=smiles, n=10000)
write_smiles(smiles=sampled_smiles, file_name="stoned-known_lib_aug.csv")

 52%|██████████████████████████████████████████▌                                       | 2175/4195 [05:00<04:39,  7.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 5345/5345 [00:02<00:00, 2041.00it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 3069/3069 [00:00<00:00, 6206.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 2653/2653 [00:01<00:00, 2053.30it/s]
