In [1]:
from litellm import completion
from typing import Optional
import litellm
from dotenv import load_dotenv
import rootutils

rootutils.setup_root(".", indicator=".project-root", pythonpath=True)

PosixPath('/Users/shreyasv/Desktop/research/deepchem/retrosynthesis/prod')

In [2]:
import os

os.environ["ENABLE_LOGGING"] = "False"

# Stability

In [3]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticCarbocycles, CalcNumAliphaticHeterocycles, CalcNumBridgeheadAtoms
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticRings, CalcNumAromaticCarbocycles, CalcNumAromaticHeterocycles, CalcNumAromaticRings

In [4]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticCarbocycles, CalcNumAliphaticHeterocycles
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticRings, CalcNumAromaticCarbocycles
from rdkit.Chem.rdMolDescriptors import CalcNumAromaticHeterocycles, CalcNumAromaticRings
from rdkit.Chem.rdMolDescriptors import CalcNumBridgeheadAtoms

In [5]:
smiles111 = "C1(C2)CC=1C=2"
smiles2 = "c1(C2)ccccc1C=2"

In [27]:
mol = Chem.MolFromSmiles(smiles2)

ring_info = mol.GetRingInfo()
print("ring info: ", ring_info)
atom_rings = ring_info.AtomRings()
print("atom_rings: ", atom_rings)
bond_rings = ring_info.BondRings()
print("bond_rings: ", bond_rings)
num_rings = len(atom_rings)
print("num_rings: ", num_rings)
num_atoms = mol.GetNumAtoms()
print("num_atoms: ", num_atoms)
num_bonds = mol.GetNumBonds()
print("num_bonds: ", num_bonds)
num_heavy_atoms = mol.GetNumHeavyAtoms()
print("num_heavy_atoms: ", num_heavy_atoms)
num_heavy_bonds = mol.GetNumBonds()
print("num_heavy_bonds: ", num_heavy_bonds)
num_aromatic_atoms = len(mol.GetAromaticAtoms())
print("num_aromatic_atoms: ", num_aromatic_atoms)
num_aliphatic_atoms = num_heavy_atoms - num_aromatic_atoms
print("num_aliphatic_atoms: ", num_aliphatic_atoms)
num_aliphatic_carbocyles = CalcNumAliphaticCarbocycles(mol)
print("num_aliphatic_carbocyles: ", num_aliphatic_carbocyles)
num_aliphatic_heterocycles = CalcNumAliphaticHeterocycles(mol)
print("num_aliphatic_heterocycles: ", num_aliphatic_heterocycles)
num_aliphatic_rings = CalcNumAliphaticRings(mol)
print("num_aliphatic_rings: ", num_aliphatic_rings)
num_aromatic_carbocyles = CalcNumAromaticCarbocycles(mol)
print("num_aromatic_carbocyles: ", num_aromatic_carbocyles)
num_aromatic_heterocycles = CalcNumAromaticHeterocycles(mol)
print("num_aromatic_heterocycles: ", num_aromatic_heterocycles)
num_aromatic_rings = CalcNumAromaticRings(mol)
print("num_aromatic_rings: ", num_aromatic_rings)
num_bridgehead_atoms = CalcNumBridgeheadAtoms(mol)
print("num_bridgehead_atoms: ", num_bridgehead_atoms)
# num_aliphatic_bonds = num_heavy_bonds - num_aromatic_bonds
# num_heteroatoms = len(mol.GetAromaticAtoms()) + len(mol.GetHeteroatoms())
# num_heterobonds = len(mol.GetAromaticBonds()) + len(mol.GetHeterobonds())

ring info:  <rdkit.Chem.rdchem.RingInfo object at 0x12f1daa40>
atom_rings:  ((1, 7, 6, 0), (2, 3, 4, 5, 6, 0))
bond_rings:  ((8, 6, 7, 0), (2, 3, 4, 5, 7, 1))
num_rings:  2
num_atoms:  8
num_bonds:  9
num_heavy_atoms:  8
num_heavy_bonds:  9
num_aromatic_atoms:  6
num_aliphatic_atoms:  2
num_aliphatic_carbocyles:  1
num_aliphatic_heterocycles:  0
num_aliphatic_rings:  1
num_aromatic_carbocyles:  1
num_aromatic_heterocycles:  0
num_aromatic_rings:  1
num_bridgehead_atoms:  0


In [None]:
def check_molecule_stability(smiles):
    """
    Performs heuristic checks on a molecule given its SMILES string to estimate stability.
    
    Args:
        smiles (str): SMILES representation of the molecule
        
    Returns:
        dict: Dictionary containing stability assessment and various metrics
    """
    # Initialize results dictionary
    results = {
        "valid_structure": False,
        "stability_score": 0,
        "issues": [],
        "metrics": {},
        "ring_data": {},
        "atom_data": {}
    }

    # Parse SMILES and check if it's a valid structure
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        results["issues"].append("Invalid SMILES string or cannot be parsed")
        return results

    results["valid_structure"] = True

    # Calculate basic properties
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    hbd = Descriptors.NumHDonors(mol)
    hba = Descriptors.NumHAcceptors(mol)
    rotatable_bonds = Descriptors.NumRotatableBonds(mol)

    results["metrics"] = {
        "molecular_weight": mw,
        "logP": logp,
        "h_bond_donors": hbd,
        "h_bond_acceptors": hba,
        "rotatable_bonds": rotatable_bonds
    }

    # Get detailed ring information
    ring_info = mol.GetRingInfo()
    atom_rings = ring_info.AtomRings()
    bond_rings = ring_info.BondRings()

    # Get atom and bond counts
    num_atoms = mol.GetNumAtoms()
    num_bonds = mol.GetNumBonds()
    num_heavy_atoms = mol.GetNumHeavyAtoms()
    num_aromatic_atoms = len(mol.GetAromaticAtoms())
    num_aliphatic_atoms = num_heavy_atoms - num_aromatic_atoms

    # Calculate ring-related descriptors
    num_aliphatic_carbocycles = CalcNumAliphaticCarbocycles(mol)
    num_aliphatic_heterocycles = CalcNumAliphaticHeterocycles(mol)
    num_aliphatic_rings = CalcNumAliphaticRings(mol)
    num_aromatic_carbocycles = CalcNumAromaticCarbocycles(mol)
    num_aromatic_heterocycles = CalcNumAromaticHeterocycles(mol)
    num_aromatic_rings = CalcNumAromaticRings(mol)
    num_bridgehead_atoms = CalcNumBridgeheadAtoms(mol)

    # Store ring data
    results["ring_data"] = {
        "num_rings": len(atom_rings),
        "atom_rings": [list(ring) for ring in atom_rings],
        "bond_rings": [list(ring) for ring in bond_rings],
        "num_aliphatic_carbocycles": num_aliphatic_carbocycles,
        "num_aliphatic_heterocycles": num_aliphatic_heterocycles,
        "num_aliphatic_rings": num_aliphatic_rings,
        "num_aromatic_carbocycles": num_aromatic_carbocycles,
        "num_aromatic_heterocycles": num_aromatic_heterocycles,
        "num_aromatic_rings": num_aromatic_rings,
        "num_bridgehead_atoms": num_bridgehead_atoms
    }

    # Store atom data
    results["atom_data"] = {
        "num_atoms": num_atoms,
        "num_bonds": num_bonds,
        "num_heavy_atoms": num_heavy_atoms,
        "num_heavy_bonds":
        num_bonds,  # Same as num_bonds since RDKit doesn't count hydrogen bonds separately
        "num_aromatic_atoms": num_aromatic_atoms,
        "num_aliphatic_atoms": num_aliphatic_atoms
    }

    # Check for strained rings (small rings are often unstable)
    for ring in atom_rings:
        if len(ring) < 3:
            results["issues"].append(
                f"Highly strained ring of size {len(ring)}")
        elif len(ring) == 3 and any(
                mol.GetAtomWithIdx(i).GetSymbol() != 'C' for i in ring):
            results["issues"].append(
                "Three-membered heterocycle (potentially unstable)")
        elif len(ring) == 4 and any(
                mol.GetAtomWithIdx(i).GetSymbol() != 'C' for i in ring):
            results["issues"].append(
                "Four-membered heterocycle (potentially unstable)")

    # Assess complex ring systems
    if num_bridgehead_atoms > 0:
        # Complex polycyclic systems can be strained
        if any(len(ring) <= 4
               for ring in atom_rings) and num_bridgehead_atoms >= 2:
            results["issues"].append(
                "Strained polycyclic system with bridgehead atoms")

    # Calculate stability score (0-100, higher is more stable)
    # Start with 100 and subtract for various issues
    stability_score = 100

    # Penalize for each identified issue
    stability_score -= len(results["issues"]) * 15

    # Penalize for extreme values of properties
    if mw > 1000: stability_score -= 10
    if abs(logp) > 10: stability_score -= 10
    if rotatable_bonds > 15: stability_score -= 10

    # Assess stability based on ring structure
    if num_aromatic_rings > 0:
        # Aromatic rings typically enhance stability
        stability_score += min(num_aromatic_rings * 5, 15)  # Cap bonus at 15

    if num_aliphatic_heterocycles > 0 and num_aliphatic_rings <= 3:
        # Small number of aliphatic heterocycles can be unstable
        stability_score -= 5 * num_aliphatic_heterocycles

    # Penalize for complex strained systems
    if num_bridgehead_atoms > 0 and any(len(ring) <= 5 for ring in atom_rings):
        stability_score -= num_bridgehead_atoms * 5

    # Cap the score between 0 and 100
    stability_score = max(0, min(100, stability_score))
    results["stability_score"] = stability_score

    # Overall stability assessment
    if stability_score >= 80:
        results["assessment"] = "Likely stable"
    elif stability_score >= 50:
        results["assessment"] = "Moderately stable"
    else:
        results["assessment"] = "Potentially unstable"

    return results

