In [13]:
import time
import pulp
import json
import pandas as pd
import re
from rdkit import Chem
from rdkit.Chem import AllChem

# Suppress RDKit warnings: set RDKit's logger to CRITICAL once for the session so
# that expected parse failures (e.g. when count_substructures re-parses fragment
# SMILES and checks the result for None) do not flood the notebook output.
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

In [22]:
def count_substructures(radius, molecule):
    """Count the atom-centred substructures (moieties) of a molecule.

    NOTE(review): this copy is shadowed by the redefinition of
    ``count_substructures`` in a later cell (In [33]); consider deleting it.

    For every atom, the largest non-empty bond environment with radius
    <= ``radius`` returned by ``Chem.FindAtomEnvironmentOfRadiusN`` is turned
    into a canonical SMILES string. Atoms for which no environment is found
    contribute their own single-atom SMILES.

    Parameters
    ----------
    radius : int
        Maximum bond distance from the centre atom to include.
    molecule : rdkit.Chem.Mol
        Molecule to analyse, e.g. from ``Chem.MolFromSmiles(smiles)`` or
        ``Chem.MolFromInchi(inchi)``.

    Returns
    -------
    dict
        ``{substructure_smiles: count}``.
    """
    mol = molecule
    counts = {}
    n_atoms = mol.GetNumAtoms()

    # Keep RDKit quiet while fragment SMILES are re-parsed below.
    rdkit_log = RDLogger.logger()
    rdkit_log.setLevel(RDLogger.CRITICAL)

    for centre in range(n_atoms):
        # Walk the radius down (radius, radius-1, ..., 0) until RDKit returns
        # a non-empty bond environment for this centre atom.
        env = None
        for r in reversed(range(radius + 1)):
            env = Chem.FindAtomEnvironmentOfRadiusN(mol, r, centre)
            if env:
                break

        if env:
            member_atoms = set()
            for bond_idx in env:
                bond = mol.GetBondWithIdx(bond_idx)
                member_atoms.update((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
        else:
            # No neighbours at any radius: the atom stands on its own.
            member_atoms = {centre}

        smi = Chem.MolFragmentToSmiles(mol, atomsToUse=list(member_atoms),
                                       bondsToUse=env, canonical=True,
                                       isomericSmiles=True)

        # Standardise via a round trip (strip explicit H, re-canonicalise);
        # fragments RDKit cannot re-parse keep their raw SMILES.
        fragment = Chem.MolFromSmiles(smi)
        if fragment:
            smi = Chem.MolToSmiles(Chem.RemoveHs(fragment))

        counts[smi] = counts.get(smi, 0) + 1
    return counts

def add_dicts(dict1, dict2):
    """Return the key-wise sum of two count dictionaries.

    NOTE(review): superseded by the redefinition of ``add_dicts`` in a later
    cell (In [33]).

    Keys present in only one input keep that input's value. Keys whose
    resulting value is 0 are dropped. Result order: ``dict1``'s keys first,
    then keys that only ``dict2`` has.
    """
    totals = {}
    for key in dict1:
        totals[key] = dict1[key] + dict2[key] if key in dict2 else dict1[key]
    for key in dict2:
        if key not in dict1:
            totals[key] = dict2[key]
    return {key: val for key, val in totals.items() if val != 0}



def subtract_dicts(dict1, dict2):
    """Return the key-wise difference ``dict1 - dict2`` of two count dictionaries.

    NOTE(review): superseded by the redefinition of ``subtract_dicts`` in a
    later cell (In [33]).

    A key missing from ``dict2`` keeps its ``dict1`` value. A key missing
    from ``dict1`` gets the negated ``dict2`` value. Zero results are dropped.
    """
    diff = {key: (dict1[key] - dict2[key]) if key in dict2 else dict1[key]
            for key in dict1}
    diff.update({key: -dict2[key] for key in dict2 if key not in dict1})
    return {key: val for key, val in diff.items() if val != 0}

def reaction_string_to_moiety_change_dict(reaction_string, radius):
    """Net moiety change (products minus reactants) of a reaction SMILES.

    NOTE(review): superseded by the redefinition in a later cell (In [33]).

    Parameters
    ----------
    reaction_string : str
        Reaction SMILES, e.g. ``"CCO.O>>CC(=O)O"``.
    radius : int or str
        Substructure radius passed to ``count_substructures``, or ``'MAX'`` to
        treat every whole species as one moiety.

    Returns
    -------
    dict
        ``{moiety_smiles: net_count}``. Positive counts are formed and negative
        counts are consumed. Zero entries are omitted.

    Raises
    ------
    ValueError
        If the string does not contain exactly one ``'>>'``, or (substructure
        mode) if a species SMILES cannot be parsed.
    """
    sides = reaction_string.split(">>")
    if len(sides) != 2:
        raise ValueError(f"Expected exactly one '>>' in reaction SMILES, got {reaction_string!r}")
    initial_state_str, final_state_str = sides

    def _side_counts(side_str):
        # Aggregate moiety counts over every species on one side.
        counts = {}
        # Skip empty fragments (an empty side such as ">>CCO", or a stray '.'),
        # so that 'MAX' mode does not count '' as a species.
        for smile in filter(None, side_str.split(".")):
            if radius == 'MAX':
                contribution = {smile: 1}  # whole-molecule mode
            else:
                mol = Chem.MolFromSmiles(smile)
                if mol is None:
                    raise ValueError(f"Could not parse SMILES {smile!r} in reaction {reaction_string!r}")
                contribution = count_substructures(radius, mol)
            counts = add_dicts(counts, contribution)
        return counts

    state_i = _side_counts(initial_state_str)
    state_f = _side_counts(final_state_str)
    return subtract_dicts(state_f, state_i)
        
def moiety_dict_to_reaction_smiles(dict):
    """Build a reaction SMILES string from a net-change dictionary.

    NOTE(review): superseded by the redefinition in a later cell (In [33]).
    The parameter keeps its original name ``dict`` (shadowing the builtin)
    so that keyword callers keep working.

    Parameters
    ----------
    dict : dict
        ``{smiles: net_count}``. Negative counts become reactants and positive
        counts become products, each repeated ``abs(count)`` times. Zero
        counts contribute nothing.

    Returns
    -------
    str
        ``"reactants>>products"``, with species on each side joined by '.'.
    """
    reactants = []
    products = []
    for smiles, count in dict.items():
        if count > 0:
            products.extend([smiles] * abs(count))
        elif count < 0:
            reactants.extend([smiles] * abs(count))
        # count == 0: nothing to emit
    return '.'.join(reactants) + '>>' + '.'.join(products)

def fix_rxn(reaction_num, leftover_string, reactions=None):
    """Subtract a "leftover" reaction from a stored reaction, species-wise.

    NOTE(review): superseded by the redefinition in a later cell (In [33]).

    Parameters
    ----------
    reaction_num : int
        Index (or key) of the known reaction in ``reactions``.
    leftover_string : str
        Reaction SMILES whose whole-molecule net change is removed.
    reactions : indexable of str, optional
        Reaction SMILES collection to look ``reaction_num`` up in. Defaults to
        the notebook global ``reaction_string``, which is resolved at call time.
        NOTE(review): that global is not defined in the cells shown here.

    Returns
    -------
    str
        Reaction SMILES of the difference ('MAX' / whole-molecule mode).
    """
    if reactions is None:
        reactions = reaction_string  # legacy behaviour: read the notebook global
    known_change = reaction_string_to_moiety_change_dict(reactions[reaction_num], 'MAX')
    leftover_change = reaction_string_to_moiety_change_dict(leftover_string, 'MAX')
    return moiety_dict_to_reaction_smiles(subtract_dicts(known_change, leftover_change))

def reaction_atom_and_e_balance(reaction_smiles):
    """Tally the atom and electron balance of a reaction SMILES.

    NOTE(review): superseded by the redefinition in a later cell (In [33]).

    Explicit hydrogens are added to every species. Each reactant contributes
    negative counts and each product positive counts: one per atom under its
    element symbol, plus (atomic number - formal charge) under ``'e'``.
    A balanced reaction therefore yields an empty total.

    Parameters
    ----------
    reaction_smiles : str
        ``"reactants>>products"``, with species separated by '.'.

    Returns
    -------
    tuple
        ``(total_balance, mol_parts, smiles)``. ``total_balance`` is the summed
        balance dict with zero entries removed. ``mol_parts`` holds one signed
        balance dict per species, reactants first and then products.
        ``smiles`` lists the species SMILES in the same order.

    Raises
    ------
    ValueError
        If a species SMILES cannot be parsed by RDKit.
    """
    def _signed_balance(smile, sign):
        # Balance dict for one species; sign is -1 for reactants, +1 for products.
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            raise ValueError(f"Could not parse SMILES {smile!r} in reaction {reaction_smiles!r}")
        mol = Chem.AddHs(mol)  # explicit H so hydrogens are counted
        balance = {}
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            balance[symbol] = balance.get(symbol, 0) + sign
            balance['e'] = balance.get('e', 0) + sign * atom.GetAtomicNum()
        # A positive formal charge means fewer electrons than protons.
        balance['e'] = balance.get('e', 0) - sign * Chem.GetFormalCharge(mol)
        return {key: value for key, value in balance.items() if value != 0}

    initial_state_str, final_state_str = reaction_smiles.split(">>")
    initial_smiles = initial_state_str.split(".")
    final_smiles = final_state_str.split(".")

    smiles = initial_smiles + final_smiles
    mol_parts = ([_signed_balance(s, -1) for s in initial_smiles]
                 + [_signed_balance(s, 1) for s in final_smiles])

    total_balance = {}
    for part in mol_parts:
        total_balance = add_dicts(total_balance, part)
    return total_balance, mol_parts, smiles


# Extra arrow-environment rules from the added protonation rules
  
def similar_mechs(solution):
    """Rank known M-CSA mechanisms by arrow-environment similarity to ``solution``.

    NOTE(review): superseded by the redefinition in a later cell (In [33]).
    Depends on the notebook globals ``Protonation_arrow_rules``,
    ``MCSA_arrow_rules`` and ``MCSA_mechanism_arrows``.

    Parameters
    ----------
    solution : iterable of str
        Rule identifiers. Each may chain several steps with '&'. A step is
        either a protonation rule name (contains 'pro', e.g. 'Lys_depro') or
        an M-CSA reference 'j_i_s' / '...(j_i_s)...' with 1-based indices.

    Returns
    -------
    dict
        ``{mechanism_id: score}`` sorted by score, highest first. The score is
        the Jaccard index |A ∩ B| / |A ∪ B| between the candidate's arrow set
        and each known mechanism's non-empty arrow set.
    """
    solution_arrows = set()
    for rule in solution:
        for step in rule.split('&'):
            if 'pro' in step:  # also matches 'depro'
                # Protonation steps map to generic arrow environments; unknown
                # names are skipped. Never fall through to the M-CSA lookup.
                solution_arrows.update(Protonation_arrow_rules.get(step, ()))
                continue
            ref = step.split('(')[1].split(')')[0] if '(' in step else step
            try:
                j, i, s = (int(part) for part in ref.split('_'))
            except ValueError:
                continue  # not a parsable 'j_i_s' reference
            if min(j, i, s) < 1:
                continue  # ids are 1-based; 0 would silently wrap to the last element
            try:
                directions = MCSA_arrow_rules[j - 1][i - 1][s - 1]
            except IndexError:
                continue  # reference outside the loaded M-CSA data
            for arrows in directions:
                solution_arrows.update(arrows)

    solution_scores = {}
    for mech_id, known_arrows in MCSA_mechanism_arrows.items():
        if not known_arrows:
            continue
        # known_arrows is non-empty, so the union is never empty.
        solution_scores[mech_id] = (len(solution_arrows.intersection(known_arrows))
                                    / len(solution_arrows.union(known_arrows)))
    return dict(sorted(solution_scores.items(), key=lambda item: item[1], reverse=True))
    

In [33]:
def count_substructures(radius, molecule):
    """
    Counts occurrences of chemical substructures (moieties) within a given molecule.
    Each atom is treated as the center of a substructure. The substructure is the
    largest non-empty bond environment with radius <= `radius` that RDKit returns
    for that atom, or the atom on its own when no environment is found.

    Parameters
    ----------
    radius : int
        The maximum bond-distance radius defining the size of each substructure.
    molecule : rdkit.Chem.Mol
        An RDKit molecule object to be analyzed.

    Returns
    -------
    dict
        A dictionary mapping the canonical SMILES of each substructure to its count.
        Example: {'CCO': 2, 'CC': 3}

    Raises
    ------
    ValueError
        If `molecule` is None, e.g. because `Chem.MolFromSmiles` failed to parse its input.
    """
    if molecule is None:
        raise ValueError(
            "count_substructures got None instead of an RDKit molecule; "
            "the SMILES/InChI it was built from probably failed to parse"
        )
    m = molecule
    smi_count = {}

    # Silence RDKit while fragments are re-parsed below. Failed re-parses are
    # handled explicitly by the `if sub_mol` check.
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    for i in range(m.GetNumAtoms()):
        # Try radius, radius-1, ..., 0 and keep the first non-empty environment
        # (an empty result means no environment of that radius was found).
        env = None
        for r in range(radius, -1, -1):
            env = Chem.FindAtomEnvironmentOfRadiusN(m, r, i)
            if env:
                break

        if env:
            # Collect every atom touched by the environment's bonds.
            atoms = set()
            for bidx in env:
                bond = m.GetBondWithIdx(bidx)
                atoms.add(bond.GetBeginAtomIdx())
                atoms.add(bond.GetEndAtomIdx())
        else:
            atoms = {i}  # no environment found: use the central atom alone

        substructure = Chem.MolFragmentToSmiles(m, atomsToUse=list(atoms),
                                                bondsToUse=env, canonical=True, isomericSmiles=True)

        # Standardize via a round trip (remove explicit H, re-canonicalize).
        # If the fragment cannot be re-parsed, keep the raw fragment SMILES.
        sub_mol = Chem.MolFromSmiles(substructure)
        if sub_mol:
            substructure = Chem.MolToSmiles(Chem.RemoveHs(sub_mol))

        smi_count[substructure] = smi_count.get(substructure, 0) + 1

    return smi_count

def add_dicts(dict1, dict2):
    """
    Returns the key-wise sum of two count dictionaries.

    Keys that appear in only one input keep that input's value. Any key whose
    summed value is 0 is left out of the result. Neither input is modified.

    Returns
    -------
    dict
        New dictionary of non-zero sums (dict1's key order first, then new keys from dict2).
    """
    totals = dict1.copy()
    for key in dict2:
        totals[key] = totals.get(key, 0) + dict2[key]

    nonzero = {}
    for key, total in totals.items():
        if total != 0:
            nonzero[key] = total
    return nonzero

def subtract_dicts(dict1, dict2):
    """
    Returns the key-wise difference dict1 - dict2 of two count dictionaries.

    A key missing from dict2 keeps its dict1 value. A key missing from dict1
    becomes the negated dict2 value. Zero results are left out. Neither input
    is modified.

    Returns
    -------
    dict
        New dictionary of non-zero differences.
    """
    remaining = dict1.copy()
    for key, amount in dict2.items():
        remaining[key] = remaining.get(key, 0) - amount
    return dict(filter(lambda item: item[1] != 0, remaining.items()))


def reaction_string_to_moiety_change_dict(reaction_string, radius):
    """
    Parses a reaction SMILES string and calculates the net change in moieties.

    Parameters
    ----------
    reaction_string : str
        A reaction SMILES string (e.g., "CCO.O>>CC(=O)O").
    radius : int or str
        The radius for substructure definition, or 'MAX' to treat entire molecules as moieties.

    Returns
    -------
    dict
        A dictionary representing the net change of moieties {moiety_smiles: count}.
        Positive counts are products, negative counts are reactants; zeros are omitted.

    Raises
    ------
    ValueError
        If `reaction_string` does not contain exactly one '>>', or (substructure
        mode) if a species SMILES cannot be parsed.
    """
    sides = reaction_string.split(">>")
    if len(sides) != 2:
        raise ValueError(f"Expected exactly one '>>' in reaction SMILES, got {reaction_string!r}")
    initial_state_str, final_state_str = sides

    def _side_counts(side_str):
        # Aggregate moiety counts over every species on one side of the reaction.
        counts = {}
        # Skip empty fragments (an empty side such as ">>CCO", or a stray '.'),
        # so that 'MAX' mode does not count '' as a species.
        for smile in filter(None, side_str.split(".")):
            if radius == 'MAX':
                contribution = {smile: 1}  # whole-molecule mode
            else:
                mol = Chem.MolFromSmiles(smile)
                if mol is None:
                    raise ValueError(f"Could not parse SMILES {smile!r} in reaction {reaction_string!r}")
                contribution = count_substructures(radius, mol)
            counts = add_dicts(counts, contribution)
        return counts

    state_i = _side_counts(initial_state_str)  # reactants
    state_f = _side_counts(final_state_str)    # products
    # Net change = products - reactants
    return subtract_dicts(state_f, state_i)

def moiety_dict_to_reaction_smiles(d):
    """
    Turns a net-change dictionary back into a reaction SMILES string.

    Parameters
    ----------
    d : dict
        {moiety_smiles: count}. A negative count puts the species on the reactant
        side abs(count) times, a positive count puts it on the product side, and
        zero counts are ignored.

    Returns
    -------
    str
        "reactants>>products", with species on each side joined by '.'.
    """
    reactant_side = [smi for smi, n in d.items() if n < 0 for _ in range(abs(n))]
    product_side = [smi for smi, n in d.items() if n > 0 for _ in range(abs(n))]
    return '.'.join(reactant_side) + '>>' + '.'.join(product_side)

def fix_rxn(reaction_num, leftover_string, reactions=None):
    """
    Calculates the difference between a known reaction and a "leftover" reaction,
    treating whole molecules as moieties ('MAX' mode).

    Parameters
    ----------
    reaction_num : int
        Index (or key) of the known reaction in `reactions`.
    leftover_string : str
        Reaction SMILES whose net change is subtracted from the known reaction.
    reactions : indexable of str, optional
        Collection of reaction SMILES to look `reaction_num` up in. Defaults to the
        notebook global `reaction_string`, which is resolved at call time so the
        original behaviour is preserved.
        NOTE(review): `reaction_string` is not defined in the cells shown here.

    Returns
    -------
    str
        A reaction SMILES string representing the difference.
    """
    if reactions is None:
        reactions = reaction_string  # legacy behaviour: read the notebook global
    # Net change for the known reaction and for the leftover part
    known_reaction_change = reaction_string_to_moiety_change_dict(reactions[reaction_num], 'MAX')
    leftover_change = reaction_string_to_moiety_change_dict(leftover_string, 'MAX')
    net_difference = subtract_dicts(known_reaction_change, leftover_change)
    return moiety_dict_to_reaction_smiles(net_difference)

def reaction_atom_and_e_balance(reaction_smiles):
    """
    Checks the atom and electron balance for a given reaction SMILES.

    Explicit hydrogens are added to every species. Reactants contribute negative
    counts and products positive counts: one per atom under its element symbol,
    plus (atomic number - formal charge) under the key 'e'. A balanced reaction
    therefore yields an empty total.

    Returns
    -------
    tuple
        A tuple containing:
        (dict: The final atom/electron balance (zero entries removed),
         list: A list of balance dicts for each molecule (reactants, then products),
         list: A list of all molecule SMILES, in the same order).

    Raises
    ------
    ValueError
        If a species SMILES cannot be parsed by RDKit.
    """
    def _signed_balance(smile, sign):
        # Balance dict for one species; sign is -1 for reactants, +1 for products.
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            raise ValueError(f"Could not parse SMILES {smile!r} in reaction {reaction_smiles!r}")
        mol = Chem.AddHs(mol)  # explicit hydrogens for accurate counting
        balance = {}
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            balance[symbol] = balance.get(symbol, 0) + sign
            balance['e'] = balance.get('e', 0) + sign * atom.GetAtomicNum()
        # A positive formal charge means fewer electrons than protons.
        balance['e'] = balance.get('e', 0) - sign * Chem.GetFormalCharge(mol)
        return {key: value for key, value in balance.items() if value != 0}

    initial_state_str, final_state_str = reaction_smiles.split(">>")
    initial_state_smiles = initial_state_str.split(".")
    final_state_smiles = final_state_str.split(".")

    smiles = initial_state_smiles + final_state_smiles
    mol_parts = ([_signed_balance(s, -1) for s in initial_state_smiles]
                 + [_signed_balance(s, 1) for s in final_state_smiles])

    # Aggregate the per-molecule balances into the overall reaction balance
    total_balance = {}
    for balance_dict in mol_parts:
        total_balance = add_dicts(total_balance, balance_dict)

    return total_balance, mol_parts, smiles

# A hardcoded dictionary mapping common protonation/deprotonation rule names
# to a set of generic "arrow environment" reaction SMILES. This is used
# in the similarity scoring to represent these common steps.
# Some sets repeat an arrow; duplicates collapse harmlessly in a set literal.
# Keep the commas between items: adjacent string literals are silently concatenated.
Protonation_arrow_rules = {
     'Phenol_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Phenol_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Asp/Glu_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Asp/Glu_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'HisN3_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'HisN3_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Lys_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Lys_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Arg_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Arg_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'HisN1_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'HisN1_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'THFA_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'THFA_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Tys_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Formate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Formate_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Ser_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Cys_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[S-]>>[H].[S-]', '[SH]>>[S]'},
     'alcohol_secondary_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     '2-Acetyllactate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Amine_secondary_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Amine_secondary_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Water_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Water_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Oxonium_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O]>>[H].[O]', '[OH+]>>[O+]'},'Oxonium_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O]>>[H].[O]', '[OH+]>>[O+]'},
     'Ammonium_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N]>>[H].[N]', '[NH+]>>[N+]'},'Ammonium_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N]>>[H].[N]', '[NH+]>>[N+]'},
     'Carboxylic_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Carboxylic_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Thiosulfate_sulfur_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[S-]>>[H].[S-]', '[SH]>>[S]'},'Thiosulfate_sulfur_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[S-]>>[H].[S-]', '[SH]>>[S]'},
     'Phosphate_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Phosphate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Phosphate_-1_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Phosphate_-1_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Phosphate_-2_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Phosphate_-2_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Phosphate_-3_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Phosphate_-3_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'amine_primary_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'amine_primary_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'PLP_N_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'PLP_N_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Adenine_methylN_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH]>>[N]', '[N]>>[H].[N]', '[N]=[C]>>[N][C]'},'Adenine_methylN_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH]>>[N]', '[N]>>[H].[N]', '[N]=[C]>>[N][C]'},
     'Bicarbonate_-1_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Bicarbonate_-1_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Methylamine_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Methylamine_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     '3-Phosphonopyruvate_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'3-Phosphonopyruvate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Nitrite_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Nitrite_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'AmineC1_sugar_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'AmineC1_sugar_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'HCl_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[Cl-]>>[H].[Cl-]', '[ClH]>>[Cl]'},'HCl_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[Cl-]>>[H].[Cl-]', '[ClH]>>[Cl]'},
     'Methylamide_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Methylamide_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Adenine_primary_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Adenine_primary_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Arsenate_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Arsenate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Cyanide_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[C-]>>[H].[C-]', '[CH]>>[C]'},'Cyanide_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[C-]>>[H].[C-]', '[CH]>>[C]'},
     'Sulfate_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Sulfate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Sulfate_-1_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Sulfate_-1_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Uracil_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N-]>>[H].[N-]', '[NH]>>[N]'},'Uracil_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N-]>>[H].[N-]', '[NH]>>[N]'},
     'Sulfite_-1_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Sulfite_-1_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'FAD_N3_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N-]>>[H].[N-]', '[NH]>>[N]'},'FAD_N3_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N-]>>[H].[N-]', '[NH]>>[N]'},
     'EsterC1C5_sugar_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O]>>[H].[O]', '[OH+]>>[O+]'},'EsterC1C5_sugar_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O]>>[H].[O]', '[OH+]>>[O+]'},
     'Formyl-amine_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Formyl-amine_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Diethylamine_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Diethylamine_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'Creatine_N_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Creatine_N_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},
     'FAD_N1_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N-]>>[H].[N-]', '[NH]>>[N]'},'FAD_N1_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[N-]>>[H].[N-]', '[NH]>>[N]'},
     'Benzoate_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},'Benzoate_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[O-]>>[H].[O-]', '[OH]>>[O]'},
     'Trimethylamine_depro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'},'Trimethylamine_pro': {'[OH]>>[O]', '[O]>>[H].[O]', '[NH+]>>[N+]', '[N]>>[H].[N]'}}
    
  

