# Quantum Optimal Control: Two-Qubit Gate

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time

from arc import Rubidium87
from qutip import *

from quantum_optimal_control.two_qubit.propagator_vl import PropagatorVL
from quantum_optimal_control.toolkits.plotting_helper import getStylishFigureAxes, PlotPlotter

# Suppress TF warnings: raise the TensorFlow Python logger threshold to ERROR
# so that lower-severity messages do not clutter the notebook output.
tf.get_logger().setLevel('ERROR')

# Set random seed for reproducibility. This fixes TensorFlow's global seed;
# the tf.random.uniform draws used for the initial pulse guess depend on it.
tf.random.set_seed(42)

## System Parameters Setup

Define the atomic parameters for rubidium-87, including the lifetimes and decay rates of the intermediate excited state ($6P_{3/2}$) and the Rydberg state ($70S_{1/2}$). For the two-qubit gate, additional parameters such as the interaction strength `V_int`, the gate time `tau`, and the detuning `Delta_i` are specified.

The Rydberg decay rate includes contributions from both radiative decay and blackbody-stimulated transitions.

In [None]:
# ARC atom object that all lifetimes and transition rates below are queried from.
atom = Rubidium87()

# --- Intermediate excited state: 6P3/2 (quantum numbers n, l, j) ---
n_i, l_i, j_i = 6, 1, 1.5
T_i = atom.getStateLifetime(n_i, l_i, j_i)  # Lifetime of the intermediate state
Gamma_ig = 1/T_i                            # Decay rate of the intermediate state

# Intermediate-state decay channels used by the two-qubit gate model:
# the full rate is assigned to the allowed channel, zero to the forbidden one.
Gamma_i1 = 1 * Gamma_ig
Gamma_i0 = 0 * Gamma_ig

# Decay rate between the hyperfine ground states in 1/s (from Levine2018).
Gamma_10 = 0.1

# --- Rydberg state: 70 S1/2 (quantum numbers n, l, j) ---
n_r, l_r, j_r = 70, 0, 0.5

# Lifetime at 300 K (radiative decay plus blackbody-stimulated transitions,
# with levels up to n_r + 50 included), radiative-only lifetime at 0 K, and
# the inverse of the 0 K transition rate from the Rydberg state to 6P3/2.
T_rTot = atom.getStateLifetime(
    n_r, l_r, j_r, temperature=300, includeLevelsUpTo=n_r + 50
)
T_rRad = atom.getStateLifetime(n_r, l_r, j_r, temperature=0)
T_ri = 1/atom.getTransitionRate(n_r, l_r, j_r, n_i, l_i, j_i, temperature=0)

# Effective lifetimes obtained by subtracting rates:
#   radiative decay that does NOT go to the intermediate state, and
#   the blackbody-only contribution (total minus radiative).
T_rgp = 1/(1/T_rRad - 1/T_ri)
T_rBB = 1/(1/T_rTot - 1/T_rRad)

# Convert the channel lifetimes back into rates.
Gamma_ri = 1/T_ri    # Rydberg -> intermediate state
Gamma_rrp = 1/T_rBB  # Blackbody stimulated decay channel
Gamma_rgp = 1/T_rgp  # Radiative decay channel to dark ground states

Gamma_rTot = Gamma_ri + Gamma_rrp + Gamma_rgp  # Total Rydberg decay rate
Gamma_rd = Gamma_rrp + Gamma_rgp               # Everything except Rydberg -> intermediate

# Decay rates in the order they are handed to the propagator.
Gammas = [Gamma_10, Gamma_i1, Gamma_ri, Gamma_rd]

# Two-qubit gate specific parameters
V_int = 2 * np.pi * 10e6       # Interaction strength in rad/s
tau = 324e-9                   # Gate time (s)
Delta_i = 2 * np.pi * -35.7e6  # Detuning for the gate (rad/s)

print(f'V_int: {V_int/(2*np.pi*1e6):.2f} MHz')
print(f'tau: {tau*1e9:.2f} ns')
print(f'Delta_i: {Delta_i/(2*np.pi*1e6):.2f} MHz')

###### Additional system parameters ######
Rabi_i = 2 * np.pi * 100e6  # Drive coupling for channel i (rad/s)
Rabi_r = 2 * np.pi * 100e6  # Drive coupling for channel r (rad/s)
del_total = 0  # NOTE(review): passed straight to PropagatorVL; meaning not visible here

