In [2]:
from pybkit.amo.atom import Yb171, HyperfineEnergyLevel
from pybkit.amo.laser import LaserPolarization
from qutip import *
import numpy as np
import matplotlib.pyplot as plt
from sympy.physics.wigner import wigner_6j
from sympy.physics.quantum.cg import CG

In [4]:
def _polarization_amplitude(polarization: LaserPolarization, q: float) -> float:
    """Return the field component of ``polarization`` that drives a transition with mF change ``q``.

    ``q`` is mF(upper) - mF(lower). Only q in {-1, 0, +1} is dipole-allowed; any other
    value returns 0 (its dipole matrix element vanishes anyway).
    """
    # mF values are half-integer floats, so q arrives as e.g. 1.0 or -0.0; those
    # compare and hash equal to the int keys below, so the lookup still matches.
    components = {-1: polarization.sigma_minus, 0: polarization.pi, 1: polarization.sigma_plus}
    return components.get(q, 0)


def _add_laser_couplings(H, yb, hyperfine_levels, basis_states, lower_term, upper_term,
                         Omega, reference_matrix_element, polarization, freq_scale):
    """Add one laser's Hermitian dipole couplings between two fine-structure manifolds to ``H``.

    For every (lower, upper) pair with ``lower.term_symbol == lower_term``,
    ``upper.term_symbol == upper_term`` and a non-zero matrix element, adds
    (Omega_pair / 2) * (|lower><upper| + |upper><lower|), where
    Omega_pair = Omega * M(lower, upper) / reference_matrix_element / freq_scale,
    multiplied by the polarization component for q = mF(upper) - mF(lower).

    Returns the updated Hamiltonian (``H`` may start out as the integer 0).
    """
    for lower in hyperfine_levels:
        if lower.term_symbol != lower_term:
            continue
        for upper in hyperfine_levels:
            if upper.term_symbol != upper_term:
                continue
            matrix_element = yb.get_matrix_element(lower, upper)
            if not matrix_element:
                continue
            q = upper.mF - lower.mF
            Omega_pair = Omega * matrix_element / reference_matrix_element / freq_scale
            Omega_pair *= _polarization_amplitude(polarization, q)
            if Omega_pair != 0:
                ket_lower = basis_states[lower.uid]
                ket_upper = basis_states[upper.uid]
                H += (Omega_pair / 2) * (ket_lower * ket_upper.dag() + ket_upper * ket_lower.dag())
    return H


def _decay_operator(level_g, level_e, q, basis_states):
    """Amplitude-weighted jump term |g><e| for spontaneous decay e -> g emitting polarization ``q``.

    Angular factor (hyperfine Wigner-Eckart reduction, nuclear spin I = 1/2):
        (-1)^(F_e + J_g + 1 + I) * sqrt((2F_e + 1)(2J_e + 1))
        * <F_e mF_e; 1 q | F_g mF_g> * {J_e J_g 1; F_g F_e I}
    The (2J_e + 1) factor makes the squared amplitudes sum to exactly 1 for each
    excited sublevel. The sum runs over every F_g, mF_g of a complete J_g manifold
    and every q. The channel therefore decays at exactly the rate its branching
    ratio assigns. A (2J_g + 1) factor would instead scale that total by
    (2J_g + 1)/(2J_e + 1).
    """
    I = 1 / 2
    F_e, mF_e, J_e = level_e.F, level_e.mF, level_e.J
    F_g, mF_g, J_g = level_g.F, level_g.mF, level_g.J
    # Cast sympy results to plain floats so only Python scalars multiply the Qobjs.
    amplitude = float(
        (-1) ** (F_e + J_g + 1 + I)
        * np.sqrt((2 * F_e + 1) * (2 * J_e + 1))
        * float(CG(F_e, mF_e, 1, q, F_g, mF_g).doit())
        * float(wigner_6j(J_e, J_g, 1, F_g, F_e, I))
    )
    return amplitude * basis_states[level_g.uid] * basis_states[level_e.uid].dag()