def similar_mechs(solution):
    """
    Calculates the similarity of a candidate mechanism (`solution`) to a global
    database of known M-CSA mechanisms.
    NOTE: Depends on global variables `Protonation_arrow_rules`, `MCSA_arrow_rules`
    and `MCSA_mechanism_arrows`.

    Parameters
    ----------
    solution : iterable of str
        Rule identifiers. Each may chain several steps with '&'. A step is either a
        protonation rule name (contains 'pro', e.g. 'Lys_depro') or an M-CSA
        reference 'j_i_s' / '...(j_i_s)...' with 1-based indices into `MCSA_arrow_rules`.

    Returns
    -------
    dict
        A dictionary of {mcsa_id: similarity_score}, sorted by score descending.
        The score is the Jaccard index |A ∩ B| / |A ∪ B| between the candidate's
        arrow set A and each known mechanism's non-empty arrow set B.
    """
    # Unique arrow environments used by the candidate solution
    solution_arrows = set()
    for rule in solution:
        for step in rule.split('&'):
            if 'pro' in step:  # also matches 'depro'
                # Protonation steps map to generic arrow environments; unknown
                # rule names are skipped.
                solution_arrows.update(Protonation_arrow_rules.get(step, ()))
                continue
            # Standard M-CSA reference, optionally wrapped in parentheses
            ref = step.split('(')[1].split(')')[0] if '(' in step else step
            try:
                j, i, s = (int(part) for part in ref.split('_'))
            except ValueError:
                continue  # not a parsable 'j_i_s' reference
            if min(j, i, s) < 1:
                continue  # ids are 1-based; 0 would silently wrap to the last element
            try:
                directions = MCSA_arrow_rules[j - 1][i - 1][s - 1]
            except IndexError:
                continue  # reference outside the loaded M-CSA data
            for arrows in directions:
                solution_arrows.update(arrows)

    solution_scores = {}
    for m, known_arrows_set in MCSA_mechanism_arrows.items():
        if not known_arrows_set:  # skip mechanisms with no recorded arrows
            continue
        intersection_size = len(solution_arrows.intersection(known_arrows_set))
        # known_arrows_set is non-empty, so the union is never empty
        union_size = len(solution_arrows.union(known_arrows_set))
        solution_scores[m] = intersection_size / union_size

    # Return the scores sorted from highest to lowest similarity
    return dict(sorted(solution_scores.items(), key=lambda item: item[1], reverse=True))

