# Quantum Optimal Control: Two-Qubit Gate

## Imports

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import time

from arc import Rubidium87
from qutip import *
from quantum_optimal_control.two_qubit.propagator_vl import (
    PropagatorVL,
    single_optimization_step
)
from quantum_optimal_control.two_qubit.qutip_simulation import gen_n_level_atom_basis, Mesolve_5lvl_t
from quantum_optimal_control.toolkits.plotting_helper import PlotPlotter, getStylishFigureAxes

# Root JAX PRNG key (fixed seed 12 for reproducibility); it is split later to
# draw the random initial control parameters ctrl_a/ctrl_b/ctrl_c.
key = jax.random.PRNGKey(12)

## System Parameters Setup

We'll define atomic parameters for Rubidium87, including lifetimes and decay rates for the intermediate excited state (6P3/2) and the Rydberg state (70 S1/2). For the two-qubit gate, additional parameters such as the interaction strength `V_int`, gate time `tau`, and detuning `Delta_i` are also specified.

The Rydberg decay rate includes contributions from both radiative decay and blackbody-stimulated transitions.

In [None]:
atom = Rubidium87()

# Intermediate excited state: 6P3/2
n_i = 6
l_i = 1
j_i = 1.5
T_i = atom.getStateLifetime(n_i, l_i, j_i)  # Lifetime of the intermediate state (s, as returned by ARC)
Gamma_ig = 1/T_i  # Decay rate of the intermediate state

# Additional intermediate decay channels for two-qubit gate
Gamma_i1 = 1 * Gamma_ig  # Allowed decay: the full intermediate-state rate is assigned to this channel
Gamma_i0 = 0            # Not used (not included in `Gammas` below)

# Hyperfine ground state decay rate (from e.g. Levine2018)
Gamma_10 = 0.1  # Decay rate between hyperfine ground states (1/s)

# Rydberg state: 70 S1/2
n_r = 70
l_r = 0
j_r = 0.5

# Total lifetime including blackbody stimulation and radiative decay
# (300 K environment; blackbody sum includes levels up to n = n_r + 50).
T_rTot = atom.getStateLifetime(n_r, l_r, j_r, temperature=300, includeLevelsUpTo=n_r + 50)
# Purely radiative lifetime (0 K, no blackbody contribution).
T_rRad = atom.getStateLifetime(n_r, l_r, j_r, temperature=0)
# Inverse of the radiative r -> 6P3/2 transition rate at 0 K.
T_ri = 1/atom.getTransitionRate(n_r, l_r, j_r, n_i, l_i, j_i, temperature=0)

# Effective lifetimes
# T_rgp: radiative decay to all states other than the intermediate 6P3/2.
T_rgp = 1/(1/T_rRad - 1/T_ri)
# T_rBB: blackbody-induced part of the decay (total rate minus radiative rate).
T_rBB = 1/(1/T_rTot - 1/T_rRad)

Gamma_ri = 1/T_ri
Gamma_rrp = 1/T_rBB
Gamma_rgp = 1/T_rgp
Gamma_rTot = Gamma_ri + Gamma_rrp + Gamma_rgp  # total Rydberg decay rate (not used further in this notebook)
Gamma_rd = Gamma_rrp + Gamma_rgp  # Rydberg decay through channels other than r -> 6P3/2

# Consolidate decay rates
# NOTE(review): order [Gamma_10, Gamma_i1, Gamma_ri, Gamma_rd] is positional and must
# match what PropagatorVL(Gammas_all=...) and Mesolve_5lvl_t expect — confirm there.
Gammas = [Gamma_10, Gamma_i1, Gamma_ri, Gamma_rd]

# Two-qubit gate specific parameters
V_int = 2 * np.pi * 10e6  # Interaction strength in rad/s
tau = 324e-9             # Gate time (s)
Delta_i = 2 * np.pi * -35.7e6  # Detuning for the gate (rad/s)

