In [2]:
import itertools
from typing import Iterator, List, Tuple

import numpy as np

In [5]:
def generate_atom_combinations(
    symbols: List[str],
) -> Iterator[Tuple[int, int, int]]:
    """Yield ordered triples of atom indices for transformations.

    The function first yields combinations of 3 heavy atom indices.
    Each combination is followed by its reverse. Once the heavy atoms
    are exhausted, the heavy atoms then get combined with the hydrogens,
    i.e. every remaining triple contains at least one hydrogen.

    Parameters
    ----------
    symbols: list of str
        List of atom elements. An atom counts as a hydrogen only if its
        symbol is exactly ``"H"``; every other symbol counts as heavy.

    Yields
    ------
    tuple of int
        Three distinct atom indices (positions in ``symbols``) as builtin
        Python ``int`` values, each triple immediately followed by its
        reverse. Nothing is yielded when there are fewer than 3 atoms.

    Examples
    --------

    ::

        >>> symbols = ["H", "C", "C", "O", "N"]
        >>> comb = generate_atom_combinations(symbols)
        >>> next(comb)
        (1, 2, 3)
        >>> next(comb)
        (3, 2, 1)
        >>> next(comb)
        (1, 2, 4)
        >>> next(comb)
        (4, 2, 1)

    """
    # Build the index lists in plain Python so the yielded tuples contain
    # builtin ints rather than numpy integer scalars (whose repr becomes
    # ``np.int64(1)`` under NumPy >= 2 and would break the example above).
    is_h = [sym == "H" for sym in symbols]
    heavy_atoms = [i for i, flag in enumerate(is_h) if not flag]
    h_atoms = [i for i, flag in enumerate(is_h) if flag]

    # Phase 1: triples made only of heavy atoms.
    for comb in itertools.combinations(heavy_atoms, 3):
        yield comb
        yield comb[::-1]

    # Phase 2: triples containing at least one hydrogen. ``combinations``
    # keeps elements in the order they appear in its input, and the
    # hydrogens are placed at the end of that input, so a triple contains a
    # hydrogen iff its last element is one. All-heavy triples were already
    # yielded in phase 1 and are skipped without storing them in a set.
    for comb in itertools.combinations(heavy_atoms + h_atoms, 3):
        if not is_h[comb[-1]]:
            continue
        yield comb
        yield comb[::-1]

In [6]:
# Demo molecule: one hydrogen (index 0) followed by four heavy atoms.
example_symbols = ["H", "C", "C", "O", "N"]
list(generate_atom_combinations(example_symbols))

[(1, 2, 3),
 (3, 2, 1),
 (1, 2, 4),
 (4, 2, 1),
 (1, 3, 4),
 (4, 3, 1),
 (2, 3, 4),
 (4, 3, 2),
 (1, 2, 0),
 (0, 2, 1),
 (1, 3, 0),
 (0, 3, 1),
 (1, 4, 0),
 (0, 4, 1),
 (2, 3, 0),
 (0, 3, 2),
 (2, 4, 0),
 (0, 4, 2),
 (3, 4, 0),
 (0, 4, 3)]