# Quantum Optimal Control: State Transfer


In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time

from arc import Rubidium87
from qutip import *

from quantum_optimal_control.state_transfer.propagator_vl import PropagatorVL
from quantum_optimal_control.toolkits.plotting_helper import getStylishFigureAxes, PlotPlotter

# Global TensorFlow seed: makes the unseeded tf.random.uniform draws used for the
# initial control amplitudes reproducible from run to run.
tf.random.set_seed(seed=1)

## System Parameters Setup


In [None]:
atom = Rubidium87()

# Intermediate excited state |i>: 6P3/2
n_i = 6
l_i = 1
j_i = 1.5
# Lifetime of the intermediate state (ARC returns seconds). No temperature is
# passed, so ARC's default temperature argument applies.
T_i = atom.getStateLifetime(n_i, l_i, j_i)
Gamma_ig = 1/T_i  # Decay rate of the intermediate state (1/s), passed to PropagatorVL below

# Rydberg state |r>: 70 S1/2
n_r = 70
l_r = 0
j_r = 0.5
# Compute the total lifetime including blackbody stimulation at 300 K (T_rTot,
# summing over levels up to n_r + 50), the radiative lifetime at 0 K (T_rRad),
# and the radiative transition lifetime (T_ri) for decay to the intermediate state.
T_rTot = atom.getStateLifetime(n_r, l_r, j_r, temperature=300, includeLevelsUpTo=n_r+50)
T_rRad = atom.getStateLifetime(n_r, l_r, j_r, temperature=0)
T_ri = 1/atom.getTransitionRate(n_r, l_r, j_r, n_i, l_i, j_i, temperature=0)
# Split the remaining decay into two effective channels by subtracting rates:
#   T_rgp: 0 K radiative decay to all states other than |i>  (rate = 1/T_rRad - 1/T_ri)
#   T_rBB: blackbody-stimulated transitions                  (rate = 1/T_rTot - 1/T_rRad)
T_rgp = 1/(1/T_rRad - 1/T_ri)
T_rBB = 1/(1/T_rTot - 1/T_rRad)

Gamma_ri = 1/T_ri    # |r> -> |i> radiative rate
Gamma_rrp = 1/T_rBB  # blackbody channel; presumably modeled as |r> -> |r'> (cf. the r_prime basis state) - confirm in PropagatorVL
Gamma_rgp = 1/T_rgp  # remaining radiative channel; presumably modeled as |r> -> |g'> - confirm in PropagatorVL
# Total decay rate from the Rydberg state. The three rates above telescope,
# so this equals 1/T_rTot up to floating-point rounding.
Gamma_rTot = Gamma_ri + Gamma_rrp + Gamma_rgp

# Control parameters: peak Rabi frequencies in rad/s (2*pi * 127 MHz), plus
# Delta_1 and del_total (presumably single-photon and total detunings; both zero here).
Rabi_1 = 2 * np.pi * 127e6
Rabi_2 = 2 * np.pi * 127e6
Delta_1 = 0
del_total = 0

# Time grid: nt steps of width delta_t covering [t_0, t_f).
t_0 = 0
t_f = 100e-9
nt = 1000
delta_t = (t_f - t_0) / nt
# One sample per step, at the start of each step, so consecutive samples are
# exactly delta_t apart. The default np.linspace(t_0, t_f, nt) would include t_f
# and use a spacing of (t_f - t_0) / (nt - 1), which disagrees with delta_t.
tlist = np.linspace(t_0, t_f, nt, endpoint=False)

# Number of basis functions for the control amplitudes, and optimizer settings
input_dim = 10
num_iters = 1250
learn_rate = 0.1

## Initialize Propagator

Instantiate the `PropagatorVL` class and set a seeded random initial guess for the control-amplitude coefficients.

In [None]:
# Build the propagator from the basis size, time grid, detunings, peak Rabi
# frequencies and decay rates defined above.
propagatorVL = PropagatorVL(
    input_dim, nt, delta_t, del_total, Delta_1, Rabi_1, Rabi_2,
    Gamma_rTot, Gamma_ig
)

# There are 5 control amplitudes: Delta_1, Re/Im of Rabi_1, Re/Im of Rabi_2
numb_ctrl_amps = 5

# Random initial guess for the three coefficient sets a, b and c.
# a and c receive (input_dim, numb_ctrl_amps) tensors; b is filled column-wise.
# NOTE: these unseeded draws rely on the global seed set with tf.random.set_seed,
# so changing the number or order of the calls changes the initial guess.
# tf.random.uniform samples from the half-open range [minval, maxval).
propagatorVL.ctrl_amplitudes_a.assign(
    tf.random.uniform([input_dim, numb_ctrl_amps], -1, 1, dtype=tf.float64)
)
# b column 0 (presumably Delta_1, following the order above): [-1, 1)
propagatorVL.ctrl_amplitudes_b[:, 0].assign(
    tf.random.uniform([input_dim], -1, 1, dtype=tf.float64)
)
# b columns 1-2 (presumably Re/Im of Rabi_1): [0, 1)
propagatorVL.ctrl_amplitudes_b[:, 1:3].assign(
    tf.random.uniform([input_dim, 2], 0, 1, dtype=tf.float64)
)
# b columns 3-4 (presumably Re/Im of Rabi_2): [-1, 0)
propagatorVL.ctrl_amplitudes_b[:, 3:5].assign(
    tf.random.uniform([input_dim, 2], -1, 0, dtype=tf.float64)
)