# Time grid settings
t_0 = 0
t_f = 2 * tau
nt = 1000
pad = int(0.03 * nt)  # Padding points for zero pulse regions
# The total window [t_0, t_f) is split into nt + 2*pad steps of length delta_t.
delta_t = (t_f - t_0) / (nt + 2 * pad)
# One time stamp per step, spaced by exactly delta_t (start of each step).
# np.linspace with its default endpoint=True would instead space the samples
# by (t_f - t_0) / (nt + 2*pad - 1), which disagrees with delta_t.
tlist = t_0 + delta_t * np.arange(nt + 2 * pad)

# Pulse bandwidth (largest frequency component)
f_std = 50e6

# Basis functions for control amplitudes (Gaussian modes)
input_dim = 10

# Optimization parameters
num_iters = 500
learn_rate = 0.02

## Initialize Propagator

Instantiate the `PropagatorVL` for the two-qubit gate. The control amplitudes are parameterized by Gaussian modes. Initial guesses for amplitudes, centers, and widths are randomized.

In [None]:
# Build the two-qubit propagator from the system, grid and decay parameters.
propagatorVL = PropagatorVL(
    input_dim, nt, pad, f_std, delta_t, del_total, V_int, Delta_i, Rabi_i, Rabi_r,
    Gammas
)

# There are 5 control amplitudes: Delta_i, Re/Im of Rabi_i, Re/Im of Rabi_r
# NOTE(review): the comments below and the post-optimization analysis treat
# columns 1/2 and 3/4 as (amplitude, phase) pairs rather than Re/Im - confirm.
numb_ctrl_amps = 5

# Set initial Gaussian amplitudes
propagatorVL.ctrl_amplitudes_a.assign(
    tf.random.uniform([input_dim, numb_ctrl_amps], -1, 1, dtype=tf.float64)
)

# Initialize Gaussian centers for Delta_i
propagatorVL.ctrl_amplitudes_b[:, 0].assign(
    tf.random.uniform([input_dim], -1, 1, dtype=tf.float64)
)
# For Rabi_i: set centers for amplitude and phase
propagatorVL.ctrl_amplitudes_b[:, 1:2].assign(
    tf.random.uniform([input_dim, 1], -0.5, 0.5, dtype=tf.float64)
)
propagatorVL.ctrl_amplitudes_b[:, 2:3].assign(
    tf.random.uniform([input_dim, 1], -1, 1, dtype=tf.float64)
)
# For Rabi_r: set centers for amplitude and phase.
# The first half of the modes is centred in [-1, -0.5), the rest in [0.5, 1).
# The second slice holds input_dim - n_lower_modes rows, so its random draw is
# sized accordingly; this keeps the assignment valid for odd input_dim too
# (for even input_dim the shapes, and hence the draws, are unchanged).
n_lower_modes = input_dim // 2
propagatorVL.ctrl_amplitudes_b[0:n_lower_modes, 3:4].assign(
    tf.random.uniform([n_lower_modes, 1], -1, -0.5, dtype=tf.float64)
)
propagatorVL.ctrl_amplitudes_b[n_lower_modes:input_dim, 3:4].assign(
    tf.random.uniform([input_dim - n_lower_modes, 1], 0.5, 1, dtype=tf.float64)
)
propagatorVL.ctrl_amplitudes_b[:, 4:5].assign(
    tf.random.uniform([input_dim, 1], -1, 1, dtype=tf.float64)
)

# Set Gaussian widths
propagatorVL.ctrl_amplitudes_c.assign(
    tf.random.uniform([input_dim, numb_ctrl_amps], 0, 0.5, dtype=tf.float64)
)

# Evaluate the cost function once for the random initial guess.
initial_infidelity = propagatorVL.target()
print('Initial Figure of Merit:', initial_infidelity.numpy())

## Visualize Initial Control Pulses

Plot the initial control pulses using the project's plotting helpers (`getStylishFigureAxes` and `PlotPlotter`).

In [None]:
# Control amplitudes produced by the random initial guess, as a NumPy array.
physical_amplitudes_initial = propagatorVL.return_physical_amplitudes().numpy()

# Legend labels, one per column of the amplitude array.
labels = [
    r"$\Delta_i$",
    r"$\mathrm{Re}(\Omega_i)$",
    r"$\mathrm{Im}(\Omega_i)$",
    r"$\mathrm{Re}(\Omega_r)$",
    r"$\mathrm{Im}(\Omega_r)$",
]
colors = ['k', 'm--', 'b:', 'c--', 'r:']

