In [1]:
from collections.abc import Mapping
from collections import defaultdict, Counter
import itertools
import random
from stereomolgraph import StereoMolGraph, MolGraph, AtomId
from stereomolgraph.stereodescriptors import Tetrahedral
from stereomolgraph.algorithms.color_refine import label_hash, numpy_int_tuple_hash, numpy_int_multiset_hash
from stereomolgraph.experimental import generate_stereoisomers

import numpy as np
import rdkit
from pprint import pprint
import stereomolgraph.ipython

print = lambda *x: pprint(x, width=1000, compact=True)
%load_ext line_profiler

In [2]:
smg = StereoMolGraph()
smg.add_atom(0, "H")
smg.add_atom(1, "H")
smg.add_atom(2, "H")
smg.add_atom(3, "H")
smg.add_atom(4, "C")

smg.add_bond(0,4)
smg.add_bond(1,4)
smg.add_bond(2,4)
smg.add_bond(3,4)

smg.set_atom_stereo(Tetrahedral( (4, *smg.bonded_to(4) ), parity = 1) )

#smg2 = StereoMolGraph.from_rdmol()
#smg2 = next(iter(generate_stereoisomers(smg2)))

In [5]:
rdmol = rdkit.Chem.AddHs(rdkit.Chem.MolFromSmiles("CC=CCC=CCC=CCCC=CCC=CCCCC=CCC=CCCCC=CCCCC=CCC=C"))
mg = MolGraph.from_rdmol(rdmol)
#rdmols = [rdkit.Chem.RenumberAtoms(rdmol,
#                                   sorted(iter(list(range(rdmol.GetNumAtoms()))),

##                                          key=lambda k: random.random())) for _ in range(100)]

#smgs = [next(iter(generate_stereoisomers(StereoMolGraph.from_rdmol(renumbered_rdmol)))) for renumbered_rdmol in rdmols]
#smg2 = next(iter(generate_stereoisomers(smg2)))