In [34]:
# Input files (relative to the notebook's working directory)
UNIQUE_RULES_CSV = "Unique_Rules.csv"
MCSA_ARROW_RULES_JSON = 'M-CSA_arrow_rules_r0.json'

Unique_Rules = pd.read_csv(UNIQUE_RULES_CSV, index_col=0)

# Collect moieties that carry atom-map labels (':').
# A moiety is kept when every token is labeled. Otherwise it is kept when its
# second token is labeled (original note: "finds moieties that have labeled atoms
# but are not taken from labeled atoms; only works for radius = 1 moieties").
labeled_moieties = []
for moiety in Unique_Rules.index.tolist():
    if ':' not in moiety:
        continue
    # Tokens: bracket atoms, parenthesised groups, or single non-bond characters
    atoms = re.findall(r'(\[.*?\]|\(.*?\)|[^-=#])', moiety)
    labeled_atom_num = moiety.count(':')
    if len(atoms) == labeled_atom_num:
        labeled_moieties.append(moiety)
    elif len(atoms) > 1 and ':' in atoms[1]:  # single-token moieties have no atoms[1]
        labeled_moieties.append(moiety)

# Replace rule column names with integer ids 0..n-1
old_names = Unique_Rules.columns.tolist()
new_names = list(range(len(old_names)))
column_mapping = dict(zip(old_names, new_names))
Unique_Rules_renamed = Unique_Rules.rename(columns=column_mapping)