In [43]:
def check_molecule_stability(smiles):
    """
    Performs heuristic checks on a molecule given its SMILES string to estimate stability.
    
    Args:
        smiles (str): SMILES representation of the molecule
        
    Returns:
        dict: Dictionary containing stability assessment and various metrics
    """
    # Initialize results dictionary
    results = {
        "valid_structure": False,
        "stability_score": 0,
        "issues": [],
        "metrics": {},
        "ring_data": {},
        "atom_data": {}
    }

    # Parse SMILES and check if it's a valid structure
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        results["issues"].append("Invalid SMILES string or cannot be parsed")
        return results

    results["valid_structure"] = True

    # Calculate basic properties
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    hbd = Descriptors.NumHDonors(mol)
    hba = Descriptors.NumHAcceptors(mol)
    rotatable_bonds = Descriptors.NumRotatableBonds(mol)

    results["metrics"] = {
        "molecular_weight": mw,
        "logP": logp,
        "h_bond_donors": hbd,
        "h_bond_acceptors": hba,
        "rotatable_bonds": rotatable_bonds
    }

    # Get detailed ring information
    ring_info = mol.GetRingInfo()
    atom_rings = ring_info.AtomRings()
    bond_rings = ring_info.BondRings()

    # Get atom and bond counts
    num_atoms = mol.GetNumAtoms()
    num_bonds = mol.GetNumBonds()
    num_heavy_atoms = mol.GetNumHeavyAtoms()
    num_aromatic_atoms = len(mol.GetAromaticAtoms())
    num_aliphatic_atoms = num_heavy_atoms - num_aromatic_atoms

    # Calculate ring-related descriptors
    num_aliphatic_carbocycles = CalcNumAliphaticCarbocycles(mol)
    num_aliphatic_heterocycles = CalcNumAliphaticHeterocycles(mol)
    num_aliphatic_rings = CalcNumAliphaticRings(mol)
    num_aromatic_carbocycles = CalcNumAromaticCarbocycles(mol)
    num_aromatic_heterocycles = CalcNumAromaticHeterocycles(mol)
    num_aromatic_rings = CalcNumAromaticRings(mol)
    num_bridgehead_atoms = CalcNumBridgeheadAtoms(mol)

    # Store ring data
    results["ring_data"] = {
        "num_rings": len(atom_rings),
        "atom_rings": [list(ring) for ring in atom_rings],
        "bond_rings": [list(ring) for ring in bond_rings],
        "num_aliphatic_carbocycles": num_aliphatic_carbocycles,
        "num_aliphatic_heterocycles": num_aliphatic_heterocycles,
        "num_aliphatic_rings": num_aliphatic_rings,
        "num_aromatic_carbocycles": num_aromatic_carbocycles,
        "num_aromatic_heterocycles": num_aromatic_heterocycles,
        "num_aromatic_rings": num_aromatic_rings,
        "num_bridgehead_atoms": num_bridgehead_atoms
    }

    # Store atom data
    results["atom_data"] = {
        "num_atoms": num_atoms,
        "num_bonds": num_bonds,
        "num_heavy_atoms": num_heavy_atoms,
        "num_heavy_bonds":
        num_bonds,  # Same as num_bonds since RDKit doesn't count hydrogen bonds separately
        "num_aromatic_atoms": num_aromatic_atoms,
        "num_aliphatic_atoms": num_aliphatic_atoms
    }

    # Check for strained rings (small rings are often unstable)
    for ring in atom_rings:
        if len(ring) < 3:
            results["issues"].append(
                f"Highly strained ring of size {len(ring)}")
        elif len(ring) == 3 and any(
                mol.GetAtomWithIdx(i).GetSymbol() != 'C' for i in ring):
            results["issues"].append(
                "Three-membered heterocycle (potentially unstable)")
        elif len(ring) == 4 and any(
                mol.GetAtomWithIdx(i).GetSymbol() != 'C' for i in ring):
            results["issues"].append(
                "Four-membered heterocycle (potentially unstable)")

    # ----------------- DETECT ANTI-AROMATIC COMPOUNDS -----------------
    # Check for anti-aromatic systems (4n π electrons)
    # Common anti-aromatic patterns include cyclobutadiene, cyclooctatetraene, etc.
    patt_cyclobutadiene = Chem.MolFromSmarts(
        "c1ccc1")  # 4-membered fully conjugated ring
    patt_cyclooctatetraene = Chem.MolFromSmarts(
        "C1=CC=CC=CC=C1")  # Non-planar COT
    patt_pentalene = Chem.MolFromSmarts("c1cc2cccc2c1")  # Pentalene pattern

    anti_aromatic_patterns = [
        (patt_cyclobutadiene, "cyclobutadiene-like (anti-aromatic)"),
        (patt_cyclooctatetraene,
         "cyclooctatetraene-like (potential anti-aromatic)"),
        (patt_pentalene, "pentalene-like (potential anti-aromatic)")
    ]

    for patt, name in anti_aromatic_patterns:
        if patt and mol.HasSubstructMatch(patt):
            results["issues"].append(f"Contains {name} motif")

    # Also detect rings with 4n π electrons
    for ring in atom_rings:
        # Only consider rings of sizes that could be anti-aromatic
        if len(ring) in [4, 8, 12, 16]:
            # Count conjugated double bonds in the ring
            double_bond_count = 0
            ring_atoms = [mol.GetAtomWithIdx(i) for i in ring]

            for atom in ring_atoms:
                if atom.GetIsAromatic():
                    double_bond_count += 0.5

                for bond in atom.GetBonds():
                    # Check if bond is in this ring and is a double bond
                    other_atom_idx = bond.GetOtherAtomIdx(atom.GetIdx())
                    if other_atom_idx in ring and bond.GetBondType(
                    ) == Chem.BondType.DOUBLE:
                        double_bond_count += 0.5  # Count each end of double bond once

            # If number of π electrons is 4n (where n is positive integer)
            pi_electrons = int(double_bond_count *
                               2)  # Each double bond contributes 2 π electrons
            if pi_electrons > 0 and pi_electrons % 4 == 0:
                results["issues"].append(
                    f"{len(ring)}-membered ring with {pi_electrons} π electrons (potential anti-aromatic)"
                )

    # ----------------- DETECT FUSED 3-4 MEMBERED RINGS -----------------
    # Look for fused small rings (which create highly strained systems)
    small_rings = []
    for ring in atom_rings:
        if len(ring) <= 4:
            small_rings.append(set(ring))

    # Check for fused small rings (rings that share atoms)
    fused_small_rings_detected = False
    for i in range(len(small_rings)):
        for j in range(i + 1, len(small_rings)):
            shared_atoms = small_rings[i].intersection(small_rings[j])
            if len(shared_atoms) > 0:
                fused_small_rings_detected = True
                ring1_size = len(small_rings[i])
                ring2_size = len(small_rings[j])
                results["issues"].append(
                    f"Fused {ring1_size} and {ring2_size}-membered rings (highly strained system)"
                )

    if fused_small_rings_detected:
        # This is a serious stability concern, add an explicit warning
        results["issues"].append(
            "WARNING: Fused small rings create highly strained and potentially explosive compounds"
        )

    # ----------------- DETECT LARGE HETEROCYCLES -----------------
    # Large heterocycles (>7 members) can be unstable
    for ring in atom_rings:
        if len(ring) >= 7:
            # Check if ring contains heteroatoms
            ring_atoms = [mol.GetAtomWithIdx(i) for i in ring]
            heteroatoms = [
                atom for atom in ring_atoms
                if atom.GetSymbol() not in ['C', 'H']
            ]

            if len(heteroatoms) > 0:
                hetero_symbols = [atom.GetSymbol() for atom in heteroatoms]
                unique_hetero = set(hetero_symbols)
                results["issues"].append(
                    f"Large ({len(ring)}-membered) heterocycle with {', '.join(unique_hetero)} (potentially unstable)"
                )

                # More severe warning for very large heterocycles (>10 members)
                if len(ring) > 10 and len(heteroatoms) >= 3:
                    results["issues"].append(
                        f"Very large heterocycle with multiple heteroatoms (significant stability concern)"
                    )

    # Assess complex ring systems
    if num_bridgehead_atoms > 0:
        # Complex polycyclic systems can be strained
        if any(len(ring) <= 4
               for ring in atom_rings) and num_bridgehead_atoms >= 2:
            results["issues"].append(
                "Strained polycyclic system with bridgehead atoms")

    # Calculate stability score (0-100, higher is more stable)
    # Start with 100 and subtract for various issues
    stability_score = 100

    # Penalize for each identified issue
    stability_score -= len(results["issues"]) * 15

    # Penalize for extreme values of properties
    if mw > 1000: stability_score -= 10
    if abs(logp) > 10: stability_score -= 10
    if rotatable_bonds > 15: stability_score -= 10

    # Assess stability based on ring structure
    if num_aromatic_rings > 0:
        # Aromatic rings typically enhance stability
        stability_score += min(num_aromatic_rings * 5, 15)  # Cap bonus at 15

    if num_aliphatic_heterocycles > 0 and num_aliphatic_rings <= 3:
        # Small number of aliphatic heterocycles can be unstable
        stability_score -= 5 * num_aliphatic_heterocycles

    # Penalize for complex strained systems
    if num_bridgehead_atoms > 0 and any(len(ring) <= 5 for ring in atom_rings):
        stability_score -= num_bridgehead_atoms * 5

    # Additional penalties for new detected issues
    if fused_small_rings_detected:
        stability_score -= 30  # Severe penalty for fused small rings

    # Penalize anti-aromatic structures
    for issue in results["issues"]:
        if "anti-aromatic" in issue:
            stability_score -= 25
        elif "π electrons" in issue:
            stability_score -= 20

    # Penalize large heterocycles
    large_heterocycle_count = sum(
        1 for issue in results["issues"]
        if "heterocycle" in issue and "large" in issue.lower())
    if large_heterocycle_count > 0:
        stability_score -= large_heterocycle_count * 10

    # Cap the score between 0 and 100
    stability_score = max(0, min(100, stability_score))
    results["stability_score"] = stability_score

    # Overall stability assessment
    if stability_score >= 80:
        results["assessment"] = "Likely stable"
    elif stability_score >= 50:
        results["assessment"] = "Moderately stable"
    else:
        results["assessment"] = "Potentially unstable"

    return results