In [None]:
def color_refine_smg(
    smg: StereoMolGraph,
    iterations: None|int = None,
    atom_labels: None|Mapping[AtomId, int] = None,
) -> Mapping[AtomId, int]:
    """
    Stereochemical color refinement.
    Each atom Stereo is aggregated at each iteration.
    Bond Stereo is aggregated at every second iteration."""
    n_atoms = len(smg.atoms)
    if atom_labels:
        assert len(atom_labels) == n_atoms
        assert set(atom_labels.keys()) == set(smg.atoms)

    initial_atom_label_hash = (label_hash(smg, ("atom_type",))
                               if atom_labels is None else atom_labels)
    if iterations == 0:
        return initial_atom_label_hash
    
    atom_hash = np.array(
        [initial_atom_label_hash[atom] for atom in smg.atoms], dtype=np.int64
    )
    arr_id_dict = {atom: a_id for a_id, atom in enumerate(smg.atoms)}
    id_arr_dict = {a_id: atom for atom, a_id in arr_id_dict.items()}

    stereo_hash_pointer = {i: [] for i in range(n_atoms)} # arr_id: list[stereo_pointer]


    grouped_atom_stereo: dict = defaultdict(list)
    atoms_with_atom_stereo: set[int] = set()
    atoms_with_bond_stereo: set[int] = set()

    grouped_bond_stereo: dict = defaultdict(list)
    atoms_with_atom_stereo: set[int] = set()

    as_atoms = []
    as_perm_atoms = []
    as_nbr_atoms = []
    
    bs_atoms = []
    bs_nbr_atoms = []
    bs_perm_atoms = []

    # i: arrays to store intermediate values. Avoids additional memory allocation.
    i_a_perm_nbrs = []
    i_a_perm = []
    i_a = []

    i_b_perm_nbrs = []
    i_b_perm = []
    i_b = []

    for atom, stereo in smg.atom_stereo.items():
        if stereo.parity is not None:

            nbr_atoms = stereo.atoms if stereo.parity != -1 else stereo._inverted_atoms()
            grouped_atom_stereo[stereo.__class__.PERMUTATION_GROUP].append((atom, nbr_atoms))

            atoms_with_atom_stereo.add(atom)

    for bond, stereo in smg.bond_stereo.items():
        if stereo.parity is not None:
            
            nbr_atoms = stereo.atoms if stereo.parity != -1 else stereo._inverted_atoms()
            grouped_bond_stereo[stereo.__class__.PERMUTATION_GROUP].append((bond, nbr_atoms))
            
            for atom in bond:
                atoms_with_bond_stereo.add(atom)

    
    atoms_without_atom_stereo = set(smg.atoms) - atoms_with_atom_stereo
    
    for atom in atoms_without_atom_stereo:
        fake_stereo_atoms = (atom, *smg.bonded_to(atom))
        perm_gen = itertools.permutations(range(1, len(fake_stereo_atoms)))
        perm_group = tuple((0, *perm) for perm in perm_gen)
        grouped_atom_stereo[perm_group].append((atom, fake_stereo_atoms))
    
    # atom_stereo
    for perm_group, atom_nbr_atoms_list_tup in grouped_atom_stereo.items():
        atom_arr_ids = np.array([arr_id_dict[atom] for atom, _nrb_atom_list in atom_nbr_atoms_list_tup],
                        dtype=np.uint16)
        as_atoms.append(atom_arr_ids)

        nbr_atoms = np.array([[arr_id_dict[a] for a in nbr_lst]
                     for _atom, nbr_lst in atom_nbr_atoms_list_tup],
                     dtype=np.uint16)

        as_nbr_atoms.append(nbr_atoms)

        perm_group = np.array(perm_group, dtype=np.uint8)
        perm_atoms = nbr_atoms[..., perm_group]
        as_perm_atoms.append(perm_atoms)

        # zeros intermediate arrays
        a_perm_nbrs = np.zeros( perm_atoms.shape,
                                 dtype=np.int64)
        i_a_perm_nbrs.append(a_perm_nbrs)
        i_a_perm.append(np.zeros(a_perm_nbrs.shape[0:2], dtype=np.int64))
        i_atom_stereo = np.zeros(a_perm_nbrs.shape[0:1], dtype=np.int64)
        i_a.append(i_atom_stereo)

        for stereo_id, atom_arr_id in enumerate(atom_arr_ids):
            stereo_hash_pointer[atom_arr_id].append(i_atom_stereo[stereo_id:stereo_id+1])
            # by reference

    
    # bond_stereo
    for perm_group, atom_nbr_atoms_list_tup in grouped_bond_stereo.items():
        atom_arr_ids = np.array([[arr_id_dict[atom] for atom in bond]
                                  for bond, _nrb_atom_list in atom_nbr_atoms_list_tup],
                        dtype=np.uint16)
        bs_atoms.append(atom_arr_ids)

        nbr_atoms = np.array([[arr_id_dict[a] for a in nbr_lst]
                     for _atom, nbr_lst in atom_nbr_atoms_list_tup],
                     dtype=np.uint16)

        bs_nbr_atoms.append(nbr_atoms)

        perm_group = np.array(perm_group, dtype=np.uint8)
        perm_atoms = nbr_atoms[..., perm_group]
        bs_perm_atoms.append(perm_atoms)

        # zeros intermediate arrays
        b_perm_nbrs = np.zeros( perm_atoms.shape,
                                 dtype=np.int64)
        i_b_perm_nbrs.append(b_perm_nbrs)
        i_b_perm.append(np.zeros(b_perm_nbrs.shape[0:2], dtype=np.int64))
        i_bond_stereo = np.zeros(b_perm_nbrs.shape[0:1], dtype=np.int64)
        i_b.append(i_bond_stereo)

        #print("init", i_b_perm, i_b_perm_nbrs, i_bond_stereo)

        for stereo_id, (atom_arr_id1, atom_arr_id2) in enumerate(atom_arr_ids):
            stereo_hash_pointer[atom_arr_id1].append(i_bond_stereo[stereo_id:stereo_id+1])
            stereo_hash_pointer[atom_arr_id2].append(i_bond_stereo[stereo_id:stereo_id+1])
    
    counter = itertools.count(0) if iterations is None else range(iterations)

    i_atoms_with_n_stereo = [] # atoms, i_stereo, group

    pntr = sorted(((id, ptr) for id, ptr in stereo_hash_pointer.items() if ptr),
               key=lambda x: len(x[1]))

    for key, group in itertools.groupby(pntr, key=lambda x: len(x[1])):
        group = list(group)

        ids = []
        pntr_groups = []
        for id, ptrs in group:
            ids.append(id)
            pntr_groups.append(ptrs)
        
        atoms = np.array(ids, dtype=np.int16)
        i_hash = np.empty((len(pntr_groups), key), dtype=np.int64)
        i_atoms_with_n_stereo.append((atoms, i_hash, pntr_groups))

    n_atom_classes = None
    counter = itertools.count(1, 1) if iterations is None else range(iterations + 1)

    for count in counter:
        # atom stereo
        for perm_atoms, _atoms, a_perm_nbrs, a_perm, a in zip(
            as_perm_atoms, as_atoms, i_a_perm_nbrs, i_a_perm, i_a):

            numpy_int_tuple_hash(atom_hash[perm_atoms], out=a_perm)

            numpy_int_multiset_hash(a_perm, out=a)
        # bond stereo
        if count % 2 == 0 and count != 0:
            
            for perm_atoms, _atoms, b_perm_nbrs, b_perm, b in zip(
                bs_perm_atoms, bs_atoms, i_b_perm_nbrs, i_b_perm, i_b):

                numpy_int_tuple_hash(atom_hash[perm_atoms], out=b_perm)

                numpy_int_multiset_hash(b_perm, out=b)

        for atoms, i_stereo, ptr_lsts in i_atoms_with_n_stereo: # atoms, i_stereo, group

            i_stereo[:] = np.asarray([[ptr.item() for ptr in ptr_list] for ptr_list in ptr_lsts])

            atom_hash[atoms] = numpy_int_multiset_hash(i_stereo, out=atom_hash[atoms])

        if count % 2 == 0 and iterations is None:
            new_n_classes = np.unique(atom_hash, sorted=False).shape[0]
            if new_n_classes == n_atom_classes:
                break
            elif new_n_classes == n_atoms:
                break
            else:
                n_atom_classes = new_n_classes

    return {id_arr_dict[arr_id]: int(h) for arr_id, h in enumerate(atom_hash)}