print(f'V_int: {V_int/(2*np.pi*1e6):.2f} MHz')
print(f'tau: {tau*1e9:.2f} ns')
print(f'Delta_i: {Delta_i/(2*np.pi*1e6):.2f} MHz')

# Additional system parameters
# Peak Rabi frequencies (rad/s); the dimensionless pulse magnitudes are scaled by these.
Rabi_i = 2 * np.pi * 100e6
Rabi_r = 2 * np.pi * 100e6
# Passed to both PropagatorVL and Mesolve_5lvl_t; presumably a total (two-photon)
# detuning — TODO confirm its meaning in those modules.
del_total = 0

# Time grid settings
t_0 = 0
t_f = 2 * tau  # simulation window [0, 2*tau]
nt = 500  # Increase for better accuracy if needed
pad = int(0.03 * nt)  # Padding points added on each side (nt + 2*pad samples in total)
# Step chosen so that nt + 2*pad steps of length delta_t span exactly t_f - t_0.
delta_t = (t_f - t_0) / (nt + 2 * pad)
# NOTE(review): np.linspace with nt + 2*pad points (endpoint included) has spacing
# (t_f - t_0) / (nt + 2*pad - 1), which differs slightly from delta_t above.
# Confirm which convention PropagatorVL and Mesolve_5lvl_t assume for the pulse samples.
tlist = np.linspace(t_0, t_f, nt + 2*pad)

# Pulse bandwidth (largest frequency component)
f_std = 50e6

# Basis dimension for the Gaussian modes
input_dim = 10

# Number of control channels (last axis of ctrl_a/ctrl_b/ctrl_c). The physical
# pulses returned by propagator.return_physical_amplitudes are read later as
# columns [Delta_i(t), Rabi_i magnitude, Rabi_i phase, Rabi_r magnitude, Rabi_r phase];
# Re/Im parts are built from magnitude and phase only for the QuTiP simulation.
# NOTE(review): not passed to PropagatorVL, so it must match the propagator's internal layout.
numb_ctrl_amps = 5

## Build Propagator
We'll create an instance of `PropagatorVL` by providing the relevant parameters (dimensions, time step, detunings, Rabi frequencies, etc.). We'll later use this to both (a) convert `(ctrl_a, ctrl_b, ctrl_c)` into physical pulses, and (b) compute the gate infidelity in the final optimization.

In [None]:
propagator = PropagatorVL(
    input_dim=input_dim,
    no_of_steps=nt,
    pad=pad,
    f_std=f_std,
    delta_t=delta_t,
    del_total=del_total,
    V_int=V_int,
    Delta_i=Delta_i,
    Rabi_i=Rabi_i,
    Rabi_r=Rabi_r,
    Gammas_all=Gammas
)

## (1) Define Analytical Pulses for Initialization
We want to approximate an initial guess:
\[
  \Omega_i(t) = \sin\Bigl(\frac{\pi}{2\tau} t\Bigr),\quad
  \Omega_r(t) = \bigl|\cos\bigl(\frac{\pi}{2\tau} t\bigr)\bigr|.
\]
We'll sample these shapes on our time grid `tlist` and run a short gradient-descent fit of the `(ctrl_a, ctrl_b, ctrl_c)` parameters so that the generated pulse magnitudes (in dimensionless form) match them.

In [None]:
def _quarter_period_phase(t, tau):
    """Return the phase pi * t / (2 * tau) shared by both analytical envelopes."""
    angular_rate = np.pi / (2 * tau)
    return angular_rate * t


def Rabi_i_analytical(t, tau):
    """Analytical intermediate-transition envelope sin(pi * t / (2 * tau)).

    Works element-wise on scalars and NumPy arrays; on 0 <= t <= 2*tau the
    result lies in [0, 1].
    """
    return np.sin(_quarter_period_phase(t, tau))


def Rabi_r_analytical(t, tau):
    """Analytical Rydberg-transition envelope |cos(pi * t / (2 * tau))|.

    Works element-wise on scalars and NumPy arrays; the result is always in [0, 1].
    """
    return np.abs(np.cos(_quarter_period_phase(t, tau)))

