In [None]:
from rdkit import Chem
from rdkit.Chem import rdShapeAlign
import numpy as np
import py3Dmol

# Element symbol -> "#RRGGBB" color. Used directly as the stick color map for
# the reference model, and as the base palette that tint_map blends toward the
# probe tint for the probe model. Key order is preserved by dict(**kwargs).
DEFAULT_PALETTE = dict(
    H="#FFFFFF",
    C="#909090",
    N="#3050F8",
    O="#FF0D0D",
    F="#90E050",
    P="#FF8000",
    S="#FFFF30",
    Cl="#1FF01F",
    Br="#A62929",
    I="#940094",
    Si="#F0C8A0",
    B="#FFB5B5",
    Na="#AB5CF2",
    K="#8F40D4",
    Mg="#8AFF00",
    Ca="#3DFF00",
    Fe="#E06633",
    Zn="#7D80B0",
)

def _hex_to_rgb(h):
    h = h.lstrip("#")
    return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))

def _rgb_to_hex(rgb):
    return "#{:02X}{:02X}{:02X}".format(*rgb)

def tint_map(base_map, tint_hex="#FF00FF", strength=0.35):
    """
    Blend every color in ``base_map`` toward ``tint_hex``.

    Each RGB channel is linearly interpolated and rounded to the nearest int:
    ``new = (1 - strength) * old + strength * tint``.
    strength=0 -> original colors
    strength=1 -> fully tint color

    Parameters
    ----------
    base_map : dict
        Mapping of key (e.g. element symbol) -> "#RRGGBB" color string.
    tint_hex : str
        Target color as "#RRGGBB".
    strength : float
        Blend weight in [0, 1].

    Returns
    -------
    dict
        New mapping with the same keys (same order) and blended "#RRGGBB"
        values; ``base_map`` itself is not modified.

    Raises
    ------
    ValueError
        If ``strength`` is outside [0, 1] (or NaN). Such values would push
        channels outside 0-255 and produce malformed hex strings
        (e.g. "#17E..." or "#-80...").
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength!r}")

    tint_rgb = _hex_to_rgb(tint_hex)
    out = {}
    for elem, hx in base_map.items():
        base_rgb = _hex_to_rgb(hx)
        blended = tuple(
            int(round((1 - strength) * channel + strength * tint_channel))
            for channel, tint_channel in zip(base_rgb, tint_rgb)
        )
        out[elem] = _rgb_to_hex(blended)
    return out

def overlay_molblocks_3d_display(ref_molblock, probe_molblock,
                                 probe_tint="#fc9803", tint_strength=0.40,
                                 width=650, height=450):
    """
    Render two MolBlock strings superimposed in one py3Dmol viewer.

    The reference is added first (model 0) and drawn as sticks colored by
    element with ``DEFAULT_PALETTE``. The probe is added second (model 1),
    drawn as sticks at 0.85 opacity, and colored with a copy of the palette
    blended toward ``probe_tint`` so the two molecules can be told apart.
    The view is zoomed to fit both models and shown; nothing is returned.

    Parameters
    ----------
    ref_molblock, probe_molblock : str
        Molecule text handed to py3Dmol with format "sdf".
    probe_tint : str
        "#RRGGBB" color the probe palette is blended toward.
    tint_strength : float
        Blend weight passed to ``tint_map`` as ``strength``.
    width, height
        Viewer size passed to ``py3Dmol.view``.
    """
    viewer = py3Dmol.view(width=width, height=height)
    probe_palette = tint_map(DEFAULT_PALETTE, tint_hex=probe_tint, strength=tint_strength)

    # (molblock, style) per model; the list position becomes the py3Dmol model index.
    layers = [
        (ref_molblock,
         {"stick": {"colorscheme": {"prop": "elem", "map": DEFAULT_PALETTE}}}),
        (probe_molblock,
         {"stick": {"colorscheme": {"prop": "elem", "map": probe_palette}, "opacity": 0.85}}),
    ]
    for model_index, (molblock, style) in enumerate(layers):
        viewer.addModel(molblock, "sdf")
        viewer.setStyle({"model": model_index}, style)

    viewer.zoomTo()
    viewer.show()

def _random_rotation_matrix(rng: np.random.Generator) -> np.ndarray:
    # random unit quaternion -> rotation matrix
    u1, u2, u3 = rng.random(3)
    q1 = np.sqrt(1 - u1) * np.sin(2*np.pi*u2)
    q2 = np.sqrt(1 - u1) * np.cos(2*np.pi*u2)
    q3 = np.sqrt(u1)     * np.sin(2*np.pi*u3)
    q4 = np.sqrt(u1)     * np.cos(2*np.pi*u3)
    R = np.array([
        [1 - 2*(q3*q3 + q4*q4),     2*(q2*q3 - q1*q4),     2*(q2*q4 + q1*q3)],
        [    2*(q2*q3 + q1*q4), 1 - 2*(q2*q2 + q4*q4),     2*(q3*q4 - q1*q2)],
        [    2*(q2*q4 - q1*q3),     2*(q3*q4 + q1*q2), 1 - 2*(q2*q2 + q3*q3)]
    ], dtype=float)
    return R

def _apply_rigid_transform(mol: Chem.Mol, conf_id: int, R: np.ndarray, t: np.ndarray) -> None:
    """
    Map every atom position of one conformer through ``x -> R @ x + t``, in place.

    Parameters
    ----------
    mol : Chem.Mol
        Molecule whose conformer ``conf_id`` is overwritten.
    conf_id : int
        Conformer id passed to ``mol.GetConformer``.
    R : np.ndarray
        3x3 matrix applied to each position (a random rotation in this notebook).
    t : np.ndarray
        Length-3 vector added after the matrix product.
    """
    conformer = mol.GetConformer(conf_id)
    for atom_idx in range(mol.GetNumAtoms()):
        old_xyz = np.array(conformer.GetAtomPosition(atom_idx), dtype=float)
        conformer.SetAtomPosition(atom_idx, R @ old_xyz + t)

def _calc_rmsd(mol1: Chem.Mol, mol2: Chem.Mol, conf_id1: int = -1, conf_id2: int = -1) -> float:
    """
    RMSD between index-matched atom positions of two already-aligned conformers.

    No superposition is performed: coordinates are compared in the frames they
    are already in, so this measures how closely a previous alignment placed
    ``mol2`` onto ``mol1``. Atom ``i`` of ``mol1`` is paired with atom ``i`` of
    ``mol2``, so the value is only meaningful when both molecules list their
    atoms in the same order.

    Parameters
    ----------
    mol1, mol2 : Chem.Mol
        Molecules with at least one conformer each.
    conf_id1, conf_id2 : int
        Conformer ids passed to ``GetConformer``.

    Returns
    -------
    float
        sqrt(mean over atoms of squared Euclidean distance).

    Raises
    ------
    ValueError
        If the atom counts differ, or if the molecules have no atoms (the mean
        would be undefined). A ``ValueError`` is used instead of ``assert`` so
        the check still runs under ``python -O``.
    """
    n_atoms = mol1.GetNumAtoms()
    if n_atoms != mol2.GetNumAtoms():
        raise ValueError("Molecules must have the same number of atoms")
    if n_atoms == 0:
        raise ValueError("Cannot compute RMSD for molecules without atoms")

    # (n_atoms, 3) coordinate arrays; vectorized instead of a per-atom Python loop.
    pos1 = np.asarray(mol1.GetConformer(conf_id1).GetPositions(), dtype=float)
    pos2 = np.asarray(mol2.GetConformer(conf_id2).GetPositions(), dtype=float)

    sq_dists = np.sum((pos1 - pos2) ** 2, axis=1)
    return float(np.sqrt(np.mean(sq_dists)))

def align_molblocks_shape(
    ref_molblock: str,
    probe_molblock: str,
    ref_conf_id: int = -1,
    probe_conf_id: int = -1,
    n_starts: int = 25,
    use_colors: bool = True,
    seed: int = 0,
    translation_scale: float = 5.0,
):
    """
    Shape-align a probe MolBlock onto a reference MolBlock using random restarts.

    Both MolBlocks are parsed with hydrogens kept. For each of ``n_starts``
    starts, a fresh copy of the probe is given a random rigid-body pose (random
    rotation plus a normally distributed translation), then aligned to the
    reference with ``rdShapeAlign.AlignMol``, which moves the copy in place.
    The best-scoring copy is kept. The score is ``colorT`` when ``use_colors``
    is true and ``shapeT`` otherwise; ties keep the earliest start.

    Parameters
    ----------
    ref_molblock, probe_molblock : str
        MolBlocks with coordinates (at least one conformer each).
    ref_conf_id, probe_conf_id : int
        Conformer ids to use for the reference and the probe.
    n_starts : int
        Number of random starting poses; must be >= 1.
    use_colors : bool
        Forwarded to ``AlignMol`` as ``useColors``; also selects the ranking score.
    seed : int
        Seed for ``np.random.default_rng``; fixes the sequence of starting poses.
    translation_scale : float
        Standard deviation of each component of the random starting translation,
        in the MolBlock's coordinate units. The default of 5.0 reproduces the
        previous hard-coded value.

    Returns
    -------
    tuple
        ``(aligned_probe, shapeT, colorT, rmsd)``. These are the best aligned
        probe copy, its shape and color Tanimoto values from ``AlignMol``, and
        the index-matched RMSD to the reference computed by ``_calc_rmsd``.

    Raises
    ------
    ValueError
        If ``n_starts < 1``, if a MolBlock cannot be parsed, or if either
        molecule has no conformer.
    """
    # Without this check, n_starts <= 0 used to fail later with an opaque
    # TypeError when unpacking the "no result" sentinel.
    if n_starts < 1:
        raise ValueError(f"n_starts must be >= 1, got {n_starts}")

    ref = Chem.MolFromMolBlock(ref_molblock, removeHs=False)
    probe0 = Chem.MolFromMolBlock(probe_molblock, removeHs=False)
    if ref is None or probe0 is None:
        raise ValueError("Failed to parse one of the MolBlocks.")
    if ref.GetNumConformers() == 0 or probe0.GetNumConformers() == 0:
        raise ValueError("Both molecules must have 3D coordinates (at least one conformer).")

    # Speedup when aligning many probes: precompute the reference shape once
    ref_shape = rdShapeAlign.PrepareConformer(ref, confId=ref_conf_id)

    rng = np.random.default_rng(seed)
    best_score = None
    best_result = None  # (aligned probe copy, shapeT, colorT)

    for _ in range(n_starts):
        probe = Chem.Mol(probe0)  # fresh copy so every start begins from the input pose
        # random rigid-body pose as a different starting point
        R = _random_rotation_matrix(rng)
        t = rng.normal(scale=translation_scale, size=3)
        _apply_rigid_transform(probe, conf_id=probe_conf_id, R=R, t=t)

        shapeT, colorT = rdShapeAlign.AlignMol(
            ref_shape, probe,
            probeConfId=probe_conf_id,
            useColors=use_colors
        )

        # NOTE(review): with colors on, poses are ranked by colorT alone rather
        # than a shapeT + colorT combination; confirm this is the intended objective.
        score = colorT if use_colors else shapeT
        if best_score is None or score > best_score:
            best_score = score
            best_result = (probe, shapeT, colorT)

    aligned_probe, shapeT, colorT = best_result
    rmsd = _calc_rmsd(ref, aligned_probe, ref_conf_id, probe_conf_id)
    return aligned_probe, shapeT, colorT, rmsd

def viz_align(ref, probe, n_starts=30, use_colors=True, seed=123):
    """
    Align ``probe`` onto ``ref``, print the scores, and show a 3D overlay.

    Parameters
    ----------
    ref, probe : str
        Reference and probe MolBlocks with coordinates.
    n_starts, use_colors, seed
        Forwarded to ``align_molblocks_shape``. The defaults (30 starts, colors
        on, seed 123) reproduce the settings previously hard-coded here.

    Returns None. The shape/color Tanimoto values and RMSD are printed, and
    the overlay is displayed by ``overlay_molblocks_3d_display``.
    """
    aligned_probe, shapeT, colorT, rmsd = align_molblocks_shape(
        ref, probe, n_starts=n_starts, use_colors=use_colors, seed=seed
    )
    print("shapeTanimoto:", shapeT, "colorTanimoto:", colorT, "RMSD:", rmsd)

    # Round-trip the reference through RDKit so both models are rendered from
    # RDKit-written MolBlocks.
    ref_m = Chem.MolFromMolBlock(ref, removeHs=False)
    probe_block_aligned = Chem.MolToMolBlock(aligned_probe)

    overlay_molblocks_3d_display(Chem.MolToMolBlock(ref_m), probe_block_aligned)

In [None]:
# Reference molecule as a V2000 MolBlock with explicit hydrogens
# (29 atoms, 30 bonds per the counts line). The string starts with a newline,
# which serves as the empty title line; the "RDKit 3D" line is therefore the
# second header line and the counts line is the fourth.
# probe1 and probe2 list the same elements in the same order with an identical
# bond table, so atom i is the same atom in all three blocks. The index-matched
# RMSD printed by viz_align relies on this.
ref_block = """
     RDKit          3D

 29 30  0  0  0  0  0  0  0  0999 V2000
    0.1916    0.2919   -1.1488 N   0  0  0  0  0  0  0  0  0  0  0  0
    1.0530    1.1006   -0.5109 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.9166    2.3235   -0.4667 O   0  0  0  0  0  0  0  0  0  0  0  0
    2.2170    0.4122    0.1156 C   0  0  0  0  0  0  0  0  0  0  0  0
   -1.0720    0.7455   -1.6969 C   0  0  0  0  0  0  0  0  0  0  0  0
   -2.2224    0.2667   -0.8474 C   0  0  0  0  0  0  0  0  0  0  0  0
   -2.2987    0.6593    0.4842 C   0  0  0  0  0  0  0  0  0  0  0  0
   -3.1902   -0.5859   -1.3591 C   0  0  0  0  0  0  0  0  0  0  0  0
   -3.3273    0.2075    1.2904 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.2234   -1.0350   -0.5534 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.2922   -0.6415    0.7721 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.4586    1.0392    0.0866 C   0  0  0  0  0  0  0  0  0  0  0  0
    2.0817   -0.8131    0.7625 C   0  0  0  0  0  0  0  0  0  0  0  0
    4.5589    0.4347    0.6663 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.1828   -1.4082    1.3521 C   0  0  0  0  0  0  0  0  0  0  0  0
    4.4224   -0.7916    1.2967 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.3506   -0.7039   -1.1314 H   0  0  0  0  0  0  0  0  0  0  0  0
   -1.0324    1.8389   -1.7032 H   0  0  0  0  0  0  0  0  0  0  0  0
   -1.1722    0.3814   -2.7219 H   0  0  0  0  0  0  0  0  0  0  0  0
   -1.5456    1.3257    0.8818 H   0  0  0  0  0  0  0  0  0  0  0  0
   -3.1383   -0.8969   -2.3919 H   0  0  0  0  0  0  0  0  0  0  0  0
   -3.3780    0.5201    2.3226 H   0  0  0  0  0  0  0  0  0  0  0  0
   -4.9731   -1.6964   -0.9603 H   0  0  0  0  0  0  0  0  0  0  0  0
   -5.0956   -0.9934    1.4011 H   0  0  0  0  0  0  0  0  0  0  0  0
    3.5407    2.0006   -0.3989 H   0  0  0  0  0  0  0  0  0  0  0  0
    1.1139   -1.2873    0.8362 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.5208    0.9227    0.6302 H   0  0  0  0  0  0  0  0  0  0  0  0
    3.0727   -2.3547    1.8586 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.2801   -1.2627    1.7515 H   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  1  0
  2  3  2  0
  2  4  1  0
  1  5  1  0
  5  6  1  0
  6  7  2  0
  8  6  1  0
  7  9  1  0
 10  8  2  0
  9 11  2  0
 11 10  1  0
  4 12  2  0
 13  4  1  0
 12 14  1  0
 15 13  2  0
 14 16  2  0
 16 15  1  0
  1 17  1  0
  5 18  1  0
  5 19  1  0
  7 20  1  0
  8 21  1  0
  9 22  1  0
 10 23  1  0
 11 24  1  0
 12 25  1  0
 13 26  1  0
 14 27  1  0
 15 28  1  0
 16 29  1  0
