# Solving the Quantum Hamilton–Jacobi–Bellman Equation

In [1]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import diffrax as dx
import equinox as eqx
import dynamiqs as dq
from jaxpulse.controllers import *
from jaxpulse.optimizers import *
from jaxpulse.bilinear import *

In [2]:
# Toggle for the guarded baseline cells below: set True to re-run the
# Gaussian-pulse optimisation and its plot.
rerun_baseline: bool = False

### Baseline: optimised Gaussian pulse for a qubit flip (|0⟩ → |1⟩)

In [3]:
if rerun_baseline:
    # Time window for the baseline pulse.
    t0, t1 = 0., 10.
    # Two-level system with identity H_0 and a single sigma_x term in H_M
    # (assumed: H_M lists the control Hamiltonians — TODO confirm in jaxpulse).
    flipper = ClosedQuantumSystem(H_0=dq.eye(2), H_M=[dq.sigmax()])
    # A single Gaussian envelope centred in the middle of the window.
    gaussian_pulse = GaussianControl.std(amp=1.0, mean=t1 / 2, sigma=1.)
    flip_control = ControlVector([gaussian_pulse])

In [4]:
if rerun_baseline:
    # Projector onto |1><1|; the running cost is the control energy u·u plus
    # the population that is *not* in |1>.
    proj1 = dq.fock_dm(2, 1)

    def _terminal_cost(y):
        # No terminal penalty.
        return 0

    def _running_cost(y, u, t):
        # NOTE(review): dq.expect presumably returns a complex value, so this
        # cost is complex-typed — confirm jaxpulse takes the real part.
        control_energy = jnp.dot(u, u).squeeze()
        return control_energy + 1 - dq.expect(proj1, y)

    flip_opt = OptimalController(
        system=flipper,
        controls=flip_control,
        duration=t1,
        y_final=_terminal_cost,
        y_statewise=_running_cost,
        y0=dq.fock(2, 0),  # start in |0>
    )

In [5]:
if rerun_baseline:
    # NOTE: this rebinds flip_opt, so re-running the cell optimises again
    # starting from the previous result (assuming optimize returns an updated
    # controller).
    optimize_kwargs = dict(N_steps=45, learning_rate=.1, verbosity=2)
    flip_opt = flip_opt.optimize(**optimize_kwargs)

In [6]:
if rerun_baseline:
    # Plot the expectation value <Z> using the controller's own plot method.
    fig, ax = plt.subplots()
    z_ops, z_names = [dq.sigmaz()], ["<Z>"]
    flip_opt.plot(ax=ax, exp_ops=z_ops, exp_names=z_names)
    ax.legend()

### HJB Implementation

In [7]:
from jaxpulse.bilinear import *
import dynamiqs as dq
import jax.numpy as jnp

In [8]:
# Example problem from the paper. Matrices are passed to
# QuantumBilinearController under the keyword names A, B, F, Q, R, H
# (their exact roles are defined by jaxpulse — TODO confirm).
A = jnp.array(
    [[13. / 6, 5. / 12],
     [-50. / 3, -8. / 3]],
    dtype=complex,
)
B = jnp.array([[-1. / 8], [0.]], dtype=complex)
F = 1000. * dq.eye(2)
Q = 10. * dq.eye(2)
R = dq.eye(1)

# H has shape (2, 2, 1); H[:, j, :] is a (2, 1) column. Only column 0 is
# non-zero, column 1 is explicitly zero.
_H_col0 = jnp.array([[-1.], [0.]], dtype=complex)
_H_col1 = jnp.zeros((2, 1), dtype=complex)
H = jnp.stack([_H_col0, _H_col1], axis=1)

t1 = 10.0
Nt = 1_000_000


In [9]:
# Build the bilinear HJB controller from the example matrices above.
# NOTE(review): t1 / dt0 = 2000 solver steps while Nt = 1_000_000 — confirm
# what Nt controls in jaxpulse; a very large value may make solves slow.
problem_matrices = dict(A=A, B=B, F=F, Q=Q, R=R, H=H)
bc = QuantumBilinearController(
    **problem_matrices,
    Nt=Nt,
    t1=t1,
    dt0=.005,
)

In [10]:
import warnings

# Initial state for the HJB iteration.
x0 = jnp.array([[.15], [0.]], dtype=complex)

# Suppress warnings only around the solve. A global
# warnings.filterwarnings('ignore') would silence every warning for the rest
# of the kernel session and hide diagnostics (e.g. for NaN divergence).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    xr, ur, ts = bc.solve(x0=x0)

# Report non-finite iterates explicitly rather than plotting NaNs silently.
if not bool(jnp.all(jnp.isfinite(xr))):
    warnings.warn(
        "bc.solve returned non-finite states (NaN/inf); the HJB iteration diverged."
    )

Jf=Array(15.311608+0.j, dtype=complex64),Jt_integrate=Array(27.346079+0.j, dtype=complex64) Total J=(42.65768814086914+0j)
Jf=Array(7.560088e-17+0.j, dtype=complex64),Jt_integrate=Array(2.0942917+0.j, dtype=complex64) Total J=(2.0942916870117188+0j)
Jf=Array(0.6590983+0.j, dtype=complex64),Jt_integrate=Array(1.5786163+0.j, dtype=complex64) Total J=(2.2377145290374756+0j)
Jf=Array(inf+nanj, dtype=complex64),Jt_integrate=Array(nan+nanj, dtype=complex64) Total J=(nan+nanj)


KeyboardInterrupt: 

In [None]:
# State trajectories for successive iterates. Index 0 is the uncontrolled run
# (x_{-1}) and index 1 the linear-control run (x_0), so index i is x_{i-1}.
# The loop is bounded by xr.shape[0]: JAX clamps out-of-bounds indices
# instead of raising, which would silently re-plot the last iterate.
fig, ax = plt.subplots()
ax.plot(ts, xr[0, :, 0], label="No Control $x_{-1}(t)$")
ax.plot(ts, xr[1, :, 0], label="Linear Control $x_{0}(t)$")
for i in range(2, min(6, xr.shape[0])):
    ax.plot(ts, xr[i, :, 0], label="$x_{" + str(i - 1) + "}(t)$")
ax.set_xlabel("t")
ax.set_ylabel("$x(t)$")
ax.set_title("HJB iterates: state trajectories")
ax.legend();
