In [1]:
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(4)

In [2]:
from __future__ import print_function
from rdkit.Chem.Draw import IPythonConsole, ReactionToImage, MolToImage, MolsToGridImage
from IPython.display import SVG, display, clear_output
import rdkit.Chem as Chem
import rdkit.Chem.AllChem as AllChem
from rdkit import DataStructs
import pandas as pd
import numpy as np
from tqdm import tqdm
import json
import sys
sys.path.append('../../')
from retrosim.utils.draw import ReactionStringToImage, TransformStringToImage
from retrosim.utils.generate_retro_templates import process_an_example
from retrosim.data.get_data import get_data_df, split_data_df
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants

from joblib import Parallel, delayed
import multiprocessing
num_cores = multiprocessing.cpu_count()

In [3]:
data = get_data_df('../data/data_processed.csv')
split_data_df(data) # 80/10/10 within each class

15150 rows with class value 1
11893 rows with class value 2
5661 rows with class value 3
909 rows with class value 4
672 rows with class value 5
8237 rows with class value 6
4613 rows with class value 7
811 rows with class value 8
1834 rows with class value 9
230 rows with class value 10
train    40004
test      5006
val       5000
Name: dataset, dtype: int64


In [4]:
data.loc[data['dataset'] == 'test']['class'].value_counts()

1     1515
2     1190
6      824
3      567
7      462
9      184
4       91
8       82
5       68
10      23
Name: class, dtype: int64

## Select one set of settings to test
(use the test_similiarity.py script to do the full set of experiments)

In [5]:
class_ = 3
similarity_metric = DataStructs.BulkTanimotoSimilarity # BulkDiceSimilarity or BulkTanimotoSimilarity
similarity_label = 'Tanimoto'
getfp = lambda smi: AllChem.GetMorganFingerprint(Chem.MolFromSmiles(smi), 2, useFeatures=True)
getfp_label = 'Morgan2Feat'
dataset = 'test'
label = '{}_class{}_fp{}_sim{}'.format(
    dataset,
    class_,
    getfp_label,
    similarity_label,
)

### Only get new FPs if necessary - is a little slow

In [6]:
try:
    if prev_FP != getfp_label:
        raise NameError
except NameError:
    all_fps = []
    for smi in tqdm(data['prod_smiles']):
        all_fps.append(getfp(smi))
    data['prod_fp'] = all_fps
    prev_FP = getfp_label

100%|██████████████████████████████████████████████████████████████████████████| 50010/50010 [00:39<00:00, 1275.70it/s]


### Get the training data subset of the full data

In [7]:
if class_ != 'all':
    datasub = data.loc[data['class'] == class_].loc[data['dataset'] == 'train']
    datasub_val = data.loc[data['class'] == class_].loc[data['dataset'] == dataset]
else:
    datasub = data.loc[data['dataset'] == 'train']
    datasub_val = data.loc[data['dataset'] == dataset]
fps = list(datasub['prod_fp'])
print('Size of knowledge base: {}'.format(len(fps)))

Size of knowledge base: 4528


## Go through full validation/test data, define the function for processing

In [8]:
jx_cache = {}
draw = False
debug = False