###### Sets ######
moiety_index = Unique_Rules_renamed.index.tolist()  # moiety set
rules_index = Unique_Rules_renamed.columns.values.tolist()  # rule set
print("Number of unique rules used in this search:", len(rules_index))

###### Parameters ######
# T[m][r]: stoichiometric coefficient of moiety m in rule r ({moiety: {rule: value}})
T = Unique_Rules_renamed.to_dict(orient="index")

with open(MCSA_ARROW_RULES_JSON, 'r') as f:
    MCSA_arrow_rules = json.load(f)

# Flatten MCSA_arrow_rules[entry][mechanism][step][direction] -> arrows into one
# set of arrow environments per mechanism, keyed "<entry>_<mechanism>" (1-based).
MCSA_mechanism_arrows = {}
for entry_no, entry in enumerate(MCSA_arrow_rules, start=1):
    for mech_no, mechanism in enumerate(entry, start=1):
        mechanism_arrows = set()
        for step in mechanism:
            for direction in step:
                mechanism_arrows.update(direction)
        MCSA_mechanism_arrows[str(entry_no) + '_' + str(mech_no)] = mechanism_arrows

Number of unique rules used in this search: 4143


In [35]:
def MechFind(desired_reaction,max_steps,iterations,time_limit,max_repeatable_rules):
    """Enumerate candidate mechanisms (ordered rule sequences) for a reaction.

    Each iteration solves two MILPs:

    * ``minRules`` chooses how many times each rule in ``rules_index`` is used.
      The summed moiety changes must equal the overall reaction change
      (constraint 1) with at most ``max_steps`` rule applications
      (constraint 2), and the total count is minimised. After each solve an
      integer cut (constraint 3) is added so the same rule multiset is not
      returned again.
    * ``OrderRules`` arranges the chosen rules into steps. After every step the
      running count (reactant moieties plus cumulative rule changes) of each
      non-ignored moiety must be >= 0 (constraint 4). Exactly one rule is used
      per step (constraint 5), and each rule is used as often as ``minRules``
      chose (constraint 6).

    Parameters
    ----------
    desired_reaction : str
        Reaction SMILES ``'reactants>>products'``.
    max_steps : int
        Upper bound on the total number of rule applications.
    iterations : int
        Number of entries to collect. A status string counts as an entry and
        also ends the search.
    time_limit : float, 'none' or None
        Wall-clock budget in seconds, checked after each ``minRules`` solve.
        ``'none'`` or ``None`` disables the check.
    max_repeatable_rules : int
        Upper bound on how often a single rule may be used. ``1`` makes the
        rule variables binary.

    Returns
    -------
    list
        Entries sorted by descending score. The score is the first value
        returned by ``similar_mechs``, or 0.0 if that dict is empty. Each entry
        is either a list of rule names (columns of ``Unique_Rules``) in step
        order, or one of the status strings ``'took too long'`` /
        ``'infeasible'``, which score 0.0.
    """
    start_time = time.time()

    # Overall reaction input: net moiety change, zero-filled over moiety_index.
    T_o = reaction_string_to_moiety_change_dict(desired_reaction,1)
    for index in moiety_index:
        if index not in T_o:
            T_o[index] = 0.0

    # Reactant input: moiety counts of the reactant side.
    C_R = count_substructures(1,Chem.MolFromSmiles(desired_reaction.split('>>')[0]))

    # Labeled moieties are left out of constraint 4 (cumulative sum) unless they
    # occur in the reactants. '[OH2:1]' is always left out because water is
    # treated as always available. This must run before C_R is zero-filled below.
    ignored_moieties = set()
    for moiety in labeled_moieties:
        if moiety == '[OH2:1]' or moiety not in C_R:
            ignored_moieties.add(moiety)

    # Add the rest of the moieties in moiety_index to C_R as zeros.
    for index in moiety_index:
        if index not in C_R:
            C_R[index] = 0.0

    ###### variables ######
    # y[r]: number of times rule r is used (binary when rules may not repeat)
    rule_category = "Binary" if max_repeatable_rules == 1 else "Integer"
    y = pulp.LpVariable.dicts("y", rules_index, lowBound=0, upBound=max_repeatable_rules, cat=rule_category)

    # create minRules MILP problem
    minRules = pulp.LpProblem("minRules", pulp.LpMinimize)

    ####### objective function ####
    # fewest rule applications
    minRules += pulp.lpSum([y[r] for r in rules_index])

    ####### constraints ####

    # Constraint 1: moiety change balance. The chosen rules reproduce T_o.
    for m_pos, m in enumerate(moiety_index):
        minRules += (pulp.lpSum([T[m][r] * y[r] for r in rules_index if T[m][r] != 0]) == T_o[m],
                     "moiety_balance_" + str(m_pos))

    # Constraint 2: customized constraint, the number of steps of the pathway.
    minRules += pulp.lpSum([y[r] for r in rules_index]) <= max_steps

    # rules_index holds positional column labels; this maps them back to names.
    rule_names = Unique_Rules.columns.tolist()

    solutions = []
    sol_num = 0
    while len(solutions) < iterations:
        minRules.solve(pulp.PULP_CBC_CMD(msg=0))

        elapsed_time = time.time() - start_time
        if time_limit not in ('none', None) and elapsed_time > time_limit:
            solutions.append('took too long')
            break
        if pulp.LpStatus[minRules.status] != 'Optimal':
            solutions.append('infeasible')
            break

        # Multiset of chosen rules, in rules_index order. Round rather than
        # truncate, so a solver value such as 0.9999999 counts as 1 and not 0.
        solution = []
        counts = {}
        for r in rules_index:
            count = int(round(y[r].varValue or 0.0))
            if count > 0:
                counts[r] = count
                solution.extend([r] * count)

        # Constraint 3: integer cut that forbids this rule multiset in later solves.
        # NOTE(review): with max_repeatable_rules > 1 the cut reads
        # sum_r n_r*y_r <= sum_r n_r - 1 (n_r = current count). This can also
        # exclude other multisets: after {a: 2} is found, any use of a is cut.
        sol_num += 1
        minRules += (pulp.lpSum([y[r] for r in solution]) <= len(solution) - 1,
                     "integer_cut_" + str(sol_num),)

        ## Set K: one step per rule application
        K_steps = list(range(1, len(solution) + 1))
        used_rules = set(solution)
        # The used rules in rules_index order, for reading the step assignment back.
        ordered_used_rules = [r for r in rules_index if r in used_rules]

        # z[k][r] = 1 when rule r is applied at step k. Only rules in this
        # solution appear in OrderRules, so only they get variables.
        z = pulp.LpVariable.dicts("z", (K_steps, ordered_used_rules), lowBound=0, upBound=1, cat="Binary")

        # create OrderRules MILP problem
        OrderRules = pulp.LpProblem("OrderRules", pulp.LpMinimize)

        ####### objective function ####
        OrderRules += 0  # pure feasibility problem

        # Constraint 4: reactant counts plus the cumulative rule changes up to
        # each step must stay >= 0.
        for m_pos, m in enumerate(moiety_index):
            if m in ignored_moieties:
                continue
            touching = [r for r in used_rules if T[m][r] != 0]
            for current_k in K_steps:
                OrderRules += (C_R[m] + pulp.lpSum([T[m][r] * z[k][r]
                                                    for k in range(1, current_k + 1)
                                                    for r in touching]) >= 0,
                               "cumulative_sum_" + str(m_pos) + '_' + str(current_k))

        # Constraint 5: exactly one rule per step.
        for k in K_steps:
            OrderRules += pulp.lpSum([z[k][r] for r in used_rules]) == 1

        # Constraint 6: each rule is placed as many times as minRules chose.
        for r in used_rules:
            OrderRules += pulp.lpSum([z[k][r] for k in K_steps]) == counts[r]

        OrderRules.solve(pulp.PULP_CBC_CMD(msg=0))

        # No feasible ordering: skip this multiset. Its integer cut is already
        # in minRules, so the next solve moves on.
        if pulp.LpStatus[OrderRules.status] != 'Optimal':
            continue

        ordered_solution = []
        for k in K_steps:
            for r in ordered_used_rules:
                if round(z[k][r].varValue or 0.0) == 1:
                    ordered_solution.append(rule_names[r])
        solutions.append(ordered_solution)

    scores = {}
    for n, solution in enumerate(solutions):
        if solution == 'took too long' or solution == 'infeasible':
            scores[n] = 0.0
        else:
            mechs = similar_mechs(solution)
            # First (highest) similarity score; 0.0 when nothing matched at all.
            scores[n] = next(iter(mechs.values()), 0.0)

    # Stable sort: entries with equal scores keep their discovery order.
    ranked = sorted(scores, key=scores.get, reverse=True)
    return [solutions[idx] for idx in ranked]