# Sample the dimensionless target envelopes on the full [0, 2*tau] grid.
# Both analytical helpers are NumPy element-wise, so the whole time array is
# evaluated in one vectorized call instead of a Python loop over time points.
target_ri = np.asarray(Rabi_i_analytical(tlist, tau), dtype=float)  # in [0, 1] on [0, 2*tau]
target_rr = np.asarray(Rabi_r_analytical(tlist, tau), dtype=float)  # in [0, 1]

### Cost Function to Match Rabi Pulses
We will define a simple cost function:
\[
  C = \frac{1}{2N}\sum_{t}\Bigl(\Omega_i^{(gen)}(t) - \Omega_i^{(target)}(t)\Bigr)^2 + \frac{1}{2N}\sum_{t}\Bigl(\Omega_r^{(gen)}(t) - \Omega_r^{(target)}(t)\Bigr)^2,
\]
where \(\Omega_i^{(gen)}(t)\) and \(\Omega_r^{(gen)}(t)\) are the dimensionless generated pulses derived from `(ctrl_a, ctrl_b, ctrl_c)`.

In [None]:
def cost_match_rabi(ctrl_a, ctrl_b, ctrl_c):
    """Mean-squared mismatch between generated and target Rabi magnitudes.

    Implements
        C = 1/(2N) * sum_t (Omega_i_gen(t) - Omega_i_target(t))**2
          + 1/(2N) * sum_t (Omega_r_gen(t) - Omega_r_target(t))**2,
    with N the number of time samples.

    Args:
        ctrl_a, ctrl_b, ctrl_c: control parameters passed unchanged to
            ``propagator.return_physical_amplitudes`` (initialized in this
            notebook with shape (input_dim, numb_ctrl_amps)).

    Returns:
        Scalar JAX array with the matching cost. Only the magnitude columns
        (1 for Rabi_i, 3 for Rabi_r) are matched; detuning and phase columns
        are ignored.

    Reads the module-level ``propagator``, ``target_ri`` and ``target_rr``.
    """
    # pulses: [M, 5] => [Delta_i(t), Rabi_i amplitude param, Rabi_i phase param,
    #                    Rabi_r amplitude param, Rabi_r phase param]
    pulses = propagator.return_physical_amplitudes(ctrl_a, ctrl_b, ctrl_c)

    # The amplitude columns are already dimensionless; scaling by the peak
    # Rabi frequency and dividing it back out would be a no-op (up to
    # rounding), so the columns are compared with the targets directly.
    diff_i = pulses[:, 1] - target_ri
    diff_r = pulses[:, 3] - target_rr

    return 0.5 * (jnp.mean(diff_i**2) + jnp.mean(diff_r**2))

@jax.jit
def single_match_step(ctrl_a, ctrl_b, ctrl_c, lr=0.03):
    """Apply one plain gradient-descent step to ``cost_match_rabi``.

    Args:
        ctrl_a, ctrl_b, ctrl_c: current control parameters.
        lr: learning rate (step size).

    Returns:
        (cost, ctrl_a_new, ctrl_b_new, ctrl_c_new). ``cost`` is evaluated at
        the incoming parameters, i.e. before the update is applied.
    """
    params = (ctrl_a, ctrl_b, ctrl_c)
    cost, grads = jax.value_and_grad(cost_match_rabi, argnums=(0, 1, 2))(*params)
    new_a, new_b, new_c = (p - lr * g for p, g in zip(params, grads))
    return cost, new_a, new_b, new_c

### Run the Matching to Obtain an Initial Guess
We'll do a moderate number of gradient steps to find `(ctrl_a, ctrl_b, ctrl_c)` that best match our target Rabi shapes. Then we'll use these as the starting point for the actual gate optimization.

