How to optimize a QML model using JAX and Optax
===============================================

Once you have set up a quantum machine learning model, data to train
with, and a cost function to minimize as an objective, the next step is
to **perform the optimization**. That is, set up a classical
optimization loop to find a minimal value of your cost function.

In this example, we'll show you how to use
[JAX](https://jax.readthedocs.io), an autodifferentiable machine
learning framework, and [Optax](https://optax.readthedocs.io/), a suite
of JAX-compatible gradient-based optimizers, to optimize a PennyLane
quantum machine learning model.

![](../_static/demonstration_assets/How_to_optimize_QML_model_using_JAX_and_Optax/socialsthumbnail_large_How_to_optimize_QML_model_using_JAX_and_Optax_2024-01-16.png){.align-center
width="50.0%"}


Set up your model, data, and cost
=================================


Here, we will create a simple QML model for our optimization. In
particular:

-   We will embed our data through a series of rotation gates.
-   We will then have an ansatz of trainable rotation gates with
    parameters `weights`; it is these values we will train to minimize
    our cost function.
-   We will train the QML model on `data`, a `(5, 4)` array, and
    optimize the model to match target predictions given by `targets`.
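
As a quick sanity check of the shapes described above, the sketch below rebuilds the `data` array with the same expression used in this tutorial and confirms that each row holds four values, one per entry of `targets`. Nothing beyond the tutorial's own definitions is assumed.

```python
from jax import numpy as jnp

n_wires = 5

# 20 evenly spaced points in [-2, 2), later reshaped to one row per wire
grid = jnp.mgrid[-2:2:0.2]
data = jnp.sin(grid.reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

print(grid.shape)     # (20,)
print(data.shape)     # (5, 4)
print(data[0].shape)  # (4,), the same shape as targets
```

Each row `data[i]` is therefore a batch of four rotation angles for wire `i`, which is what allows the circuit to return four predictions at once.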


In [12]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev = qml.device("default.qubit", wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    # data embedding
    for i in range(n_wires):
        # data[i] will be of shape (4,); we are
        # taking advantage of operation vectorization here
        qml.RY(data[i], wires=i)

    # trainable ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

We will define a simple cost function that computes the sum of squared
differences between the model output and the target data, scaled by
`len(data)`, and [just-in-time (JIT)
compile](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)
it:


In [13]:
@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss
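
To make the arithmetic in `loss_fn` concrete, here is a small sketch that evaluates the same expression on made-up predictions. The `predictions` values are hypothetical and chosen only for illustration, and `n_rows` stands in for `len(data)`, which is 5 because `data` has shape `(5, 4)`.

```python
from jax import numpy as jnp

targets = jnp.array([-0.2, 0.4, 0.35, 0.2])
# hypothetical model outputs, not produced by the circuit in this tutorial
predictions = jnp.array([0.0, 0.5, 0.25, 0.1])
n_rows = 5  # stands in for len(data), since data has shape (5, 4)

# same expression as in loss_fn
squared_errors = (targets - predictions) ** 2
loss = jnp.sum(squared_errors / n_rows)

print(squared_errors)
print(loss)
```

With these values the squared differences sum to 0.07, so the loss is approximately 0.014. Note that the divisor is the number of rows of `data` (the number of wires), not the number of targets.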

Note that the model above is just an example for demonstration. There
are important considerations that must be taken into account when
performing QML research, including the method of data embedding, the
circuit architecture, and the choice of cost function, in order to build
models that may be useful. This is still an active area of research; see
our [demonstrations](https://pennylane.ai/qml/demonstrations) for details.


Initialize your parameters
==========================


Now, we can generate our trainable parameters `weights` and `bias` that
will be used to train our QML model.


In [14]:
weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

Plugging the trainable parameters, data, and target labels into our cost
function, we can see the current loss as well as the parameter
gradients:


In [15]:
print(loss_fn(params, data, targets))

print(jax.grad(loss_fn)(params, data, targets))

0.29232618
{'bias': Array(-0.754321, dtype=float32, weak_type=True), 'weights': Array([[-1.9507733e-01,  5.2854650e-02, -4.8925212e-01],
       [-1.9968867e-02, -5.3287148e-02,  9.2290469e-02],
       [-2.7175695e-03, -9.6455216e-05, -4.7958046e-03],
       [-6.3544422e-02,  3.6111072e-02, -2.0519713e-01],
       [-9.0263695e-02,  1.6375928e-01, -5.6426275e-01]], dtype=float32)}
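
The gradient printed above has the same dictionary structure as `params`. The following sketch illustrates this behaviour of `jax.grad` on a toy function; `toy_loss` is a hypothetical stand-in for `loss_fn` and does not involve a quantum circuit.

```python
import jax
from jax import numpy as jnp

def toy_loss(params):
    # hypothetical stand-in for loss_fn: a simple quadratic in both entries
    return jnp.sum(params["weights"] ** 2) + params["bias"] ** 2

params = {"weights": jnp.ones([5, 3]), "bias": jnp.array(0.5)}

# jax.grad differentiates with respect to the first argument by default
# and returns a pytree with the same structure as that argument
grads = jax.grad(toy_loss)(params)

print(grads["weights"].shape)  # (5, 3)
print(grads["bias"])           # derivative of bias**2 is 2 * bias = 1.0
```

Because the gradients mirror the structure of `params`, they can be passed directly to an optimizer that operates on the same dictionary.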


Create the optimizer
====================


We can now use Optax to create an optimizer and train our circuit.
Here, we choose the Adam optimizer, but any of the [other available
optimizers](https://optax.readthedocs.io/en/latest/api.html) may be used
instead.


In [16]:
opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)

We first define our `update_step` function, which needs to do a couple
of things:

-   Compute the loss function (so we can track training) and the
    gradients (so we can apply an optimization step). We can do this in
    one execution via the `jax.value_and_grad` function.
-   Apply the update step of our optimizer via `opt.update`.
-   Update the parameters via `optax.apply_updates`.


In [17]:
def update_step(params, opt_state, data, targets):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

loss_history = []

for i in range(100):
    params, opt_state, loss_val = update_step(params, opt_state, data, targets)

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

    loss_history.append(loss_val)

Step: 0 Loss: 0.2923261821269989
Step: 5 Loss: 0.04476672783493996
Step: 10 Loss: 0.03190239891409874
Step: 15 Loss: 0.03623729571700096
Step: 20 Loss: 0.03370063751935959
Step: 25 Loss: 0.028723960742354393
Step: 30 Loss: 0.02301179990172386
Step: 35 Loss: 0.018715690821409225
Step: 40 Loss: 0.014776434749364853
Step: 45 Loss: 0.010427684523165226
Step: 50 Loss: 0.009645545855164528
Step: 55 Loss: 0.024109993129968643
Step: 60 Loss: 0.00808287039399147
Step: 65 Loss: 0.00760952103883028
Step: 70 Loss: 0.007097803056240082
Step: 75 Loss: 0.006783411838114262
Step: 80 Loss: 0.006902276072651148
Step: 85 Loss: 0.006584083661437035
Step: 90 Loss: 0.006034402176737785
Step: 95 Loss: 0.0049751754850149155


Jitting the optimization loop
=============================


In the above example, we JIT compiled our cost function `loss_fn`.
However, we can also JIT compile the entire optimization loop; this
means that the for-loop around optimization is not happening in Python,
but is compiled and executed natively. This avoids (potentially costly)
data transfer between Python and our JIT compiled cost function with
each update step.


In [9]:
@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, print_training = args

    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return (params, opt_state, data, targets, print_training)

@jax.jit
def optimization_jit(params, data, targets, print_training=False):
    opt = optax.adam(learning_rate=0.3)
    opt_state = opt.init(params)

    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)

    return params

Note that we use `jax.lax.fori_loop` and `jax.lax.cond`, rather than a
standard Python for loop and if statement, to allow the control flow to
be JIT compatible. We also use `jax.debug.print` to allow printing to
take place at function run-time, rather than compile-time.
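
The same control-flow pattern can be seen in isolation in the sketch below, which runs plain gradient descent on a toy quadratic. The function `toy_loss`, the step size, and the number of steps are hypothetical choices for illustration and are not part of the QML model.

```python
import jax
from jax import numpy as jnp

def toy_loss(x):
    # hypothetical stand-in for the QML cost function
    return jnp.sum((x - 3.0) ** 2)

def body(i, x):
    loss_val, grad = jax.value_and_grad(toy_loss)(x)

    def print_fn():
        # runs at execution time, with concrete values substituted
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # jax.lax.cond replaces a Python `if` on a traced value
    jax.lax.cond(jnp.mod(i, 5) == 0, print_fn, lambda: None)
    return x - 0.1 * grad

@jax.jit
def run(x0):
    # jax.lax.fori_loop replaces a Python `for` loop inside jitted code
    return jax.lax.fori_loop(0, 20, body, x0)

x_final = run(jnp.zeros(2))
print(x_final)
```

The loop body must return a value with the same structure, shape, and dtype as the value it receives, which is why the tutorial's `update_step_jit` returns the full `args` tuple.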


In [10]:
params = {"weights": weights, "bias": bias}
optimization_jit(params, data, targets, print_training=True)

Step: 0  Loss: 0.2923261821269989
Step: 5  Loss: 0.044766709208488464
Step: 10  Loss: 0.03190242126584053
Step: 15  Loss: 0.03623732179403305
Step: 20  Loss: 0.03370067477226257
Step: 25  Loss: 0.028724024072289467
Step: 30  Loss: 0.023011792451143265
Step: 35  Loss: 0.018715735524892807
Step: 40  Loss: 0.014776429161429405
Step: 45  Loss: 0.010427681729197502
Step: 50  Loss: 0.009645579382777214
Step: 55  Loss: 0.024109533056616783
Step: 60  Loss: 0.008082757703959942
Step: 65  Loss: 0.0076092444360256195
Step: 70  Loss: 0.007097671739757061
Step: 75  Loss: 0.006783545948565006
Step: 80  Loss: 0.006901954300701618
Step: 85  Loss: 0.006584125570952892
Step: 90  Loss: 0.006033747456967831
Step: 95  Loss: 0.0049752178601920605


{'bias': Array(-0.75290495, dtype=float32),
 'weights': Array([[ 1.630918  ,  1.5501642 ,  0.6721541 ],
        [ 0.7266173 ,  0.3642349 , -0.7562605 ],
        [ 2.7837987 ,  0.62709916,  3.450068  ],
        [-1.101276  , -0.12706573,  0.89288384],
        [ 1.2723563 ,  1.1062955 ,  2.2205076 ]], dtype=float32)}

Appendix: Timing the two approaches
===================================

We can time the two approaches (JIT compiling just the cost function
versus JIT compiling the entire optimization loop) to compare their
performance:


In [11]:
from timeit import repeat

def optimization(params, data, targets):
    opt = optax.adam(learning_rate=0.3)
    opt_state = opt.init(params)

    for i in range(100):
        params, opt_state, loss_val = update_step(params, opt_state, data, targets)

    return params

reps = 5
num = 2

times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting just the cost (best of {reps}): {result} sec per loop")

times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop")

Jitting just the cost (best of 5): 0.7518415104714222 sec per loop
Jitting the entire optimization (best of 5): 0.0035007075057365 sec per loop


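In this example, JIT compiling the entire optimization loop is
significantly faster: about 0.0035 seconds per run compared with about
0.75 seconds, a speedup of roughly 200 times on the machine used here.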
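
One general caveat when timing JAX code is that JAX dispatches computations asynchronously, so a benchmark can stop the clock before the result has actually been computed. A minimal sketch of a timing helper that waits for the result with `jax.block_until_ready` is shown below. The helper `timed_call` and the function `toy_fn` are hypothetical, and `toy_fn` is a toy stand-in rather than the optimization routine from this tutorial.

```python
import time
import jax
from jax import numpy as jnp

def timed_call(fn, *args):
    """Hypothetical helper: run fn and wait for its outputs before stopping the clock."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start

@jax.jit
def toy_fn(x):
    # toy stand-in for an expensive jitted computation
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 1000)

out, first = timed_call(toy_fn, x)   # includes tracing and compilation
_, second = timed_call(toy_fn, x)    # reuses the compiled function

print(f"first call: {first:.6f} s, second call: {second:.6f} s")
```

Timing the first and later calls separately also makes the one-off compilation cost visible, which matters when comparing the two approaches on short runs.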
