### Optimal Control on Simple Systems: Driven Jaynes–Cummings (JC) Model

In [None]:
%load_ext autoreload
import jax.numpy as jnp
from jaxtyping import Array
import matplotlib.pyplot as plt
import dynamiqs as dq
import strawberryfields as sf
import os
from controllers import ControlVector, SinusoidalControl, ConstantControl
from optimizers import ClosedQuantumSystem, OptimalController
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.90'

#### System Setup

In [None]:
# Hilbert-space truncation: resonator Fock cutoff and number of qubit levels.
N_res = 4
N_qub = 2
# Bare frequencies and coupling strength.
omega_res = 3
omega_qub = 2.4
g = 0.1  # NOTE(review): g does not enter H_1 below; the coupling is set via its control.

# Operators on the composite resonator (x) qubit space.
a = dq.tensor(dq.destroy(N_res), dq.eye(N_qub))  # resonator annihilation
b = dq.tensor(dq.eye(N_res), dq.destroy(N_qub))  # qubit lowering
# NOTE: despite its name, Z is the qubit *number* operator (diag(0, 1) for a
# two-level qubit), not Pauli-Z; <Z> is therefore the excited-state population.
Z = dq.tensor(dq.eye(N_res), dq.number(N_qub))

# Drift Hamiltonian (same left-to-right evaluation as `omega * a^dag @ a`).
H_0 = omega_res * dq.dag(a) @ a + omega_qub * dq.dag(b) @ b
# Control Hamiltonians.
H_1 = dq.dag(a) @ b + a @ dq.dag(b)        # resonator-qubit exchange coupling
H_2 = 1.0 / jnp.sqrt(2) * (a + dq.dag(a))  # resonator X quadrature
H_3 = 1j / jnp.sqrt(2) * (dq.dag(a) - a)   # resonator P quadrature

# Initial state: resonator vacuum, qubit in its lowest level.
psi_0 = dq.tensor(dq.fock(N_res, 0), dq.fock(N_qub, 0))


#### Control Setup

In [None]:
# Initial control parameters (presumably the quantities the optimizer tunes).
# The constant coupling amplitude starts at the bare coupling strength g.
coupling_control = ConstantControl(
    k=jnp.array(g),
)
# Resonator X-quadrature drive; omega starts on resonance with omega_res.
sine_drive_X = SinusoidalControl(
    a=jnp.array([0.5]),
    omega=jnp.array([3.0]),
    phi=jnp.array([0.0]),
)
# Resonator P-quadrature drive (not currently in the control vector).
# phi must be a JAX array like the other parameters: JAX does not implicitly
# convert Python lists, so `omega * t + [0.0]` would raise a TypeError.
sine_drive_P = SinusoidalControl(
    a=jnp.array([0.5]),
    omega=jnp.array([3.5]),
    phi=jnp.array([0.0]),
)

#### Optimizer Setup

In [None]:
# Closed-system model: drift H_0 plus the control Hamiltonians listed in H_M.
jc_system = ClosedQuantumSystem(
    dim=N_res * N_qub,
    H_0=H_0,
    H_M=[H_1, H_2],
)

# Controls, presumably paired in order with H_M above (verify in optimizers.py):
# the coupling amplitude with H_1 and the sinusoid with the X quadrature H_2.
# To add the P-quadrature drive, append H_3 to H_M and sine_drive_P here.
jc_controls = ControlVector([coupling_control, sine_drive_X])

def final_penalty(psi_tF: Array) -> float:
    """Terminal cost: one minus the expectation value of Z in ``psi_tF``.

    Args:
        psi_tF: the state the cost is evaluated on (presumably the state at
            the end of the evolution -- confirm in OptimalController).

    Returns:
        ``1 - <Z>`` as a real scalar. ``dq.expect`` returns a complex value;
        for the Hermitian Z only the real part is physical, and gradient
        transforms such as ``jax.grad`` require a real-valued loss, so the
        imaginary part is dropped.
    """
    return 1 - dq.expect(Z, psi_tF).real

def statewise_penalty(psi_ti: Array) -> float:
    """Running state cost; currently contributes nothing for any state.

    Args:
        psi_ti: the state at one saved time point (ignored).

    Returns:
        ``0.0`` -- a float, matching the annotated return type.
    """
    return 0.0

def control_penalty(u_m_ti: Array) -> float:
    """Running control cost; currently contributes nothing for any control value.

    Args:
        u_m_ti: control amplitude(s) at one time point (ignored).

    Returns:
        ``0.0`` -- a float, matching the annotated return type.
    """
    return 0.0

# Assemble the optimal-control problem from the model, the controls, the
# initial state, the cost terms and the time grid.
controller = OptimalController(
    # Dynamics and starting point.
    system=jc_system,
    controls=jc_controls,
    y0=psi_0,
    # Cost terms: terminal state, running state, running control.
    y_final=final_penalty,
    y_statewise=statewise_penalty,
    u_statewise=control_penalty,
    # Time grid (dt_start / dt_save presumably the initial solver step and the
    # save interval -- confirm in optimizers.py).
    duration=15.0,
    dt_start=0.01,
    dt_save=0.1,
)

#### Pre-Training

In [None]:
# Baseline: <Z> under the initial (untrained) controls.
fig, ax = plt.subplots()
controller.plot(ax=ax, exp_ops=[Z], exp_names=[r"$\langle Z\rangle$"])
fig.legend()
fig.show()

#### Training

In [None]:
# Optimize the control parameters; the result is bound to a new name so the
# original `controller` binding is left in place for comparison.
new_controller = controller.optimize(N_steps=10, learning_rate=0.1, verbosity=2)

#### Post-Training

In [None]:

# <Z> under the optimized controls, plotted the same way as the baseline.
fig_opt, ax_opt = plt.subplots()
new_controller.plot(
    ax=ax_opt,  # keyword, matching the pre-training plot call
    exp_ops=[Z],
    exp_names=[r"$\langle Z\rangle$"],
)
fig_opt.legend()
fig_opt.show()