In [None]:
def color_refine_smg(
    mg: StereoMolGraph,
    max_iter: Optional[int] = None,
    atom_labels: Optional[Iterable[str]] = ("atom_type",),
    bond_labels: Optional[Iterable[str]] = None,
) -> dict[AtomId, int]:
    """Iteratively refine integer atom hashes of ``mg`` using atom stereo.

    Starting from ``label_hash(mg, atom_labels, bond_labels)``, every round
    gives each atom a new hash computed from its current hash together with
    one contribution per atom-stereo descriptor (``mg.get_atom_stereo``) the
    atom takes part in. A descriptor's contribution to one of its atoms is a
    hash of the descriptor's atom hashes -- with that atom itself replaced by
    a reserved "self" value 0 -- made invariant under the descriptor's
    ``PERMUTATION_GROUP`` by sorting the per-permutation hashes. Descriptors
    with ``parity == -1`` are read through ``_inverted_atoms()`` (presumably
    so that handedness is encoded purely by atom order -- TODO confirm).

    Args:
        mg: Graph whose atoms are coloured.
        max_iter: Upper bound on refinement rounds. ``None`` iterates until
            the number of distinct hashes stops changing or every atom has a
            distinct hash; ``0`` returns the initial label hashes.
        atom_labels: Forwarded to ``label_hash``.
        bond_labels: Forwarded to ``label_hash``.

    Returns:
        Mapping from every atom of ``mg`` to its final hash. Atoms that share
        a hash were not distinguished by the refinement.
    """
    initial_hash = label_hash(mg, atom_labels, bond_labels)

    atoms = list(mg.atoms)  # snapshot once so every pass sees the same order
    n_atoms = len(atoms)
    atom_index = {atom: idx for idx, atom in enumerate(atoms)}

    # One int64 slot per atom plus a trailing sentinel slot (index n_atoms,
    # addressed as -1 below) whose value 0 means "the atom itself" inside a
    # stereo tuple. The sentinel must remain 0 for the whole run.
    atom_hash = np.array(
        [initial_hash[atom] for atom in atoms] + [0], dtype=np.int64
    )

    # Per atom: [own index, then 1-element views into stereo hash buffers].
    # The leading index is replaced by a view into atom_hash further below.
    atom_inputs = [[idx] for idx in range(n_atoms)]

    # Per permutation group:
    # (s_arr, s_tup_arr, s_atom_arr, s_atom_perm_id_arr)
    stereo_tables = []

    stereos = [
        atom_stereo
        for atom in atoms
        if (atom_stereo := mg.get_atom_stereo(atom)) is not None
    ]
    stereos.sort(key=lambda s: s.PERMUTATION_GROUP)

    for perm_group_tup, group_iter in itertools.groupby(
        stereos, key=lambda s: s.PERMUTATION_GROUP
    ):
        group = list(group_iter)
        n_s_atoms = len(group[0].atoms)
        size = n_s_atoms * len(group)

        # Permutations as an index array of shape (n_perms, n_s_atoms); the
        # entries are positions inside a descriptor, stored as uint8.
        perm_group = np.array(tuple(perm_group_tup), dtype=np.uint8)
        n_perms = len(perm_group)

        # One output hash per (descriptor, atom position) pair.
        s_arr = np.zeros(size, dtype=np.int64)
        # Scratch buffers, refilled every round.
        s_tup_arr = np.empty((size, n_perms), dtype=np.int64)
        s_atom_arr = np.empty((size, n_perms, n_s_atoms), dtype=np.int64)

        rows = []
        for s_idx, stereo in enumerate(group):
            s_atoms = (
                stereo._inverted_atoms() if stereo.parity == -1 else stereo.atoms
            )
            row = [atom_index[atom] for atom in s_atoms]
            rows.append(row)
            for pos, idx in enumerate(row):
                slot = s_idx * n_s_atoms + pos
                # Basic slicing yields a view, so later in-place writes to
                # s_arr are visible through atom_inputs.
                atom_inputs[idx].append(s_arr[slot : slot + 1])

        # Shape (size, n_s_atoms): each descriptor row repeated once per
        # position; in copy p the atom at position p becomes -1, i.e. the
        # sentinel slot of atom_hash.
        ids = np.repeat(np.array(rows, dtype=np.int32), n_s_atoms, axis=0)
        self_mask = np.tile(np.eye(n_s_atoms, dtype=bool), (len(rows), 1))
        ids[self_mask] = -1
        # Apply every permutation: shape (size, n_perms, n_s_atoms).
        s_atom_perm_id_arr = ids[..., perm_group]

        stereo_tables.append((s_arr, s_tup_arr, s_atom_arr, s_atom_perm_id_arr))

    # Group atoms by their number of inputs so each group can be hashed as one
    # rectangular array.
    atom_inputs.sort(key=len)
    atom_groups = []
    for n_inputs, same_len in itertools.groupby(atom_inputs, key=len):
        group_ids = []
        group_views = []
        for inputs in same_len:
            idx = inputs[0]
            group_ids.append(idx)
            inputs[0] = atom_hash[idx : idx + 1]  # view into atom_hash
            group_views.append(inputs)
        atom_groups.append(
            (
                np.array(group_ids, dtype=np.int32),
                np.empty((len(group_ids), n_inputs, 1), dtype=np.int64),
                group_views,
            )
        )

    # Buffer for the next round's hashes. A copy (not np.empty) keeps the
    # sentinel slot at 0; only the first n_atoms entries are ever rewritten.
    new_atom_hash = atom_hash.copy()
    n_atom_classes = None
    rounds = itertools.repeat(None) if max_iter is None else range(max_iter)

    for _ in rounds:
        # Stereo contributions from the current atom hashes.
        for s_arr, s_tup_arr, s_atom_arr, s_atom_perm_id_arr in stereo_tables:
            s_atom_arr[:] = atom_hash[s_atom_perm_id_arr]
            s_tup_arr[:] = numpy_int_tuple_hash(s_atom_arr, out=s_tup_arr)
            # Sorting over the permutation axis makes the result independent
            # of which group element produced which tuple.
            s_tup_arr.sort(axis=-1)
            # Must write in place: atom_inputs holds views into s_arr.
            s_arr[:] = numpy_int_tuple_hash(s_tup_arr, out=s_arr)

        # New atom hashes from own hash + stereo contributions.
        for group_ids, buf, group_views in atom_groups:
            buf[:] = group_views
            flat = buf.reshape(buf.shape[:-1])
            # NOTE(review): the own hash is sorted together with the stereo
            # contributions rather than kept in a fixed slot.
            flat.sort(axis=-1)
            new_atom_hash[group_ids] = numpy_int_tuple_hash(
                flat, out=new_atom_hash[group_ids]
            )

        # Count classes over real atoms only, never the sentinel slot.
        new_n_classes = np.unique(new_atom_hash[:n_atoms]).size
        if new_n_classes == n_atom_classes:
            # Partition is stable; the current atom_hash already encodes it.
            break
        # Commit in place (views alias atom_hash). Both arrays hold 0 in the
        # sentinel slot, so it stays 0.
        atom_hash[:] = new_atom_hash
        if new_n_classes == n_atoms:
            # Every atom distinct: nothing left to refine.
            break
        n_atom_classes = new_n_classes

    return {atom: int(h) for atom, h in zip(atoms, atom_hash)}
