# In [None]:
from __future__ import annotations

import itertools
import warnings
from typing import List, Literal

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem, rdMolTransforms
from rdkit.Geometry import Point3D
from rdkit.Geometry.rdGeometry import Point3D


def embed(
    mol: Chem.Mol, num_confs: int = 100, rmsd_threshold: float | None = None
) -> Chem.Mol:
    """Embed ``num_confs`` ETKDGv3 conformers and MMFF94-optimise them in place.

    Args:
        mol: Molecule to embed. Conformers are added to this object in place.
        num_confs: Number of conformers requested from ``EmbedMultipleConfs``.
        rmsd_threshold: If not ``None``, passed to ETKDG as ``pruneRmsThresh``
            so near-duplicate conformers are pruned during embedding.

    Returns:
        The same ``mol`` object (returned for convenience).

    Raises:
        ValueError: If conformers were generated but MMFF94 parameters cannot
            be assigned to ``mol``.
    """
    params = AllChem.ETKDGv3()
    params.randomSeed = 0xC0FFEE  # fixed seed -> reproducible conformers
    if rmsd_threshold is not None:  # 0.0 is a legitimate value; don't test truthiness
        params.pruneRmsThresh = rmsd_threshold
    params.numThreads = 0  # use all cores for embedding
    conf_ids = list(AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, params=params))
    if not conf_ids:
        return mol  # embedding produced nothing; nothing to optimise

    props = AllChem.MMFFGetMoleculeProperties(mol, mmffVariant="MMFF94")
    if props is None:
        raise ValueError("MMFF94 parameters could not be assigned to this molecule")

    # Optimise exactly the conformers the embedder reported, rather than
    # assuming their IDs are 0..n-1.
    for cid in conf_ids:
        ff = AllChem.MMFFGetMoleculeForceField(mol, props, confId=cid)
        ff.Minimize(maxIts=250)
    return mol


def get_bond_vector_angle(conf, a1, b1, a2, b2):
    """Return the angle between bond vectors b1->a1 and b2->a2 in ``conf``.

    The value is whatever ``Point3D.AngleTo`` yields (radians in RDKit).
    """
    position = conf.GetAtomPosition
    first_vec = position(a1) - position(b1)
    second_vec = position(a2) - position(b2)
    return first_vec.AngleTo(second_vec)


def get_vectors(rdmol: Chem.rdchem.Mol, atom: Literal["H", "dummy"], conf_id: int = -1):
    """Describe the geometry of every pair of terminal attachment atoms.

    For each unordered pair of hydrogens (``atom="H"``) or dummy atoms
    (``atom="dummy"``), this records the angle between their bond vectors, a
    dihedral through the two atoms and their anchors, and the distance
    between the two atoms, all measured in conformer ``conf_id``.

    Args:
        rdmol: Molecule with at least one conformer.
        atom: Which atoms act as attachment points: ``"H"`` or ``"dummy"``.
        conf_id: Conformer to measure (``-1`` = RDKit default conformer).

    Returns:
        ``pd.DataFrame`` with one row per pair and columns
        ``attachment_atom_idxs``, ``bond_vector_angle``, ``dihedral_angle``,
        ``h_distance``. If fewer than two such atoms exist the frame is empty
        but still has these columns, so downstream column access works.

    Raises:
        ValueError: If ``rdmol`` has no conformers or ``atom`` is not
            ``"H"``/``"dummy"``.
    """
    if rdmol.GetNumConformers() == 0:
        raise ValueError("Molecule has no conformers")
    if atom not in ("H", "dummy"):
        raise ValueError(f"atom must be 'H' or 'dummy', got {atom!r}")

    # Dummy atoms have atomic number 0, hydrogens atomic number 1.
    atomic_num = 0 if atom == "dummy" else 1
    sub_idxs = [
        a.GetIdx() for a in rdmol.GetAtoms() if a.GetAtomicNum() == atomic_num
    ]

    conf = rdmol.GetConformer(conf_id)

    attachments = []
    for h1, h2 in itertools.combinations(sub_idxs, 2):
        # H / dummy atoms are terminal: their single neighbour is the anchor atom.
        ha1 = rdmol.GetAtomWithIdx(h1).GetNeighbors()[0].GetIdx()
        ha2 = rdmol.GetAtomWithIdx(h2).GetNeighbors()[0].GetIdx()

        # NOTE(review): the dihedral is taken over h1-ha1-h2-ha2 rather than the
        # more common h1-ha1-ha2-h2. Queries and database rows both use this
        # function, so matching is self-consistent; confirm the order is intended.
        attachments.append(
            dict(
                attachment_atom_idxs=(h1, h2),
                bond_vector_angle=get_bond_vector_angle(conf, h1, ha1, h2, ha2),
                dihedral_angle=rdMolTransforms.GetDihedralRad(conf, h1, ha1, h2, ha2),
                h_distance=Point3D.Distance(
                    conf.GetAtomPosition(h1), conf.GetAtomPosition(h2)
                ),
            )
        )

    return pd.DataFrame(
        attachments,
        columns=[
            "attachment_atom_idxs",
            "bond_vector_angle",
            "dihedral_angle",
            "h_distance",
        ],
    )