In [37]:
def Mechanism_Matrix(mech_rules_list,desired_reaction):
    """Build a moiety-by-column table for a reaction and one candidate mechanism.

    Rows are moieties, i.e. substructures from ``count_substructures`` with
    radius 1. Columns, in order:

    * ``'RXN'``: the net moiety change from
      ``reaction_string_to_moiety_change_dict``.
    * ``'Reactants'``: moiety counts summed over the reactant molecules.
    * One column per rule in ``mech_rules_list``, holding that rule's non-zero
      entries from ``Unique_Rules``.
    * ``'Products'``.

    Missing entries are filled with 0 and the table is cast to int.

    Parameters
    ----------
    mech_rules_list : list of str
        Rule names (columns of ``Unique_Rules``) in step order, e.g. one list
        entry of MechFind's result.
    desired_reaction : str
        Reaction SMILES ``'reactants>>products'``.

    Returns
    -------
    pandas.DataFrame
        The moiety table described above.

    Notes
    -----
    Side effect: sets the global pandas display options ``max_rows`` and
    ``max_columns`` to the table's size, so the whole table is shown when
    displayed.
    """
    reactants = desired_reaction.split('>>')[0]
    products = desired_reaction.split('>>')[1]

    # moiety counts summed over every reactant / product molecule
    reactant_moieties = {}
    for reactant in reactants.split('.'):
        reactant_moieties = add_dicts(reactant_moieties,count_substructures(1,Chem.MolFromSmiles(reactant)))

    product_moieties = {}
    for product in products.split('.'):
        product_moieties = add_dicts(product_moieties,count_substructures(1,Chem.MolFromSmiles(product)))

    true_rxn = reaction_string_to_moiety_change_dict(desired_reaction,1)

    # Collect every column first and concatenate once. Growing the frame with
    # pd.concat inside the loop copies it for every rule. Named Series also
    # tolerate empty inputs, where reassigning .columns on an empty frame fails.
    columns = [
        pd.Series(true_rxn, name='RXN', dtype=float),
        pd.Series(reactant_moieties, name='Reactants', dtype=float),
    ]
    for rule in mech_rules_list:
        rule_column = Unique_Rules[rule]
        # only the moieties this rule changes, in Unique_Rules' row order
        columns.append(pd.Series(rule_column[rule_column != 0].to_dict(), name=rule, dtype=float))
    columns.append(pd.Series(product_moieties, name='Products', dtype=float))

    # Outer join on moiety; rows appear in order of first appearance across columns.
    Mechanism = pd.concat(columns, axis=1).fillna(0).astype(int)

    pd.options.display.max_rows = len(Mechanism)
    pd.options.display.max_columns = len(Mechanism.columns)

    return Mechanism

