In [1]:
!pip install rdkit -q

Collecting rdkit
  Downloading rdkit-2024.9.5-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2024.9.5-cp311-cp311-manylinux_2_28_x86_64.whl (34.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.9.5


In [3]:
import sys
from collections import defaultdict
import pandas as pd
import sys, os, re
import rdkit
from rdkit import Chem, RDLogger
from rdkit.Chem import rdChemReactions
RDLogger.DisableLog('rdApp.*')
from template_extractor import extract_from_reaction
from rdkit.Chem import rdDepictor
rdDepictor.SetPreferCoordGen(True)
from rdkit import Chem

In [4]:
def build_template_extractor(args):
    setting = {'verbose': False, 'use_stereo': False, 'use_symbol': False, 'max_unmap': 5, 'retro': False, 'remote': True, 'least_atom_num': 0}
    for k in setting.keys():
        if k in args.keys():
            setting[k] = args[k]
    print ('Template extractor setting:', setting)
    return lambda x: extract_from_reaction(x, setting)

def get_reaction_template(extractor, rxn, _id = 0):
    rxn = {'reactants': rxn.split('>>')[0], 'products': rxn.split('>>')[1], '_id': _id}
    result = extractor(rxn)
    return rxn, result

def get_full_template(template, H_change, Charge_change, Chiral_change):
    H_code = ''.join([str(H_change[k+1]) for k in range(len(H_change))])
    Charge_code = ''.join([str(Charge_change[k+1]) for k in range(len(Charge_change))])
    Chiral_code = ''.join([str(Chiral_change[k+1]) for k in range(len(Chiral_change))])
    return '_'.join([template, H_code, Charge_code, Chiral_code])

In [5]:
args ={'verbose': False, 'use_stereo': False, 'use_symbol': True, 'max_unmap': 5, 'retro': False, 'remote': True, 'least_atom_num': 0,
      'dataset':'USPTO_Mechanism'}

In [6]:
extractor = build_template_extractor(args)

Template extractor setting: {'verbose': False, 'use_stereo': False, 'use_symbol': True, 'max_unmap': 5, 'retro': False, 'remote': True, 'least_atom_num': 0}


In [7]:
rxn = '[CH3:1][CH2:2][Cl:3].[OH-:4]>>[CH3:1][CH2:2][OH:4].[Cl-:3]'
output = get_reaction_template(extractor, rxn, _id = 0)
output[1]['reaction_smarts']

'[OH-;D0:4].[CH2;D2;+0:2]-[Cl;H0;D1;+0:3]>>[Cl-;H0;D0:3].[CH2;D2;+0:2]-[OH;D1;+0:4]'

In [8]:
atom_mapped_rxns_list = ['[O:1]=[CH:2][c:3]1[cH:4][n:5][cH:6][cH:7][cH:8]1.[OH3+:9]>>[OH+:1]=[CH:2][c:3]1[cH:4][n:5][cH:6][cH:7][cH:8]1.[OH2:9]',
 '[OH+:5]=[CH:6][c:7]1[cH:8][n:9][cH:10][cH:11][cH:12]1.[OH:1][CH2:2][CH2:3][OH:4]>>[OH+:1]([CH2:2][CH2:3][OH:4])[CH:6]([OH:5])[c:7]1[cH:8][n:9][cH:10][cH:11][cH:12]1',
 '[OH+:1]([CH2:2][CH2:3][OH:4])[CH:5]([OH:6])[c:7]1[cH:8][n:9][cH:10][cH:11][cH:12]1>>[O:1]([CH2:2][CH2:3][OH:4])[CH:5]([OH2+:6])[c:7]1[cH:8][n:9][cH:10][cH:11][cH:12]1',
 '[O:1]([CH2:2][CH2:3][OH:4])[CH:5]([OH2+:6])[c:7]1[cH:8][n:9][cH:10][cH:11][cH:12]1>>[O+:1](\\[CH2:2][CH2:3][OH:4])=[CH:5]/[c:7]1[cH:8][n:9][cH:10][cH:11][cH:12]1.[OH2:6]',
 '[O+:1](\\[CH2:2][CH2:3][OH:4])=[CH:5]/[c:6]1[cH:7][n:8][cH:9][cH:10][cH:11]1>>[O:1]1[CH2:2][CH2:3][OH+:4][CH:5]1[c:6]1[cH:7][n:8][cH:9][cH:10][cH:11]1',
 '[O:3]1[CH2:4][CH2:5][OH+:6][CH:7]1[c:8]1[cH:9][n:10][cH:11][cH:12][cH:13]1.[OH2:1].[OH2:2]>>[O:3]1[CH2:4][CH2:5][O:6][CH:7]1[c:8]1[cH:9][n:10][cH:11][cH:12][cH:13]1.[OH3+:1]']

In [9]:
rxn = atom_mapped_rxns_list[2]
output = get_reaction_template(extractor, rxn, _id = 0)
output[1]['reaction_smarts']

'[OH;D1;+0:6].[OH+;D2:1]>>[OH2+;D1:6].[O;H0;D2;+0:1]'

In [10]:
templates = []
for rxn in atom_mapped_rxns_list:
  output = get_reaction_template(extractor, rxn, _id = 0)
  templates.append((output[1]['reaction_smarts'], output[1]['intra_only']))
templates

[('[O;H0;D1;+0:1].[OH3+;D0:9]>>[OH+;D1:1].[OH2;D0;+0:9]', False),
 ('[OH;D1;+0:1].[CH;D2;+0:6]=[OH+;D1:5]>>[OH+;D2:1]-[CH;D3;+0:6]-[OH;D1;+0:5]',
  False),
 ('[OH;D1;+0:6].[OH+;D2:1]>>[OH2+;D1:6].[O;H0;D2;+0:1]', True),
 ('[O;H0;D2;+0:1]-[CH;D3;+0:5]-[OH2+;D1:6]>>[OH2;D0;+0:6].[CH;D2;+0:5]=[O+;H0;D2:1]',
  True),
 ('[OH;D1;+0:4].[CH;D2;+0:5]=[O+;H0;D2:1]>>[O;H0;D2;+0:1]-[CH;D3;+0:5]-[OH+;D2:4]',
  True),
 ('[OH2;D0;+0:1].[OH+;D2:6]>>[OH3+;D0:1].[O;H0;D2;+0:6]', False)]

In [11]:
templates_ = []
for i in templates:
  r, p = i[0].split('>>')
  templates_.append(((f"({r})>>{p}"), i[1]))
templates_

[('([O;H0;D1;+0:1].[OH3+;D0:9])>>[OH+;D1:1].[OH2;D0;+0:9]', False),
 ('([OH;D1;+0:1].[CH;D2;+0:6]=[OH+;D1:5])>>[OH+;D2:1]-[CH;D3;+0:6]-[OH;D1;+0:5]',
  False),
 ('([OH;D1;+0:6].[OH+;D2:1])>>[OH2+;D1:6].[O;H0;D2;+0:1]', True),
 ('([O;H0;D2;+0:1]-[CH;D3;+0:5]-[OH2+;D1:6])>>[OH2;D0;+0:6].[CH;D2;+0:5]=[O+;H0;D2:1]',
  True),
 ('([OH;D1;+0:4].[CH;D2;+0:5]=[O+;H0;D2:1])>>[O;H0;D2;+0:1]-[CH;D3;+0:5]-[OH+;D2:4]',
  True),
 ('([OH2;D0;+0:1].[OH+;D2:6])>>[OH3+;D0:1].[O;H0;D2;+0:6]', False)]

In [12]:
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdChemReactions as Reactions
from rdkit.Chem.Draw import IPythonConsole
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import matplotlib.pyplot as plt
from rdkit.Chem import rdDepictor
rdDepictor.SetPreferCoordGen(True)
import os
os.getcwd()

import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions
import re
from collections import Counter
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdChemReactions
from rdkit.Chem import rdDepictor
rdDepictor.SetPreferCoordGen(True)
#from rdchiral.template_extractor import extract_from_reaction

In [13]:
from itertools import permutations
from rdkit import Chem
from collections import Counter


def out_(outcome):
  mapnums = [a.GetAtomMapNum() for m in outcome for a in m.GetAtoms() if a.GetAtomMapNum()]
  #print(mapnums)
  if len(mapnums) != len(set(mapnums)): # duplicate?

      merged_mol = Chem.RWMol(outcome[0])
      merged_map_to_id = {a.GetAtomMapNum(): a.GetIdx() for a in outcome[0].GetAtoms() if a.GetAtomMapNum()}
      for j in range(1, len(outcome)):
          new_mol = outcome[j]
          for a in new_mol.GetAtoms():
              if a.GetAtomMapNum() not in merged_map_to_id:
                  merged_map_to_id[a.GetAtomMapNum()] = merged_mol.AddAtom(a)
          for b in new_mol.GetBonds():
              bi = b.GetBeginAtom().GetAtomMapNum()
              bj = b.GetEndAtom().GetAtomMapNum()

              if not merged_mol.GetBondBetweenAtoms(
                      merged_map_to_id[bi], merged_map_to_id[bj]):
                  merged_mol.AddBond(merged_map_to_id[bi],
                      merged_map_to_id[bj], b.GetBondType())
                  merged_mol.GetBondBetweenAtoms(
                      merged_map_to_id[bi], merged_map_to_id[bj]
                  ).SetStereo(b.GetStereo())
                  merged_mol.GetBondBetweenAtoms(
                      merged_map_to_id[bi], merged_map_to_id[bj]
                  ).SetBondDir(b.GetBondDir())
      outcome = merged_mol.GetMol()

  else:
      new_outcome = outcome[0]
      for j in range(1, len(outcome)):
          new_outcome = AllChem.CombineMols(new_outcome, outcome[j])
      outcome = new_outcome

  return outcome


#this code is from:
def make_rxns(source_rxn, reactants, intra_only=False):
    new_rxns = []
    product_sets = source_rxn.RunReactants(reactants)
    #print('product_sets before:\n', product_sets, '\n')

    if intra_only:
      product_sets = [(out_(pset),) for pset in product_sets]

    for pset in product_sets:
        new_rxn = AllChem.ChemicalReaction()
        for react in reactants:
            react = Chem.Mol(react)
            for a in react.GetAtoms():
                a.SetIntProp('molAtomMapNumber', a.GetIdx())
            new_rxn.AddReactantTemplate(react)
        for prod in pset:
            for a in prod.GetAtoms():
                a.SetIntProp('molAtomMapNumber', int(a.GetProp('react_atom_idx')))
            new_rxn.AddProductTemplate(prod)
            #print('new reaction:===>\n', AllChem.ReactionToSmiles(new_rxn))

            #----some tweaking, needed for one edge case---------
            # Get the reaction SMILES
            rxn_smiles = AllChem.ReactionToSmiles(new_rxn)

            # Tweak: Remove parentheses if they enclose the product part
            reactant_smiles, product_smiles = rxn_smiles.split('>>')
            if product_smiles.startswith('(') and product_smiles.endswith(')'):
                product_smiles = product_smiles[1:-1]
            tweaked_rxn_smiles = f"{reactant_smiles}>>{product_smiles}"

            # Convert back to a `ChemicalReaction` object
            tweaked_rxn = AllChem.ReactionFromSmarts(tweaked_rxn_smiles)
            new_rxns.append(tweaked_rxn)
            #----------------------------------------------------------
    return new_rxns

def match_pdts(true_pdts, gen_pdts):
    # Split each input string into separate SMILES components
    true_parts = true_pdts.split('.')
    gen_parts = gen_pdts.split('.')
    # Canonicalize each component of true_pdts
    true_canonicals = set()
    for smi in true_parts:
        mol = Chem.MolFromSmiles(smi)
        [a.SetAtomMapNum(0) for a in mol.GetAtoms()]
        if mol is not None:
            # Convert to canonical SMILES without stereochemistry
            true_canonicals.add(Chem.MolToSmiles(mol, isomericSmiles=False))

    # Canonicalize and compare each component of gen_pdts
    for smi in gen_parts:
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            [a.SetAtomMapNum(0) for a in mol.GetAtoms()]
            gen_canonical = Chem.MolToSmiles(mol, isomericSmiles=False)
            if gen_canonical in true_canonicals:
                return True

    return False

def find_all_reac_reag_comb(template_smarts_str, str_of_smiles, intra_only=False):
    #print('these are getting used:', template_smarts_str, str_of_smiles)
    list_of_smiles = str_of_smiles.split('.')
    #print('all smiles:', list_of_smiles)
    reactants_smarts = template_smarts_str.split('>>')[0][1:-1]
    reactants_smarts = reactants_smarts.split('.')
    #print('reactants_smarts:', reactants_smarts)
    # Dictionary to hold matches for each SMARTS pattern
    smarts_to_matches = {smarts: [] for smarts in reactants_smarts}

    if intra_only:
        # Check if a single molecule contains all SMARTS
        results = []
        for smiles in list_of_smiles:
            mol = Chem.MolFromSmiles(smiles)
            if all(mol.HasSubstructMatch(Chem.MolFromSmarts(smarts)) for smarts in reactants_smarts):
                # Intramolecular case: all SMARTS must match within the same molecule
                reagents = list_of_smiles.copy()
                reagents.remove(smiles)  # Remaining molecules are considered reagents
                results.append(([smiles], reagents))
        return results

    for smiles in list_of_smiles:
        mol = Chem.MolFromSmiles(smiles)
        for smarts in reactants_smarts:
            #print('smarts here:\n', smarts)
            smarts_mol = Chem.MolFromSmarts(smarts)
            if mol.HasSubstructMatch(smarts_mol):
                smarts_to_matches[smarts].append(smiles)

    # Generate all permutations of matches for each SMARTS pattern
    reactant_combinations = []
    for reactant_permutation in permutations(list_of_smiles, len(reactants_smarts)):
        match = True
        for reactant, smarts in zip(reactant_permutation, reactants_smarts):
            mol = Chem.MolFromSmiles(reactant)
            smarts_mol = Chem.MolFromSmarts(smarts)
            if not mol.HasSubstructMatch(smarts_mol):
                match = False
                break
        if match:
            reactant_combinations.append(list(reactant_permutation))


    # Remove duplicate combinations
    unique_combinations = []
    for combo in reactant_combinations:
        if combo not in unique_combinations:
            unique_combinations.append(combo)

    # Generate reagents for each reactant combination, maintaining counts
    results = []
    original_counts = Counter(list_of_smiles)  # Keep track of the original counts
    for reactants in unique_combinations:
        reactant_counts = Counter(reactants)
        reagents_counts = original_counts - reactant_counts  # Subtract reactant counts
        reagents = list(reagents_counts.elements())  # Reconstruct the reagents list
        results.append((reactants, reagents))


    return results


def transform_func3(template_smarts_str_, str_of_smiles):
    template_smarts_str = template_smarts_str_[0]
    intra_only = template_smarts_str_[1]

    reactant_reagents = find_all_reac_reag_comb(template_smarts_str, str_of_smiles, intra_only)
    #print('reactant_reagents:-->\n', reactant_reagents)
    # if not reactant_reagents:
    #     print(f"No valid reactant-reagent combinations found for: {template_smarts_str} with {str_of_smiles}")
    #     return []  # Return an empty list or handle appropriately


    all_mapped_rxns = []  # Store results for all reactant-reagent combinations

    for reactants_smiles_list, reagents_smiles_list in reactant_reagents:
        reagent_smiles = '.'.join(reagents_smiles_list)
        reactants_mols_list = [Chem.MolFromSmiles(smiles) for smiles in reactants_smiles_list]
        template = AllChem.ReactionFromSmarts(template_smarts_str)

        # -------------------combine the reactants into a single mol object--------------
        rmol = None
        for mol in reactants_mols_list:
            if rmol is None:
                rmol = mol  # If rmol is None, set it to the first molecule
            else:
                rmol = Chem.CombineMols(rmol, mol)

        # ----------------------mapping the reactants and products-------------------------
        atom_mapped_rxns = []
        for r in make_rxns(template, [rmol], intra_only):
            #print('template, intra_only:\n', template, intra_only, '\n')
            smi = AllChem.ReactionToSmiles(r)
            #print('atom_mapped_rxns before cleaning:', smi, '\n')
            smi = re.sub(r'^\((.*)\)>', r'\1>', smi)  # Clean up the reaction SMARTS
            atom_mapped_rxns.append(smi)
            #print('atom_mapped_rxns after cleaning:', atom_mapped_rxns, '\n')
        # --------------------------adding the reagents------------------------------------
        for i in atom_mapped_rxns:
            mapped_rxn = i.replace(">>", f">{reagent_smiles}>")
            all_mapped_rxns.append(mapped_rxn)

    # Remove duplicates from all mapped reactions
    all_mapped_rxns = list(set(all_mapped_rxns))
    return all_mapped_rxns


def dfs_with_processing1(start, transform_func, validate_func, target_output, param_list):
    visited = set()  # To track visited states and avoid cycles

    def process_output(output):
        """Custom processing logic on the outputs if required."""
        try:
            output_reagent = output.split('>')[1]
            output_pdt = output.split('>')[2]
            if output_reagent:
                processed = '.'.join([output_pdt, output_reagent])
            else:
                processed = output_pdt
            return processed

        except IndexError:
            return None  # If processing fails, return None to signal an invalid path

    def is_valid_molecule(smiles):
        """Check if a SMILES string corresponds to a valid molecule."""
        try:
            mol = Chem.MolFromSmiles(smiles)
            return mol is not None
        except:
            return False

    def dfs(current_input, path, param_index):
        # Stop if the current input matches the target output
        if validate_func(current_input, target_output):
            print('Got it, matched!')
            return path

        if (current_input, param_index) in visited:
            return None
        visited.add((current_input, param_index))

        # Check if we've exhausted all transformation rules
        if param_index >= len(param_list):
            return None

        # Apply the current transformation rule
        current_param = param_list[param_index]
        raw_outputs = transform_func(current_param, current_input)  # Raw outputs from transform_func
        #print('raw_outputs:\n', raw_outputs)
        processed_outputs = []

        for raw in raw_outputs:
            # Process the raw output
            processed = process_output(raw)
            if processed is None:
                continue  # Skip invalid processing results

            # Validate individual SMILES in the processed output
            components = processed.split('.')
            if all(is_valid_molecule(comp) for comp in components):
                processed_outputs.append((raw, processed))
        #print('processed_outputs=======>\n', processed_outputs)
        for raw_output, processed_output in processed_outputs:
            # Recursive DFS with the next SMARTS pattern in param_list
            result = dfs(processed_output, path + [(current_param, raw_output, processed_output)], param_index + 1)
            if result is not None:
                return result

        return None  # No valid path found

    # Start with the first SMARTS pattern
    return dfs(start, [], 0)


In [14]:
start = '[O:1]=[CH:2][c:3]1[cH:4][n:5][cH:6][cH:7][cH:8]1.[OH3+:9]'
target_output = '[OH+:1]=[CH:2][c:3]1[cH:4][n:5][cH:6][cH:7][cH:8]1.[OH2:9]'
results = dfs_with_processing1(start, transform_func3, match_pdts, target_output, templates_[:1])
[results[i][1] for i in range(len(results))]

Got it, matched!


['[O:0]=[CH:1][c:2]1[cH:3][n:4][cH:5][cH:6][cH:7]1.[OH3+:8]>>[OH+:0]=[CH:1][c:2]1[cH:3][n:4][cH:5][cH:6][cH:7]1.[OH2:8]']

In [15]:
#Task1- prepare a atom-mapped rxn mechanism
#Task2- automatically get the templates
#Task3- apply the templates to a set of reactants (given the products)
#Task4- to do the task3 I need a set of samples

In [16]:
examples = ['O=Cc1ccc2nc(C(=O)Nc3ccccc3)cn2c1.OCCO.[OH3+]>>O=C(Nc1ccccc1)c2cn3cc(C4OCCO4)ccc3n2', #aldehyde
            'O=Cc1cc(Br)cs1.OCCO.[OH3+]>>Brc1csc(C2OCCO2)c1', #aldehyde
            'CC(=O)c1ccc(Nc2ncc(Cl)c(NC3CC3)n2)cc1.OCCO.[OH3+]>>CC1(OCCO1)c2ccc(Nc3ncc(Cl)c(NC4CC4)n3)cc2'] #ketone

In [17]:
start = 'O=Cc1ccc2nc(C(=O)Nc3ccccc3)cn2c1.OCCO.[OH3+]'
target_output = 'O=C(Nc1ccccc1)c2cn3cc(C4OCCO4)ccc3n2'
results = dfs_with_processing1(start, transform_func3, match_pdts, target_output, templates_)
[results[i][1] for i in range(len(results))]

Got it, matched!


['[O:0]=[CH:1][c:2]1[cH:3][cH:4][c:5]2[n:6][c:7]([C:8](=[O:9])[NH:10][c:11]3[cH:12][cH:13][cH:14][cH:15][cH:16]3)[cH:17][n:18]2[cH:19]1.[OH3+:20]>OCCO>[OH+:0]=[CH:1][c:2]1[cH:3][cH:4][c:5]2[n:6][c:7]([C:8](=[O:9])[NH:10][c:11]3[cH:12][cH:13][cH:14][cH:15][cH:16]3)[cH:17][n:18]2[cH:19]1.[OH2:20]',
 '[OH+:4]=[CH:5][c:6]1[cH:7][cH:8][c:9]2[n:10][c:11]([C:12](=[O:13])[NH:14][c:15]3[cH:16][cH:17][cH:18][cH:19][cH:20]3)[cH:21][n:22]2[cH:23]1.[OH:0][CH2:1][CH2:2][OH:3]>[OH2:20]>[OH+:0]([CH2:1][CH2:2][OH:3])[CH:5]([OH:4])[c:6]1[cH:7][cH:8][c:9]2[n:10][c:11]([C:12](=[O:13])[NH:14][c:15]3[cH:16][cH:17][cH:18][cH:19][cH:20]3)[cH:21][n:22]2[cH:23]1',
 '[OH+:0]([CH2:1][CH2:2][OH:3])[CH:4]([OH:5])[c:6]1[cH:7][cH:8][c:9]2[n:10][c:11]([C:12](=[O:13])[NH:14][c:15]3[cH:16][cH:17][cH:18][cH:19][cH:20]3)[cH:21][n:22]2[cH:23]1>[OH2:20]>[O:0]([CH2:1][CH2:2][OH:3])[CH:4]([OH2+:5])[c:6]1[cH:7][cH:8][c:9]2[n:10][c:11]([C:12](=[O:13])[NH:14][c:15]3[cH:16][cH:17][cH:18][cH:19][cH:20]3)[cH:21][n:22]2[cH:23]1',
 '

In [18]:
start = 'O=Cc1cc(Br)cs1.OCCO.[OH3+]'
target_output = 'Brc1csc(C2OCCO2)c1'
results = dfs_with_processing1(start, transform_func3, match_pdts, target_output, templates_)
[results[i][1] for i in range(len(results))]

Got it, matched!


['[O:0]=[CH:1][c:2]1[cH:3][c:4]([Br:5])[cH:6][s:7]1.[OH3+:8]>OCCO>[OH+:0]=[CH:1][c:2]1[cH:3][c:4]([Br:5])[cH:6][s:7]1',
 '[OH+:4]=[CH:5][c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1.[OH:0][CH2:1][CH2:2][OH:3]>>[OH+:0]([CH2:1][CH2:2][OH:3])[CH:5]([OH:4])[c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1',
 '[OH+:0]([CH2:1][CH2:2][OH:3])[CH:4]([OH:5])[c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1>>[O:0]([CH2:1][CH2:2][OH:3])[CH:4]([OH2+:5])[c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1',
 '[O:0]([CH2:1][CH2:2][OH:3])[CH:4]([OH2+:5])[c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1>>[O+:0]([CH2:1][CH2:2][OH:3])=[CH:4][c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1.[OH2:5]',
 '[O+:0]([CH2:1][CH2:2][OH:3])=[CH:4][c:5]1[cH:6][c:7]([Br:8])[cH:9][s:10]1>[OH2:5]>[O:0]1[CH2:1][CH2:2][OH+:3][CH:4]1[c:5]1[cH:6][c:7]([Br:8])[cH:9][s:10]1',
 '[O:1]1[CH2:2][CH2:3][OH+:4][CH:5]1[c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1.[OH2:0]>>[O:1]1[CH2:2][CH2:3][O:4][CH:5]1[c:6]1[cH:7][c:8]([Br:9])[cH:10][s:11]1.[OH3+:0]']

In [19]:
start = 'CC(=O)c1ccc(Nc2ncc(Cl)c(NC3CC3)n2)cc1.OCCO.[OH3+]'
target_output = 'CC1(OCCO1)c2ccc(Nc3ncc(Cl)c(NC4CC4)n3)cc2'
results = dfs_with_processing1(start, transform_func3, match_pdts, target_output, templates_)
[results[i][1] for i in range(len(results))]

TypeError: object of type 'NoneType' has no len()

In [20]:
#so automatically extract the templates for the ketone substrates and check, if it works
#but the problem is I need to get atom-mapped_rxn, this I can't bypass, except If I use ML model or some programs, e.g. marvins

In [21]:
atom_mapped_rxns_list = ['[CH3:1][C:2](=[O:21])[c:3]1[cH:4][cH:5][c:6]([NH:7][c:8]2[n:18][cH:17][c:15]([Cl:16])[c:10]([NH:11][CH:12]3[CH2:13][CH2:14]3)[n:9]2)[cH:19][cH:20]1.[OH3+:22]>>[CH3:1][C:2](=[OH+:21])[c:3]1[cH:4][cH:5][c:6]([NH:7][c:8]2[n:18][cH:17][c:15]([Cl:16])[c:10]([NH:11][CH:12]3[CH2:13][CH2:14]3)[n:9]2)[cH:19][cH:20]1.[OH2:22]',
'[CH3:1][C:2](=[OH+:21])[c:3]1[cH:4][cH:5][c:6]([NH:7][c:8]2[n:18][cH:17][c:15]([Cl:16])[c:10]([NH:11][CH:12]3[CH2:13][CH2:14]3)[n:9]2)[cH:19][cH:20]1.[OH:22][CH2:23][CH2:24][OH:25]>>[CH3:1][C:2]([OH:21])([OH+:25][CH2:24][CH2:23][OH:22])[c:3]1[cH:4][cH:5][c:6]([NH:7][c:8]2[n:18][cH:17][c:15]([Cl:16])[c:10]([NH:11][CH:12]3[CH2:13][CH2:14]3)[n:9]2)[cH:19][cH:20]1',
'[CH3:1][C:2]([OH:7])([OH+:3][CH2:4][CH2:5][OH:6])[c:8]1[cH:9][cH:10][c:11]([NH:12][c:13]2[n:14][cH:15][c:16]([Cl:17])[c:18]([NH:19][CH:20]3[CH2:21][CH2:22]3)[n:23]2)[cH:24][cH:25]1>>[CH3:1][C:2]([OH2+:7])([O:3][CH2:4][CH2:5][OH:6])[c:8]1[cH:9][cH:10][c:11]([NH:12][c:13]2[n:14][cH:15][c:16]([Cl:17])[c:18]([NH:19][CH:20]3[CH2:21][CH2:22]3)[n:23]2)[cH:24][cH:25]1',
'[CH3:1][C:2]([OH2+:7])([O:3][CH2:4][CH2:5][OH:6])[c:8]1[cH:9][cH:10][c:11]([NH:12][c:13]2[n:14][cH:15][c:16]([Cl:17])[c:18]([NH:19][CH:20]3[CH2:21][CH2:22]3)[n:23]2)[cH:24][cH:25]1>>[CH3:1]\[C:2](=[O+:3]/[CH2:4][CH2:5][OH:6])[c:8]1[cH:9][cH:10][c:11]([NH:12][c:13]2[n:14][cH:15][c:16]([Cl:17])[c:18]([NH:19][CH:20]3[CH2:21][CH2:22]3)[n:23]2)[cH:24][cH:25]1.[OH2:7]',
'[CH3:1]\[C:2](=[O+:21]/[CH2:22][CH2:23][OH:24])[c:3]1[cH:4][cH:5][c:6]([NH:7][c:8]2[n:18][cH:17][c:15]([Cl:16])[c:10]([NH:11][CH:12]3[CH2:13][CH2:14]3)[n:9]2)[cH:19][cH:20]1>>[CH3:1][C:2]1([O:21][CH2:22][CH2:23][OH+:24]1)[c:3]1[cH:4][cH:5][c:6]([NH:7][c:8]2[n:18][cH:17][c:15]([Cl:16])[c:10]([NH:11][CH:12]3[CH2:13][CH2:14]3)[n:9]2)[cH:19][cH:20]1',
'[CH3:1][C:2]1([O:3][CH2:4][CH2:5][OH+:6]1)[c:7]1[cH:8][cH:9][c:10]([NH:11][c:12]2[n:13][cH:14][c:15]([Cl:16])[c:17]([NH:18][CH:19]3[CH2:20][CH2:21]3)[n:22]2)[cH:23][cH:24]1.[OH2:25]>>[CH3:1][C:2]1([O:6][CH2:5][CH2:4][O:3]1)[c:7]1[cH:8][cH:9][c:10]([NH:11][c:12]2[n:13][cH:14][c:15]([Cl:16])[c:17]([NH:18][CH:19]3[CH2:20][CH2:21]3)[n:22]2)[cH:23][cH:24]1.[OH3+:25]']

In [22]:
templates = []
for rxn in atom_mapped_rxns_list:
  output = get_reaction_template(extractor, rxn, _id = 0)
  templates.append((output[1]['reaction_smarts'], output[1]['intra_only']))
templates

[('[O;H0;D1;+0:21].[OH3+;D0:22]>>[OH+;D1:21].[OH2;D0;+0:22]', False),
 ('[OH;D1;+0:25].[C;H0;D3;+0:2]=[OH+;D1:21]>>[OH+;D2:25]-[C;H0;D4;+0:2]-[OH;D1;+0:21]',
  False),
 ('[OH;D1;+0:7].[OH+;D2:3]>>[OH2+;D1:7].[O;H0;D2;+0:3]', True),
 ('[O;H0;D2;+0:3]-[C;H0;D4;+0:2]-[OH2+;D1:7]>>[OH2;D0;+0:7].[C;H0;D3;+0:2]=[O+;H0;D2:3]',
  True),
 ('[OH;D1;+0:24].[C;H0;D3;+0:2]=[O+;H0;D2:21]>>[O;H0;D2;+0:21]-[C;H0;D4;+0:2]-[OH+;D2:24]',
  True),
 ('[OH2;D0;+0:25].[OH+;D2:6]>>[OH3+;D0:25].[O;H0;D2;+0:6]', False)]

In [23]:
templates_ = []
for i in templates:
  r, p = i[0].split('>>')
  templates_.append(((f"({r})>>{p}"), i[1]))
templates_

[('([O;H0;D1;+0:21].[OH3+;D0:22])>>[OH+;D1:21].[OH2;D0;+0:22]', False),
 ('([OH;D1;+0:25].[C;H0;D3;+0:2]=[OH+;D1:21])>>[OH+;D2:25]-[C;H0;D4;+0:2]-[OH;D1;+0:21]',
  False),
 ('([OH;D1;+0:7].[OH+;D2:3])>>[OH2+;D1:7].[O;H0;D2;+0:3]', True),
 ('([O;H0;D2;+0:3]-[C;H0;D4;+0:2]-[OH2+;D1:7])>>[OH2;D0;+0:7].[C;H0;D3;+0:2]=[O+;H0;D2:3]',
  True),
 ('([OH;D1;+0:24].[C;H0;D3;+0:2]=[O+;H0;D2:21])>>[O;H0;D2;+0:21]-[C;H0;D4;+0:2]-[OH+;D2:24]',
  True),
 ('([OH2;D0;+0:25].[OH+;D2:6])>>[OH3+;D0:25].[O;H0;D2;+0:6]', False)]

In [24]:
transform_func3(template_smarts_str_=templates_[1], str_of_smiles='CC(c1ccc(Nc2nc(NC3CC3)c(Cl)cn2)cc1)=[OH+].OCCO')

['[CH3:4][C:5]([c:6]1[cH:7][cH:8][c:9]([NH:10][c:11]2[n:12][c:13]([NH:14][CH:15]3[CH2:16][CH2:17]3)[c:18]([Cl:19])[cH:20][n:21]2)[cH:22][cH:23]1)=[OH+:24].[OH:0][CH2:1][CH2:2][OH:3]>>[OH:0][CH2:1][CH2:2][OH+:3][C:5]([CH3:4])([c:6]1[cH:7][cH:8][c:9]([NH:10][c:11]2[n:12][c:13]([NH:14][CH:15]3[CH2:16][CH2:17]3)[c:18]([Cl:19])[cH:20][n:21]2)[cH:22][cH:23]1)[OH:24]',
 '[CH3:4][C:5]([c:6]1[cH:7][cH:8][c:9]([NH:10][c:11]2[n:12][c:13]([NH:14][CH:15]3[CH2:16][CH2:17]3)[c:18]([Cl:19])[cH:20][n:21]2)[cH:22][cH:23]1)=[OH+:24].[OH:0][CH2:1][CH2:2][OH:3]>>[OH+:0]([CH2:1][CH2:2][OH:3])[C:5]([CH3:4])([c:6]1[cH:7][cH:8][c:9]([NH:10][c:11]2[n:12][c:13]([NH:14][CH:15]3[CH2:16][CH2:17]3)[c:18]([Cl:19])[cH:20][n:21]2)[cH:22][cH:23]1)[OH:24]']

In [25]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np

def kmeans_clustering_with_visualization(fingerprints, max_clusters=10):
    """
    Perform K-means clustering on molecular fingerprints and visualize clusters.

    Args:
        fingerprints (np.ndarray): 2D array of molecular fingerprints.
        max_clusters (int): Maximum number of clusters to test.

    Returns:
        dict: Contains cluster labels, number of clusters, and PCA-transformed fingerprints.
    """
    best_k = None
    best_score = -1
    best_model = None
    scores = []

    # Find the optimal number of clusters
    for n_clusters in range(2, max_clusters + 1):
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
        labels = kmeans.fit_predict(fingerprints)

        # Calculate silhouette score
        score = silhouette_score(fingerprints, labels)
        scores.append((n_clusters, score))

        if score > best_score:
            best_k = n_clusters
            best_score = score
            best_model = kmeans

    # Perform PCA for visualization
    pca = PCA(n_components=2, random_state=42)
    transformed_fingerprints = pca.fit_transform(fingerprints)

    # # Plot the silhouette scores
    # clusters, sil_scores = zip(*scores)
    # plt.figure(figsize=(12, 5))
    # plt.subplot(1, 2, 1)
    # plt.plot(clusters, sil_scores, marker='o', label='Silhouette Score')
    # plt.xlabel('Number of Clusters')
    # plt.ylabel('Silhouette Score')
    # plt.title('Silhouette Scores vs Number of Clusters')
    # plt.legend()

    # # Plot the clusters
    # plt.subplot(1, 2, 2)
    # labels = best_model.labels_
    # plt.scatter(
    #     transformed_fingerprints[:, 0],
    #     transformed_fingerprints[:, 1],
    #     c=labels,
    #     cmap='viridis',
    #     marker='o',
    #     edgecolor='k'
    # )
    # plt.xlabel('PCA Component 1')
    # plt.ylabel('PCA Component 2')
    # plt.title(f'Clustering with K={best_k}')
    # plt.colorbar(label='Cluster ID')

    # plt.tight_layout()
    # plt.show()

    # Return clustering results
    return {
        "labels": labels,
        "n_clusters": best_k,
        "transformed_fingerprints": transformed_fingerprints,
        "model": best_model,
        "scores": scores
    }



In [26]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

def generate_difference_fingerprint(rxn_smiles):
    """
    Generates a difference fingerprint for a given reaction SMILES.

    Args:
        rxn_smiles (str): Reaction SMILES in the form "reactant>>product".

    Returns:
        np.ndarray: Difference fingerprint as a binary array.
    """
    def generate_fingerprint(smiles, fp_size=2048):
        """Generate a molecular fingerprint (binary) from SMILES."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES: {smiles}")
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=fp_size)
        # Convert the ExplicitBitVect to a numpy array
        arr = np.array(list(fp))
        return arr

    # Step 1: Split reaction SMILES into reactants and products
    try:
        reactants_smiles, products_smiles = rxn_smiles.split(">>")
    except ValueError:
        raise ValueError(f"Invalid reaction SMILES format: {rxn_smiles}")

    # Step 2: Generate fingerprints for reactants and products
    reactant_fps = np.zeros((2048,), dtype=int)
    for reactant in reactants_smiles.split('.'):
        reactant_fps |= generate_fingerprint(reactant)  # Combine reactants with OR

    product_fps = np.zeros((2048,), dtype=int)
    for product in products_smiles.split('.'):
        product_fps |= generate_fingerprint(product)  # Combine products with OR

    # Step 3: Create the difference fingerprint (XOR)
    difference_fp = reactant_fps ^ product_fps

    return difference_fp

# Example Usage
reaction_smiles = "CCO>>CC=O"  # Simple ethanol oxidation reaction
diff_fp = generate_difference_fingerprint(reaction_smiles)
print("Difference Fingerprint:", diff_fp)


Difference Fingerprint: [0 0 0 ... 0 0 0]


In [27]:
# df = pd.read_csv('/content/nucleophilic_attack_to_(thio)carbonyl_or_sulfonyl_gen_mech_new.csv')
# df.head(3)

In [28]:
# diff_fps = [generate_difference_fingerprint(i) for i in list(df['updated_reaction'])]
# diff_fps = np.array(diff_fps)
# diff_fps.shape

In [29]:
# output = kmeans_clustering_with_visualization(diff_fps)
# set(output['labels'])

In [30]:
#df['label'] = output['labels']

In [31]:
#df[df['label']==0]['updated_reaction'].iloc[3]

In [32]:
#df[df['label']==5]['updated_reaction'].iloc[6]

In [33]:
#df.value_counts('label')

In [34]:
#df.shape