M  END
"""
# First probe: same elements, atom order and bond table as ref_block, but
# different 3D coordinates. Like ref_block, the string starts with a newline
# that serves as the empty MolBlock title line.
probe1 = """
     RDKit          3D

 29 30  0  0  0  0  0  0  0  0999 V2000
    0.0408    0.4891    0.4661 N   0  0  0  0  0  0  0  0  0  0  0  0
   -1.0761    1.1517    0.8019 C   0  0  0  0  0  0  0  0  0  0  0  0
   -1.0577    2.1125    1.5732 O   0  0  0  0  0  0  0  0  0  0  0  0
   -2.3382    0.6648    0.1778 C   0  0  0  0  0  0  0  0  0  0  0  0
    1.3445    0.8577    1.0007 C   0  0  0  0  0  0  0  0  0  0  0  0
    2.4208    0.0556    0.3282 C   0  0  0  0  0  0  0  0  0  0  0  0
    2.7642   -1.2006    0.8135 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.0604    0.5405   -0.8055 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.7344   -1.9567    0.1794 C   0  0  0  0  0  0  0  0  0  0  0  0
    4.0315   -0.2151   -1.4398 C   0  0  0  0  0  0  0  0  0  0  0  0
    4.3701   -1.4647   -0.9484 C   0  0  0  0  0  0  0  0  0  0  0  0
   -3.3572    1.5886   -0.0379 C   0  0  0  0  0  0  0  0  0  0  0  0
   -2.5453   -0.6704   -0.1580 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.5548    1.1924   -0.6030 C   0  0  0  0  0  0  0  0  0  0  0  0
   -3.7490   -1.0646   -0.7151 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.7510   -0.1358   -0.9450 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.0101   -0.2553   -0.2126 H   0  0  0  0  0  0  0  0  0  0  0  0
    1.4909    1.9318    0.8433 H   0  0  0  0  0  0  0  0  0  0  0  0
    1.3412    0.6808    2.0823 H   0  0  0  0  0  0  0  0  0  0  0  0
    2.2721   -1.5828    1.6955 H   0  0  0  0  0  0  0  0  0  0  0  0
    2.7992    1.5159   -1.1884 H   0  0  0  0  0  0  0  0  0  0  0  0
    3.9966   -2.9294    0.5669 H   0  0  0  0  0  0  0  0  0  0  0  0
    4.5253    0.1726   -2.3178 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.1292   -2.0526   -1.4415 H   0  0  0  0  0  0  0  0  0  0  0  0
   -3.1866    2.6167    0.2461 H   0  0  0  0  0  0  0  0  0  0  0  0
   -1.7857   -1.4110    0.0408 H   0  0  0  0  0  0  0  0  0  0  0  0
   -5.3354    1.9180   -0.7717 H   0  0  0  0  0  0  0  0  0  0  0  0
   -3.9073   -2.1018   -0.9676 H   0  0  0  0  0  0  0  0  0  0  0  0
   -5.6870   -0.4481   -1.3818 H   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  1  0
  2  3  2  0
  2  4  1  0
  1  5  1  0
  5  6  1  0
  6  7  2  0
  8  6  1  0
  7  9  1  0
 10  8  2  0
  9 11  2  0
 11 10  1  0
  4 12  2  0
 13  4  1  0
 12 14  1  0
 15 13  2  0
 14 16  2  0
 16 15  1  0
  1 17  1  0
  5 18  1  0
  5 19  1  0
  7 20  1  0
  8 21  1  0
  9 22  1  0
 10 23  1  0
 11 24  1  0
 12 25  1  0
 13 26  1  0
 14 27  1  0
 15 28  1  0
 16 29  1  0
