In [1]:
import qutip as qt
import numpy as np

import jax
import jax.numpy as jnp
from diffrax import Dopri5, PIDController

In [2]:
import goat
from objective import Objective

In [3]:
H_drft = qt.sigmaz()
H_ctrl = [qt.sigmax(), qt.sigmay()]

initial = qt.basis(2, 0)
target = qt.basis(2, 1)

evo_time = np.linspace(0, np.pi, 100)

# Python

In [4]:
def sin(t, p):
    return p[0] * np.sin(p[1] * t + p[2])

def grad_sin(t, p, idx):
    if idx==0: return np.sin(p[1] * t + p[2])
    if idx==1: return p[0] * np.cos(p[1] * t + p[2]) * t
    if idx==2: return p[0] * np.cos(p[1] * t + p[2])
    if idx==3: return p[0] * np.cos(p[1] * t + p[2]) * p[1] # w.r.t. time

evo = [H_drft,
       [H_ctrl[0], sin, {"grad": grad_sin}],
       [H_ctrl[1], sin, {"grad": grad_sin}]]

In [5]:
%%timeit
goat.optimize_pulses(
    objectives=[Objective(initial, evo, target)],
    pulse_options={
        sin: {
            "guess": np.zeros(3),
            "bounds": np.array([(-10, 10)]),
        }
    },
    tlist=evo_time,
    kwargs={
        "optimizer": {
            "disp": False,
            "niter": 100,  # = 0: no global search
            "seed": 123,  # deterministic results
            # num of iters without improvement before stopping
            "niter_success": 10
        },
        "integrator": {
            "progress_bar": False
        }
    }
)

1.36 s ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Cython

In [6]:
%load_ext Cython

In [7]:
%%cython
import numpy as np 

cpdef double sin_cy(double t, double[:] p):
    return p[0] * np.sin(p[1] * t + p[2])

cpdef double grad_sin_cy(double t, double[:] p, int idx):
    if idx==0: return np.sin(p[1] * t + p[2])
    if idx==1: return p[0] * np.cos(p[1] * t + p[2]) * t
    if idx==2: return p[0] * np.cos(p[1] * t + p[2])
    if idx==3: return p[0] * np.cos(p[1] * t + p[2]) * p[1] # w.r.t. time

In [8]:
evo_cy = [H_drft,
          [H_ctrl[0], sin_cy, {"grad": grad_sin_cy}],
          [H_ctrl[1], sin_cy, {"grad": grad_sin_cy}]]

In [9]:
%%timeit
goat.optimize_pulses(
    objectives=[Objective(initial, evo_cy, target)],
    pulse_options={
        sin_cy: {
            "guess": np.zeros(3),
            "bounds": np.array([(-10, 10)]),
        }
    },
    tlist=evo_time,
    kwargs={
        "optimizer": {
            "disp": False,
            "niter": 100,  # = 0: no global search
            "seed": 123,  # deterministic results
            # num of iters without improvement before stopping
            "niter_success": 10
        },
        "integrator": {
            "progress_bar": False
        }
    }
)

1.44 s ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# JAX

In [10]:
def sin_jax(t, p):
    return p[0] * jnp.sin(p[1] * t + p[2])

def grad_sin_jax(t, p, idx):
    dt, dp = jax.grad(sin_jax, argnums=(0, 1))(t, p)
    return jnp.concatenate((dp, dt), axis=None)[idx]

evo_jax = [H_drft,
          [H_ctrl[0], sin_jax, {"grad": grad_sin_jax}],
          [H_ctrl[1], sin_jax, {"grad": grad_sin_jax}]]

In [11]:
%%timeit
goat.optimize_pulses(
    objectives=[Objective(initial, evo_jax, target)],
    pulse_options={
        sin_jax: {
            "guess": np.zeros(3),
            "bounds": np.array([(-10, 10)]),
        }
    },
    tlist=evo_time,
    kwargs={
        "optimizer": {
            "disp": False,
            "niter": 100,  # = 0: no global search
            "seed": 123,  # deterministic results
            # num of iters without improvement before stopping
            "niter_success": 10
        },
        "integrator": {
            "progress_bar": False,
            "method": "diffrax",
            "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5),
            "solver": Dopri5()
        }
    }
)

9.28 s ± 193 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
