In [3]:
from rdkit import Chem
from typing import List, Dict, Tuple, Set
from collections import namedtuple, deque
from rdkit.Chem import ChemicalFeatures,AllChem
import numpy as np
import torch
def get_mol(smiles: str, kekulize: bool = False) -> Chem.Mol:
    """Build an RDKit molecule from a SMILES string.

    Parameters
    ----------
    smiles: str,
        SMILES string describing the molecule
    kekulize: bool,
        When True, the parsed molecule is kekulized in place before
        being returned

    Returns
    -------
    The object produced by ``Chem.MolFromSmiles``. The result may be
    ``None`` (the code explicitly guards against that case), in which
    case it is returned unchanged and no kekulization is attempted.
    """
    parsed = Chem.MolFromSmiles(smiles)

    # Nothing more to do when parsing failed or kekulization is not wanted.
    if parsed is None or not kekulize:
        return parsed

    Chem.Kekulize(parsed)
    return parsed

def get_bond_info(mol: Chem.Mol) -> Dict:
    """Summarise every bond of a molecule, keyed by atom-map numbers.

    Parameters
    ----------
    mol: Chem.Mol
        Molecule to inspect; ``None`` is accepted and yields an empty dict

    Returns
    -------
    Dict mapping ``(low_map_num, high_map_num)`` to
    ``[bond type as double, bond index]``. Bonds whose end atoms share the
    same pair of map numbers overwrite each other (last one wins).
    """
    if mol is None:
        return {}

    info_by_pair = {}
    for bnd in mol.GetBonds():
        first_map = bnd.GetBeginAtom().GetAtomMapNum()
        second_map = bnd.GetEndAtom().GetAtomMapNum()

        # Order the pair so that (a, b) and (b, a) address the same entry.
        pair = (min(first_map, second_map), max(first_map, second_map))
        info_by_pair[pair] = [bnd.GetBondTypeAsDouble(), bnd.GetIdx()]
    return info_by_pair

In [4]:
def get_reaction_core(r: str, p: str, use_h_labels: bool = False) -> Set[int]:
    """Get the reaction core (atom-map numbers of the atoms that change).

    An atom belongs to the core when it sits on a bond that
    * exists in both reactants and product but with a different bond order,
    * exists only in the product, or
    * exists only in the reactants while both of its atoms are still
      present (by map number) in the product.

    Parameters
    ----------
    r: str,
        SMILES string representing the reactants
    p: str,
        SMILES string representing the product
    use_h_labels: bool,
        Whether to fall back to changes in hydrogen counts: when no bond
        change was found, atoms whose total hydrogen count differs between
        reactants and product are added to the core

    Returns
    -------
    Set of atom-map numbers. An empty set is returned when either SMILES
    cannot be parsed, when the reactants contain no atoms, or when no
    change is detected. The return value is always a set, so callers can
    reliably test it with ``== set()`` or ``not core``.
    """
    reac_mol = get_mol(r)
    prod_mol = get_mol(p)

    if reac_mol is None or prod_mol is None:
        print(reac_mol, prod_mol)
        # Same type as the success path (previously a ``(set(), [])`` tuple,
        # which never compared equal to ``set()`` at the call site).
        return set()

    prod_bonds = get_bond_info(prod_mol)
    # map number -> atom index in the product
    p_amap_idx = {atom.GetAtomMapNum(): atom.GetIdx() for atom in prod_mol.GetAtoms()}

    reac_map_nums = [atom.GetAtomMapNum() for atom in reac_mol.GetAtoms()]
    if not reac_map_nums:
        # A reactant molecule without atoms has no core; bail out instead of
        # letting max() raise ValueError on an empty sequence.
        return set()

    # Give every unmapped reactant atom a fresh, unique map number so that
    # its bonds do not all collapse onto map number 0.
    max_amap = max(reac_map_nums)
    for atom in reac_mol.GetAtoms():
        if atom.GetAtomMapNum() == 0:
            max_amap += 1
            atom.SetAtomMapNum(max_amap)

    reac_bonds = get_bond_info(reac_mol)
    # map number -> atom index in the reactants
    reac_amap = {atom.GetAtomMapNum(): atom.GetIdx() for atom in reac_mol.GetAtoms()}

    rxn_core = set()

    # Bonds that are new in the product or whose order changed.
    for bond, prod_info in prod_bonds.items():
        if bond not in reac_bonds or prod_info[0] != reac_bonds[bond][0]:
            rxn_core.update(bond)

    # Bonds broken by the reaction, restricted to atoms kept in the product.
    for bond in reac_bonds:
        if bond not in prod_bonds and all(amap in p_amap_idx for amap in bond):
            rxn_core.update(bond)

    # Hydrogen-count fallback. The previous version collected these atoms in
    # a scratch set that was never returned, so the flag had no effect.
    if use_h_labels and not rxn_core:
        for atom in prod_mol.GetAtoms():
            amap_num = atom.GetAtomMapNum()
            reac_idx = reac_amap.get(amap_num)
            if reac_idx is None:
                # Product atom has no mapped counterpart in the reactants.
                continue

            numHs_prod = atom.GetTotalNumHs()
            numHs_reac = reac_mol.GetAtomWithIdx(reac_idx).GetTotalNumHs()
            if numHs_prod != numHs_reac:
                rxn_core.add(amap_num)

    return rxn_core

