In [1]:
import sys
import os

import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG=True
from rdkit.Chem import Descriptors
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import rdChemReactions
from rdkit.Chem import rdqueries # faster than iterating atoms https://sourceforge.net/p/rdkit/mailman/message/34538007/ 
from rdkit.Chem.rdchem import Atom
from rdkit import DataStructs
import numpy as np

from itertools import chain
import random

from tqdm import tqdm
import csv
import re 
import pickle
import copy

import torch
import torch.nn.functional as F
import torch.nn as nn

### data

In [11]:
sample_rxns = [
    '[CH:2](=[O:3])[NH:4][c:5]1[c:6]([CH2:7][CH2:8][c:9]2[n+:10]([CH3:15])[cH:11][cH:12][cH:13][cH:14]2)[cH:16][c:17]([O:20][CH3:21])[cH:18][cH:19]1>[I-].[I-].C[n+]1ccccc1C=Cc1ccccc1[N+](=O)[O-]>[CH:2](=[O:3])[NH:4][c:5]1[c:6]([CH2:7][CH2:8][CH:9]2[N:10]([CH3:15])[CH2:11][CH2:12][CH2:13][CH2:14]2)[cH:16][c:17]([O:20][CH3:21])[cH:18][cH:19]1',
 'Cl[c:2]1[c:3]([C:4](=[O:5])[c:6]2[cH:7][c:8]([CH:13]([C:14](=[O:15])[OH:16])[CH3:17])[cH:9][cH:10][c:11]2[OH:12])[cH:18][cH:19][cH:20][n:21]1>O.[Cu].[Cu](I)I.[OH-].[Na+]>[c:2]12[c:3]([c:4](=[O:5])[c:6]3[cH:7][c:8]([CH:13]([C:14](=[O:15])[OH:16])[CH3:17])[cH:9][cH:10][c:11]3[o:12]1)[cH:18][cH:19][cH:20][n:21]2',
 '[CH3:3][c:4]1[c:5]([OH:12])[c:6]([CH3:11])[cH:7][cH:8][c:9]1[CH3:10].CC(O)=[O:15]>II>[CH3:3][C:4]1=[C:9]([CH3:10])[C:8](=[O:15])[CH:7]=[C:6]([CH3:11])[C:5]1=[O:12]',
 '[CH3:1][O:2][CH:3]([C:4]#[CH:5])[CH2:6][CH2:7][CH2:8][CH2:9][CH3:10].I[I:29]>CC(C)C(C)BC(C)C(C)C.C[N+](C)(C)[O-].[OH-].[Na+]>[CH3:1][O:2][CH:3](/[CH:4]=[CH:5]/[I:29])[CH2:6][CH2:7][CH2:8][CH2:9][CH3:10]',
 'O=[C:4]1[CH:3]([CH2:1][CH3:2])[CH2:8][CH2:7][CH2:6][CH:5]1[CH3:9].[CH3:11][CH:12]([CH2:13][O:14][CH3:15])[NH2:16]>C1(C)C=CC=CC=1.CCOCC.[Ti](Cl)(Cl)(Cl)Cl>[CH2:1]([CH3:2])[CH:3]1[C:4](=[N:16][CH:12]([CH3:11])[CH2:13][O:14][CH3:15])[CH:5]([CH3:9])[CH2:6][CH2:7][CH2:8]1',
]

In [56]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdChemReactions
from rdkit import DataStructs
import numpy as np
from scipy import sparse