# All of c: small values in [0, 0.1)
propagatorVL.ctrl_amplitudes_c.assign(
    tf.random.uniform([input_dim, numb_ctrl_amps], 0, 0.1, dtype=tf.float64)
)

# Cost of the initial guess, before any optimization
initial_infidelity = propagatorVL.target()
print('Initial Figure of Merit:', initial_infidelity.numpy())

## Visualize Initial Control Pulses

In [None]:
# Plot the initial (random) physical control amplitudes, one curve per control.
# Rows of the array are time steps; columns are the controls in the order of `labels`.
physical_amplitudes_initial = propagatorVL.return_physical_amplitudes().numpy()
labels = [r"$\Delta_1$", r"$\mathrm{Re}(\Omega_1)$", r"$\mathrm{Im}(\Omega_1)$", r"$\mathrm{Re}(\Omega_2)$", r"$\mathrm{Im}(\Omega_2)$"]
colors = ['k', 'm--', 'b:', 'c--', 'r:']  # not passed to the plots in this cell

fig, ax = getStylishFigureAxes(1, 1)
# Iterate over the columns with NumPy's .T (the array is already NumPy, so there
# is no need to round-trip through tf.transpose and hand tensors to the plotter).
for ind, amplitude in enumerate(physical_amplitudes_initial.T):
    PlotPlotter(fig, ax, tlist * 1e9, amplitude, style={'label': labels[ind], 'marker': '', 'linestyle': '-', 'linewidth': 1}).draw()
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
ax.set_yticks([-1, 0, 1])
ax.legend(fontsize=4, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

## Optimize Control Pulses

Minimize the target cost function (the figure of merit returned by `propagatorVL.target()`) with the Adam gradient-based optimizer.

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learn_rate)


@tf.function
def optimization_step():
    """Apply one Adam update to the control coefficients.

    Differentiates ``propagatorVL.target()`` with respect to the coefficient
    sets a, b and c, applies the gradients, and returns the cost re-evaluated
    *after* the update.
    """
    trainable = [
        propagatorVL.ctrl_amplitudes_a,
        propagatorVL.ctrl_amplitudes_b,
        propagatorVL.ctrl_amplitudes_c,
    ]
    with tf.GradientTape() as tape:
        cost = propagatorVL.target()
    gradients = tape.gradient(cost, trainable)
    optimizer.apply_gradients(zip(gradients, trainable))
    # Second forward pass: the cost of the parameters now stored in the variables.
    return propagatorVL.target()

# Run the optimization, remembering the lowest cost seen AND the control
# coefficients that produced it. Adam does not decrease the cost monotonically,
# so the last iterate need not be the best one. Restoring the best coefficients
# afterwards keeps the reported figure of merit consistent with the pulses that
# the following cells analyze.
_trainable = [
    propagatorVL.ctrl_amplitudes_a,
    propagatorVL.ctrl_amplitudes_b,
    propagatorVL.ctrl_amplitudes_c,
]
best_infidelity = float('inf')  # any finite cost improves on this (also safe if num_iters == 0)
best_params = None
for step in range(num_iters):
    # optimization_step() returns the cost evaluated after its update, i.e. the
    # cost of the coefficients currently held in the variables. np.squeeze turns
    # the single-element result (e.g. shape (1, 1)) into a plain float.
    current_cost = float(np.squeeze(optimization_step().numpy()))
    if (step + 1) % 50 == 0 or step == 0:
        print(f'Step {step+1}: Figure of Merit = {current_cost:.5f}')
    if current_cost < best_infidelity:
        best_infidelity = current_cost
        best_params = [var.numpy().copy() for var in _trainable]

# Put the best coefficients back so downstream analysis uses them.
if best_params is not None:
    for var, value in zip(_trainable, best_params):
        var.assign(value)

print('Optimized Figure of Merit:', best_infidelity)

## Post-Optimization Analysis

Simulate the dynamics using the optimized pulses, compute the Rabi amplitudes, and analyze the dark state overlap.

Below we plot the raw real/imaginary components of the optimized pulses, followed by the magnitudes and phases of the two Rabi fields.

In [None]:
# Optimized physical amplitudes: one row per time step, one column per control,
# ordered [Delta_1, Re(Omega_1), Im(Omega_1), Re(Omega_2), Im(Omega_2)].
physical_amplitudes_final = propagatorVL.return_physical_amplitudes().numpy()


def _polar_from_columns(amplitudes, re_col, im_col):
    """Return (modulus, phase) of amplitudes[:, re_col] + 1j * amplitudes[:, im_col].

    The phase comes from np.angle and therefore lies in (-pi, pi].
    """
    re_part = amplitudes[:, re_col]
    im_part = amplitudes[:, im_col]
    modulus = np.sqrt(re_part**2 + im_part**2)
    phase = np.angle(re_part + 1j * im_part)
    return modulus, phase