def _get_ring_systems(mol: Chem.Mol) -> List[set[int]]:
    """Return a list of atom-index sets, one per *ring system* (fused rings collapsed)."""
    ring_info = mol.GetRingInfo()
    rings = [set(r) for r in ring_info.AtomRings()]

    systems: List[set[int]] = []
    for r in rings:
        merged = False
        for s in systems:
            if r & s:  # share ≥1 atom → same system
                s |= r
                merged = True
                break
        if not merged:
            systems.append(set(r))

    # one more pass in case several rings chained together
    changed = True
    while changed:
        changed = False
        out: List[set[int]] = []
        for s in systems:
            for o in out:
                if s & o:
                    o |= s
                    changed = True
                    break
            else:
                out.append(set(s))
        systems = out
    return systems


def get_exocyclic_bonds(mol, ring_system):
    """Return the indices of the two exocyclic single bonds on a ring system.

    A bond counts when exactly one of its atoms is in ``ring_system``, it is a
    single bond, and neither end is hydrogen.

    Args:
        mol: RDKit molecule.
        ring_system: Collection of atom indices forming one ring system.

    Returns:
        A list of exactly two bond indices. If the ring system has any other
        number of such bonds, a warning is emitted and ``[]`` is returned,
        because downstream splitting needs exactly two cut points.
    """
    ring_atoms = set(ring_system)
    exo_bonds = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtom(), bond.GetEndAtom()
        if (begin.GetIdx() in ring_atoms) == (end.GetIdx() in ring_atoms):
            continue  # both ends inside, or both outside, the ring system
        if bond.GetBondType() != Chem.BondType.SINGLE:
            continue
        if begin.GetAtomicNum() == 1 or end.GetAtomicNum() == 1:
            continue
        exo_bonds.append(bond.GetIdx())

    if len(exo_bonds) != 2:
        warnings.warn(
            f"Expected 2 exocyclic single bonds on ring system, found "
            f"{len(exo_bonds)} ({exo_bonds}); skipping it",
            stacklevel=2,
        )
        return []
    return exo_bonds