fig, ax = getStylishFigureAxes(1, 1)
# Transpose so that iteration yields one control channel (column) at a time.
initial_traces = tf.transpose(physical_amplitudes_initial)
for channel, trace in enumerate(initial_traces):
    trace_style = {
        'label': labels[channel],
        'marker': '',
        'linestyle': '-',
        'linewidth': 1,
    }
    # Time axis is converted from seconds to nanoseconds for display.
    PlotPlotter(fig, ax, tlist * 1e9, trace, style=trace_style).draw()

ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
# Legend is anchored outside the axes, to the upper right of the plot area.
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

## Optimize Control Pulses

Apply gradient-based optimization to minimize the target cost function and update the control pulses.

In [None]:
# Adam optimizer shared by every optimization step.
optimizer = tf.keras.optimizers.Adam(learn_rate)


@tf.function
def optimization_step():
    """Apply one Adam update to the Gaussian pulse parameters.

    The cost ``propagatorVL.target()`` is differentiated with respect to the
    three parameter sets (``ctrl_amplitudes_a``, ``ctrl_amplitudes_b``,
    ``ctrl_amplitudes_c``) and the optimizer applies the resulting gradients.

    Returns:
        The cost re-evaluated *after* the update has been applied.
    """
    pulse_parameters = [
        propagatorVL.ctrl_amplitudes_a,
        propagatorVL.ctrl_amplitudes_b,
        propagatorVL.ctrl_amplitudes_c,
    ]
    with tf.GradientTape() as tape:
        cost = propagatorVL.target()
    gradients = tape.gradient(cost, pulse_parameters)
    optimizer.apply_gradients(zip(gradients, pulse_parameters))
    # A second forward pass so the caller sees the post-update cost.
    return propagatorVL.target()

# Run the optimization and remember the lowest cost seen so far.
# NOTE: only the best cost VALUE is tracked; the pulse parameters held by
# propagatorVL afterwards are those of the last step, not of the best step.
best_infidelity = 1.0
for step in range(num_iters):
    current_cost = optimization_step()
    # Progress report every third step; the cost is indexed as a (1, 1) array.
    if (step + 1) % 3 == 0:
        print(f'Step {step+1}: Figure of Merit = {current_cost.numpy()[0][0]:.5f}')
    if current_cost < best_infidelity:
        best_infidelity = current_cost

# best_infidelity is still the plain float 1.0 when no step improved on it
# (or when num_iters is 0), and a float has no .numpy(). np.asarray accepts
# both the float and a tensor, and prints a tensor exactly as .numpy() would.
print('Optimized Figure of Merit:', np.asarray(best_infidelity))


## Post-Optimization Analysis

Simulate the dynamics using the optimized pulses, compute the Rabi amplitudes and phases, and analyze state overlaps.

In [None]:
# Control amplitudes after optimization, as a NumPy array.
physical_amplitudes_final = propagatorVL.return_physical_amplitudes().numpy()

fig, ax = getStylishFigureAxes(1, 1)
# Transpose so that iteration yields one control channel (column) at a time.
optimized_traces = tf.transpose(physical_amplitudes_final)
for channel, trace in enumerate(optimized_traces):
    trace_style = {'label': labels[channel], 'marker': '', 'linestyle': '-'}
    PlotPlotter(
        fig, ax, tlist * 1e9, trace,
        style=trace_style,
        xticks=[0, 250, 500, 750],
    ).draw()

ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
# Legend is anchored outside the axes, to the upper right of the plot area.
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()


In [None]:
# Calculate Rabi amplitudes and phases.
# Columns 1/3 are used as signed amplitudes and columns 2/4 as phases in
# units of pi; the magnitude is that of amplitude * exp(i * phase).
scaled_amp_i = physical_amplitudes_final[:, 1]
scaled_amp_r = physical_amplitudes_final[:, 3]
Rabi_i_phase = np.pi * physical_amplitudes_final[:, 2]
Rabi_r_phase = np.pi * physical_amplitudes_final[:, 4]

Rabi_i_amplitude = np.sqrt(
    (scaled_amp_i * np.cos(Rabi_i_phase))**2
    + (scaled_amp_i * np.sin(Rabi_i_phase))**2
)
Rabi_r_amplitude = np.sqrt(
    (scaled_amp_r * np.cos(Rabi_r_phase))**2
    + (scaled_amp_r * np.sin(Rabi_r_phase))**2
)