Rabi_1_amplitude, Rabi_1_phase = _polar_from_columns(physical_amplitudes_final, 1, 2)
Rabi_2_amplitude, Rabi_2_phase = _polar_from_columns(physical_amplitudes_final, 3, 4)

# Plot the optimized control pulses (raw columns: Delta_1 and the real and
# imaginary parts of both Rabi fields), one curve per column.
fig, ax = getStylishFigureAxes(1, 1)
# The array is already NumPy, so iterate its columns via .T instead of
# converting it to a TensorFlow tensor with tf.transpose.
for ind, amplitude in enumerate(physical_amplitudes_final.T):
    PlotPlotter(fig, ax, tlist * 1e9, amplitude, style={'label': labels[ind], 'marker': '', 'linestyle': '-'}).draw()
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Pulse Amplitude')
ax.legend(fontsize=6, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

# Two figures: first the Rabi-field magnitudes, then their phases in units of pi.
# Each entry is ([(curve values, curve label), ...], y-axis label).
_rabi_figures = [
    (
        [(Rabi_1_amplitude, r'$|\Omega_1|$'), (Rabi_2_amplitude, r'$|\Omega_2|$')],
        r'$\widetilde{\Omega}(t)/\Omega_{max}$',
    ),
    (
        [(Rabi_1_phase / np.pi, r'$\varphi_1$'), (Rabi_2_phase / np.pi, r'$\varphi_2$')],
        r'$\varphi/\pi$',
    ),
]
for curves, y_label in _rabi_figures:
    fig, ax = getStylishFigureAxes(1, 1)
    for values, curve_label in curves:
        PlotPlotter(fig, ax, tlist * 1e9, values, style={'label': curve_label, 'marker': '', 'linestyle': '-'}).draw()
    ax.set_xlabel('Time (ns)')
    ax.set_ylabel(y_label)
    ax.legend(fontsize=6)
    plt.show()

In [None]:
# Compute dark state overlap over time: step the initial state through the
# per-step exponentials and, at the start of every step, record |<psi|D>|^2,
# where D is the dark state built from that step's optimized Rabi fields.
exps = propagatorVL.exponentials()  # exps[idx] is applied at step idx in the loop below
g_prime, r_prime, g, i, r = propagatorVL.nLevelAtomBasis(propagatorVL.dim)  # basis states; only g and r are used here
current_state = propagatorVL.psi_0  # presumably a (dim, 1) column vector - TODO confirm
DS_overlap_squared = []
for idx in range(propagatorVL.no_of_steps):
    # Calculate current Rabi frequencies as real [Re, Im] pairs scaled by the peak values
    current_Rabi_1 = Rabi_1 * physical_amplitudes_final[idx, 1:3]
    current_Rabi_2 = Rabi_2 * physical_amplitudes_final[idx, 3:5]
    # These arrays are real, so the .imag terms are zero and the sum equals
    # Re(Om1)^2 + Im(Om1)^2 + Re(Om2)^2 + Im(Om2)^2 = |Om1|^2 + |Om2|^2.
    # NOTE(review): if both fields vanish at a step, this divides by zero and
    # that step's overlap becomes nan.
    norm_dark = 1.0/np.sqrt(np.sum(np.square(current_Rabi_1.real) + np.square(current_Rabi_1.imag) +
                                  np.square(current_Rabi_2.real) + np.square(current_Rabi_2.imag)))
    # D = (Omega_2 |g> - conj(Omega_1) |r>) / norm.
    # NOTE(review): whether this or its complex conjugate is the dark state
    # depends on the Hamiltonian convention inside PropagatorVL - confirm there.
    dark_state = norm_dark * ((current_Rabi_2[0] + 1j*current_Rabi_2[1]) * g -
                              (current_Rabi_1[0] - 1j*current_Rabi_1[1]) * r)
    # <psi|D>: adjoint_a conjugate-transposes the state before the product.
    overlap = tf.linalg.matmul(current_state, tf.reshape(tf.cast(dark_state, tf.complex128), (-1,1)), adjoint_a=True)[0,0]
    DS_overlap_squared.append(tf.math.real(overlap * tf.math.conj(overlap)).numpy())  # |<psi|D>|^2
    # Propagate state by one step. The slice keeps only the top-left dim x dim
    # block of the step's exponential (it may be larger than dim x dim - confirm).
    current_state = tf.linalg.matmul(exps[idx, :propagatorVL.dim, :propagatorVL.dim], current_state)


In [None]:
# Plot the dark-state population |<D|psi>|^2 against time in nanoseconds.
time_axis_ns = tlist * 1e9
fig, ax = getStylishFigureAxes(1, 1)
PlotPlotter(
    fig,
    ax,
    time_axis_ns,
    DS_overlap_squared,
    style={'label': 'Dark State Overlap', 'marker': '', 'linestyle': '-', 'linewidth': 1},
).draw()
ax.set_xlabel('Time (ns)')
ax.set_ylabel(r'$|\langle D|\psi \rangle|^2$')
plt.show()