def stitch_fragments(
    central: Chem.Mol,
    r1: Chem.Mol,
    r2: Chem.Mol,
    bond_order: Chem.BondType = Chem.BondType.SINGLE,
) -> Chem.Mol:
    """Re-connect ``central``, ``r1`` and ``r2`` into a single molecule.

    Each fragment must carry dummy atoms (atomic number 0) whose *isotope* is a
    join label:
        * ``central`` has dummies labelled 1 and 2;
        * ``r1`` and ``r2`` each carry one of those labels (either way round;
          pairs are matched by label, not by argument position).
    This is the isotope labelling written by ``replace_atoms_with_dummy`` and
    requested from ``FragmentOnBonds`` in ``cut_molecule``.

    For each label, the two atoms bonded to the matching dummies are joined
    by a new bond of type ``bond_order``. Then every dummy is removed, the
    result is sanitised, and stereochemistry is reassigned from 3D.

    Parameters
    ----------
    central, r1, r2 : Chem.Mol
        The three pieces to be joined.
    bond_order : Chem.BondType, optional
        Bond type to use when reconnecting (defaults to *single*).

    Returns
    -------
    Chem.Mol
        A sanitised molecule with the dummies removed and the pieces fused.

    Raises
    ------
    ValueError
        If any label does not occur exactly twice, or a dummy atom is not
        bonded to exactly one atom.

    """
    # Merge the three fragments into one editable molecule.
    rw = Chem.RWMol(Chem.CombineMols(Chem.CombineMols(central, r1), r2))

    # Bucket dummy atoms by their isotope label.
    dummies_by_label: dict[int, list[int]] = {}
    for atom in rw.GetAtoms():
        if atom.GetAtomicNum() == 0:
            dummies_by_label.setdefault(atom.GetIsotope(), []).append(atom.GetIdx())

    # For each label, bond together the real atoms the two dummies hang off.
    for label, idxs in dummies_by_label.items():
        if len(idxs) != 2:
            raise ValueError(f"Label {label} occurs {len(idxs)} times – expected 2")
        anchors = []
        for dummy_idx in idxs:
            neighbours = rw.GetAtomWithIdx(dummy_idx).GetNeighbors()
            if len(neighbours) != 1:
                raise ValueError(
                    f"Dummy atom {dummy_idx} (label {label}) has "
                    f"{len(neighbours)} neighbours – expected 1"
                )
            anchors.append(neighbours[0].GetIdx())
        rw.AddBond(anchors[0], anchors[1], bond_order)

    # Delete the dummies highest-index first so the remaining indices stay valid.
    all_dummy_idx = sorted(
        (idx for idxs in dummies_by_label.values() for idx in idxs), reverse=True
    )
    for idx in all_dummy_idx:
        rw.RemoveAtom(idx)

    # Final clean-up.
    Chem.SanitizeMol(rw)
    Chem.AssignStereochemistryFrom3D(rw)
    Chem.AssignAtomChiralTagsFromStructure(rw)
    return rw.GetMol()


def cut_molecule(mol: Chem.Mol, bonds) -> List[tuple[Chem.Mol, Chem.Mol, Chem.Mol]]:
    """Split ``mol`` into ``(central, rgroup_a, rgroup_b)`` triples.

    Args:
        mol: Molecule to fragment.
        bonds: Iterable of ``(ring_system, bond_ids)`` pairs. ``ring_system``
            is a collection of atom indices of ``mol``, and ``bond_ids`` are
            the two bonds to cut around it.

    Returns:
        One tuple per usable entry in ``bonds``. ``central`` is the fragment
        containing the ring system. The two R-group fragments follow in
        ``GetMolFrags`` order. Cut points are dummy atoms whose isotopes are
        1 and 2 (``FragmentOnBonds`` ``dummyLabels``).

        An entry is skipped when ``bond_ids`` is empty. It is skipped with a
        warning when ``bond_ids`` does not hold two bonds, or when cutting
        does not yield exactly three fragments.
    """
    split_mols = []
    for ring_system, bond_ids in bonds:
        if not bond_ids:
            continue  # nothing to cut for this ring system
        if len(bond_ids) != 2:
            warnings.warn(
                f"Expected 2 bonds to cut, got {len(bond_ids)}; skipping", stacklevel=2
            )
            continue

        # Insert dummies labelled 1 and 2 so the pieces can be re-joined later.
        fragged = Chem.FragmentOnBonds(
            mol,
            list(bond_ids),
            addDummies=True,
            dummyLabels=[(1, 1), (2, 2)],
        )

        # atom_maps[i] lists the indices (in `fragged`) of fragment i's atoms.
        # The atoms inside each fragment Mol are renumbered from 0, so these
        # indices, not frag.GetAtomIdx(), must be compared against ring_system.
        atom_maps: list = []
        frags = Chem.GetMolFrags(
            fragged, asMols=True, sanitizeFrags=True, fragsMolAtomMapping=atom_maps
        )
        if len(frags) != 3:
            # e.g. both cut bonds lead into the same substituent (a linker loop)
            warnings.warn(
                f"Cutting bonds {list(bond_ids)} gave {len(frags)} fragments, "
                "expected 3; skipping",
                stacklevel=2,
            )
            continue

        ring_atoms = set(ring_system)
        ring_counts = [len(ring_atoms.intersection(idxs)) for idxs in atom_maps]
        central_pos = max(range(len(frags)), key=ring_counts.__getitem__)
        central = frags[central_pos]
        rgroups = [frag for pos, frag in enumerate(frags) if pos != central_pos]

        central.UpdatePropertyCache()
        for rgroup in rgroups:
            rgroup.UpdatePropertyCache()

        split_mols.append((central, rgroups[0], rgroups[1]))

    return split_mols


