In [1]:
import itertools
from typing import Optional

import numpy as np
from centrex_tlf import states, hamiltonian, transitions, couplings

In [None]:
def get_transition_from_state(
    ground_state: states.CoupledBasisState, excited_state: states.CoupledBasisState, transition: str = "E1"
) -> Optional[transitions.OpticalTransition]:
    if transition == "E1":
        checker = couplings.utils.check_transition_coupled_allowed
    elif transition == "E2":
        checker = couplings.utils.check_transition_coupled_allowed_E2
    else:
        raise ValueError(f"{transition} coupling not supported")
    if checker(ground_state, excited_state):
        ΔJ = excited_state.J - ground_state.J
        ΔJs = [val.value for val in transitions.OpticalTransitionType]
        if (ΔJ > max(ΔJs)) or (ΔJ < min(ΔJs)):
            return None
        return transitions.OpticalTransition(
            transitions.OpticalTransitionType(ΔJ),
            ground_state.J,
            excited_state.F1,
            excited_state.F,
        )


In [None]:
# Rotational manifolds included in the calculation.
J_ground = list(range(9))       # X state: J = 0 ... 8
J_excited = list(range(1, 11))  # B state: J = 1 ... 10

ground_select = states.QuantumSelector(J=J_ground)
# Both parities of every excited J are kept.
excited_select = states.QuantumSelector(J=J_excited, P=[-1, 1])

QN_X = states.generate_coupled_states_X(ground_select)
QN_B = states.generate_coupled_states_B(excited_select, basis=states.Basis.CoupledP)

# Small field along z; NOTE(review): presumably just to lift mF degeneracy — confirm units.
reduced_hamiltonian = hamiltonian.generate_total_reduced_hamiltonian(QN_X, QN_B, B=np.array([0, 0, 1e-3]))

nr_ground_states = len(reduced_hamiltonian.X_states)
nr_excited_states = len(reduced_hamiltonian.B_states)

# Diagonal of H_int, angular frequency -> MHz. X states occupy the leading
# entries and B states the trailing ones.
H_X = np.diag(reduced_hamiltonian.H_int).real[:nr_ground_states] / (2 * np.pi * 1e6)  # MHz
H_B = np.diag(reduced_hamiltonian.H_int).real[-nr_excited_states:] / (2 * np.pi * 1e6)  # MHz

In [258]:
# Enumerate all (X, B) state pairs and keep those whose dominant components form
# an E1 transition, together with the reduced electric-dipole matrix element.
transitions_E1 = []
for (idg, ground_state), (ide, excited_state) in itertools.product(
    enumerate(reduced_hamiltonian.X_states),
    enumerate(reduced_hamiltonian.B_states),
):
    # Quick E1 pre-filter on dominant components: |ΔJ| <= 1 and opposite parity.
    if (
        abs(excited_state.largest.J - ground_state.largest.J) > 1
        or excited_state.largest.P == ground_state.largest.P
    ):
        continue
    cpl = hamiltonian.generate_ED_ME_mixed_state(excited_state, ground_state, reduced=True)
    transition = get_transition_from_state(ground_state.largest, excited_state.largest)
    if transition is not None:
        transitions_E1.append((idg, ide, transition, cpl))
    elif cpl != 0:
        # Diagnostic: selection rules rejected the pair but the coupling is non-zero.
        print(idg, ide, transition, cpl)
        print(ground_state.largest)
        print(excited_state.largest)

In [259]:
# Enumerate all (X, B) state pairs and keep those whose dominant components form
# an E2 transition with a non-negligible reduced electric-quadrupole element.
transitions_E2 = []
for (idg, ground_state), (ide, excited_state) in itertools.product(
    enumerate(reduced_hamiltonian.X_states),
    enumerate(reduced_hamiltonian.B_states),
):
    # Quick E2 pre-filter on dominant components: |ΔJ| == 2 and equal parity.
    if (
        abs(excited_state.largest.J - ground_state.largest.J) != 2
        or excited_state.largest.P != ground_state.largest.P
    ):
        continue
    cpl = hamiltonian.generate_EQ_ME_mixed_state(excited_state, ground_state, reduced=True)
    transition = get_transition_from_state(ground_state.largest, excited_state.largest, "E2")
    if transition is None:
        if cpl != 0:
            # Diagnostic: selection rules rejected the pair but the coupling is non-zero.
            print(idg, ide, transition, cpl)
            print(ground_state.largest)
            print(excited_state.largest)
    elif abs(cpl) > 1e-14:
        # Drop numerically vanishing couplings.
        transitions_E2.append((idg, ide, transition, cpl))