In [22]:
def color_refine_mg(
    mg: MolGraph,
    iterations: None | int = None,
    atom_labels: None | Mapping[AtomId, int] = None,
) -> Mapping[AtomId, int]:
    """Color refinement algorithm for MolGraph.

    This algorithm refines the atom coloring based on their connectivity.
    Identical to the Weisfeiler-Lehman (1-WL) algorithm.

    :param mg: MolGraph object containing the atoms and their connectivity.
    :param max_iter: Maximum number of iterations for refinement.
        Default is None, which means it will run until convergence."""
    n_atoms = len(mg.atoms)
    if atom_labels:
        assert len(atom_labels) == n_atoms
        assert set(atom_labels.keys()) == set(mg.atoms)

    initial_atom_label_hash = (
        label_hash(mg, ("atom_type",)) if atom_labels is None else atom_labels
    )
    if iterations == 0:
        return initial_atom_label_hash

    atom_hash = np.array(
        [initial_atom_label_hash[atom] for atom in mg.atoms], dtype=np.int64
    )

    arr_id_dict, id_arr_dict = {}, {}
    for id, atom in enumerate(mg.atoms):
        arr_id_dict[atom] = id
        id_arr_dict[id] = atom
    
    bonded_lst = [(id, [arr_id_dict[a] for a in mg.bonded_to(atom)])
                  for atom, id in arr_id_dict.items()]
    
    bonded_lst.sort(key=lambda x: len(x[1]))

    id_nbrs_tuple_list = []
    for _key, group in itertools.groupby(bonded_lst, key=lambda x: len(x[1])):
        group = list(group)
        ids_lst, nbrs_lists = [], []
        for id, nbrs in group:
            ids_lst.append(id)
            nbrs_lists.append(nbrs)
        ids = np.array(ids_lst, dtype=np.int16)
        nbrs = np.array(nbrs_lists, dtype=np.int16)
        id_nbrs_tuple_list.append((ids, nbrs))
    
    
    n_atom_classes = np.unique(atom_hash, sorted=False).shape[0]
    counter = itertools.count(1, 1) if iterations is None else range(iterations + 1)
    new_atom_hashes = np.empty_like(atom_hash, dtype=np.int64)

    for _ in counter:
        for ids, nbrs in id_nbrs_tuple_list:
            # Compute the new hash for each atom based on its neighbors
            new_atom_hashes[ids] = numpy_int_multiset_hash(atom_hash[nbrs])
        
        new_n_classes = np.unique(atom_hash, sorted=False).shape[0]
        if new_n_classes == n_atom_classes:
            break
        elif new_n_classes == n_atoms:
            break
        else:
            n_atom_classes = new_n_classes
            atom_hash, new_atom_hashes = new_atom_hashes, atom_hash

    return {id_arr_dict[arr_id]: int(h) for arr_id, h in enumerate(atom_hash)}

In [None]:
{frozenset(sorted(Counter(color_refine_smg(smg, iterations=None).values()).items())) for smg in smgs}

{frozenset({(-8300630967067387254, 2),
            (-6613449356398020983, 4),
            (-6611864318748356162, 1),
            (-6105569032067994779, 4),
            (-4516373216795078962, 2),
            (-2369376690388761227, 4),
            (-686108559420242651, 6),
            (-443162110294980450, 2),
            (2618423732730923493, 4),
            (3028710709823717049, 2),
            (3680202844943605406, 2),
            (6298530441632723246, 2)})}

In [24]:
n = 20
color_refine_mg(mg, iterations=n)
%timeit color_refine_mg(mg, iterations=n)
%lprun -f color_refine_mg color_refine_mg(mg, iterations=n)

452 μs ± 47 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Timer unit: 1e-07 s

Total time: 0.0021943 s
File: C:\Users\maxim\AppData\Local\Temp\ipykernel_44184\3693040269.py
Function: color_refine_mg at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def color_refine_mg(
     2                                               mg: MolGraph,
     3                                               iterations: None | int = None,
     4                                               atom_labels: None | Mapping[AtomId, int] = None,
     5                                           ) -> Mapping[AtomId, int]:
     6                                               """Color refinement algorithm for MolGraph.
     7                                           
     8                                               This algorithm refines the atom coloring based on their connectivity.
     9                                               Identical to the Weisfeiler-Lehman (1-WL) algorithm.
    10    

In [260]:
a_list = [[], [], []]

In [261]:
a_stereo_hash = np.array([1,2,3,4])

In [262]:
a_list[0].append(a_stereo_hash[1:2])

In [263]:
a_list


[[array([2])], [], []]

In [264]:
a_stereo_hash[1] = 12

In [265]:
a_stereo_hash, a_list

(array([ 1, 12,  3,  4]), [[array([12])], [], []])