In [6]:
from src.utils.stability_checks import check_molecule_stability

In [8]:
# Example usage
if __name__ == "__main__":
    # Example molecules
    examples = [
        ("CC(=O)O", "Acetic acid - stable compound"),
        ("C1=CC=CC=C1", "Benzene - stable aromatic"),
        ("OO", "Hydrogen peroxide - reactive"),
        ("C1CC1", "Cyclopropane - strained ring"),
        ("CN=NC", "Azo compound - potentially unstable"),
        ("C(C(=O)Cl)C(=O)Cl", "Malonyl dichloride - reactive"),
        ("CCCCCCCCCCCCCCCCCCCCCCCCCCCC",
         "Long alkane - stable but low solubility"),
        ("C12C3C4C1C5C2C3C45", "Cubane - strained structure"),
        (smiles111, "smiles 111 unstable"),
        ("C1=CC=C1", "Cyclobutadiene (anti-aromatic)"),
        ("C1=CC=CC=CC=C1",
         "Cyclooctatetraene (non-planar, potential anti-aromatic)"),
        ("C12CC1CC2", "Bicyclo[1.1.0]butane (fused 3-membered rings)"),
        ("C1CC12CC2", "Bicyclo[2.1.0]pentane (fused 3 and 4-membered rings)"),
        ("C1NCCOCCN1", "7-membered heterocycle with N and O"),
        ("C1NCCOCCSCCN1", "12-membered heterocycle with N, O, and S"),
        ("CC(=O)OC1=C(Cl)C=C(C=[C+]CNC)C=C1", "Mercury acetylide - unstable"),
        ("C1=CC=C2C=CC[C@@H]3[C@@H](N(C)[C+]=C)C2=C13", "some thing")
    ]

    for smiles, description in examples:
        print(f"\n{description}")
        result = check_molecule_stability(smiles)
        print(f"Result: {result}")
        print(f"SMILES: {smiles}")
        print(f"Stability assessment: {result.get('assessment', 'Unknown')}")
        print(f"Stability score: {result.get('stability_score', 0)}/100")

        if result["issues"]:
            print("Issues identified:")
            for issue in result["issues"]:
                print(f"  - {issue}")

        print("Basic metrics:")
        for key, value in result.get("metrics", {}).items():
            print(f"  - {key}: {value}")

        print("Ring data:")
        for key, value in result.get("ring_data", {}).items():
            if key != "atom_rings" and key != "bond_rings":
                print(f"  - {key}: {value}")

        print("Atom data:")
        for key, value in result.get("atom_data", {}).items():
            print(f"  - {key}: {value}")
        print("--------------------")