We do **not** attempt to match phases or detunings at this stage; to keep it simple, we only match the magnitudes of $\Omega_i(t)$ and $\Omega_r(t)$.

In [None]:
# Initialize the control parameters with small random values (std 0.3).
key, s1, s2, s3 = jax.random.split(key, 4)

ctrl_a_init = jax.random.normal(s1, shape=(input_dim, numb_ctrl_amps))*0.3
ctrl_b_init = jax.random.normal(s2, shape=(input_dim, numb_ctrl_amps))*0.3
ctrl_c_init = jax.random.normal(s3, shape=(input_dim, numb_ctrl_amps))*0.3

init_cost = cost_match_rabi(ctrl_a_init, ctrl_b_init, ctrl_c_init)
print(f"Initial matching cost = {init_cost:.6e}")

# Perform gradient-based matching
num_match_iters = 1000
lr_match = 0.03
print_every_match = 200  # progress-print interval

ctrl_a, ctrl_b, ctrl_c = ctrl_a_init, ctrl_b_init, ctrl_c_init

# Warm-up call triggers JIT compilation. Its updated parameters are kept, so
# the controls receive num_match_iters + 1 updates in total.
_warm_val, ctrl_a, ctrl_b, ctrl_c = single_match_step(ctrl_a, ctrl_b, ctrl_c, lr_match)

cost_history_match = []
for step in range(num_match_iters):
    val, ctrl_a, ctrl_b, ctrl_c = single_match_step(ctrl_a, ctrl_b, ctrl_c, lr_match)
    # `val` is the cost *before* this step's update.
    cost_history_match.append(val)
    if (step + 1) % print_every_match == 0:
        print(f"Match step {step+1}, cost = {val:.6e}")

# single_match_step reports the pre-update cost, so cost_history_match[-1]
# lags one update behind (and does not exist if num_match_iters == 0).
# Evaluate the cost of the parameters that are actually kept.
final_match_cost = cost_match_rabi(ctrl_a, ctrl_b, ctrl_c)
print(f"Final matching cost after {num_match_iters} steps = {final_match_cost:.6e}")

### Check the Matched Pulses
We confirm that the dimensionless pulses for $\Omega_i$ and $\Omega_r$ at this stage are close to the target shapes.

In [None]:
pulses_matched = propagator.return_physical_amplitudes(ctrl_a, ctrl_b, ctrl_c)
pulses_matched_np = np.array(pulses_matched)

# Magnitude columns in physical units (rad/s) ...
ri_matched = Rabi_i * pulses_matched_np[:, 1]
rr_matched = Rabi_r * pulses_matched_np[:, 3]
# ... and normalized by the peak Rabi frequencies for comparison with the targets.
ri_dimless = ri_matched / Rabi_i
rr_dimless = rr_matched / Rabi_r

# (legend label, data, line style): targets dashed, matched pulses solid.
matched_plot_curves = [
    ('Target Rabi_i', target_ri, '--'),
    ('Target Rabi_r', target_rr, '--'),
    ('Matched Rabi_i', ri_dimless, '-'),
    ('Matched Rabi_r', rr_dimless, '-'),
]
t_ns = tlist * 1e9  # time axis in nanoseconds

fig, ax = getStylishFigureAxes(1, 1)
for curve_label, curve_values, curve_style in matched_plot_curves:
    PlotPlotter(
        fig, ax,
        t_ns,
        curve_values,
        style={'label': curve_label, 'marker': '', 'linestyle': curve_style},
    ).draw()

ax.set_xlabel('Time (ns)')
ax.set_ylabel(r'$\tilde{\Omega}/\Omega_{max}$')
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

## (2) Main Gate Optimization
Next, we proceed with the gate-infidelity minimization just as in `02_two_qubit_gate.ipynb`. However, **instead** of randomizing `(ctrl_a, ctrl_b, ctrl_c)` at the start, we use the matched values from above as our initial guess.