def simulate_constant_raman_gate(
    yb: Yb171,
    hyperfine_levels: list[HyperfineEnergyLevel],
    Omega_649: float,
    Omega_770: float,
    intermediate_detuning: float,
    polarization_649: LaserPolarization,
    polarization_770: LaserPolarization,
    gate_time: float,
    num_points: int,
    initial_states: list[str],
):
    """Simulate a constant-amplitude 3P0 <-> 3S1 <-> 3P2 Raman drive with a Lindblad master equation.

    The 649 laser couples 3P0 <-> 3S1 and the 770 laser couples 3P2 <-> 3S1.
    Internally the Hamiltonian is in units of 1e9 rad/s and time is in ns.
    3S1 decays into 3P0 and 3P2 through polarization- and hyperfine-resolved jump
    operators. It also decays into one extra "3P1" loss state (basis index N-1),
    which is not coupled to anything else.

    Args:
        yb: atom model providing hyperfine levels and dipole matrix elements.
        hyperfine_levels: levels spanning the basis; basis index = position in this list.
        Omega_649: Rabi frequency [rad/s] of the reference pair
            3P0 F=1/2 mF=1/2 <-> 3S1 F=3/2 mF=1/2. Other pairs are scaled by their
            matrix element relative to it.
        Omega_770: Rabi frequency [rad/s] of the reference pair
            3P2 F=5/2 mF=1/2 <-> 3S1 F=3/2 mF=1/2.
        intermediate_detuning: angular frequency [rad/s] added to the diagonal of every
            3P0 and 3P2 level (3S1 sits at zero apart from its sublevel offsets).
        polarization_649, polarization_770: polarization components of each laser, with
            q = mF(3S1) - mF(lower level).
        gate_time: total evolution time [s].
        num_points: number of time samples.
        initial_states: labels from {'00', '01', '10', '11'}; label 'ab' starts from |a><b|.
            Here |0> = 3P0 F=1/2 mF=-1/2 and |1> = 3P0 F=1/2 mF=+1/2. '01' and '10'
            are off-diagonal operators rather than physical states; presumably they are
            there to reconstruct the action on qubit coherences.

    Returns:
        dict with:
        - 'basis_states': level uid -> ket (the loss state is not listed);
        - 't_list': sample times in ns;
        - 'results': initial-state label -> qutip mesolve result, with states stored.
    """
    freq_scale = 1e9  # Hamiltonian in 1e9 rad/s, time in ns

    # One ket per hyperfine level plus one trailing loss state for 3S1 -> 3P1 decay.
    N = len(hyperfine_levels) + 1
    basis_states = {level.uid: basis(N, i) for i, level in enumerate(hyperfine_levels)}

    # Reference levels: they fix the Rabi-frequency normalization and the sublevel
    # energy offsets.
    ground_state = yb.get_hyperfine_level('6s6p 3P0 F=0.5 mF=0.5')
    intermediate_state = yb.get_hyperfine_level('6s7s 3S1 F=1.5 mF=0.5')
    excited_state = yb.get_hyperfine_level('6s6p 3P2 F=2.5 mF=0.5')

    # Both references use the same (lower, upper) argument order as the coupling loops.
    # This makes the reference pair's Rabi frequency exactly Omega_649 / Omega_770,
    # whatever order convention get_matrix_element follows.
    reference_matrix_element_649 = yb.get_matrix_element(ground_state, intermediate_state)
    reference_matrix_element_770 = yb.get_matrix_element(excited_state, intermediate_state)

    # ===========
    # Hamiltonian
    # ===========
    H = 0
    H = _add_laser_couplings(H, yb, hyperfine_levels, basis_states, '3P0', '3S1',
                             Omega_649, reference_matrix_element_649, polarization_649, freq_scale)
    H = _add_laser_couplings(H, yb, hyperfine_levels, basis_states, '3P2', '3S1',
                             Omega_770, reference_matrix_element_770, polarization_770, freq_scale)

    # Diagonal terms. 3P0 and 3P2 both get the intermediate detuning. 3S1 and 3P2
    # sublevels also get their energy offset from the reference level. 3P0 sublevels
    # get no sublevel offset in this model.
    for level in hyperfine_levels:
        ket = basis_states[level.uid]
        projector = ket * ket.dag()
        if level.term_symbol in ['3P0', '3P2']:
            H += (intermediate_detuning / freq_scale) * projector
        if level.term_symbol == '3S1':
            sublevel_detuning = (level.energy_Hz - intermediate_state.energy_Hz) / freq_scale
            H += 2 * np.pi * sublevel_detuning * projector
        if level.term_symbol == '3P2':
            sublevel_detuning = (level.energy_Hz - excited_state.energy_Hz) / freq_scale
            H += 2 * np.pi * sublevel_detuning * projector

    t_list = np.linspace(0, gate_time * freq_scale, num_points)  # ns

    # ===================
    # Collapse operators
    # ===================
    levels_3s1 = [level for level in hyperfine_levels if level.term_symbol == '3S1']
    levels_3p2 = [level for level in hyperfine_levels if level.term_symbol == '3P2']
    levels_3p0 = [level for level in hyperfine_levels if level.term_symbol == '3P0']
    decay_rate_3s1 = 2 * np.pi * 10.01 * 1e6 / freq_scale
    branching_ratios = {'3P2': 0.5, '3P1': 0.37, '3P0': 0.13}

    c_ops = []
    # Decay to 3P0 / 3P2: one jump operator per (final manifold, photon polarization q).
    # Paths emitting the same polarization are summed into one operator.
    for term, final_levels in (('3P0', levels_3p0), ('3P2', levels_3p2)):
        rate = decay_rate_3s1 * branching_ratios[term]
        for q in [-1, 0, 1]:
            terms = [
                _decay_operator(level_g, level_e, q, basis_states)
                for level_e in levels_3s1
                for level_g in final_levels
            ]
            if terms:  # skip manifolds absent from hyperfine_levels
                c_ops.append(sum(terms) * np.sqrt(rate))

    # Decay to the single 3P1 loss state: one operator per 3S1 sublevel.
    # A single summed operator would couple the 3S1 sublevels to each other and
    # leave some superpositions undecayed.
    loss_state = basis(N, N - 1)
    loss_amplitude = np.sqrt(decay_rate_3s1 * branching_ratios['3P1'])
    c_ops.extend(
        loss_state * basis_states[level.uid].dag() * loss_amplitude
        for level in levels_3s1
    )

    # =======================
    # Simulate time evolution
    # =======================
    qubit_basis_0 = basis_states['6s6p 3P0 F=0.5 mF=-0.5']
    qubit_basis_1 = basis_states['6s6p 3P0 F=0.5 mF=0.5']
    rho_initial_dict = {
        '00': qubit_basis_0 * qubit_basis_0.dag(),
        '01': qubit_basis_0 * qubit_basis_1.dag(),
        '10': qubit_basis_1 * qubit_basis_0.dag(),
        '11': qubit_basis_1 * qubit_basis_1.dag(),
    }

    result_dict = {}
    for state in initial_states:
        result_dict[state] = mesolve(
            H,
            rho_initial_dict[state],
            t_list,
            e_ops=None,  # populations are read from the stored states instead
            c_ops=c_ops,
            options=dict(store_states=True),
        )

    return {
        'basis_states': basis_states,
        't_list': t_list,
        'results': result_dict,
    }