In [39]:
# Search settings for MechFind
max_steps = 20               # at most 20 rule applications in total
iterations = 10              # collect up to 10 entries
time_limit = 120             # seconds, checked after each rule-selection solve
max_repeatable_rules = 5     # one rule may be used up to 5 times
# thioester + ethylammonium -> amide + thiol, with labeled water -> hydronium
desired_reaction = 'CC(=O)SCC.[OH2:1].[NH3+]CC>>CC(=O)NCC.SCC.[OH3+:1]'

solutions = MechFind(
    desired_reaction=desired_reaction,
    max_steps=max_steps,
    iterations=iterations,
    time_limit=time_limit,
    max_repeatable_rules=max_repeatable_rules,
)
solutions

[['22_1_4&(80_1_3)&96_1_4&102_1_6&(106_1_8)&216_1_3&(229_1_5)&(236_1_3)&317_1_5&329_1_2&336_1_3&526_1_2&527_1_2&576_1_2&673_1_2&785_1_2&843_1_2&984_1_3&HisN1_depro',
  '22_1_1',
  '22_1_2',
  '22_1_3'],
 ['344_1_1', '344_1_2', 'Methylamide_depro'],
 ['amine_primary_depro', '224_1_2&524_1_2&525_2_2', '524_1_3&525_2_3'],
 ['22_1_1',
  '(163_2_2)&216_1_2&(239_1_9)&(239_2_9)&571_1_2&765_1_2&(787_1_2)&788_1_2&(791_1_3)&846_1_1',
  '(38_1_10)',
  '524_1_3&525_2_3'],
 ['22_1_4&(80_1_3)&96_1_4&102_1_6&(106_1_8)&216_1_3&(229_1_5)&(236_1_3)&317_1_5&329_1_2&336_1_3&526_1_2&527_1_2&576_1_2&673_1_2&785_1_2&843_1_2&984_1_3&HisN1_depro',
  '22_1_1',
  '524_1_3&525_2_3'],
 ['(4_1_1)&13_1_4&19_1_1&36_1_3&(49_1_7)&68_1_4&99_1_7&100_1_5&107_1_7&116_1_3&141_1_3&141_1_6&141_1_8&141_1_11&159_1_3&166_1_4&171_1_3&209_1_2&223_1_8&(244_1_3)&332_1_2&339_1_2&345_1_5&365_1_2&507_1_3&525_1_6&525_2_4&546_1_4&546_2_6&570_1_2&588_1_2&610_1_2&615_1_2&627_1_5&632_1_2&647_1_2&654_1_2&663_1_2&665_1_3&696_1_2&702_1_6&801_1

In [40]:
# Moiety table for the first (highest-scoring) MechFind entry.
# NOTE(review): solutions[0] may be a status string ('took too long' / 'infeasible');
# Mechanism_Matrix would then iterate it character by character as rule names.
Mechanism_Matrix(mech_rules_list=solutions[0], desired_reaction=desired_reaction)

Unnamed: 0,RXN,Reactants,22_1_4&(80_1_3)&96_1_4&102_1_6&(106_1_8)&216_1_3&(229_1_5)&(236_1_3)&317_1_5&329_1_2&336_1_3&526_1_2&527_1_2&576_1_2&673_1_2&785_1_2&843_1_2&984_1_3&HisN1_depro,22_1_1,22_1_2,22_1_3,Products
CC(N)=O,1,0,0,0,0,1,1
CNC,1,0,0,0,0,1,1
CCN,1,0,0,0,0,1,1
CS,1,0,0,0,1,0,1
[OH3+:1],1,0,1,0,0,0,1
CC(=O)S,-1,1,0,-1,0,0,0
CSC,-1,1,0,0,-1,0,0
[OH2:1],-1,1,-1,0,0,0,0
C[NH3+],-1,1,0,-1,0,0,0
CC[NH3+],-1,1,0,-1,0,0,0