We use `propagator.target(ctrl_a, ctrl_b, ctrl_c)`, which returns a cost combining the gate infidelity and an adiabatic penalty (as defined in `PropagatorVL`). We minimize it by repeatedly calling `single_optimization_step` with a fixed learning rate.

In [None]:
# Start the gate optimization from the matched (pulse-shape-fitted) controls.
ctrl_a_opt = ctrl_a
ctrl_b_opt = ctrl_b
ctrl_c_opt = ctrl_c

# Evaluate initial gate cost
init_gate_cost = propagator.target(ctrl_a_opt, ctrl_b_opt, ctrl_c_opt)
print(f"Initial gate cost with matched pulses: {init_gate_cost:.6e}")

num_gate_iters = 1250
learning_rate = 0.02
print_every_gate = 10  # progress-print interval

cost_history_gate = []

# Warm-up step for JIT. Its returned parameters are kept, so it counts as one
# extra optimization step.
_warm_gate_cost, ctrl_a_opt, ctrl_b_opt, ctrl_c_opt = single_optimization_step(
    propagator, ctrl_a_opt, ctrl_b_opt, ctrl_c_opt, lr=learning_rate
)

for step in range(num_gate_iters):
    cost_val, ctrl_a_opt, ctrl_b_opt, ctrl_c_opt = single_optimization_step(
        propagator, ctrl_a_opt, ctrl_b_opt, ctrl_c_opt, lr=learning_rate
    )
    cost_history_gate.append(cost_val)
    if (step + 1) % print_every_gate == 0:
        print(f"Iteration {step+1}, gate cost = {cost_val:.6e}")

# The per-step cost may be evaluated before that step's update (as in
# single_match_step), and cost_history_gate is empty if num_gate_iters == 0.
# Re-evaluate the cost of the parameters actually kept. The name `best_cost`
# is kept for downstream use; it holds the cost of the *final* parameters.
best_cost = propagator.target(ctrl_a_opt, ctrl_b_opt, ctrl_c_opt)
print(f"Final cost after {num_gate_iters} steps: {best_cost:.6e}")

### Analyze Optimized Pulses
Now we plot the final optimized pulses. Their gate infidelity is evaluated in the next section.

In [None]:
final_pulses = propagator.return_physical_amplitudes(ctrl_a_opt, ctrl_b_opt, ctrl_c_opt)
final_pulses_np = np.array(final_pulses)

# Legend labels, one per pulse column. Every backslash sequence sits inside
# $...$ so matplotlib renders it as mathtext instead of printing it literally.
labels = [
    r"$\Delta_i$",
    r"$\Omega_i$ mag",
    r"$\Omega_i$ phase",
    r"$\Omega_r$ mag",
    r"$\Omega_r$ phase",
]

fig, ax = getStylishFigureAxes(1, 1)
# Iterate over columns with a dedicated name instead of the generic `i`, which
# a later cell binds to a basis state; fall back to a generic label if the
# propagator returns more columns than there are labels.
for col_idx, column in enumerate(final_pulses_np.T):
    col_label = labels[col_idx] if col_idx < len(labels) else f"pulse {col_idx}"
    PlotPlotter(
        fig,
        ax,
        tlist*1e9,
        column,
        style={'label': col_label, 'marker': '', 'linestyle': '-'},
        xticks=[0, 250, 500, 750]
    ).draw()

ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

## Final Infidelity and Adiabatic Metric
We can use `propagator.metrics(ctrl_a_opt, ctrl_b_opt, ctrl_c_opt)` to get the final infidelity and an additional adiabatic measure. Let's do that.

In [None]:
# Final gate infidelity and adiabaticity measure for the optimized controls.
infid, adiab = propagator.metrics(ctrl_a_opt, ctrl_b_opt, ctrl_c_opt)
print(f"Final infidelity: {infid:.6e}")
print(f"Adiabatic metric: {adiab:.6e}")

In [None]:
g_0, g_1, i, r, dark = gen_n_level_atom_basis(5)

