# Quantum Optimal Control: Two-Qubit Gate (JAX Version)

In this notebook, we demonstrate how to construct and optimize a two-qubit CZ gate using a JAX-based Van Loan propagator (`PropagatorVLJAX`). We keep the same structure as the previous TensorFlow version but replace the underlying code with JAX. We'll:
1. Define the physical system parameters (Rubidium atomic parameters, lifetimes, decay rates)
2. Set up the gate parameters, time grids, and an instance of `PropagatorVLJAX`
3. Randomly initialize control pulses in terms of Gaussian modes
4. Run a gradient-based optimization loop using JAX's automatic differentiation
5. Analyze and plot the resulting optimal pulses


In [None]:
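import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import time

# Enable 64-bit precision so the jnp.float64 casts below are honored (otherwise JAX truncates to float32 with a warning)
jax.config.update("jax_enable_x64", True)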

from arc import Rubidium87
from qutip import *
from quantum_optimal_control.two_qubit.propagator_vl_jax import (
    PropagatorVLJAX,
    single_optimization_step
)
from quantum_optimal_control.toolkits.plotting_helper import PlotPlotter, getStylishFigureAxes

key = jax.random.PRNGKey(6)


## System Parameters Setup

We'll define atomic parameters for rubidium-87 (ARC's `Rubidium87`), including lifetimes and decay rates for the intermediate excited state ($6P_{3/2}$) and the Rydberg state ($70S_{1/2}$). For the two-qubit gate, we also specify the interaction strength `V_int`, the gate time `tau`, and the detuning `Delta_i`.

The Rydberg decay rate includes contributions from both radiative decay and blackbody-stimulated transitions.
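The rate bookkeeping in the next cell can be summarized as follows, writing each rate as the inverse of the corresponding lifetime computed there. Removing the $r \to i$ channel from the total radiative rate leaves radiative decay to all other states ($g'$), and removing the radiative rate from the 300 K total isolates the blackbody-stimulated part ($r'$):

$$
\begin{aligned}
\Gamma_{ri} &= \frac{1}{T_{ri}}, \qquad
\Gamma_{rg'} = \frac{1}{T_{r,\mathrm{rad}}} - \frac{1}{T_{ri}}, \qquad
\Gamma_{rr'} = \frac{1}{T_{r,\mathrm{tot}}} - \frac{1}{T_{r,\mathrm{rad}}}, \\
\Gamma_{r,\mathrm{tot}} &= \Gamma_{ri} + \Gamma_{rr'} + \Gamma_{rg'} = \frac{1}{T_{r,\mathrm{tot}}}, \qquad
\Gamma_{rd} = \Gamma_{rr'} + \Gamma_{rg'}.
\end{aligned}
$$

Here $\Gamma_{rg'}$, $\Gamma_{rr'}$, $\Gamma_{rd}$ correspond to `Gamma_rgp`, `Gamma_rrp`, `Gamma_rd` in the code, and $\Gamma_{rd}$ collects every decay channel out of $r$ except the one back to $i$.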

In [None]:
atom = Rubidium87()

# Intermediate excited state: 6P3/2
n_i = 6
l_i = 1
j_i = 1.5
T_i = atom.getStateLifetime(n_i, l_i, j_i)  # Lifetime of the intermediate state
Gamma_ig = 1/T_i  # Decay rate of the intermediate state

# Additional intermediate decay channels for two-qubit gate
Gamma_i1 = 1 * Gamma_ig  # Allowed decay
Gamma_i0 = 0            # Not used

# Hyperfine ground-state decay rate (e.g. from Levine et al. 2018)
Gamma_10 = 0.1  # Decay rate between hyperfine ground states (1/s)

# Rydberg state: 70 S1/2
n_r = 70
l_r = 0
j_r = 0.5

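# Total Rydberg lifetime at 300 K (radiative decay + blackbody-stimulated transitions)
T_rTot = atom.getStateLifetime(n_r, l_r, j_r, temperature=300, includeLevelsUpTo=n_r + 50)
# Purely radiative Rydberg lifetime (T = 0 K)
T_rRad = atom.getStateLifetime(n_r, l_r, j_r, temperature=0)
# Inverse of the radiative transition rate r -> i
T_ri = 1/atom.getTransitionRate(n_r, l_r, j_r, n_i, l_i, j_i, temperature=0)

# Effective lifetimes of the remaining channels
T_rgp = 1/(1/T_rRad - 1/T_ri)   # radiative decay from r to states other than i
T_rBB = 1/(1/T_rTot - 1/T_rRad)  # blackbody-stimulated transitions out of r

Gamma_ri = 1/T_ri     # r -> i
Gamma_rrp = 1/T_rBB   # r -> r' (blackbody-stimulated)
Gamma_rgp = 1/T_rgp   # r -> g' (radiative, excluding i)
Gamma_rTot = Gamma_ri + Gamma_rrp + Gamma_rgp  # equals 1/T_rTot
Gamma_rd = Gamma_rrp + Gamma_rgp  # all decay out of r except r -> i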

# Consolidate decay rates
Gammas = [Gamma_10, Gamma_i1, Gamma_ri, Gamma_rd]

# Two-qubit gate specific parameters
V_int = 2 * np.pi * 10e6  # Interaction strength in rad/s
tau = 324e-9             # Gate time (s)
Delta_i = 2 * np.pi * -35.7e6  # Detuning for the gate (rad/s)

print(f'V_int: {V_int/(2*np.pi*1e6):.2f} MHz')
print(f'tau: {tau*1e9:.2f} ns')
print(f'Delta_i: {Delta_i/(2*np.pi*1e6):.2f} MHz')

###### Additional system parameters ######
Rabi_i = 2 * np.pi * 100e6
Rabi_r = 2 * np.pi * 100e6
del_total = 0

# Time grid settings
t_0 = 0
t_f = 2 * tau
nt = 500  # Number of time steps; increase for accuracy, decrease for speed
pad = int(0.03 * nt)  # Padding points
delta_t = (t_f - t_0) / (nt + 2 * pad)
tlist = np.linspace(t_0, t_f, nt + 2*pad)

# Pulse bandwidth (largest frequency component)
f_std = 50e6

# Basis dimension for the Gaussian modes
input_dim = 10

# Number of control channels: Delta_i plus two parameters each for Rabi_i and Rabi_r (Re/Im in the old TF approach)
numb_ctrl_amps = 5


## Initializing the JAX Propagator

Next, we create an instance of `PropagatorVLJAX` by providing the relevant parameters (dimensions, time step, detunings, Rabi frequencies, etc.). We also initialize random control parameters (`ctrl_a`, `ctrl_b`, `ctrl_c`) using JAX's random utilities. These parameters define how our pulses are constructed from sums of Gaussian modes.

In [None]:
# Create the JAX-based propagator
propagator_jax = PropagatorVLJAX(
    input_dim=input_dim,
    no_of_steps=nt,
    pad=pad,
    f_std=f_std,
    delta_t=delta_t,
    del_total=del_total,
    V_int=V_int,
    Delta_i=Delta_i,
    Rabi_i=Rabi_i,
    Rabi_r=Rabi_r,
    Gammas_all=Gammas
)

### Randomly Initialize Controls

We define three arrays of shape `(input_dim, numb_ctrl_amps)`: `ctrl_a` (amplitudes), `ctrl_b` (centers), and `ctrl_c` (widths). `PropagatorVLJAX` uses them internally to build each time-dependent pulse as a sum of Gaussian basis functions. We initialize them randomly below, after defining analytical reference pulse shapes.
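The exact pulse formula inside `PropagatorVLJAX` isn't shown in this notebook, so the snippet below is only a hypothetical sketch of a sum-of-Gaussians parametrization, $p_k(s) = \sum_j a_{jk} \exp\!\left[-(s - b_{jk})^2 / (2 c_{jk}^2)\right]$. It assumes `ctrl_a` holds amplitudes, `ctrl_b` holds centers in a normalized time coordinate $s \in [-1, 1]$, and `ctrl_c` holds widths, as the comments in the initialization cell suggest; how `f_std` constrains the widths is not modeled. The function name `gaussian_mode_pulses` and the small width offset are illustrative choices, not part of the library.

```python
import jax
import jax.numpy as jnp


def gaussian_mode_pulses(ctrl_a, ctrl_b, ctrl_c, num_steps):
    """Hypothetical sum-of-Gaussians pulse model (not PropagatorVLJAX's exact formula).

    ctrl_a, ctrl_b, ctrl_c: arrays of shape (input_dim, num_channels) holding the
    amplitude, center, and width of each Gaussian mode in normalized time s in [-1, 1].
    Returns an array of shape (num_steps, num_channels).
    """
    s = jnp.linspace(-1.0, 1.0, num_steps)  # normalized time grid
    # Broadcast to (num_steps, input_dim, num_channels)
    diff = s[:, None, None] - ctrl_b[None, :, :]
    width = ctrl_c[None, :, :] + 1e-3  # small offset avoids division by zero when a width is 0
    modes = ctrl_a[None, :, :] * jnp.exp(-0.5 * (diff / width) ** 2)
    return modes.sum(axis=1)  # sum the Gaussian modes of each channel


demo_key = jax.random.PRNGKey(0)
ka, kb, kc = jax.random.split(demo_key, 3)
demo_dim, demo_channels = 10, 5
demo_a = jax.random.uniform(ka, (demo_dim, demo_channels), minval=-1.0, maxval=1.0)
demo_b = jax.random.uniform(kb, (demo_dim, demo_channels), minval=-1.0, maxval=1.0)
demo_c = jax.random.uniform(kc, (demo_dim, demo_channels), minval=0.0, maxval=0.5)

pulses = gaussian_mode_pulses(demo_a, demo_b, demo_c, num_steps=530)
print(pulses.shape)  # (530, 5)
```

Because every mode is a smooth Gaussian, the resulting pulses are smooth by construction, and the number of free parameters is fixed by `input_dim` rather than by the number of time steps.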

In [None]:
Rabi_i_analytical = lambda t, tau: np.sin(np.pi / (2 * tau) * t)  # Normalized Rabi profile connecting |1> to |i>
Rabi_r_analytical = lambda t, tau: abs(np.cos(np.pi / (2 * tau) * t))  # Normalized Rabi profile connecting |i> to |r>

In [None]:
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)

# ctrl_a: random amplitudes
ctrl_a = jax.random.uniform(
    subkey1,
    shape=(input_dim, numb_ctrl_amps),
    minval=-1.0,
    maxval=1.0
).astype(jnp.float64)

# ctrl_b: random centers in [-1,1]
ctrl_b = jax.random.uniform(
    subkey2,
    shape=(input_dim, numb_ctrl_amps),
    minval=-1.0,
    maxval=1.0
).astype(jnp.float64)

# ctrl_c: random widths in [0,0.5]
ctrl_c = jax.random.uniform(
    subkey3,
    shape=(input_dim, numb_ctrl_amps),
    minval=0.0,
    maxval=0.5
).astype(jnp.float64)

# Evaluate initial cost
init_cost = propagator_jax.target(ctrl_a, ctrl_b, ctrl_c)
print(f"Initial cost: {init_cost:.6e}")

## Visualize Initial Control Pulses

Just as we did in the TensorFlow version, let's examine the raw pulses produced by our random parameters before any optimization.

In [None]:
# Retrieve physical pulses [M, 5]
initial_pulses = propagator_jax.return_physical_amplitudes(ctrl_a, ctrl_b, ctrl_c)
initial_pulses_np = np.array(initial_pulses)

labels = [r"$\Delta_i$", r"$\Omega_i$ mag", r"$\Omega_i$ phase", r"$\Omega_r$ mag", r"$\Omega_r$ phase"]

fig, ax = getStylishFigureAxes(1,1)
for i in range(initial_pulses_np.shape[1]):
    PlotPlotter(
        fig,
        ax,
        tlist*1e9,
        initial_pulses_np[:, i],
        style={'label': labels[i], 'marker': '', 'linestyle': '-'},
        xticks=[0, 250, 500, 750]
    ).draw()

ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

## Gradient-Based Optimization

We'll run a simple gradient-based optimization loop. Each call to `single_optimization_step` returns the cost together with the updated control parameters; the cost itself is defined in `PropagatorVLJAX.target()`.

For demonstration, we run a fixed number of iterations and print the cost every 10 steps. For more sophisticated optimizers such as Adam, you can write your own update step or build one with the `optax` library.
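As a hedged sketch of the Adam alternative, the snippet below writes the Adam update by hand with `jax.value_and_grad` and `jax.tree_util.tree_map`. It is applied to `toy_target`, a hypothetical quadratic stand-in for `propagator_jax.target`, and assumes the real cost is a differentiable scalar function of `(ctrl_a, ctrl_b, ctrl_c)`. With `optax`, `optax.adam(learning_rate)` provides the same update rule.

```python
import jax
import jax.numpy as jnp


def toy_target(ctrl_a, ctrl_b, ctrl_c):
    """Hypothetical stand-in for propagator_jax.target: a smooth scalar cost."""
    return (
        jnp.sum((ctrl_a - 0.5) ** 2)
        + jnp.sum(ctrl_b ** 2)
        + jnp.sum((ctrl_c - 0.2) ** 2)
    )


# Cost and gradient with respect to the tuple (ctrl_a, ctrl_b, ctrl_c)
cost_and_grad = jax.jit(jax.value_and_grad(lambda params: toy_target(*params)))


def adam_step(params, m, v, t, lr=0.02, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update. params, m, v are tuples of arrays; t is the 1-based step index."""
    cost, grads = cost_and_grad(params)
    tmap = jax.tree_util.tree_map
    m = tmap(lambda mi, g: b1 * mi + (1 - b1) * g, m, grads)       # first-moment estimate
    v = tmap(lambda vi, g: b2 * vi + (1 - b2) * g ** 2, v, grads)  # second-moment estimate
    m_hat = tmap(lambda mi: mi / (1 - b1 ** t), m)  # bias correction
    v_hat = tmap(lambda vi: vi / (1 - b2 ** t), v)
    params = tmap(lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps), params, m_hat, v_hat)
    return cost, params, m, v


params = (jnp.zeros((10, 5)), jnp.ones((10, 5)), jnp.zeros((10, 5)))
m = jax.tree_util.tree_map(jnp.zeros_like, params)
v = jax.tree_util.tree_map(jnp.zeros_like, params)

history = []
for t in range(1, 501):
    cost, params, m, v = adam_step(params, m, v, t)
    history.append(float(cost))  # cost evaluated before this step's update

print(f"initial cost = {history[0]:.3e}, final cost = {history[-1]:.3e}")
```

Swapping `toy_target` for `propagator_jax.target` would additionally require that `target` be traceable by `jax.value_and_grad`, which is assumed rather than shown here. Adam's per-parameter step normalization is often helpful when the amplitude, center, and width parameters have very different gradient scales.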

In [None]:
num_iters = 1000
learning_rate = 0.02

cost_history = []

# Warm-up call to trigger JIT compilation (note: this also applies one update to the controls)
warmup_cost, ctrl_a, ctrl_b, ctrl_c = single_optimization_step(
    propagator_jax, ctrl_a, ctrl_b, ctrl_c, lr=learning_rate
)

for step in range(num_iters):
    cost_val, ctrl_a, ctrl_b, ctrl_c = single_optimization_step(
        propagator_jax, ctrl_a, ctrl_b, ctrl_c, lr=learning_rate
    )
    cost_val_np = float(cost_val)  # convert the JAX scalar to a Python float
    cost_history.append(cost_val_np)
    if (step + 1) % 10 == 0:
        print(f"Iteration {step+1}, cost = {cost_val_np:.6e}")

final_cost = cost_history[-1]
print(f"Final cost after {num_iters} steps: {final_cost:.6e}")

## Analyze Optimized Pulses

We'll now plot the final pulses and compare them with the initial random guess. The first column corresponds to $\Delta_i(t)$, and the remaining four columns hold the amplitude and phase parameters of the intermediate and Rydberg drives (in a parametric form).

In [None]:
# Get final pulses
final_pulses = propagator_jax.return_physical_amplitudes(ctrl_a, ctrl_b, ctrl_c)
final_pulses_np = np.array(final_pulses)

fig, ax = getStylishFigureAxes(1,1)
for i in range(final_pulses_np.shape[1]):
    PlotPlotter(
        fig, ax,
        tlist*1e9,
        final_pulses_np[:, i],
        style={'label': labels[i], 'marker': '', 'linestyle': '-'},
        xticks=[0, 250, 500, 750]
    ).draw()

ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

### Extract Rabi Magnitudes and Phases

Just like the previous approach, we can interpret the second and third columns as amplitude and phase for $\Omega_i$, and the fourth and fifth columns as amplitude and phase for $\Omega_r$. Let's convert them to a more direct $(\text{magnitude}, \text{phase})$ representation.
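The conversion below assumes the convention implied by the analysis cells, $\Omega = \Omega_{\max}\, a\, e^{i\pi p}$, where $a$ is the amplitude parameter, $p$ the phase parameter, and $\Omega_{\max}$ is `Rabi_i` or `Rabi_r`; whether `PropagatorVLJAX` uses exactly this form internally is an assumption. The helper `complex_rabi` is hypothetical. The sketch shows that the magnitude is simply $\Omega_{\max}|a|$, and that a negative amplitude parameter appears as an extra phase of $\pi$.

```python
import numpy as np


def complex_rabi(amp_param, phase_param, rabi_max):
    """Complex Rabi frequency under the assumed convention
    Omega = rabi_max * amp_param * exp(1j * pi * phase_param)."""
    return rabi_max * amp_param * np.exp(1j * np.pi * phase_param)


rabi_max = 2 * np.pi * 100e6  # rad/s, same value as Rabi_i and Rabi_r above
amp_param = np.array([0.5, -0.5, 1.0])
phase_param = np.array([0.0, 0.0, 0.5])

omega = complex_rabi(amp_param, phase_param, rabi_max)
magnitude = np.abs(omega)  # equals rabi_max * |amp_param|
phase = np.angle(omega)    # in [-pi, pi]; a negative amp_param shifts the phase by pi
print(magnitude / rabi_max)
print(phase / np.pi)
```

This is why taking the absolute value of the amplitude column is enough to recover $|\Omega|$, while the plotted phase $\pi p$ can differ from the true phase of $\Omega$ by $\pi$ wherever $a < 0$.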

In [None]:
# Scale the normalized parameters to physical units
rabi_i_mag = Rabi_i * final_pulses_np[:, 1]  # rad/s
rabi_i_ph = np.pi * final_pulses_np[:, 2]    # rad
rabi_r_mag = Rabi_r * final_pulses_np[:, 3]  # rad/s
rabi_r_ph = np.pi * final_pulses_np[:, 4]    # rad

# Magnitudes of the complex Rabi frequencies |mag * exp(i * phase)|
rabi_i_amplitude = np.abs(rabi_i_mag * np.exp(1j * rabi_i_ph))
rabi_r_amplitude = np.abs(rabi_r_mag * np.exp(1j * rabi_r_ph))

fig, ax = getStylishFigureAxes(1, 1)
PlotPlotter(
    fig, ax,
    tlist*1e9,
    rabi_i_amplitude,
    style={'label': r'$|\Omega_i|$', 'marker': '', 'linestyle': '-'}
).draw()

PlotPlotter(
    fig, ax,
    tlist*1e9,
    rabi_r_amplitude,
    style={'label': r'$|\Omega_r|$', 'marker': '', 'linestyle': '-'}
).draw()

ax.legend(fontsize=6)
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Rabi magnitude (rad/s)')
plt.show()

In [None]:
fig, ax = getStylishFigureAxes(1, 1)
PlotPlotter(
    fig, ax,
    tlist*1e9,
    rabi_i_ph,
    style={'label': r'$\phi_i$', 'marker': '', 'linestyle': '-'}
).draw()

PlotPlotter(
    fig, ax,
    tlist*1e9,
    rabi_r_ph,
    style={'label': r'$\phi_r$', 'marker': '', 'linestyle': '-'}
).draw()

ax.legend(fontsize=6)
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Phase (rad)')
plt.show()

In [None]:
final_np = final_pulses_np
# final_np[:, 0]: Delta_i(t)
# final_np[:, 1]: Rabi_i amplitude param, final_np[:, 2]: Rabi_i phase param
# final_np[:, 3]: Rabi_r amplitude param, final_np[:, 4]: Rabi_r phase param

rabi_i_mag = Rabi_i * final_np[:,1]
rabi_i_ph = np.pi * final_np[:,2]
rabi_r_mag = Rabi_r * final_np[:,3]
rabi_r_ph = np.pi * final_np[:,4]

# Rabi_i_abs = np.sqrt((rabi_i_mag * np.cos(rabi_i_ph / Rabi_i))**2 + (rabi_i_mag * np.sin(rabi_i_ph / Rabi_i))**2)
# Actually simpler might be: Rabi_i_abs ~ rabi_i_mag, but we must be mindful the code lumps amplitude & phase differently.
# We'll do a direct approach:
# ri_complex = rabi_i_mag * np.exp(1j * rabi_i_ph / Rabi_i)  # but the original code lumps pi factor

# Instead let's keep it simpler:
Rabi_i_magnitude = np.sqrt((final_np[:,1] * np.cos(np.pi * final_np[:,2]))**2 + (final_np[:,1] * np.sin(np.pi * final_np[:,2]))**2)
Rabi_r_magnitude = np.sqrt((final_np[:,3] * np.cos(np.pi * final_np[:,4]))**2 + (final_np[:,3] * np.sin(np.pi * final_np[:,4]))**2)

Rabi_i_phase = np.pi * final_np[:,2]
Rabi_r_phase = np.pi * final_np[:,4]

fig, ax = getStylishFigureAxes(1, 1)
PlotPlotter(
    fig, ax,
    tlist*1e9,
    Rabi_i_magnitude,
    style={'label': r'$|\Omega_i|$', 'marker': '', 'linestyle': '-'}
).draw()

PlotPlotter(
    fig, ax,
    tlist*1e9,
    Rabi_r_magnitude,
    style={'label': r'$|\Omega_r|$', 'marker': '', 'linestyle': '-'}
).draw()

ax.legend(fontsize=6)
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Normalized Rabi Magnitude (dimensionless)')
plt.show()

In [None]:
fig, ax = getStylishFigureAxes(1, 1)
PlotPlotter(
    fig, ax,
    tlist*1e9,
    Rabi_i_phase,
    style={'label': r'$\phi_i$', 'marker': '', 'linestyle': '-'}
).draw()

PlotPlotter(
    fig, ax,
    tlist*1e9,
    Rabi_r_phase,
    style={'label': r'$\phi_r$', 'marker': '', 'linestyle': '-'}
).draw()

ax.legend(fontsize=6)
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Phase (rad)')
plt.show()

## Brief State Evolution Check

We can run the final solution through the propagator (via `propagate`) and examine the final operator, or even do a step-by-step evolution to check relevant population or amplitude metrics.
Below is a snippet that obtains the final 32x32 Van Loan operator and extracts the top-left 16x16 as the effective two-qubit evolution operator.
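For reference, the standard Van Loan construction exponentiates an upper block-triangular matrix: $\exp\!\left(\begin{bmatrix} A & B \\ 0 & A \end{bmatrix} t\right) = \begin{bmatrix} e^{At} & \int_0^t e^{A(t-s)} B\, e^{As}\,ds \\ 0 & e^{At} \end{bmatrix}$. The top-left block is the ordinary evolution, and when $B = \partial A / \partial\theta$ the top-right block is $\partial e^{At} / \partial\theta$. The sketch below illustrates this with `jax.scipy.linalg.expm` on a hypothetical two-level example; which $A$ and $B$ `PropagatorVLJAX` actually places in its $32\times32$ matrix is not shown in this notebook and is an assumption here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def van_loan_blocks(A, B, t):
    """Exponentiate [[A, B], [0, A]] * t and return its two top blocks:
    (exp(A t), integral_0^t exp(A (t - s)) B exp(A s) ds)."""
    n = A.shape[0]
    zeros = jnp.zeros_like(A)
    M = jnp.block([[A, B], [zeros, A]]) * t
    E = expm(M)
    return E[:n, :n], E[:n, n:]


# Hypothetical example: H = theta * sigma_x, A = -i H, and B = dA/dtheta
sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
theta, t = 0.7, 1.3
A = -1j * theta * sigma_x
B = -1j * sigma_x

U, dU_dtheta = van_loan_blocks(A, B, t)
print(U)
print(dU_dtheta)
```

Because $A$ and $B$ commute in this toy example, the integral reduces to $t B e^{At} = -i t \sigma_x U$, which makes the result easy to check by hand; in general they do not commute, and the block exponential is what makes the derivative cheap to obtain.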


In [None]:
final_vl = propagator_jax.propagate(ctrl_a, ctrl_b, ctrl_c)
dim_sq = propagator_jax.dim * propagator_jax.dim
U_final = final_vl[0:dim_sq, 0:dim_sq]

print("Final 2-qubit operator shape:", U_final.shape)
print("Trace of U_final:", np.trace(np.array(U_final)))


### Infidelity Check

We already have a cost measure built in, but let's verify the final infidelity and other metrics we can retrieve from `propagator_jax.metrics(...)`.
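The internals of `metrics()` aren't shown here, so as an independent, hedged cross-check, one common definition is the average gate fidelity of a (possibly non-unitary) $d \times d$ operator $U$ relative to a unitary target $U_0$: $F = \left[\mathrm{Tr}(M M^\dagger) + |\mathrm{Tr}\, M|^2\right] / [d(d+1)]$ with $M = U_0^\dagger U$. The sketch assumes you already have a $4\times4$ operator on the computational states $|00\rangle, |01\rangle, |10\rangle, |11\rangle$; how to extract it from `U_final` depends on `PropagatorVLJAX`'s conventions and is not shown. The helper `average_gate_fidelity` is hypothetical, and the target used by the library may differ from the `CZ` below, for example by single-qubit phases.

```python
import jax.numpy as jnp

# CZ gate in the computational basis |00>, |01>, |10>, |11>
CZ = jnp.diag(jnp.array([1, 1, 1, -1], dtype=jnp.complex64))


def average_gate_fidelity(U, U_target):
    """Average gate fidelity of a d x d operator U with respect to a unitary target."""
    d = U_target.shape[0]
    M = U_target.conj().T @ U
    return (jnp.real(jnp.trace(M @ M.conj().T)) + jnp.abs(jnp.trace(M)) ** 2) / (d * (d + 1))


F_perfect = average_gate_fidelity(CZ, CZ)
F_identity = average_gate_fidelity(jnp.eye(4, dtype=jnp.complex64), CZ)
print(f"1 - F(CZ) = {float(1 - F_perfect):.1e}, 1 - F(identity) = {float(1 - F_identity):.2f}")
```

The $\mathrm{Tr}(M M^\dagger)$ term penalizes population leaking out of the computational subspace (for example through decay), so the formula remains meaningful for the non-unitary evolution produced with the decay rates defined earlier.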

In [None]:
infid, adiab = propagator_jax.metrics(ctrl_a, ctrl_b, ctrl_c)
print(f"Final infidelity: {infid:.6e}, Adiabatic metric: {adiab:.6e}")


## Conclusion

We've successfully replaced the TensorFlow-based code with a JAX-based Van Loan propagator for two-qubit gates. We:
- Defined the same Rubidium87 system parameters.
- Constructed pulses via random Gaussian modes.
- Performed gradient updates on the cost function using JAX's autodiff.
- Analyzed the final pulses and checked the resulting gate operator.

For production-scale runs, we can increase `nt`, adjust the basis dimension `input_dim`, or combine `PropagatorVLJAX` with more advanced optimizers from JAX or other libraries.