In [1]:
import tensorcircuit as tc
import optax
import jax.numpy as jnp

In [2]:
# Select JAX as the tensorcircuit backend. The returned backend object K is
# what the later cells use for K.jit and K.value_and_grad.
K = tc.set_backend('jax')

In [3]:
# Sanity check of a single-parameter circuit: apply rx with angle theta to
# qubit 0 of a two-qubit circuit, then inspect the state, <Z> on qubit 0,
# and one sample.
theta = jnp.array([0.1])
c = tc.Circuit(2)
# Disabled alternative preparation, kept for reference:
# c.H(0)
# c.CNOT(0,1)
c.rx(0, theta=theta)

# Full state vector of the circuit.
rx_state = c.wavefunction()
print(rx_state)

# Expectation value of the Z gate on qubit 0.
z0_expectation = c.expectation([tc.gates.z(), [0]])
print(z0_expectation)

# One sample from the circuit; judging by the recorded output this is a pair
# (bitstring, second value) -- presumably the probability of that bitstring.
rx_sample = c.sample()
print(rx_sample)



[0.99875027+0.j         0.        +0.j         0.        -0.04997917j
 0.        +0.j        ]
(0.9950042+0j)
(DeviceArray([0., 0.], dtype=float32), DeviceArray(0.997502, dtype=float32))


In [4]:
@K.jit
def loss(theta):
    """Return the squared real part of <Z> on qubit 0 after an rx(theta) gate.

    A fresh two-qubit circuit is built on every call, an ``rx`` gate with
    angle ``theta`` is applied to qubit 0, and the expectation of
    ``tc.gates.z()`` on qubit 0 is evaluated.

    Args:
        theta: rotation angle passed straight through to ``rx``.

    Returns:
        ``real(<Z_0>) ** 2`` as computed by the active backend.
    """
    circuit = tc.Circuit(2)
    circuit.rx(0, theta=theta)
    z_on_qubit0 = circuit.expectation([tc.gates.z(), [0]])
    z_real = jnp.real(z_on_qubit0)
    return z_real**2

In [5]:
# Adam optimiser; the step size is named so it is easy to find and tune.
LEARNING_RATE = 0.1
opt = optax.adam(learning_rate=LEARNING_RATE)
# Optimiser state is initialised from the current value of theta.
opt_state = opt.init(theta)

In [6]:
# Run 100 Adam steps on loss(theta). Each iteration prints the loss evaluated
# at the parameter *before* the update, followed by the updated parameter.
#
# The value-and-gradient transform does not depend on the loop variable, so it
# is built once here rather than being re-created on every iteration.
loss_and_grad = K.value_and_grad(loss)
for i in range(100):
    loss_val, grad_val = loss_and_grad(theta)
    # Turn the raw gradient into an Adam update and advance the optimiser state.
    updates, opt_state = opt.update(grad_val, opt_state, theta)
    theta = optax.apply_updates(theta, updates)
    print(loss_val, theta)

0.9900333 [0.19999933]
0.96053064 [0.29672948]
0.9145057 [0.39310214]
0.85326844 [0.49003172]
0.7784849 [0.5879161]
0.69238746 [0.6869318]
0.59783113 [0.7871142]
0.49828398 [0.88837236]
0.3977522 [0.99047184]
0.30062768 [1.0929972]
0.21143986 [1.1952951]
0.13449739 [1.2964073]
0.07341866 [1.3950065]
0.030585047 [1.4893711]
0.0066154273 [1.5774459]
4.4215773e-05 [1.65703]
0.0074178264 [1.7260747]
0.02391821 [1.7830018]
0.04435927 [1.8269169]
0.06417589 [1.8576481]
0.08005169 [1.8756391]
0.09008593 [1.8817768]
0.09363124 [1.8772259]
0.09099664 [1.8633027]
0.0831475 [1.8413959]
0.07145425 [1.8129225]
0.057488322 [1.779308]
0.042850673 [1.7419775]
0.029017888 [1.7023505]
0.01720689 [1.6618298]
0.008264225 [1.621781]
0.0025971862 [1.5835]
0.00016137381 [1.5481712]
0.00051181106 [1.5168201]
0.0029106103 [1.4902707]
0.006470366 [1.4691136]
0.010303806 [1.4536899]
0.013651336 [1.4440937]
0.015967837 [1.4401888]
0.016961563 [1.4416358]
0.016589856 [1.4479252]
0.015021493 [1.4584107]
0.012577425