In [5]:
# Atom model. NOTE(review): B_field units are not visible here (5e-4 would be 5 G if in tesla) — confirm.
yb = Yb171(B_field=5e-4)

# Simulation basis: every hyperfine sublevel of the three fine-structure manifolds,
# in this order (basis index = position in hyperfine_levels).
fine_level_ids = ['6s6p 3P0', '6s6p 3P2', '6s7s 3S1']
fine_levels = [yb.get_fine_level(level_id) for level_id in fine_level_ids]
hyperfine_levels = [
    hyperfine_level
    for fine_level in fine_levels
    for hyperfine_level in fine_level.get_hyperfine_levels()
]

# Target two-photon (Raman) Rabi frequency and single-photon detuning, both angular [rad/s].
Omega_Raman = 2 * np.pi * 1 * 1e6
intermediate_detuning = 2 * np.pi * 20 * 1e9
# With equal single-photon Rabi frequencies, Omega_Raman = Omega**2 / (2 * Delta),
# so each beam needs Omega = sqrt(2 * Delta * Omega_Raman).
Omega_649 = np.sqrt(2 * intermediate_detuning * Omega_Raman)
Omega_770 = np.sqrt(2 * intermediate_detuning * Omega_Raman)
polarization_649 = LaserPolarization(sigma_plus=1, sigma_minus=0, pi=0)
polarization_770 = LaserPolarization(sigma_plus=1, sigma_minus=0, pi=0)
gate_time = 1e-6  # [s]
num_points = 1000
initial_states = ['00', '11']