def create_rxn_MorganFP(rxn_smi, fp_type='diff', 
                        radius=2, rctfp_size=16384, prodfp_size=16384, 
                        max_rcts=4, 
                        useChirality=True, dtype='int8'):
    '''
    fp_type: 'diff' or 'sep', 
    'diff' (difference):
    Creates reaction MorganFP following Schneider et al in J. Chem. Inf. Model. 2015, 55, 1, 39–53
    reactionFP = productFP - sum(reactantFPs)
    
    'sep' (separate):
    Creates separate reactantsFP and productFP following Gao et al in ACS Cent. Sci. 2018, 4, 11, 1465–1476
    '''
    # initialise empty fp numpy arrays
    if fp_type == 'diff':
        diff_fp = np.empty(fp_size, dtype = dtype)
    elif fp_type == 'sep':
        rcts_fp = np.empty(rctfp_size, dtype = dtype)
        prod_fp = np.empty(prodfp_size, dtype = dtype)
    elif fp_type == 'precomp':
        rct_fps = np.empty((max_rcts, rctfp_size)) 
        prod_fp = np.empty((1, rctfp_size), dtype = dtype)
        assert rctfp_size == prodfp_size, 'rctfp_size != prodfp_size, unable to build sparse matrix!!!'
    else:
        print('ERROR: fp_type not recognised!')
        return
    
    # create product FP
    prod_mol = Chem.MolFromSmiles(rxn_smi.split('>')[-1])
    try:
        prod_fp_bit = AllChem.GetMorganFingerprintAsBitVect(
                        mol=prod_mol, radius=radius, nBits=prodfp_size, useChirality=useChirality)
        if fp_type == 'precomp':
            DataStructs.ConvertToNumpyArray(prod_fp_bit, prod_fp[0, :])
        else:      # on-the-fly creation of MorganFP during training  
            fp = np.empty(prodfp_size, dtype = dtype)   # temporarily store numpy array as fp 
            DataStructs.ConvertToNumpyArray(prod_fp_bit, fp)
            if fp_type == 'diff':
                diff_fp += fp
            elif fp_type == 'sep':
                prod_fp = fp
    except Exception as e:
        print("Cannot build product fp due to {}".format(e))
        return
                                  
    # create reactant FPs, subtracting each from product FP
    rcts_smi = rxn_smi.split('>')[0].split('.')
    for i, rct_smi in enumerate(rcts_smi):
        rct_mol = Chem.MolFromSmiles(rct_smi)
        try:
            rct_fp_bit = AllChem.GetMorganFingerprintAsBitVect(
                            mol=rct_mol, radius=radius, nBits=rctfp_size, useChirality=useChirality)
            if fp_type == 'precomp':
                DataStructs.ConvertToNumpyArray(rct_fp_bit, rct_fps[i, :])
            else:     # on-the-fly creation of MorganFP during training  
                fp = np.empty(rctfp_size, dtype = dtype)
                DataStructs.ConvertToNumpyArray(rct_fp_bit, fp)
                if fp_type == 'diff':
                    diff_fp -= fp
                elif fp_type == 'sep':
                    rcts_fp += fp
        except Exception as e:
            print("Cannot build reactant fp due to {}".format(e))
            return
    
    if fp_type == 'diff':
        return diff_fp
    elif fp_type == 'sep':
        return np.concatenate([rcts_fp, prod_fp])
    elif fp_type == 'precomp':
        rct_fps = rct_fps.reshape(1, -1) # flatten into 1 long row-array 
        return np.concatenate([rct_fps, prod_fp], axis=1), len(rcts_smi)
    
    
def make_sparse_FP(rxn_smi_dataset, radius, fp_size, max_rcts, 
                   useChirality=True, dtype='int8', toprint=False, every=10000):
    '''
    rxn_smi_dataset: expects a list of rxn_smi (strings)
    returns: a sparse matrix of reaction fingerprints + a list of num_rcts per rxn
    '''
    sparse_fps, list_num_rcts = [], []   
    for i, rxn_smi in enumerate(rxn_smi_dataset):          
        rxn_fp, num_rcts = create_rxn_MorganFP(rxn_smi, fp_type='precomp', 
                            radius=radius, rctfp_size=fp_size, prodfp_size=fp_size, 
                            max_rcts=max_rcts, 
                            useChirality=useChirality, dtype=dtype)
        rxn_fp_sparse = sparse.csr_matrix(rxn_fp)
        sparse_fps.append(rxn_fp_sparse)
        list_num_rcts.append(num_rcts)
        if toprint and i%every == 0:
            print('Processed: {} rxn SMILES'.format(i))

    return sparse.vstack(sparse_fps), list_num_rcts

In [59]:
with open(os.getcwd()+'/clean_rxn_50k_nomap_noreagent.pickle', 'rb') as handle:
    clean_rxn = pickle.load(handle)

In [None]:
%%time
clean_rxn_sparse_fp = {'train': None, 'valid': None, 'test': None}
clean_rxn_sparse_fp_numrcts = {'train': None, 'valid': None, 'test': None}

for key in clean_rxn.keys():
    print('\nMaking sparse FPs for {}'.format(key))
    fp_sparse_matrix, list_num_rcts = make_sparse_FP(clean_rxn[key], 3, 4096, max_rcts=4, 
                                       useChirality=True, dtype='int8', toprint=True, every=10000)
    clean_rxn_sparse_fp[key] = fp_sparse_matrix
    clean_rxn_sparse_fp_numrcts[key] = list_num_rcts

# takes about 2 mins 15 secs