Acetic acid - stable compound
Result: {'valid_structure': True, 'stability_score': 100, 'issues': [], 'metrics': {'molecular_weight': 60.05200000000001, 'logP': 0.09089999999999993, 'h_bond_donors': 1, 'h_bond_acceptors': 1, 'rotatable_bonds': 0}, 'ring_data': {'num_rings': 0, 'atom_rings': [], 'bond_rings': [], 'num_aliphatic_carbocycles': 0, 'num_aliphatic_heterocycles': 0, 'num_aliphatic_rings': 0, 'num_aromatic_carbocycles': 0, 'num_aromatic_heterocycles': 0, 'num_aromatic_rings': 0, 'num_bridgehead_atoms': 0}, 'atom_data': {'num_atoms': 4, 'num_bonds': 3, 'num_heavy_atoms': 4, 'num_heavy_bonds': 3, 'num_aromatic_atoms': 0, 'num_aliphatic_atoms': 4}, 'assessment': 'Likely stable'}
SMILES: CC(=O)O
Stability assessment: Likely stable
Stability score: 100/100
Basic metrics:
  - molecular_weight: 60.05200000000001
  - logP: 0.09089999999999993
  - h_bond_donors: 1
  - h_bond_acceptors: 1
  - rotatable_bonds: 0
Ring data:
  - num_rings: 0
  - num_aliphatic_carbocycles: 0
  - num_alipha

# Hallucination Checker

In [46]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdFMCS import FindMCS

In [51]:
def identify_reaction_sites(reactant_smiles, product_smiles):
    """
    Identifies reaction sites by comparing the maximum common substructure
    between reactant and product molecules.
    
    Args:
        reactant_smiles (str): SMILES notation of the reactant molecule
        product_smiles (str): SMILES notation of the product molecule
    
    Returns:
        tuple: Contains:
            - List of atoms indices that changed in the reactant
            - List of atoms indices that changed in the product
            - The maximum common substructure as a SMILES string
    """
    # Convert SMILES to RDKit molecules
    reactant = Chem.MolFromSmiles(reactant_smiles)
    product = Chem.MolFromSmiles(product_smiles)

    if not reactant or not product:
        raise ValueError("Invalid SMILES notation provided")

    # Find Maximum Common Substructure (MCS)
    mcs = FindMCS([reactant, product],
                  matchValences=True,
                  ringMatchesRingOnly=True,
                  completeRingsOnly=True)

    # Convert MCS to molecule
    mcs_mol = Chem.MolFromSmarts(mcs.smartsString)

    # Find matches in both molecules
    reactant_match = reactant.GetSubstructMatch(mcs_mol)
    product_match = product.GetSubstructMatch(mcs_mol)

    # Identify changed atoms in reactant
    reactant_changed = set(range(reactant.GetNumAtoms())) - set(reactant_match)
    product_changed = set(range(product.GetNumAtoms())) - set(product_match)

    # Convert MCS to SMILES for visualization
    mcs_smiles = Chem.MolToSmiles(mcs_mol) if mcs_mol else ""

    return (list(reactant_changed), list(product_changed), mcs_smiles)

In [52]:
def analyze_reaction(reactant_smiles, product_smiles):
    """
    Performs a complete analysis of the reaction sites.
    
    Args:
        reactant_smiles (str): SMILES notation of the reactant
        product_smiles (str): SMILES notation of the product
    
    Returns:
        dict: Detailed analysis of the reaction sites
    """
    reactant = Chem.MolFromSmiles(reactant_smiles)
    product = Chem.MolFromSmiles(product_smiles)

    # Get reaction sites
    reactant_sites, product_sites, mcs = identify_reaction_sites(
        reactant_smiles, product_smiles)

    # Analyze environments around reaction sites
    analysis = {'reactant_sites': [], 'product_sites': [], 'mcs': mcs}

    # Analyze reactant sites
    for site in reactant_sites:
        environment = get_atom_environment(reactant, site)
        atom = reactant.GetAtomWithIdx(site)
        analysis['reactant_sites'].append({
            'atom_index': site,
            'atom_symbol': atom.GetSymbol(),
            'environment': environment
        })

    # Analyze product sites
    for site in product_sites:
        environment = get_atom_environment(product, site)
        atom = product.GetAtomWithIdx(site)
        analysis['product_sites'].append({
            'atom_index': site,
            'atom_symbol': atom.GetSymbol(),
            'environment': environment
        })

    return analysis

In [None]:
def get_atom_environment(mol, atom_idx, radius=1):
    """
    Gets the chemical environment around a specific atom.
    
    Args:
        mol (RDKit.Mol): The molecule to analyze
        atom_idx (int): Index of the atom of interest
        radius (int): Number of bonds to consider in the environment
    
    Returns:
        str: SMILES notation of the atom's environment
    """
    # Create a new molecule to highlight the environment
    env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom_idx)

    # This is the corrected approach - create an atom map to track atoms
    amap = {}
    submol = Chem.PathToSubmol(mol, env, atomMap=amap)

    # If we couldn't get a valid substructure, return empty string
    if submol.GetNumAtoms() == 0:
        return ""

    return Chem.MolToSmiles(submol)

In [75]:
# Example usage
if __name__ == "__main__":
    # Example molecules (ethanol to ethanoic acid)
    reactant_smiles = "c1c(CCNC)cc(C(=O)O)cc1"
    product_smiles = "c1cc(CCN(C)CC)c(C(=O)O)cc1"

    results = analyze_reaction(reactant_smiles, product_smiles)

    print("Reaction Analysis Results:")
    print(f"Maximum Common Substructure: {results['mcs']}")
    print("\nReactant reaction sites:")
    for site in results['reactant_sites']:
        print(f"Atom {site['atom_index']} ({site['atom_symbol']}):")
        print(f"Environment: {site['environment']}")

    print("\nProduct reaction sites:")
    for site in results['product_sites']:
        print(f"Atom {site['atom_index']} ({site['atom_symbol']}):")
        print(f"Environment: {site['environment']}")

Analyzing reaction:
Reactant: c1c(CCNC)cc(C(=O)O)cc1
Product: c1cc(CCN(C)CC)c(C(=O)O)cc1

Structural consistency check:
  consistent: True
  message: No obvious structural inconsistencies detected