In [241]:
from dataclasses import dataclass
import numpy.typing as npt

@dataclass
class Transition:
    """An optical transition together with all matrix elements grouped under it."""

    transition: transitions.OpticalTransition  # label shared by every grouped pair
    ground_indices: list[int]  # ground-state index of each grouped pair
    excited_indices: list[int]  # excited-state index of each grouped pair
    coupling_elements_squared: list[float]  # |coupling|^2 of each pair (energy weights)
    ground_energies: npt.NDArray[np.float64]  # energy of each pair's ground state
    excited_energies: npt.NDArray[np.float64]  # energy of each pair's excited state
    weighted_energy: float  # coupling-weighted mean transition energy
    nphotons: float | None  # photons scattered before loss; None if not computed
    branching: dict[int, float] | None  # ground J -> branching fraction; None if not computed

    def __repr__(self) -> str:
        # Compact label used when printing lists of transitions.
        photons_text = "None" if self.nphotons is None else f"{self.nphotons:.2f}"
        return f"Transition({self.transition.name}, nγ={photons_text})"

In [234]:
def precalculate_couplings(X_states, B_states):
    """Total E1 coupling strength between every excited and ground state.

    For each (excited, ground) pair, the squared magnitudes of the
    non-reduced electric-dipole matrix elements are summed over the three
    Cartesian polarization unit vectors (x, y, z).

    Args:
        X_states: sequence of ground states (columns of the result).
        B_states: sequence of excited states (rows of the result).

    Returns:
        Float array of shape ``(len(B_states), len(X_states))``. Entries whose
        summed strength is below 1e-14 are left at zero.
    """
    polarizations = np.eye(3, dtype=complex)  # rows: x, y, z unit vectors
    strengths = np.zeros((len(B_states), len(X_states)))

    for row, excited in enumerate(B_states):
        for col, ground in enumerate(X_states):
            total = sum(
                np.abs(
                    hamiltonian.generate_ED_ME_mixed_state(
                        excited, ground, pol_vec=pol, reduced=False
                    )
                ) ** 2
                for pol in polarizations
            )
            # Treat numerically vanishing couplings as exactly zero.
            if total < 1e-14:
                continue
            strengths[row, col] = total
    return strengths

# J of the dominant component of each ground state, aligned with coupling_matrix columns.
ground_Js = np.array([x_state.largest.J for x_state in reduced_hamiltonian.X_states])
# E1 strengths: rows are B states, columns are X states (three matrix elements per pair).
coupling_matrix = precalculate_couplings(
    reduced_hamiltonian.X_states,
    reduced_hamiltonian.B_states,
)