def replace_atoms_with_dummy(mol: Chem.Mol, atom_indices: list[int]) -> Chem.Mol:
    """Return a copy of ``mol`` whose listed atoms are turned into labelled dummies.

    The k-th index in ``atom_indices`` (counting from 1) becomes a dummy atom
    (atomic number 0) with isotope k, i.e. ``[1*]``, ``[2*]``, ... The input
    ``mol`` is not modified.
    """
    editable = Chem.RWMol(Chem.Mol(mol))  # operate on a copy, never the caller's mol
    for label, atom_idx in enumerate(atom_indices, start=1):
        dummy = editable.GetAtomWithIdx(atom_idx)
        dummy.SetAtomicNum(0)
        dummy.SetIsotope(label)
    return editable.GetMol()


def extract_conformer(mol: Chem.Mol, conf_id: int) -> Chem.Mol:
    """Return a copy of ``mol`` that carries only conformer ``conf_id``.

    Args:
        mol: RDKit molecule, possibly with several conformers (left unchanged).
        conf_id: ID of the conformer to keep.

    Returns:
        A new Chem.Mol holding a copy of that single conformer; its ID is
        reassigned by ``AddConformer(..., assignId=True)``.

    """
    single = Chem.Mol(mol)
    source_conf = mol.GetConformer(conf_id)
    single.RemoveAllConformers()
    single.AddConformer(Chem.Conformer(source_conf), assignId=True)
    return single


def save_mols_to_sdf(mols: list[Chem.Mol], filename: str) -> None:
    """Write every non-``None`` molecule in ``mols`` to ``filename`` as an SD file.

    The writer is closed in a ``finally`` block, so the file handle is released
    even if writing one of the molecules raises.
    """
    writer = Chem.SDWriter(filename)
    try:
        for mol in mols:
            if mol is not None:
                writer.write(mol)
    finally:
        writer.close()


def angle_diff(a, b):
    """Return the smallest absolute angular separation of ``a`` and ``b``.

    Inputs are in radians and are treated as 2π-periodic. The result lies in
    [0, π]. Scalars and NumPy arrays/Series are handled element-wise.
    """
    wrapped = np.mod(a - b + np.pi, 2 * np.pi)  # shift into [0, 2π)
    return np.abs(wrapped - np.pi)


def find_df_matches_with_tolerance(
    df: pd.DataFrame,
    query: tuple[float, float, float],
    tol_bond: float,
    tol_dihedral: float,
    tol_dist: float,
) -> pd.DataFrame:
    """Keep the rows of ``df`` whose descriptors are all close to ``query``.

    Both angles (radians) are compared modulo 2π via ``angle_diff``. The
    distance is compared by absolute difference. A row survives only if every
    component is within its (inclusive) tolerance.

    Args:
        df: DataFrame with columns ['bond_vector_angle', 'dihedral_angle', 'h_distance']
        query: tuple of (bond_vector_angle, dihedral_angle, h_distance)
        tol_bond: tolerance on the bond-vector angle
        tol_dihedral: tolerance on the dihedral angle
        tol_dist: tolerance on the distance

    Returns:
        The matching rows of ``df`` (original index preserved).

    """
    target_angle, target_dihedral, target_dist = query

    angle_ok = angle_diff(df["bond_vector_angle"].to_numpy(), target_angle) <= tol_bond
    dihedral_ok = (
        angle_diff(df["dihedral_angle"].to_numpy(), target_dihedral) <= tol_dihedral
    )
    dist_ok = np.abs(df["h_distance"].to_numpy() - target_dist) <= tol_dist

    keep = angle_ok & dihedral_ok & dist_ok
    return df[keep]