# Plot raw control pulses (real and imaginary parts)
fig, ax = getStylishFigureAxes(1, 1)
for channel, trace in enumerate(tf.transpose(physical_amplitudes_final)):
    raw_style = {'label': labels[channel], 'marker': '', 'linestyle': '-'}
    PlotPlotter(fig, ax, tlist * 1e9, trace, style=raw_style).draw()
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

# Plot computed Rabi amplitudes (magnitudes)
fig, ax = getStylishFigureAxes(1, 1)
magnitude_traces = (
    (Rabi_i_amplitude, r'$|\Omega_i|$'),
    (Rabi_r_amplitude, r'$|\Omega_r|$'),
)
for values, legend_label in magnitude_traces:
    PlotPlotter(
        fig, ax, tlist * 1e9, values,
        style={'label': legend_label, 'marker': '', 'linestyle': '-'},
    ).draw()
ax.set_xlabel('Time (ns)')
ax.set_ylabel(r'$\widetilde{\Omega}(t)/\Omega_{max}$')
ax.legend(fontsize=6)
plt.show()

# Plot phases (shown in units of pi)
fig, ax = getStylishFigureAxes(1, 1)
phase_traces = (
    (Rabi_i_phase, r'$\varphi_i$'),
    (Rabi_r_phase, r'$\varphi_r$'),
)
for phase, legend_label in phase_traces:
    PlotPlotter(
        fig, ax, tlist * 1e9, phase/np.pi,
        style={'label': legend_label, 'marker': '', 'linestyle': '-'},
    ).draw()
ax.set_xlabel('Time (ns)')
ax.set_ylabel(r'$\varphi/\pi$')
ax.legend(fontsize=6)
plt.show()

In [None]:
# Compare with analytical solution

def Rabi_1_analytical_adiabatic_cz(t, args):  # Rabi connecting 1 to i
    """Return the analytical envelope sin(pi * t / (2 * tau)).

    Args:
        t: Time (scalar or NumPy array), in the same unit as ``args['tau']``.
        args: Mapping that provides the time scale under the key ``'tau'``.

    Returns:
        The sine envelope evaluated at ``t`` (0 at t = 0, 1 at t = tau).
    """
    angular_rate = np.pi / (2 * args['tau'])
    return np.sin(angular_rate * t)

def Rabi_2_analytical_adiabatic_cz(t, args):  # Rabi connecting i to r
    """Return the analytical envelope |cos(pi * t / (2 * tau))|.

    Args:
        t: Time (scalar or NumPy array), in the same unit as ``args['tau']``.
        args: Mapping that provides the time scale under the key ``'tau'``.

    Returns:
        The rectified cosine envelope at ``t`` (1 at t = 0, 0 at t = tau).
    """
    angular_rate = np.pi / (2 * args['tau'])
    return np.abs(np.cos(angular_rate * t))

# Overlay the optimized drive magnitudes with the analytical envelopes.
# All LaTeX labels are raw strings: '\O' is not a valid escape sequence, and
# non-raw literals containing it trigger a warning on newer Python versions.
# The resulting label text is unchanged.
fig, ax = getStylishFigureAxes(1, 1)
PlotPlotter(fig, ax, tlist * 1e9, Rabi_i_amplitude, xlabel='Time (ns)', ylabel=r'$\widetilde{\Omega}(t)/\Omega_{max}$', style={'label': r'$\Omega_1$', 'linewidth': 0.75, 'linestyle': '-', 'marker': ''}, xticks=[0, 200, 400, 600], yticks=[0, 0.5, 1]).draw()
PlotPlotter(fig, ax, tlist * 1e9, Rabi_r_amplitude, style={'label': r'$\Omega_2$', 'linewidth': 0.75, 'linestyle': '-', 'marker': ''}).draw()
# Analytical curves use tau = t_f / 2, i.e. half of the simulated time window.
PlotPlotter(fig, ax, tlist * 1e9, Rabi_2_analytical_adiabatic_cz(tlist, {'tau': t_f / 2}), style={'label': r'$\Omega_2$ (analytical)', 'linewidth': 0.75, 'linestyle': '--', 'color': '#4c7ca4', 'marker': ''}).draw()
PlotPlotter(fig, ax, tlist * 1e9, Rabi_1_analytical_adiabatic_cz(tlist, {'tau': t_f / 2}), style={'label': r'$\Omega_1$ (analytical)', 'linewidth': 0.75, 'linestyle': '--', 'color': '#c53a3c', 'marker': ''}).draw()
plt.show()

### State Overlap Analysis