M  END
"""

# Second probe: same elements, atom order and bond table as ref_block, but
# different 3D coordinates.
# NOTE(review): probe2 is not passed to viz_align in any cell shown here;
# either align it as well or drop it.
probe2 = """
     RDKit          3D

 29 30  0  0  0  0  0  0  0  0999 V2000
    0.2524    0.6765    0.8862 N   0  0  0  0  0  0  0  0  0  0  0  0
    0.9757   -0.4432    1.0586 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.6757   -1.3100    1.8802 O   0  0  0  0  0  0  0  0  0  0  0  0
    2.1801   -0.5681    0.1904 C   0  0  0  0  0  0  0  0  0  0  0  0
   -0.9911    0.9308    1.5893 C   0  0  0  0  0  0  0  0  0  0  0  0
   -2.1821    0.7172    0.6898 C   0  0  0  0  0  0  0  0  0  0  0  0
   -2.9893    1.7768    0.3009 C   0  0  0  0  0  0  0  0  0  0  0  0
   -2.4640   -0.5595    0.2178 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.0673    1.5627   -0.5417 C   0  0  0  0  0  0  0  0  0  0  0  0
   -3.5374   -0.7731   -0.6275 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.3415    0.2889   -1.0087 C   0  0  0  0  0  0  0  0  0  0  0  0
    2.5911   -1.8457   -0.1785 C   0  0  0  0  0  0  0  0  0  0  0  0
    2.9168    0.5351   -0.2324 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.7038   -2.0193   -0.9806 C   0  0  0  0  0  0  0  0  0  0  0  0
    4.0357    0.3564   -1.0263 C   0  0  0  0  0  0  0  0  0  0  0  0
    4.4258   -0.9171   -1.4085 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.4745    1.2890    0.1166 H   0  0  0  0  0  0  0  0  0  0  0  0
   -1.0208    0.2290    2.4279 H   0  0  0  0  0  0  0  0  0  0  0  0
   -0.9778    1.9539    1.9721 H   0  0  0  0  0  0  0  0  0  0  0  0
   -2.7783    2.7724    0.6622 H   0  0  0  0  0  0  0  0  0  0  0  0
   -1.8350   -1.3843    0.5224 H   0  0  0  0  0  0  0  0  0  0  0  0
   -4.6920    2.3922   -0.8366 H   0  0  0  0  0  0  0  0  0  0  0  0
   -3.7486   -1.7691   -0.9868 H   0  0  0  0  0  0  0  0  0  0  0  0
   -5.1811    0.1215   -1.6662 H   0  0  0  0  0  0  0  0  0  0  0  0
    2.0211   -2.6923    0.1751 H   0  0  0  0  0  0  0  0  0  0  0  0
    2.6410    1.5301    0.0819 H   0  0  0  0  0  0  0  0  0  0  0  0
    4.0091   -3.0138   -1.2672 H   0  0  0  0  0  0  0  0  0  0  0  0
    4.6074    1.2140   -1.3459 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.2964   -1.0509   -2.0318 H   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  1  0
  2  3  2  0
  2  4  1  0
  1  5  1  0
  5  6  1  0
  6  7  2  0
  8  6  1  0
  7  9  1  0
 10  8  2  0
  9 11  2  0
 11 10  1  0
  4 12  2  0
 13  4  1  0
 12 14  1  0
 15 13  2  0
 14 16  2  0
 16 15  1  0
  1 17  1  0
  5 18  1  0
  5 19  1  0
  7 20  1  0
  8 21  1  0
  9 22  1  0
 10 23  1  0
 11 24  1  0
 12 25  1  0
 13 26  1  0
 14 27  1  0
 15 28  1  0
 16 29  1  0
M  END
"""

In [None]:
# Align probe1 onto the reference, print its scores, and show the tinted overlay.
viz_align(ref=ref_block, probe=probe1)