### 16384 FP np arrays are massive, saving to disk causes memory error (>8 GB). need to convert to sparse csr matrix then save as .npz 
- for 4096 x (4 max rcts + 1 prod) sparse matrix, storage size is about 80 MB for USPTO-50k dataset

In [49]:
with open(os.getcwd()+'/clean_rxn_50k_sparse_FPs.pickle', 'wb') as handle:
    pickle.dump(clean_rxn_sparse_fp, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
with open(os.getcwd()+'/clean_rxn_50k_sparse_FPs_numrcts.pickle', 'wb') as handle:
    pickle.dump(clean_rxn_sparse_fp_numrcts, handle, protocol=pickle.HIGHEST_PROTOCOL)

### .npz takes up less space (9 MB) than .pickle (80 MB) 

In [52]:
sparse.save_npz(os.getcwd()+'/clean_rxn_50k_sparse_FPs_train.npz', clean_rxn_sparse_fp['train'])
sparse.save_npz(os.getcwd()+'/clean_rxn_50k_sparse_FPs_valid.npz', clean_rxn_sparse_fp['valid'])
sparse.save_npz(os.getcwd()+'/clean_rxn_50k_sparse_FPs_test.npz', clean_rxn_sparse_fp['test'])

### Try on USPTO_STEREO
- max_rcts = 9 

In [61]:
with open(os.getcwd()+'/USPTO_STEREO_pickles/clean_rxn_nomap_noreagent.pickle', 'rb') as handle:
    clean_rxn = pickle.load(handle)

In [63]:
%%time
clean_rxn_sparse_fp = {'train': None, 'valid': None, 'test': None}
clean_rxn_sparse_fp_numrcts = {'train': None, 'valid': None, 'test': None}

for key in clean_rxn.keys():
    print('\nMaking sparse FPs for {}'.format(key))
    fp_sparse_matrix, list_num_rcts = make_sparse_FP(clean_rxn[key], 3, 4096, max_rcts=9, 
                                       useChirality=True, dtype='int8', toprint=True, every=10000)
    clean_rxn_sparse_fp[key] = fp_sparse_matrix
    clean_rxn_sparse_fp_numrcts[key] = list_num_rcts


Making sparse FPs for train
Processed: 0 rxn SMILES
Processed: 10000 rxn SMILES
Processed: 20000 rxn SMILES
Processed: 30000 rxn SMILES
Processed: 40000 rxn SMILES
Processed: 50000 rxn SMILES
Processed: 60000 rxn SMILES
Processed: 70000 rxn SMILES
Processed: 80000 rxn SMILES
Processed: 90000 rxn SMILES
Processed: 100000 rxn SMILES
Processed: 110000 rxn SMILES
Processed: 120000 rxn SMILES
Processed: 130000 rxn SMILES
Processed: 140000 rxn SMILES
Processed: 150000 rxn SMILES
Processed: 160000 rxn SMILES
Processed: 170000 rxn SMILES
Processed: 180000 rxn SMILES
Processed: 190000 rxn SMILES
Processed: 200000 rxn SMILES
Processed: 210000 rxn SMILES
Processed: 220000 rxn SMILES
Processed: 230000 rxn SMILES
Processed: 240000 rxn SMILES
Processed: 250000 rxn SMILES
Processed: 260000 rxn SMILES
Processed: 270000 rxn SMILES
Processed: 280000 rxn SMILES
Processed: 290000 rxn SMILES
Processed: 300000 rxn SMILES
Processed: 310000 rxn SMILES
Processed: 320000 rxn SMILES
Processed: 330000 rxn SMILES

In [67]:
clean_rxn_sparse_fp['test']

<48165x40960 sparse matrix of type '<class 'numpy.float64'>'
	with 13412426 stored elements in Compressed Sparse Row format>

In [64]:
sparse.save_npz(os.getcwd()+'/USPTO_STEREO_pickles/clean_rxn_sparse_FPs_train.npz', clean_rxn_sparse_fp['train'])
sparse.save_npz(os.getcwd()+'/USPTO_STEREO_pickles/clean_rxn_sparse_FPs_valid.npz', clean_rxn_sparse_fp['valid'])
sparse.save_npz(os.getcwd()+'/USPTO_STEREO_pickles/clean_rxn_sparse_FPs_test.npz', clean_rxn_sparse_fp['test'])

In [66]:
with open(os.getcwd()+'/USPTO_STEREO_pickles/clean_rxn_sparse_FPs_numrcts.pickle', 'wb') as handle:
    pickle.dump(clean_rxn_sparse_fp_numrcts, handle, protocol=pickle.HIGHEST_PROTOCOL)