In [None]:
def group_transitions(
    transitions_sequence: list[tuple[int, int, transitions.OpticalTransition, complex]],
    H_X: npt.NDArray[np.float64],
    H_B: npt.NDArray[np.float64],
    coupling_matrix: npt.NDArray[np.float64],
    ground_Js: npt.NDArray[np.int_],
    vib_branching: float = 0.99,
) -> list[Transition]:
    """Group individual matrix elements by optical transition and derive summary data.

    For every distinct ``OpticalTransition`` in ``transitions_sequence`` this collects
    the participating ground/excited indices and |cpl|^2 weights, then fills in:

    * ``branching``: for each ground-state J, the E1 branching fraction (from
      ``coupling_matrix``) averaged over the distinct excited states of the group.
      This is ``None`` when J'-1, J', J'+1 are not all present in ``ground_Js``,
      because the decay channels would be incomplete.
    * ``nphotons``: ``1 / (1 - vib_branching * branching[J_ground])``. It is ``inf``
      if that denominator is not positive, and ``None`` when branching is unavailable.
    * ``ground_energies`` / ``excited_energies``: ``H_X`` / ``H_B`` at the grouped indices.
    * ``weighted_energy``: transition energy averaged with the |cpl|^2 weights.
      It falls back to the plain mean when all weights are zero.

    Args:
        transitions_sequence: ``(ground_index, excited_index, transition, coupling)`` tuples.
        H_X: ground-state energies, indexed by ground index.
        H_B: excited-state energies, indexed by excited index.
        coupling_matrix: E1 strengths with rows = excited states, columns aligned
            with ``ground_Js``.
        ground_Js: J of each ground state (one entry per ``coupling_matrix`` column).
        vib_branching: vibrational branching ratio used for ``nphotons``. The default
            0.99 is the value previously hard-coded here.

    Returns:
        One ``Transition`` per distinct optical transition, in first-seen order.
    """
    ground_Js = np.asarray(ground_Js)
    grouped: list[Transition] = []

    # 1) Group matrix elements by transition. OpticalTransition is compared by
    #    equality (not assumed hashable), hence the linear search.
    for idg, ide, optical_transition, cpl in transitions_sequence:
        cpl_squared = np.abs(cpl) ** 2
        group = next((t for t in grouped if t.transition == optical_transition), None)
        if group is None:
            grouped.append(
                Transition(
                    transition=optical_transition,
                    ground_indices=[idg],
                    excited_indices=[ide],
                    coupling_elements_squared=[cpl_squared],
                    ground_energies=np.array([]),
                    excited_energies=np.array([]),
                    weighted_energy=0.0,
                    nphotons=None,
                    branching=None,
                )
            )
        else:
            group.ground_indices.append(idg)
            group.excited_indices.append(ide)
            group.coupling_elements_squared.append(cpl_squared)

    unique_Js = np.unique(ground_Js)
    available_Js = set(unique_Js.tolist())

    # 2) Per-group branching, photon number and energies.
    for group in grouped:
        J_excited = group.transition.J_excited
        # E1 decay from J' reaches J'-1, J', J'+1; all must be in the basis for
        # the branching fractions to be meaningful.
        required_Js = {J for J in (J_excited - 1, J_excited, J_excited + 1) if J >= 0}

        if required_Js <= available_Js:
            # excited_indices holds one entry per (ground, excited) pair, so the same
            # excited state can repeat; count each excited state once.
            exc_indices = np.unique(group.excited_indices)
            strengths = coupling_matrix[exc_indices, :]  # (N_exc, N_ground)

            total_decay = strengths.sum(axis=1)
            total_decay[total_decay == 0] = 1.0  # avoid 0/0 for uncoupled rows
            br_matrix = strengths / total_decay[:, None]

            branching = {
                int(J): float(np.mean(br_matrix[:, ground_Js == J].sum(axis=1)))
                for J in unique_Js
            }
            group.branching = branching

            denominator = 1 - vib_branching * branching.get(group.transition.J_ground, 0.0)
            # A non-positive denominator means the cycle never loses population.
            group.nphotons = 1 / denominator if denominator > 0 else float("inf")
        else:
            group.branching = None
            group.nphotons = None

        group.ground_energies = H_X[group.ground_indices]
        group.excited_energies = H_B[group.excited_indices]
        energies = group.excited_energies - group.ground_energies

        weights = np.asarray(group.coupling_elements_squared, dtype=float)
        if weights.sum() != 0:
            group.weighted_energy = float(np.average(energies, weights=weights))
        else:
            group.weighted_energy = float(np.mean(energies))

    return grouped

transitions_list_E1 = group_transitions(transitions_E1, H_X, H_B, coupling_matrix, ground_Js)
transitions_list_E2 = group_transitions(transitions_E2, H_X, H_B, coupling_matrix, ground_Js)

# Summary: how many raw matrix elements collapsed into how many transitions.
for n_elements, grouped in (
    (len(transitions_E1), transitions_list_E1),
    (len(transitions_E2), transitions_list_E2),
):
    print(f"Grouped {n_elements} matrix elements into {len(grouped)} unique transitions")