# In [None]:
# Build a database of exit-vector descriptors for ring systems found in
# approved drugs (currently a two-entry placeholder: benzene and pyridine).
approved_ring_system_smiles = ["c1ccccc1", "c1cccnc1"]

results = []  # NOTE(review): unused in this cell; confirm a later cell needs it.

# mol_ix -> binary-serialised molecule; later cells rebuild it with
# Chem.Mol(mol_dict[mol_ix]) and pick a conformer by conf_id.
mol_dict = {}
vector_dfs = []
for mol_ix, smi in enumerate(approved_ring_system_smiles):
    mol = Chem.MolFromSmiles(smi)
    mol = Chem.AddHs(mol)  # explicit Hs are the candidate attachment points
    embed(mol, rmsd_threshold=0.25)  # adds and MMFF-optimises conformers in place
    mol_dict[mol_ix] = mol.ToBinary()
    # One row per H-H pair per conformer, tagged with where it came from.
    # NOTE(review): assumes conformer IDs are 0..n-1 — confirm for pruned embeddings.
    for conf_id in range(mol.GetNumConformers()):
        vector_df = get_vectors(mol, atom="H", conf_id=conf_id)
        vector_df["conf_id"] = conf_id
        vector_df["mol_ix"] = mol_ix
        vector_dfs.append(vector_df)

vector_df = pd.concat(vector_dfs).reset_index(drop=True)

vector_df.sample(5)  # quick notebook sanity check (random rows each run)

# In [None]:
smiles = "CCOc1cnc2ccccc2c1OC"  # quinoline core with two single-bond side chains
# smiles = "COc1ncccc1C"

# Embed with explicit hydrogens (better geometry, fixed seed for reproducibility),
# then strip them again: only heavy-atom cut points are used below. AddHs appends
# the hydrogens, so heavy-atom indices are unchanged by the round trip.
mol_with_hs = Chem.AddHs(Chem.MolFromSmiles(smiles))
if AllChem.EmbedMolecule(mol_with_hs, randomSeed=0xC0FFEE) == -1:
    raise RuntimeError(f"Could not embed {smiles}")
mol = Chem.RemoveHs(mol_with_hs)

remols = []

# Only ring systems with a usable pair of exocyclic bonds are cut;
# get_exocyclic_bonds returns [] for all others.
bond_to_break = []
for ring_system in _get_ring_systems(mol):
    exo_bonds = get_exocyclic_bonds(mol, ring_system)
    if exo_bonds:
        bond_to_break.append((ring_system, exo_bonds))

split_mols = cut_molecule(mol, bonds=bond_to_break)
for central, rgroup0, rgroup1 in split_mols:
    ref_vector_df = get_vectors(central, atom="dummy")
    for ref in ref_vector_df.itertuples():
        matches_df = find_df_matches_with_tolerance(
            vector_df,
            query=(ref.bond_vector_angle, ref.dihedral_angle, ref.h_distance),
            tol_bond=0.1,
            tol_dihedral=0.1,
            tol_dist=0.5,
        )

        for match in matches_df.itertuples():
            # Rebuild the database molecule the match came from.
            query_mol = Chem.Mol(mol_dict[match.mol_ix])

            # Turn the matched hydrogens into labelled dummies (isotopes 1, 2).
            query_mol_w_dummies = replace_atoms_with_dummy(
                query_mol, atom_indices=match.attachment_atom_idxs
            )
            query_mol_w_dummies = Chem.RemoveHs(query_mol_w_dummies)

            # TODO: Optimise the positions of `query_mol_w_dummies` to match `central`

            remol = stitch_fragments(query_mol_w_dummies, rgroup0, rgroup1)
            remol = extract_conformer(remol, conf_id=match.conf_id)

            remols.append(remol)

save_mols_to_sdf([mol], "/tmp/ref.sdf")
save_mols_to_sdf(remols, "/tmp/re.sdf")