In [5]:
import pandas as pd
# Output columns, filled in lock-step: one entry per successfully processed row.
index = []
r_mapped = []
r_unmapped = []
a_u = []
p_mapped = []
p_u = []
y = []
core = []
core_index = []
label = []
updated_reactions = []

#the data that going to be processed
large_df = pd.read_csv('dy_no_rdkit.csv')
df = large_df
print(df.shape)
error = []  # positions of the rows that were dropped (no core found or processing failed)

# Pull the columns once, up front: a missing column now fails loudly here
# instead of silently turning every row into an "error" inside the loop.
reactants = df['reactant_smiles'].values
products = df['product_smiles'].values
yie = df['yield'].values
mapped_rxns = df['mapped_rxn'].values
reagents_col = df['reagents'].values
reactants_col = df['reactants'].values
products_col = df['products'].values
labels_col = df['labels'].values

for i in range(df.shape[0]):
    if i % 1000 == 0:
        print(i)

    # Phase 1: compute everything for this row. Nothing is appended to the
    # output lists here, so a failure cannot leave them with unequal lengths.
    try:
        rxn = mapped_rxns[i]
        a = reagents_col[i]

        r, p = rxn.split('>>')
        reactant_parts = r.split('.')
        r_original = reactants_col[i]
        a_original = a
        p_original = products_col[i]
        # Strip the reagents from the left-hand side of the mapped reaction;
        # list.remove raises ValueError when a reagent is not present, which
        # drops the row (same outcome as before).
        for ai in a.split('.'):
            reactant_parts.remove(ai)
        r = '.'.join(reactant_parts)

        rxn_core = get_reaction_core(r, p)
        # Defensive: tolerate a tuple-shaped failure value by looking at its
        # first element, so an empty result is always recognised as empty.
        if isinstance(rxn_core, tuple):
            rxn_core = rxn_core[0]
        if not rxn_core:
            error.append(i)
            continue

        # For each reactant molecule, the atom indices that belong to the core.
        per_molecule_centers = []
        for reactant_smiles in r.split('.'):
            reac_moli = get_mol(reactant_smiles)
            if reac_moli is None:
                raise ValueError('could not parse reactant: ' + reactant_smiles)
            moli_center = [str(atom.GetIdx()) for atom in reac_moli.GetAtoms()
                           if atom.GetAtomMapNum() in rxn_core]
            per_molecule_centers.append('.'.join(moli_center))
        core_index_atom = '>'.join(per_molecule_centers)

        row_label = labels_col[i]
        reaction_original = r_original + '.' + '>' + a_original + '>' + p_original # for t5chem
    except Exception:
        # Best-effort processing: malformed rows are recorded and skipped.
        # Unlike a bare ``except:``, this lets KeyboardInterrupt through.
        error.append(i) #can be used to index the deleted reactions!!!!
        continue

    # Phase 2: the row is complete, record it in every output list at once.
    index.append(i) ##index is important!!!
    r_mapped.append(r)
    p_mapped.append(p)
    a_u.append(a)
    r_unmapped.append(reactants[i])
    p_u.append(products[i])
    y.append(yie[i])
    core.append(rxn_core)
    core_index.append(core_index_atom)
    label.append(row_label)
    updated_reactions.append(reaction_original)

print(len(updated_reactions))
data = {'reactants_mapped':r_mapped, 'reactants':r_unmapped, 'reagents':a_u, 'products_mapped':p_mapped, 'products':p_u, 'yield':y, 'core_index_mapping': core, 'core_index_atom': core_index, 'label': label}
df1 = pd.DataFrame(data)
print(df1.head(3))
df1.to_csv('USPTO500MT_test_processed_100.csv', index=False)



(14238, 21)
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
9497
                                    reactants_mapped  \
0  [CH3:1][C:2]1([CH3:3])[CH2:4][CH2:5][CH2:6][CH...   
1  Br[c:21]1[cH:20][cH:19][c:18](-[c:17]2[cH:16][...   
2  Cl[S:10](=[O:11])(=[O:12])[c:13]1[cH:14][cH:15...   

                                           reactants  \
0                              CC1(C)CCCCC(C)(C)C1=O   
1  COC(=O)NC(C(=O)N1CCCC1c1ncc(-c2ccc(Br)cc2)n1CO...   
2          CCCC[C@H]1CCC(O)C1.Cc1ccc(S(=O)(=O)Cl)cc1   

                                            reagents  \
0  CCOCC.O.O=S(=O)(O)O.[Al+3].[H-].[H-].[H-].[H-]...   
1  CC(=O)[O-].CC(=O)[O-].CC(C)(C)[O-].Cc1ccccc1.[...   
2                                           c1ccncc1   

                                     products_mapped  \
0  [CH3:1][C:2]1([CH3:3])[CH2:4][CH2:5][CH2:6][CH...   
1  [CH3:1][O:2][C:3](=[O:4])[NH:5][CH:6]([C:7](=[...   
2  [CH3:1][CH2:2][CH2:3][CH2:4][C@H:5]1[CH2:6][CH...   

      

In [14]:
# Number of rows in the processed frame.
# NOTE(review): the recorded output of this cell (8716) differs from the 9497
# rows printed by the processing cell, and the execution count jumps from 5 to
# 14 -- presumably df1 was changed by cells not shown here; confirm with
# Restart Kernel -> Run All.
len(df1)

8716