In [2]:
import jax
import jax.numpy as jnp
import pennylane as qml

# Added to silence some warnings.
jax.config.update("jax_enable_x64", True)

dev = qml.device("default.mixed", wires=2)


In [4]:
@qml.qnode(dev, interface="jax")
def circuit(param):
    # These two gates represent our QML model.
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # The expval here will be the "cost function" we try to minimize.
    # Usually, this would be defined by the problem we want to solve,
    # but for this example we'll just use a single PauliZ.
    return qml.expval(qml.PauliZ(0))

print("\nGradient Descent")
print("---------------")

# We use jax.grad here to transform our circuit method into one
# that calcuates the gradient of the output relative to the input.

grad_circuit = jax.grad(circuit)
print(f"grad_circuit(jnp.pi / 2): {grad_circuit(jnp.pi / 2):0.3f}")

# We can then use this grad_circuit function to optimize the parameter value
# via gradient descent.
param = 0.123  # Some initial value.

print(f"Initial param: {param:0.3f}")
print(f"Initial cost: {circuit(param):0.3f}")

for _ in range(100):  # Run for 100 steps.
    param -= grad_circuit(param)  # Gradient-descent update.

print(f"Tuned param: {param:0.3f}")
print(f"Tuned cost: {circuit(param):0.3f}")



Gradient Descent
---------------
grad_circuit(jnp.pi / 2): -1.000
Initial param: 0.123
Initial cost: 0.992
Tuned param: 3.142
Tuned cost: -1.000


In [21]:
import pennylane as qml
from jax import numpy as jnp
import jax
import optax

learning_rate = 0.15

dev = qml.device("default.qubit", wires=1, shots=None)
def f():
    qml.X(wires=0)
@jax.jit
@qml.qnode(dev, interface="jax")
def energy(a):
    f()
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

optimizer = optax.adam(learning_rate)

params = jnp.array(0.5)
opt_state = optimizer.init(params)

for _ in range(200):
    grads = jax.grad(energy)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)


In [None]:
import random 
import numpy as np
random.uniform(, np.pi)

2.561254757251045