In [2]:
import qml_essentials.yaqsi as ys
import qml_essentials.operations as op
import numpy as np
import optax as otx
import jax

In [3]:
def circuit():
    """Two-wire circuit: ``op.H`` on wire 0, then ``op.CX`` on wires [0, 1].

    Takes no arguments and returns nothing; the gates are issued by calling
    the ``op`` constructors (presumably recorded by ``ys.Script`` when the
    circuit is executed -- confirm against the yaqsi documentation).
    """
    op.H(wires=0)
    op.CX(wires=[0, 1])

# One PauliZ observable per wire used by the circuit above.
obs = [op.PauliZ(wires=wire) for wire in (0, 1)]

In [4]:
# Wrap the `circuit` function defined in the previous cell in a yaqsi Script.
yss = ys.Script(circuit)
# Run it with type="probs"; as the last expression of the cell, the returned
# array is what the notebook displays.
# NOTE(review): how `obs` influences a "probs" execution is not visible here -- confirm.
yss.execute(type="probs", obs=obs)

Array([0.5, 0. , 0. , 0.5], dtype=float64)

In [5]:
n_qubits = 20

def circuit():
    """Apply ``op.H`` on wire 0 followed by a chain of ``op.CX`` gates.

    The chain links each neighbouring pair ``(i, i + 1)`` among wires
    ``0 .. n_qubits - 1``, so the circuit touches exactly ``n_qubits`` wires,
    matching the observables built below.
    """
    op.H(wires=0)
    # n_qubits wires have n_qubits - 1 neighbouring pairs. Iterating over
    # range(n_qubits) would emit CX on wires [n_qubits - 1, n_qubits] and
    # silently pull an extra wire into the circuit.
    for i in range(n_qubits - 1):
        op.CX(wires=[i, i + 1])

# One PauliZ observable per wire of the circuit.
obs = [op.PauliZ(wires=i) for i in range(n_qubits)]
yss = ys.Script(circuit)
# Last expression of the cell: the probability array is displayed.
yss.execute(type="probs", obs=obs)

Array([0.5, 0. , 0. , ..., 0. , 0. , 0.5], dtype=float64)

In [22]:
n_qubits = 1

def circuit(phi, theta, omega):
    """Single-wire circuit made of two ``op.Rot`` gates on wire 0.

    Args:
        phi, theta, omega: The three angles forwarded, in this order, to the
            first ``op.Rot`` call.

    The second ``op.Rot`` uses the fixed angles (pi, pi/2, pi/4).
    """
    op.Rot(phi, theta, omega, wires=0)
    op.Rot(np.pi, np.pi / 2, np.pi / 4, wires=0)

# PauliZ on every wire (a single one here, since n_qubits == 1).
obs = [op.PauliZ(wires=wire) for wire in range(n_qubits)]
yss = ys.Script(circuit)
# Evaluate with type="expval", passing the circuit arguments via `args`;
# as the last expression of the cell the result is displayed.
yss.execute(type="expval", obs=obs, args=(np.pi, np.pi / 2, np.pi / 4))

Array([0.70710678], dtype=float64)

In [24]:
def cost_fct(params):
    """Cost to minimise: the PauliZ result on wire 0 for the script ``yss``.

    Args:
        params: Sequence of three values, unpacked as ``(phi, theta, omega)``
            and forwarded to the circuit through ``args``.

    Returns:
        The first entry of the array returned by
        ``yss.execute(type="expval", ...)``.

    Relies on the module-level ``yss`` created in an earlier cell.
    """
    phi, theta, omega = params
    # Use the `wires=` keyword like every other observable in this notebook.
    return yss.execute(
        type="expval", obs=[op.PauliZ(wires=0)], args=(phi, theta, omega)
    )[0]

params = jax.numpy.array([0.1, 0.2, 0.3])
opt = otx.adam(0.01)
opt_state = opt.init(params)

# jax.grad only builds the derivative function, so it is loop-invariant:
# create it once instead of once per optimisation step.
grad_fct = jax.grad(cost_fct)

print(params)
for epoch in range(1, 101):
    grads = grad_fct(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = otx.apply_updates(params, updates)

print(params)

[0.1 0.2 0.3]
[ 0.1        -0.77972245 -0.01600711]