# Initial state: equal superposition of the four two-atom computational states.
# Target: the same superposition with the |g_1 g_1> amplitude sign-flipped,
# i.e. the action of a controlled-Z gate on psi0.
psi0 = 1/2 * (tensor(g_0, g_0) + tensor(g_0, g_1) +
              tensor(g_1, g_0) + tensor(g_1, g_1))
psi_targ = 1/2 * (tensor(g_0, g_0) + tensor(g_0, g_1) +
                  tensor(g_1, g_0) - tensor(g_1, g_1))

# Name the optimized pulse columns: 0 = Delta_i(t), 1/2 = Rabi_i magnitude/phase,
# 3/4 = Rabi_r magnitude/phase. The phase columns are multiplied by pi below,
# i.e. they are stored in units of pi.
final_pulses_np = np.array(final_pulses)
delta_col = final_pulses_np[:, 0]
ri_mag, ri_phase = final_pulses_np[:, 1], final_pulses_np[:, 2]
rr_mag, rr_phase = final_pulses_np[:, 3], final_pulses_np[:, 4]

# Cartesian (Re/Im) components of the complex Rabi pulses for the QuTiP model.
Rabi_i_Pulse_Re = ri_mag * np.cos(np.pi * ri_phase)
Rabi_i_Pulse_Im = ri_mag * np.sin(np.pi * ri_phase)
Rabi_r_Pulse_Re = rr_mag * np.cos(np.pi * rr_phase)
Rabi_r_Pulse_Im = rr_mag * np.sin(np.pi * rr_phase)
Delta_i_Pulse = delta_col

# Positional argument list consumed by Mesolve_5lvl_t (order matters).
argsME5lvl = [
    psi0, psi_targ, V_int, Rabi_i, Rabi_r,
    Rabi_i_Pulse_Re, Rabi_r_Pulse_Re, Rabi_i_Pulse_Im, Rabi_r_Pulse_Im,
    Delta_i_Pulse, Delta_i, del_total, Gammas,
]

# Master-equation evolution over tlist. NOTE(review): the fidelity is read as
# the final value of the second expectation operator; presumably that is the
# projector onto psi_targ — confirm in Mesolve_5lvl_t.
result = Mesolve_5lvl_t(tlist, argsME5lvl, output_states=True)
fidelity = result.expect[1][-1]
print(f"fidelity: {fidelity:.5f}")

In [None]:
# Calculate the fidelity of the un-optimized analytical pulses for comparison.
# The envelopes come from the same helpers used for the initial-guess targets,
# so both sections are guaranteed to use identical shapes.
Rabi_i_re_analytical = Rabi_i_analytical(tlist, tau)
Rabi_i_im_analytical = np.zeros_like(tlist)  # real-valued (zero-phase) pulse
Rabi_r_re_analytical = Rabi_r_analytical(tlist, tau)
Rabi_r_im_analytical = np.zeros_like(tlist)
# Constant unit Delta_i(t) column; presumably this means a fixed detuning Delta_i
# throughout — confirm how Mesolve_5lvl_t combines Delta_i_Pulse and Delta_i.
Delta_i_analytical = np.ones_like(tlist)

argsME5lvl_analytical = [psi0, psi_targ, V_int, Rabi_i, Rabi_r,
                         Rabi_i_re_analytical, Rabi_r_re_analytical, Rabi_i_im_analytical,
                         Rabi_r_im_analytical, Delta_i_analytical, Delta_i, del_total, Gammas]
result_analytical = Mesolve_5lvl_t(tlist, argsME5lvl_analytical, output_states=True)
fidelity_analytical = result_analytical.expect[1][-1]
# Same format as the optimized-pulse fidelity printout, for easy comparison.
print(f"analytical fidelity: {fidelity_analytical:.5f}")


## (3) **Optional**: Using the Optimized Propagator
Here, we demonstrate how to use the new `PropagatorVLOpt` class in `propagator_vl_optimized.py` to potentially accelerate the same two-qubit gate computations.