Grouped 3100 matrix elements into 36 unique transitions
Grouped 1986 matrix elements into 20 unique transitions


In [261]:
# All E1 and E2 transitions, ordered by coupling-weighted transition energy.
sorted([*transitions_list_E1, *transitions_list_E2], key=lambda tr: tr.weighted_energy)

[Transition(O(3) F1'=1/2 F'=0, nγ=1.00),
 Transition(O(3) F1'=1/2 F'=1, nγ=1.00),
 Transition(O(3) F1'=3/2 F'=1, nγ=1.00),
 Transition(O(3) F1'=3/2 F'=2, nγ=1.00),
 Transition(P(2) F1'=1/2 F'=0, nγ=1.49),
 Transition(P(2) F1'=1/2 F'=1, nγ=1.49),
 Transition(P(3) F1'=5/2 F'=2, nγ=1.89),
 Transition(P(3) F1'=5/2 F'=3, nγ=1.89),
 Transition(P(3) F1'=3/2 F'=1, nγ=1.40),
 Transition(P(3) F1'=3/2 F'=2, nγ=1.40),
 Transition(P(2) F1'=3/2 F'=1, nγ=2.04),
 Transition(P(2) F1'=3/2 F'=2, nγ=2.06),
 Transition(Q(1) F1'=1/2 F'=0, nγ=100.00),
 Transition(Q(1) F1'=1/2 F'=1, nγ=98.76),
 Transition(R(0) F1'=1/2 F'=0, nγ=2.94),
 Transition(R(0) F1'=1/2 F'=1, nγ=2.94),
 Transition(Q(1) F1'=3/2 F'=1, nγ=8.46),
 Transition(Q(2) F1'=5/2 F'=2, nγ=99.29),
 Transition(Q(1) F1'=3/2 F'=2, nγ=8.27),
 Transition(Q(2) F1'=5/2 F'=3, nγ=100.00),
 Transition(Q(3) F1'=7/2 F'=3, nγ=None),
 Transition(Q(3) F1'=7/2 F'=4, nγ=None),
 Transition(Q(2) F1'=3/2 F'=1, nγ=5.24),
 Transition(Q(3) F1'=5/2 F'=2, nγ=None),
 Transitio

In [262]:
# NOTE(review): index 29 looks chosen interactively; it depends on the current basis selection.
sorted(transitions_list_E1, key=lambda tr: tr.weighted_energy)[29]

Transition(R(2) F1'=7/2 F'=4, nγ=None)

In [266]:
# Inspect the |cpl|^2 weights collected for the last grouped E2 transition.
transitions_list_E2[-1].coupling_elements_squared

[np.float64(6.0995378592746174),
 np.float64(6.099537859735365),
 np.float64(6.099537860237556),
 np.float64(6.0995378607763175),
 np.float64(6.09953786127765),
 np.float64(6.099537861779277),
 np.float64(6.099537862279258),
 np.float64(6.099537862779242),
 np.float64(6.099537863263625),
 np.float64(6.099537863763245),
 np.float64(6.099537864259974),
 np.float64(6.099537864755969),
 np.float64(6.099537865275921),
 np.float64(6.099537856442198),
 np.float64(6.099537856902945),
 np.float64(6.099537857405136),
 np.float64(6.099537857943895),
 np.float64(6.09953785844523),
 np.float64(6.099537858946855),
 np.float64(6.099537859446838),
 np.float64(6.099537859946819),
 np.float64(6.099537860431205),
 np.float64(6.099537860930822),
 np.float64(6.099537861427553),
 np.float64(6.099537861923549),
 np.float64(6.0995378624435),
 np.float64(6.099537854538079),
 np.float64(6.099537854998827),
 np.float64(6.099537855501017),
 np.float64(6.099537856039777),
 np.float64(6.099537856541109),
 np.float6

In [264]:
# Label of the last grouped E1 transition (last distinct transition first seen in transitions_E1).
transitions_list_E1[-1].transition

OpticalTransition(R(3) F1'=9/2 F'=5)