The following section computes the squared overlap between the instantaneous dark state and the evolving state throughout the pulse, as well as the squared overlaps of the evolving states with several two-qubit basis states.

In [None]:
# Note: The following analysis assumes that the propagatorVL object contains initial states
# psi_pp, psi_10, and psi_11 for the different two-qubit configurations.


def _overlap_squared(state, ket):
    """Return |<state|ket>|^2 as a NumPy value.

    Computes ``state^dagger @ ket`` and uses its [0, 0] element, so both
    arguments are expected to be column vectors of equal length.
    """
    amplitude = tf.linalg.matmul(state, ket, adjoint_a=True)[0, 0]
    return tf.math.real(amplitude * tf.math.conj(amplitude)).numpy()


DS_overlap_squared = []
plus_overlap_squared = []
targ_overlap_squared = []
rr_overlap_squared = []
i_overlap_squared = []
state11_11_overlap_squared = []
state11_rr_overlap_squared = []
state11_i_overlap_squared = []
state11_DS_overlap_squared = []

dim = propagatorVL.dim
exps = propagatorVL.exponentials()

# Assume initial two-qubit states are stored in propagatorVL
current_state_pp = propagatorVL.psi_pp  # For |++> initial state
current_state_10 = propagatorVL.psi_10  # For |10> initial state
current_state_11 = propagatorVL.psi_11  # For |11> initial state

g_0, g_1, i, r = propagatorVL.nLevelAtomBasis(dim)

# Reference states do not depend on the time step, so build and cast them once
# instead of on every loop iteration.
plus = 1/np.sqrt(2) * (g_0 + g_1)
targ_two_qubit = 1/2 * (tensor(g_0, g_0) + tensor(g_1, g_0) + tensor(g_0, g_1) - tensor(g_1, g_1))
state_10 = tensor(g_1, g_0)
state_r0 = tensor(r, g_0)

ket_pp = tf.cast(tensor(plus, plus), tf.complex128)
ket_targ = tf.cast(targ_two_qubit, tf.complex128)
ket_rr = tf.cast(tensor(r, r), tf.complex128)
# NOTE(review): tensor(i, identity(dim)) combines a basis state with an
# operator, so this is presumably not a single column vector; only element
# [0, 0] of the product is used below. Kept as in the original - confirm.
ket_i = tf.cast(tensor(i, identity(dim)), tf.complex128)
ket_11 = tf.cast(tensor(g_1, g_1), tf.complex128)
ket_10 = tf.cast(state_10, tf.complex128)

for ind in range(propagatorVL.no_of_steps + 2 * propagatorVL.padding):
    # Compute instantaneous drives for |10> state
    current_Rabi_i_real = propagatorVL.Rabi_i * physical_amplitudes_final[ind, 1] * np.cos(np.pi * physical_amplitudes_final[ind, 2])
    current_Rabi_i_imag = propagatorVL.Rabi_i * physical_amplitudes_final[ind, 1] * np.sin(np.pi * physical_amplitudes_final[ind, 2])
    current_Rabi_r_real = propagatorVL.Rabi_r * physical_amplitudes_final[ind, 3] * np.cos(np.pi * physical_amplitudes_final[ind, 4])
    current_Rabi_r_imag = propagatorVL.Rabi_r * physical_amplitudes_final[ind, 3] * np.sin(np.pi * physical_amplitudes_final[ind, 4])

    # Instantaneous dark state: superposition of |1,0> and |r,0> weighted by
    # the two complex drives and normalized by their combined magnitude.
    # NOTE(review): if both drives are exactly zero at a sample (possibly in
    # the zero-padded regions) this division yields inf/nan - confirm.
    norm_dark = 1 / np.sqrt(current_Rabi_i_real**2 + current_Rabi_i_imag**2 + current_Rabi_r_real**2 + current_Rabi_r_imag**2)
    dark_state = norm_dark * (
        (current_Rabi_r_real + 1j*current_Rabi_r_imag) * state_10 -
        (current_Rabi_i_real - 1j*current_Rabi_i_imag) * state_r0
    )
    # Cast once to a complex128 column vector and reuse it for both
    # dark-state overlaps (the |11> overlap previously received the raw object).
    dark_ket = tf.reshape(tf.cast(dark_state, tf.complex128), (-1, 1))

    # Overlaps along the |10> and |++> evolutions.
    DS_overlap_squared.append(_overlap_squared(current_state_10, dark_ket))
    plus_overlap_squared.append(_overlap_squared(current_state_pp, ket_pp))
    targ_overlap_squared.append(_overlap_squared(current_state_pp, ket_targ))
    rr_overlap_squared.append(_overlap_squared(current_state_pp, ket_rr))
    i_overlap_squared.append(_overlap_squared(current_state_pp, ket_i))

    # Overlaps along the |11> evolution.
    state11_11_overlap_squared.append(_overlap_squared(current_state_11, ket_11))
    state11_rr_overlap_squared.append(_overlap_squared(current_state_11, ket_rr))
    # Uses the |11> evolution like every other state11_* quantity; the
    # original took current_state_pp here.
    # NOTE(review): the reference ket is |1,0>, although the result is
    # plotted as rho_i - confirm which reference state is intended.
    state11_i_overlap_squared.append(_overlap_squared(current_state_11, ket_10))
    # NOTE(review): this dark state is built on |1,0>/|r,0>; confirm it is
    # the intended reference for the |11> evolution.
    state11_DS_overlap_squared.append(_overlap_squared(current_state_11, dark_ket))

    # Propagate the states by one time step using the leading dim**2 x dim**2
    # block of this step's exponential.
    step_propagator = exps[ind, :dim**2, :dim**2]
    current_state_10 = tf.linalg.matmul(step_propagator, current_state_10)
    current_state_pp = tf.linalg.matmul(step_propagator, current_state_pp)
    current_state_11 = tf.linalg.matmul(step_propagator, current_state_11)