In [None]:
from quantum_optimal_control.two_qubit.propagator_vl_optimized import (
    PropagatorVLOpt, single_optimization_step_opt
)

# Create an instance of the optimized propagator with exactly the same
# physical and numerical parameters as `propagator`, so results are comparable.
propagator_opt = PropagatorVLOpt(
    input_dim=input_dim,
    no_of_steps=nt,
    pad=pad,
    f_std=f_std,
    delta_t=delta_t,
    del_total=del_total,
    V_int=V_int,
    Delta_i=Delta_i,
    Rabi_i=Rabi_i,
    Rabi_r=Rabi_r,
    Gammas_all=Gammas
)

### Timing Comparison
We show how one might measure runtime for a few calls to `.target(...)` or a quick optimization loop using the new optimized version. The speedup will depend on hardware and specific JAX environment.

In [None]:
import time

# Use the matched controls from the pulse-shape fit as the test point
# (ctrl_a_init/ctrl_b_init/ctrl_c_init are the random starting values).
test_ctrl_a = ctrl_a
test_ctrl_b = ctrl_b
test_ctrl_c = ctrl_c

n_timing_calls = 10


def _time_target_calls(prop, a, b, c, n_calls):
    """Return wall-clock seconds for ``n_calls`` evaluations of ``prop.target(a, b, c)``.

    One warm-up call (not timed) absorbs JIT compilation. The last result is
    blocked on so that JAX's asynchronous dispatch is included in the timing.
    """
    jax.block_until_ready(prop.target(a, b, c))
    start = time.perf_counter()
    result = None
    for _ in range(n_calls):
        result = prop.target(a, b, c)
    jax.block_until_ready(result)
    return time.perf_counter() - start


t_opt = _time_target_calls(propagator_opt, test_ctrl_a, test_ctrl_b, test_ctrl_c, n_timing_calls)
print(f"Optimized version: {n_timing_calls} calls to target(...) took {t_opt:.4f} s")

# Baseline: the same measurement with the original propagator, for comparison.
t_ref = _time_target_calls(propagator, test_ctrl_a, test_ctrl_b, test_ctrl_c, n_timing_calls)
print(f"Original version: {n_timing_calls} calls to target(...) took {t_ref:.4f} s")

### Optional: Quick Optimization with the Optimized Propagator
We can run a few gradient steps to see if it converges similarly and measure time.

In [None]:
test_lr = 0.02
n_steps = 50
ctrl_a_opt2 = test_ctrl_a
ctrl_b_opt2 = test_ctrl_b
ctrl_c_opt2 = test_ctrl_c

# Warm-up call (JIT compilation); its updated parameters are kept.
_warm_val2, ctrl_a_opt2, ctrl_b_opt2, ctrl_c_opt2 = single_optimization_step_opt(
    propagator_opt, ctrl_a_opt2, ctrl_b_opt2, ctrl_c_opt2, lr=test_lr)
cost_val2 = _warm_val2  # keeps cost_val2 defined even when n_steps == 0

t2 = time.perf_counter()
for step in range(n_steps):
    cost_val2, ctrl_a_opt2, ctrl_b_opt2, ctrl_c_opt2 = single_optimization_step_opt(
        propagator_opt, ctrl_a_opt2, ctrl_b_opt2, ctrl_c_opt2, lr=test_lr
    )
# Block on the updated parameters too: the returned cost can be ready before
# the final parameter update has finished computing.
jax.block_until_ready((cost_val2, ctrl_a_opt2, ctrl_b_opt2, ctrl_c_opt2))
t3 = time.perf_counter()
print(f"{n_steps} steps of optimization with PropagatorVLOpt: {t3 - t2:.4f} s")

# The per-step cost may be evaluated before that step's update; report the
# cost of the parameters actually kept.
final_cost2 = propagator_opt.target(ctrl_a_opt2, ctrl_b_opt2, ctrl_c_opt2)
print(f"Final cost: {final_cost2:.6e}")