In [1]:
from collections.abc import Iterable
from collections import defaultdict
import itertools
from stereomolgraph import StereoMolGraph, AtomId
from stereomolgraph.stereodescriptors import Tetrahedral
from stereomolgraph.algorithms.color_refine import label_hash, numpy_int_tuple_hash, numpy_int_multiset_hash
from stereomolgraph.experimental import generate_stereoisomers

import numpy as np
import rdkit

from pprint import pprint

In [2]:
smg = StereoMolGraph()
smg.add_atom(0, "H")
smg.add_atom(1, "H")
smg.add_atom(2, "H")
smg.add_atom(3, "H")
smg.add_atom(4, "C")

smg.add_bond(0,4)
smg.add_bond(1,4)
smg.add_bond(2,4)
smg.add_bond(3,4)

smg.set_atom_stereo(Tetrahedral( (4, *smg.bonded_to(4) ), parity = 1) )

smg2 = StereoMolGraph.from_rdmol(rdkit.Chem.AddHs(rdkit.Chem.MolFromSmiles("CCC")))
smg2 = next(iter(generate_stereoisomers(smg2)))

In [19]:
def color_refine_smg(
    smg: StereoMolGraph,
    max_iter: None|int = 1,
    atom_labels: Iterable[str] = ("atom_type",),
) -> dict[AtomId, int]:

    atom_label_hash = label_hash(smg, atom_labels)
    atom_hash = np.array(
        [atom_label_hash[atom] for atom in smg.atoms], dtype=np.int64
    )
    atom_stereo_hash = np.zeros(atom_hash.shape, dtype=np.int64)
    atom_bond_hash = np.zeros(atom_hash.shape, dtype=np.int64)

    arr_id = {atom: a_id for a_id, atom in enumerate(smg.atoms)}

    grouped_atom_stereo: dict = defaultdict(list)
    atoms_with_atom_stereo: set[int] = set()

    grouped_bond_stereo: dict = defaultdict(list)
    atoms_with_atom_stereo: set[int] = set()

    as_perm_groups: list[tuple[tuple[int, ...], ...]] = []
    as_perm_atoms = []
    
    bs_atoms: list[np.ndarray[tuple[int], np.dtype[np.int16]]] = []
    bs_nbr_atoms = []

    # i: arrays to store intermediate values. Avoids additional memory allocation.
    i_a_perm_nbrs = []
    i_a_perm = []
    i_a = []

    i_b_perm_nbrs = []
    i_b_perm = []
    i_b = []

    for atom, stereo in smg.atom_stereo.items():
        if stereo.parity is not None:
            atoms_with_atom_stereo.add(atom)
            nbr_atoms = stereo.atoms if stereo.parity != -1 else stereo._inverted_atoms()

            grouped_atom_stereo[stereo.__class__.PERMUTATION_GROUP].append(nbr_atoms)

    for bond, stereo in smg.bond_stereo.items():
        if stereo.parity is not None:
            for a in bond:
                atoms_with_bond_stereo.add(a)
                nbr_atoms = stereo.atoms if stereo.parity != -1 else stereo._inverted_atoms()

                grouped_atom_stereo[stereo.__class__.PERMUTATION_GROUP].append(nbr_atoms)

    
    atoms_with_atom_stereo = set(smg.atoms) - atoms_with_atom_stereo
    # TODO: for atoms without atom_stereo

    # atom_stereo
    for perm_group, nbr_atoms_list in grouped_atom_stereo.items():
        perm_group = np.array(perm_group, dtype=np.uint8)
        as_perm_groups.append(perm_group)
        

        as_atoms.append(np.array([arr_id[atoms[0]] for atoms in nbr_atoms_list],
                        dtype=np.uint16))

        nbr_atoms = np.array([[arr_id[a] for a in nbr_lst]  
                     for nbr_lst in nbr_atoms_list ], dtype=np.uint16)

        as_nbr_atoms.append(nbr_atoms)

        perm_atoms = nbr_atoms[..., perm_group]
        as_perm_atoms.append(perm_atoms)

        # empty intermediate arrays
        a_perm_nbrs = np.empty( perm_atoms.shape,
                                 dtype=np.int64)
        i_a_perm_nbrs.append(a_perm_nbrs)
        i_a_perm.append(np.empty(a_perm_nbrs.shape[0:2], dtype=np.int64))
        i_a.append(np.empty(a_perm_nbrs.shape[0:1], dtype=np.int64))

    
    # bond_stereo
    for perm_group, nbr_atoms_list in grouped_bond_stereo.items():
        perm_group = np.array(perm_group, dtype=np.uint8)
        as_perm_groups.append(perm_group)
        

        bs_atoms.append(np.array([arr_id[atoms[0]] for atoms in nbr_atoms_list],
                        dtype=np.uint16))

        nbr_atoms = np.array([[arr_id[a] for a in nbr_lst]  
                     for nbr_lst in nbr_atoms_list ], dtype=np.uint16)

        bs_nbr_atoms.append(nbr_atoms)

        perm_atoms = nbr_atoms[..., perm_group]
        bs_perm_atoms.append(perm_atoms)

        # empty intermediate arrays
        b_perm_nbrs = np.empty( perm_atoms.shape,
                                 dtype=np.int64)
        i_b_perm_nbrs.append(b_perm_nbrs)
        i_b_perm.append(np.empty(b_perm_nbrs.shape[0:2], dtype=np.int64))
        i_b.append(np.empty(b_perm_nbrs.shape[0:1], dtype=np.int64))
    
    counter = itertools.count(0) if max_iter is None else range(max_iter)

    for count in counter:
        # atom stereo
        for perm_atoms, atoms, _nbr_atoms, a_perm_nbrs, a_perm, a in zip(
            as_perm_atoms, as_atoms, as_nbr_atoms, i_a_perm_nbrs, i_a_perm, i_a):

            a_perm_nbrs[:] = atom_hash[perm_atoms]

            numpy_int_tuple_hash(a_perm_nbrs, out=a_perm)

            numpy_int_multiset_hash(a_perm, out=a)
            atom_stereo_hash[atoms] = a

        # bond stereo

        if count % 2 == 0:
            for perm_atoms, atoms, _nbr_atoms, b_perm_nbrs, b_perm, b in zip(
                bs_perm_atoms, bs_atoms, bs_nbr_atoms, i_b_perm_nbrs, i_b_perm, i_b):

                b_perm_nbrs[:] = atom_hash[perm_atoms]

                numpy_int_tuple_hash(b_perm_nbrs, out=b_perm)

                numpy_int_multiset_hash(b_perm, out=b)
                bond_stereo_hash[atoms] = b
       




In [20]:
color_refine_smg(smg2)

NameError: name 'as_atoms' is not defined