In [None]:
# Propagator |++>: every trace in this figure is computed from the state that
# starts in |++> (plus_overlap, targ_overlap, i_overlap and rr_overlap), so the
# comment and title name that state instead of |01> / |10>.
fig, ax = getStylishFigureAxes(1, 1)
xticks = [0, 200, 400, 600]  # also reused by the figures that follow
PlotPlotter(fig, ax, tlist/1e-9, plus_overlap_squared, xlabel='Time (ns)', ylabel=r'$\langle P \rangle$', title='Bare State Overlap: Propagator |++>', style={'label': r'$\rho_{++}$', 'marker': '', 'linestyle': '-', 'linewidth': 0.75}, xticks=xticks).draw()
PlotPlotter(fig, ax, tlist/1e-9, targ_overlap_squared, style={'label': r'$\rho_{targ}$', 'marker': '', 'linestyle': '-', 'linewidth': 0.75}).draw()
PlotPlotter(fig, ax, tlist/1e-9, i_overlap_squared, style={'label': r'$\rho_{i}$', 'marker': '', 'linestyle': '-', 'linewidth': 0.75}).draw()
PlotPlotter(fig, ax, tlist/1e-9, rr_overlap_squared, style={'label': r'$\rho_{rr}$', 'marker': '', 'linestyle': '-', 'linewidth': 0.75}).draw()
ax.legend(fontsize=6)
plt.show()

# Propagator |11>: squared overlaps recorded in the state11_* lists.
fig, ax = getStylishFigureAxes(1, 1)

# First trace also sets the axis labels, title and x ticks of the figure.
PlotPlotter(
    fig, ax, tlist/1e-9, state11_11_overlap_squared,
    xlabel='Time (ns)',
    ylabel=r'$\langle P \rangle$',
    title='Bare State Overlap: Propagator |11>',
    style={'label': r'$\rho_{11}$', 'marker': '', 'linestyle': '-', 'linewidth': 0.75},
    xticks=xticks,
).draw()

# Remaining traces share the same line style and only differ in data/label.
remaining_11_traces = (
    (state11_i_overlap_squared, r'$\rho_{i}$'),
    (state11_rr_overlap_squared, r'$\rho_{rr}$'),
)
for overlap_values, legend_label in remaining_11_traces:
    PlotPlotter(
        fig, ax, tlist/1e-9, overlap_values,
        style={'label': legend_label, 'marker': '', 'linestyle': '-', 'linewidth': 0.75},
    ).draw()

ax.legend(fontsize=6)
plt.show()

# Plot dark state overlap throughout the pulse duration (time axis in ns).
fig, ax = getStylishFigureAxes(1, 1)
dark_state_plot = PlotPlotter(
    fig, ax, tlist/1e-9, DS_overlap_squared,
    xlabel='Time (ns)',
    ylabel=r'$|\langle D|\psi(t)\rangle|^2$',
    title='Dark State Overlap',
    style={'marker': '', 'linestyle': '-', 'linewidth': 0.75},
    xticks=xticks,
)
dark_state_plot.draw()
plt.show()