def do_one(ix, draw=draw, debug=debug, max_prec=100, nopause=False):
    
    rec_for_printing = ''
    
    ex = Chem.MolFromSmiles(datasub_val['prod_smiles'][ix])
    rct = rdchiralReactants(datasub_val['prod_smiles'][ix])
    if draw: 
        print('Mol {}'.format(ix))
    if debug:
        rec_for_printing += datasub_val['prod_smiles'][ix] + '\n'
        rec_for_printing += 'True reaction:'
        rec_for_printing += datasub_val['rxn_smiles'][ix] + '\n'
    fp = datasub_val['prod_fp'][ix]
    
    sims = similarity_metric(fp, [fp_ for fp_ in datasub['prod_fp']])
    js = np.argsort(sims)[::-1]

    if draw: display(ReactionStringToImage(datasub_val['rxn_smiles'][ix]))
    
    prec_goal = Chem.MolFromSmiles(datasub_val['rxn_smiles'][ix].split('>')[0])
    [a.ClearProp('molAtomMapNumber') for a in prec_goal.GetAtoms()]
    prec_goal = Chem.MolToSmiles(prec_goal, True)
    
    # Sometimes stereochem takes another canonicalization...
    prec_goal = Chem.MolToSmiles(Chem.MolFromSmiles(prec_goal), True)
    if debug:
        rec_for_printing += 'prec_goal: {}\n'.format(prec_goal)
    # Get probability of precursors
    probs = {}
    
    for ji, j in enumerate(js[:max_prec]):
        jx = datasub.index[j]
        
        if draw:
            print('\n\n' + '-'*50 + '\n')
            print('RANK {} PRECEDENT'.format(ji+1))
            print('PRODUCT MATCH SCORE: {}'.format(sims[j]))
            display(ReactionStringToImage(datasub['rxn_smiles'][jx]))
        if debug:
            rec_for_printing += '\nReaction precedent {}, prod similarity {}\n'.format(
                ji+1, sims[j])
            rec_for_printing += '-> rxn_smiles {}\n'.format(datasub['rxn_smiles'][jx])
        if jx in jx_cache:
            (rxn, template, rcts_ref_fp) = jx_cache[jx]
        else:
            retro_canonical = process_an_example(datasub['rxn_smiles'][jx], super_general=True)
            if retro_canonical is None:
                continue
            template = '(' +  retro_canonical.replace('>>', ')>>') 
            rcts_ref_fp = getfp(datasub['rxn_smiles'][jx].split('>')[0])
            rxn = rdchiralReaction(template)
            jx_cache[jx] = (rxn, template, rcts_ref_fp)
        if debug:
            rec_for_printing += '-> template: {}\n'.format(template)
        if draw:
            print('-> template: {}'.format(template))
            
        try:
            outcomes = rdchiralRun(rxn, rct, combine_enantiomers=False)
        except Exception as e:
            print(e)
            outcomes = []

        if not outcomes and draw:
            print('No precursors could be generated!')
            
        for precursors in outcomes:
            precursors_fp = getfp(precursors)
            precursors_sim = similarity_metric(precursors_fp, [rcts_ref_fp])[0]
            if debug:
                rec_for_printing += 'prec sim {} smiles {}\n'.format(precursors_sim, precursors)
            if draw:
                print('Precursor similarity {}, overall {}, smiles {}'.format(
                    precursors_sim, precursors_sim*sims[j], precursors))
                display(MolToImage(Chem.MolFromSmiles(precursors)))
            if precursors in probs:
                probs[precursors] = max(probs[precursors], precursors_sim * sims[j])
            else:
                probs[precursors] = precursors_sim * sims[j]
    
    testlimit = 50
    mols = []
    legends = []

    found_rank = 9999
    for r, (prec, prob) in enumerate(sorted(probs.iteritems(), key=lambda x:x[1], reverse=True)[:testlimit]):
        mols.append(Chem.MolFromSmiles(prec))
        if prec == prec_goal:
            found_rank = r + 1
            legends.append('[TRUE] {}'.format(prob))
        else:
            legends.append('{}'.format(prob))
            pass
    if found_rank == 9999 and debug:
        print(datasub_val['rxn_smiles'][ix])
        display(ReactionStringToImage(datasub_val['rxn_smiles'][ix]))
        print(prec_goal)
        draw = True
    
    if draw: 
        img=MolsToGridImage(mols[:9],molsPerRow=3,subImgSize=(300,300),legends=legends[:9])
        display(img)
        for mol in mols[:9]:
            print(Chem.MolToSmiles(mol, True))
    if debug:
        print(rec_for_printing)
    if (draw or debug) and not nopause:
        raw_input('pause')
        clear_output()
        
    return found_rank

## More examples from each class

In [9]:
import random
if class_ != 'all': 
    random.seed(123)
else:
    import time
    random.seed(time.time())
ix = [ixi for ixi in datasub_val.loc[datasub_val['class'] == 3].index]
ixi = random.choice(ix)
print(ixi)
print(class_)
ixi = 45105
do_one(ixi, debug=False, draw=True, nopause=True)

45342
3


KeyError: 45105