Reaction site information:
  MCS SMARTS: [#6&R]1:&@[#6&R](-&!@[#6&!R]-&!@[#6&!R]-&!@[#7&!R]-&!@[#6&!R]):&@[#6&R]:&@[#6&R]:&@[#6&R]:&@[#6&R]:&@1
  Reactant reaction site atoms: [8, 9, 10]
  Product reaction site atoms: [7, 8, 10, 11, 12]
  Reactant attachment points: [7]
  Product attachment points: [9, 5]

Images are available in the result dictionary
Reaction Analysis Results:


TypeError: 'NoneType' object is not subscriptable

In [74]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdFMCS
from rdkit.Chem import Draw
import numpy as np


def identify_reaction_site(reactant_smiles, product_smiles, visualize=True):
    """
    Identifies the reaction site by comparing the reactant and product molecules.
    
    Parameters:
    -----------
    reactant_smiles : str
        SMILES notation of the reactant molecule
    product_smiles : str
        SMILES notation of the product molecule
    visualize : bool
        Whether to generate visualizations of the molecules and reaction site
        
    Returns:
    --------
    dict
        Dictionary containing information about the reaction site
    """
    # Convert SMILES to RDKit molecules
    try:
        reactant = Chem.MolFromSmiles(reactant_smiles)
        product = Chem.MolFromSmiles(product_smiles)

        if reactant is None:
            return {"error": "Invalid reactant SMILES notation"}
        if product is None:
            return {"error": "Invalid product SMILES notation"}
    except Exception as e:
        return {"error": f"Error parsing SMILES: {str(e)}"}

    # Find the Maximum Common Substructure (MCS)
    mcs = rdFMCS.FindMCS([reactant, product],
                         completeRingsOnly=True,
                         ringMatchesRingOnly=True,
                         matchValences=True)

    mcs_mol = Chem.MolFromSmarts(mcs.smartsString)

    # Get atom indices for the MCS in both molecules
    reactant_match = reactant.GetSubstructMatch(mcs_mol)
    product_match = product.GetSubstructMatch(mcs_mol)

    # Get all atoms in both molecules
    reactant_atoms = set(range(reactant.GetNumAtoms()))
    product_atoms = set(range(product.GetNumAtoms()))

    # Find atoms that are not part of the MCS (reaction site)
    reactant_reaction_site = reactant_atoms - set(reactant_match)
    product_reaction_site = product_atoms - set(product_match)

    # Identify neighboring atoms to the reaction site
    reactant_neighbors = set()
    for atom_idx in reactant_reaction_site:
        atom = reactant.GetAtomWithIdx(atom_idx)
        for neighbor in atom.GetNeighbors():
            neighbor_idx = neighbor.GetIdx()
            if neighbor_idx in reactant_match:
                reactant_neighbors.add(neighbor_idx)

    product_neighbors = set()
    for atom_idx in product_reaction_site:
        atom = product.GetAtomWithIdx(atom_idx)
        for neighbor in atom.GetNeighbors():
            neighbor_idx = neighbor.GetIdx()
            if neighbor_idx in product_match:
                product_neighbors.add(neighbor_idx)

    # Create molecules with highlighted reaction sites
    reactant_highlight = Chem.Mol(reactant)
    for idx in reactant_reaction_site:
        reactant_highlight.GetAtomWithIdx(idx).SetProp("atomNote", "R")
    for idx in reactant_neighbors:
        reactant_highlight.GetAtomWithIdx(idx).SetProp("atomNote", "N")

    product_highlight = Chem.Mol(product)
    for idx in product_reaction_site:
        product_highlight.GetAtomWithIdx(idx).SetProp("atomNote", "P")
    for idx in product_neighbors:
        product_highlight.GetAtomWithIdx(idx).SetProp("atomNote", "N")

    # Prepare result
    result = {
        "reactant_smiles": reactant_smiles,
        "product_smiles": product_smiles,
        "mcs_smarts": mcs.smartsString,
        "reactant_reaction_site": list(reactant_reaction_site),
        "product_reaction_site": list(product_reaction_site),
        "reactant_attachment_points": list(reactant_neighbors),
        "product_attachment_points": list(product_neighbors)
    }

    # Visualize if requested
    if visualize:
        # Generate 2D coordinates for better visualization
        AllChem.Compute2DCoords(reactant)
        AllChem.Compute2DCoords(product)
        AllChem.Compute2DCoords(reactant_highlight)
        AllChem.Compute2DCoords(product_highlight)

        # Create visualizations with highlighted atoms
        result["reactant_image"] = Draw.MolToImage(
            reactant_highlight, highlightAtoms=list(reactant_reaction_site))
        result["product_image"] = Draw.MolToImage(
            product_highlight, highlightAtoms=list(product_reaction_site))

    return result


def check_structural_inconsistencies(reactant_smiles, product_smiles):
    """
    Checks for structural inconsistencies between reactant and product,
    such as jumping side chains or unexpected rearrangements.
    
    Parameters:
    -----------
    reactant_smiles : str
        SMILES notation of the reactant molecule
    product_smiles : str
        SMILES notation of the product molecule
        
    Returns:
    --------
    dict
        Dictionary containing information about potential inconsistencies
    """
    reactant = Chem.MolFromSmiles(reactant_smiles)
    product = Chem.MolFromSmiles(product_smiles)

    if reactant is None or product is None:
        return {"error": "Invalid SMILES notation"}

    # Generate canonical SMILES to ensure consistent representation
    reactant_canonical = Chem.MolToSmiles(reactant, isomericSmiles=True)
    product_canonical = Chem.MolToSmiles(product, isomericSmiles=True)

    # Find MCS with more relaxed settings to detect potential issues
    mcs = rdFMCS.FindMCS([reactant, product],
                         completeRingsOnly=False,
                         ringMatchesRingOnly=False,
                         matchValences=False,
                         timeout=60)

    mcs_mol = Chem.MolFromSmarts(mcs.smartsString)

    # Get atom indices for the MCS in both molecules
    reactant_match = reactant.GetSubstructMatch(mcs_mol)
    product_match = product.GetSubstructMatch(mcs_mol)

    # Check for ring system changes
    reactant_rings = Chem.GetSSSR(reactant)
    product_rings = Chem.GetSSSR(product)

    if len(reactant_rings) != len(product_rings):
        return {
            "consistent":
            False,
            "issue":
            f"Ring count changed from {len(reactant_rings)} to {len(product_rings)}"
        }

    # Check for unusual changes in atom connectivity patterns
    # This would detect side chain "jumping" from ortho to meta positions
    reactant_adj_matrix = Chem.GetAdjacencyMatrix(reactant)
    product_adj_matrix = Chem.GetAdjacencyMatrix(product)

    # Check if the MCS is a significant portion of both molecules
    if len(reactant_match) < reactant.GetNumAtoms() * 0.7 or len(
            product_match) < product.GetNumAtoms() * 0.7:
        return {
            "consistent": False,
            "issue": "Significant structural rearrangement detected"
        }

    # Compare connected fragments to identify potential rearrangements
    # This is a simplistic approach - more sophisticated checks would be needed for complex molecules
    reactant_fragments = Chem.GetMolFrags(reactant)
    product_fragments = Chem.GetMolFrags(product)

    if len(reactant_fragments) != len(product_fragments):
        return {
            "consistent":
            False,
            "issue":
            f"Fragment count changed from {len(reactant_fragments)} to {len(product_fragments)}"
        }

    # If we passed all checks, the structures are likely consistent
    return {
        "consistent": True,
        "message": "No obvious structural inconsistencies detected"
    }


# Example usage
def analyze_reaction(reactant_smiles, product_smiles):
    """
    Comprehensive analysis of the reaction between reactant and product.
    
    Parameters:
    -----------
    reactant_smiles : str
        SMILES notation of the reactant molecule
    product_smiles : str
        SMILES notation of the product molecule
    """
    print(
        f"Analyzing reaction:\nReactant: {reactant_smiles}\nProduct: {product_smiles}\n"
    )

    # Check for structural inconsistencies
    inconsistency_check = check_structural_inconsistencies(
        reactant_smiles, product_smiles)
    print("Structural consistency check:")
    for key, value in inconsistency_check.items():
        print(f"  {key}: {value}")
    print()

    # Identify reaction site
    reaction_site_info = identify_reaction_site(reactant_smiles,
                                                product_smiles)

    if "error" in reaction_site_info:
        print(f"Error: {reaction_site_info['error']}")
        return

    print("Reaction site information:")
    print(f"  MCS SMARTS: {reaction_site_info['mcs_smarts']}")
    print(
        f"  Reactant reaction site atoms: {reaction_site_info['reactant_reaction_site']}"
    )
    print(
        f"  Product reaction site atoms: {reaction_site_info['product_reaction_site']}"
    )
    print(
        f"  Reactant attachment points: {reaction_site_info['reactant_attachment_points']}"
    )
    print(
        f"  Product attachment points: {reaction_site_info['product_attachment_points']}"
    )

    # If visualization was enabled, the images would be available in the result dict
    if "reactant_image" in reaction_site_info:
        print("\nImages are available in the result dictionary")
        # In a Jupyter notebook, you could display them with:
        # display(reaction_site_info["reactant_image"])
        # display(reaction_site_info["product_image"])


# Example: To use this code, uncomment and replace with your SMILES strings
reactant_smiles = "CC(=O)c1ccccc1"  # Acetophenone
product_smiles = "CC(O)(c1ccccc1)C"  # 2-Phenyl-2-propanol
reactant_smiles = "c1c(CCNC)cc(C(=O)O)cc1"
product_smiles = "c1cc(CCN(C)CC)c(C(=O)O)cc1"
analyze_reaction(reactant_smiles, product_smiles)

Analyzing reaction:
Reactant: c1c(CCNC)cc(C(=O)O)cc1
Product: c1cc(CCN(C)CC)c(C(=O)O)cc1

Structural consistency check:
  consistent: True
  message: No obvious structural inconsistencies detected

Reaction site information:
  MCS SMARTS: [#6&R]1:&@[#6&R](-&!@[#6&!R]-&!@[#6&!R]-&!@[#7&!R]-&!@[#6&!R]):&@[#6&R]:&@[#6&R]:&@[#6&R]:&@[#6&R]:&@1
  Reactant reaction site atoms: [8, 9, 10]
  Product reaction site atoms: [7, 8, 10, 11, 12]
  Reactant attachment points: [7]
  Product attachment points: [9, 5]

Images are available in the result dictionary


In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, rdmolops
from collections import Counter, defaultdict
import numpy as np


def compare_molecules(reactant_smiles, product_smiles):
    """
    Compare reactant and product molecules to detect potential hallucinations or inconsistencies.
    
    Parameters:
    -----------
    reactant_smiles : str
        SMILES string of the reactant molecule
    product_smiles : str
        SMILES string of the product molecule
        
    Returns:
    --------
    dict
        Dictionary containing validation results and detected issues
    """
    results = {
        "valid_reactant": False,
        "valid_product": False,
        "atom_count_consistent": False,
        "ring_size_changes": [],
        "substituent_position_changes": [],
        "detected_issues": []
    }

    # Check if SMILES strings are valid
    reactant_mol = Chem.MolFromSmiles(reactant_smiles)
    product_mol = Chem.MolFromSmiles(product_smiles)

    if reactant_mol is None:
        results["detected_issues"].append("Invalid reactant SMILES string")
        return results
    else:
        results["valid_reactant"] = True

    if product_mol is None:
        results["detected_issues"].append("Invalid product SMILES string")
        return results
    else:
        results["valid_product"] = True

    # Get basic molecule properties
    reactant_atoms = Counter(
        [atom.GetSymbol() for atom in reactant_mol.GetAtoms()])
    product_atoms = Counter(
        [atom.GetSymbol() for atom in product_mol.GetAtoms()])

    # Check atom count consistency
    for atom_symbol in set(
            list(reactant_atoms.keys()) + list(product_atoms.keys())):
        if reactant_atoms.get(atom_symbol,
                              0) != product_atoms.get(atom_symbol, 0):
            results["detected_issues"].append(
                f"Atom count mismatch for {atom_symbol}: "
                f"Reactant has {reactant_atoms.get(atom_symbol, 0)}, "
                f"Product has {product_atoms.get(atom_symbol, 0)}")

    if not any("Atom count mismatch" in issue
               for issue in results["detected_issues"]):
        results["atom_count_consistent"] = True

    # Check for ring size changes
    reactant_rings = Chem.GetSSSR(reactant_mol)
    product_rings = Chem.GetSSSR(product_mol)

    reactant_ring_sizes = [len(ring) for ring in reactant_rings]
    product_ring_sizes = [len(ring) for ring in product_rings]

    # Sort ring sizes for easier comparison
    reactant_ring_sizes.sort()
    product_ring_sizes.sort()

    if reactant_ring_sizes != product_ring_sizes:
        results["detected_issues"].append(
            f"Ring size change detected: Reactant rings {reactant_ring_sizes}, "
            f"Product rings {product_ring_sizes}")

        # Report specific ring changes
        for r_size in reactant_ring_sizes:
            if reactant_ring_sizes.count(r_size) > product_ring_sizes.count(
                    r_size):
                results["ring_size_changes"].append(
                    f"{r_size}-membered ring removed")

        for p_size in product_ring_sizes:
            if product_ring_sizes.count(p_size) > reactant_ring_sizes.count(
                    p_size):
                results["ring_size_changes"].append(
                    f"{p_size}-membered ring added")

    # Check for aromatic ring changes
    reactant_aromatic_atoms = set([
        atom.GetIdx() for atom in reactant_mol.GetAtoms()
        if atom.GetIsAromatic()
    ])
    product_aromatic_atoms = set([
        atom.GetIdx() for atom in product_mol.GetAtoms()
        if atom.GetIsAromatic()
    ])

    # Check if the number of aromatic atoms changed significantly
    if abs(len(reactant_aromatic_atoms) - len(product_aromatic_atoms)) > 2:
        results["detected_issues"].append(
            f"Significant change in aromaticity: Reactant has {len(reactant_aromatic_atoms)} "
            f"aromatic atoms, Product has {len(product_aromatic_atoms)}")

    # Advanced check for substituent position changes on rings
    check_ring_substituent_positions(reactant_mol, product_mol, results)

    # Check for unnecessary bond formations
    reactant_bonds = Counter(
        [bond.GetBondType() for bond in reactant_mol.GetBonds()])
    product_bonds = Counter(
        [bond.GetBondType() for bond in product_mol.GetBonds()])

    if sum(reactant_bonds.values()) < sum(product_bonds.values()):
        results["detected_issues"].append(
            f"Possible unnecessary bonds formed: Reactant has {sum(reactant_bonds.values())} bonds, "
            f"Product has {sum(product_bonds.values())} bonds")

    return results


def check_ring_substituent_positions(reactant_mol, product_mol, results):
    """
    Detect changes in the position of substituents on aromatic rings.
    
    Parameters:
    -----------
    reactant_mol : RDKit Mol
        RDKit molecule object of the reactant
    product_mol : RDKit Mol
        RDKit molecule object of the product
    results : dict
        Results dictionary to update with findings
    """
    # Get all ring systems in both molecules
    reactant_ring_info = identify_ring_systems(reactant_mol)
    product_ring_info = identify_ring_systems(product_mol)

    # If ring counts mismatch, this is already caught in the main function
    if len(reactant_ring_info) != len(product_ring_info):
        return

    # For each aromatic ring, identify and compare substituent patterns
    for r_idx, reactant_ring in enumerate(reactant_ring_info):
        if not reactant_ring['is_aromatic']:
            continue

        # Find a matching aromatic ring in the product
        matching_rings = [
            p for p in product_ring_info if p['is_aromatic']
            and p['size'] == reactant_ring['size'] and not p['matched']
        ]

        if not matching_rings:
            continue

        product_ring = matching_rings[0]
        product_ring['matched'] = True  # Mark this ring as matched

        # Identify substituents and their positions for both rings
        reactant_substituents = identify_substituents(reactant_mol,
                                                      reactant_ring)
        product_substituents = identify_substituents(product_mol, product_ring)

        # Create signature of each substituent
        reactant_sig = {}
        product_sig = {}

        for subst in reactant_substituents:
            sig = get_substituent_signature(reactant_mol, subst)
            if sig not in reactant_sig:
                reactant_sig[sig] = []
            reactant_sig[sig].append(subst)

        for subst in product_substituents:
            sig = get_substituent_signature(product_mol, subst)
            if sig not in product_sig:
                product_sig[sig] = []
            product_sig[sig].append(subst)

        # Check for position changes of similar substituents
        for sig in set(reactant_sig.keys()).intersection(
                set(product_sig.keys())):
            r_positions = [pos_map[s['position']] for s in reactant_sig[sig]]
            p_positions = [pos_map[s['position']] for s in product_sig[sig]]

            # Sort positions for easier comparison
            r_positions.sort()
            p_positions.sort()

            if r_positions != p_positions:
                # We found a substituent that has changed position
                subst_name = get_friendly_substituent_name(sig)
                results["detected_issues"].append(
                    f"Substituent position change detected: {subst_name} moved from "
                    f"{', '.join(r_positions)} to {', '.join(p_positions)} position(s)"
                )
                results["substituent_position_changes"].append({
                    "substituent":
                    subst_name,
                    "from_positions":
                    r_positions,
                    "to_positions":
                    p_positions
                })


def identify_ring_systems(mol):
    """
    Identify all ring systems in a molecule and their properties.
    
    Parameters:
    -----------
    mol : RDKit Mol
        RDKit molecule object
        
    Returns:
    --------
    list
        List of dictionaries containing ring information
    """
    rings = []
    ring_info = Chem.GetSSSR(mol)

    for idx, ring in enumerate(ring_info):
        ring_atoms = list(ring)
        is_aromatic = all(
            mol.GetAtomWithIdx(atom_idx).GetIsAromatic()
            for atom_idx in ring_atoms)

        rings.append({
            'id': idx,
            'atoms': ring_atoms,
            'size': len(ring_atoms),
            'is_aromatic': is_aromatic,
            'matched':
            False  # Used later for matching rings between reactant and product
        })

    return rings


def identify_substituents(mol, ring_info):
    """
    Identify all substituents attached to a ring and their positions.
    
    Parameters:
    -----------
    mol : RDKit Mol
        RDKit molecule object
    ring_info : dict
        Dictionary containing ring information
        
    Returns:
    --------
    list
        List of dictionaries containing substituent information
    """
    substituents = []
    ring_atoms = set(ring_info['atoms'])

    # Get connections from ring atoms to non-ring atoms
    for ring_atom_idx in ring_atoms:
        ring_atom = mol.GetAtomWithIdx(ring_atom_idx)

        for neighbor in ring_atom.GetNeighbors():
            neighbor_idx = neighbor.GetIdx()

            # Skip atoms that are part of the ring
            if neighbor_idx in ring_atoms:
                continue

            # Determine the position (ortho, meta, para) relative to other substituents
            position = determine_ring_position(mol, ring_atom_idx, ring_atoms,
                                               ring_info['size'])

            # Find the entire substituent group connected to this point
            subst_atoms = get_connected_atoms(mol, neighbor_idx, ring_atoms)

            substituents.append({
                'attachment_point': ring_atom_idx,
                'first_atom': neighbor_idx,
                'atoms': subst_atoms,
                'position': position
            })

    return substituents


def determine_ring_position(mol, atom_idx, ring_atoms, ring_size):
    """
    Determine the position of a substituent on a ring (ortho, meta, para, etc.).
    
    Parameters:
    -----------
    mol : RDKit Mol
        RDKit molecule object
    atom_idx : int
        Index of the ring atom where the substituent is attached
    ring_atoms : set
        Set of atom indices that form the ring
    ring_size : int
        Size of the ring
        
    Returns:
    --------
    str
        Position description ("ortho", "meta", "para", or numbered position)
    """
    # For 6-membered rings, use ortho/meta/para nomenclature
    if ring_size == 6:
        # Find other substituents on the ring
        other_subst = []
        for ring_atom in ring_atoms:
            if ring_atom == atom_idx:
                continue

            atom = mol.GetAtomWithIdx(ring_atom)
            for neighbor in atom.GetNeighbors():
                if neighbor.GetIdx() not in ring_atoms:
                    other_subst.append(ring_atom)
                    break

        # If no other substituents, just return position number
        if not other_subst:
            return "1"

        # Calculate distance to other substituents
        distances = {}
        for other in other_subst:
            # Use shortest path through the ring
            path = rdmolops.GetShortestPath(mol, atom_idx, other)
            if path:
                path_len = len(
                    path
                ) - 1  # Subtract 1 because path includes both endpoints

                # Convert distance to position name
                if path_len == 1:
                    pos = "ortho"
                elif path_len == 2:
                    pos = "meta"
                elif path_len == 3:
                    pos = "para"
                else:
                    pos = str(path_len)

                distances[other] = pos

        # Return the closest position if multiple are found
        if distances:
            positions = list(distances.values())
            # Prioritize ortho, then meta, then para for consistent naming
            if "ortho" in positions:
                return "ortho"
            elif "meta" in positions:
                return "meta"
            elif "para" in positions:
                return "para"
            else:
                return positions[0]

    # For other ring sizes, use numbered positions (1, 2, 3, etc.)
    return "1"  # Default for now


def get_connected_atoms(mol, start_idx, exclude_atoms):
    """
    Get all atoms connected to a starting atom, excluding a set of atoms.
    
    Parameters:
    -----------
    mol : RDKit Mol
        RDKit molecule object
    start_idx : int
        Index of the starting atom
    exclude_atoms : set
        Set of atom indices to exclude
        
    Returns:
    --------
    list
        List of atom indices that form the connected component
    """
    visited = set([start_idx])
    queue = [start_idx]

    while queue:
        current = queue.pop(0)
        atom = mol.GetAtomWithIdx(current)

        for neighbor in atom.GetNeighbors():
            neighbor_idx = neighbor.GetIdx()

            if neighbor_idx not in visited and neighbor_idx not in exclude_atoms:
                visited.add(neighbor_idx)
                queue.append(neighbor_idx)

    return list(visited)


def get_substituent_signature(mol, substituent):
    """
    Generate a signature for a substituent to identify similar groups.
    
    Parameters:
    -----------
    mol : RDKit Mol
        RDKit molecule object
    substituent : dict
        Dictionary containing substituent information
        
    Returns:
    --------
    str
        Signature string for the substituent
    """
    # Create a fragment of just the substituent
    atoms = substituent['atoms']
    if not atoms:
        return ""

    # Get the SMILES of the fragment
    # This is a simplified approach - a more robust one would create a proper fragment
    atom_symbols = []
    for atom_idx in atoms:
        atom = mol.GetAtomWithIdx(atom_idx)
        atom_symbols.append(atom.GetSymbol())

    # Count elements as a basic signature
    elem_counts = Counter(atom_symbols)
    signature = ".".join(f"{elem}{count}"
                         for elem, count in sorted(elem_counts.items()))

    # For more complex substituents, we could use a more sophisticated approach
    # like a Morgan fingerprint or a proper SMILES fragment

    return signature


def get_friendly_substituent_name(signature):
    """
    Convert a substituent signature to a friendly name when possible.
    
    Parameters:
    -----------
    signature : str
        Signature string for the substituent
        
    Returns:
    --------
    str
        Friendly name for the substituent
    """
    # Map of common substituent signatures to friendly names
    common_substituents = {
        "C1": "Methyl",
        "C2": "Ethyl",
        "C3": "Propyl",
        "N1": "Amino",
        "O1": "Hydroxy",
        "O2": "Carboxyl",
        "O2.C1": "Carboxyl acid",
        "Cl1": "Chloro",
        "Br1": "Bromo",
        "F1": "Fluoro",
        "I1": "Iodo",
        "N1.C1": "Methylamino",
        "C1.O1": "Hydroxy methyl",
        "C1.N1": "Aminomethyl",
        "N1.O1": "Nitro",
        "N1.O2": "Nitro",
        "S1": "Thiol"
    }

    return common_substituents.get(signature, f"Group ({signature})")


# Position mapping for consistent naming
pos_map = {
    "1": "position 1",
    "2": "position 2",
    "3": "position 3",
    "4": "position 4",
    "5": "position 5",
    "6": "position 6",
    "ortho": "ortho",
    "meta": "meta",
    "para": "para"
}

# Test with the provided example
if __name__ == "__main__":
    test_reactant = "c1c(CCNC)cc(CC(=O)O)cc1"
    test_product = "c1cc(CCN(C)CC)c(CC(=O)O)cc1"

    # test_reactant = "C2CC1CCCC1C2"
    # test_product = "C2CCC1CCCCC1C2"

    test_reactant = "CC(=O)O"
    test_product = "c1cc(CCN(C)CC)c(CC(=O)O)cc1"

    print("Testing with provided example:")
    print(f"Reactant: {test_reactant}")
    print(f"Product: {test_product}")

    results = compare_molecules(test_reactant, test_product)

    print(results)

    print("\nValidation Results:")
    print("-------------------")

    if results['substituent_position_changes']:
        print("\nSubstituent Position Changes:")
        for change in results['substituent_position_changes']:
            print(
                f"- {change['substituent']} moved from {', '.join(change['from_positions'])} "
                f"to {', '.join(change['to_positions'])}")

    if results['detected_issues']:
        print("\nDetected Issues:")
        for issue in results['detected_issues']:
            print(f"- {issue}")

    # print("\n\nTo run with your own SMILES strings:")
    # main()

Testing with provided example:
Reactant: CC(=O)O.CCC
Product: c1cc(CCN(C)CC)c(CC(=O)O)cc1
{'valid_reactant': True, 'valid_product': True, 'atom_count_consistent': False, 'ring_size_changes': ['6-membered ring added'], 'substituent_position_changes': [], 'detected_issues': ['Atom count mismatch for N: Reactant has 0, Product has 1', 'Atom count mismatch for C: Reactant has 5, Product has 13', 'Ring size change detected: Reactant rings [], Product rings [6]', 'Significant change in aromaticity: Reactant has 0 aromatic atoms, Product has 6', 'Possible unnecessary bonds formed: Reactant has 5 bonds, Product has 16 bonds']}

Validation Results:
-------------------

Detected Issues:
- Atom count mismatch for N: Reactant has 0, Product has 1
- Atom count mismatch for C: Reactant has 5, Product has 13
- Ring size change detected: Reactant rings [], Product rings [6]
- Significant change in aromaticity: Reactant has 0 aromatic atoms, Product has 6
- Possible unnecessary bonds formed: Reactant h

In [83]:
from rdkit import Chem
from rdkit.Chem import AllChem, rdMolDescriptors
from collections import defaultdict


def validate_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return False, None
    try:
        Chem.SanitizeMol(mol)
        return True, mol
    except:
        return False, None


def get_ring_details(mol):
    ring_info = []
    ri = mol.GetRingInfo()
    for ring in ri.AtomRings():
        if all(mol.GetAtomWithIdx(i).GetIsAromatic() for i in ring):
            substituents = []
            for atom_idx in ring:
                atom = mol.GetAtomWithIdx(atom_idx)
                for neighbor in atom.GetNeighbors():
                    if neighbor.GetIdx() not in ring:
                        substituents.append(atom_idx)
            ring_info.append({
                'size':
                len(ring),
                'substituents':
                sorted(ring.index(s) for s in substituents),
                'atom_indices':
                ring
            })
    return ring_info


def get_min_distance(pos1, pos2, ring_size):
    diff = abs(pos1 - pos2)
    return min(diff, ring_size - diff)


def compare_positions(reactant_rings, product_rings):
    position_changes = []

    for r_ring in reactant_rings:
        for p_ring in product_rings:
            if (r_ring['size'] == p_ring['size'] and len(
                    r_ring['substituents']) == len(p_ring['substituents'])):

                # Get positions in normalized order
                r_subs = sorted(r_ring['substituents'])
                p_subs = sorted(p_ring['substituents'])

                # Calculate pairwise distances
                r_dist = get_min_distance(r_subs[0], r_subs[1], r_ring['size'])
                p_dist = get_min_distance(p_subs[0], p_subs[1], p_ring['size'])

                if r_dist != p_dist:
                    position_changes.append({
                        'reactant_positions': r_subs,
                        'product_positions': p_subs,
                        'reactant_distance': r_dist,
                        'product_distance': p_dist,
                        'ring_size': r_ring['size']
                    })

    return position_changes


def analyze_position_changes(position_changes):
    results = []
    for change in position_changes:
        if change['ring_size'] == 6:
            r_dist = change['reactant_distance']
            p_dist = change['product_distance']
            if (r_dist == 1 and p_dist == 2) or (r_dist == 2 and p_dist == 1):
                results.append("Ortho <-> Meta position change detected")
            elif (r_dist == 1 and p_dist == 3) or (r_dist == 3
                                                   and p_dist == 1):
                results.append("Ortho <-> Para position change detected")
            elif (r_dist == 2 and p_dist == 3) or (r_dist == 3
                                                   and p_dist == 2):
                results.append("Meta <-> Para position change detected")
    return results


def compare_structures(reactant_smiles, product_smiles):
    results = {'position_changes': [], 'other_issues': []}

    valid_r, r_mol = validate_smiles(reactant_smiles)
    valid_p, p_mol = validate_smiles(product_smiles)

    if not valid_r or not valid_p:
        results['error'] = "Invalid SMILES input"
        return results

    r_rings = get_ring_details(r_mol)
    p_rings = get_ring_details(p_mol)

    position_changes = compare_positions(r_rings, p_rings)
    analyzed_changes = analyze_position_changes(position_changes)

    results['position_changes'] = analyzed_changes
    return results


# Example usage
reactant = "c1c(CCNC)cc(C(=O)O)cc1"
product = "c1cc(CCN(C)CC)c(C(=O)O)cc1"

reactant = "C2CC1CCCC1C2"
product = "C2CCC1CCCCC1C2"

results = compare_structures(reactant, product)
print(results)
print("Position changes detected:", results['position_changes'])

{'position_changes': [], 'other_issues': []}
Position changes detected: []