output = simulate_constant_raman_gate(
    yb,
    hyperfine_levels,
    Omega_649,
    Omega_770,
    intermediate_detuning,
    polarization_649,
    polarization_770,
    gate_time,
    num_points,
    initial_states
)

In [None]:
%matplotlib widget

basis_states = output['basis_states']
results = output['results']

print(basis_states)

ground_0_idx = [i for i, state in enumerate(basis_states.to_records()) \
    if state['level'] == '3P0' and state['F'] == 0.5 and state['mF'] == -0.5][0]
ground_1_idx = [i for i, state in enumerate(basis_states.to_records()) \
    if state['level'] == '3P0' and state['F'] == 0.5 and state['mF'] == 0.5][0]

intermediate_0_idx = [i for i, state in enumerate(basis_states.to_records()) \
    if state['level'] == '3S1' and state['F'] == 1.5 and state['mF'] == -0.5][0]
intermediate_1_idx = [i for i, state in enumerate(basis_states.to_records()) \
    if state['level'] == '3S1' and state['F'] == 1.5 and state['mF'] == 0.5][0]

excited_0_idx = [i for i, state in enumerate(basis_states.to_records()) \
    if state['level'] == '3P2' and state['F'] == 2.5 and state['mF'] == -0.5][0]
excited_1_idx = [i for i, state in enumerate(basis_states.to_records()) \
    if state['level'] == '3P2' and state['F'] == 2.5 and state['mF'] == 0.5][0]


ts = output['t_list']
ys_0 = results['00'].expect[ground_0_idx]
ys_1 = results['11'].expect[ground_1_idx]

fig, ax = plt.subplots(ncols=2, sharex=True, figsize=(8, 4))

ax[0].plot(ts, ys_0)
ax[1].plot(ts, ys_1)


ax[0].set_xlabel('Time [ns]')
ax[1].set_xlabel('Time [ns]')

In [None]:
# Cross-check the 3S1 decay rate used in the simulation against the 3S1 lifetime.
Gamma = 2 * np.pi * 10.01 * 1e6  # angular decay rate [rad/s]
# Only a cell's last expression is displayed, so print this intermediate value explicitly.
print(f'lifetime implied by Gamma: {1 / Gamma / 1e-9:.2f} ns')
tau = 15.9 * 1e-9  # 3S1 lifetime [s]
Gamma = 1 / tau / 2 / np.pi  # NOTE: name reused — now the linewidth Gamma / 2pi in Hz
Gamma * 1e-6  # linewidth [MHz]

In [None]:
# NOTE(review): 1/tau (the 3S1 decay rate, from the linewidth cell above) divided by a
# 100 GHz detuning, expressed in units of 1e-4. This is presumably a spontaneous-scattering
# error estimate Gamma/Delta — confirm intent. 100e9 carries no 2*pi; if it is meant as an
# angular detuning (2*pi*100 GHz), this ratio is 2*pi too large.
1 / tau  / 100e9 / 1e-4

In [None]:
# Estimate of the intermediate-state (3S1) admixture for a Raman drive at 100 GHz detuning.
Omega_R = 2 * np.pi / 1e-6  # Raman Rabi frequency [rad/s]: one full 2*pi cycle per microsecond
Delta = 2 * np.pi * 100 * 1e9  # single-photon detuning [rad/s]
# Equal single-photon Rabi frequency that yields Omega_R = Omega**2 / (2 * Delta).
Omega = np.sqrt(2 * Omega_R * Delta)
# Presumably the adiabatic admixture |Omega / (2 Delta)|**2 of the intermediate state — confirm.
intermediate_population = Omega**2 / (4 * Delta**2)
intermediate_population

In [None]:
# Off-resonant Raman transfer for a two-photon detuning of 2*pi*10 MHz, reusing Omega_R
# from the previous cell. NOTE(review): presumably the leakage into a non-resonant
# Raman pair in the weak-drive limit Omega_R**2 / (4 delta**2) — confirm.
delta = 2 * np.pi * 10 * 1e6  # two-photon detuning [rad/s]
raman_leakage = Omega_R**2 / (4 * delta**2)
raman_leakage