In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
import jax
import numpy as np
import jax.numpy as jnp
import tensorcircuit as tc
K = tc.set_backend("jax")
from magic_game import bit_to_num
from quantum_model import Quantum_Strategy
from optax import adam

Please first ``pip install -U qiskit`` to enable related functionality in translation module


In [2]:
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
weight_A = jax.random.normal(key, (4, 16))
weight_B = jax.random.normal(subkey, (4, 16))

2024-04-12 23:01:51.100186: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# two qubit Pauli basis
I = jnp.array([[1, 0], [0, 1]], dtype=jnp.complex64)
X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64)
Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
basis = jnp.array([jnp.kron(a, b) for a in [I, X, Y, Z] for b in [I, X, Y, Z]])
bit_to_op = jnp.array([(I + Z) / 2, (I - Z) / 2])

In [4]:
def weight_to_hamiltonian(weight):
    return jnp.einsum("i,ijk->jk", weight, basis)

In [5]:
def get_nll(n, weights_A, weights_B, inputs, targets):
    # n: number of games
    # weights_A: (4, 16)
    # weights_B: (4, 16)
    # inputs: (4 * n,)
    # targets: (4 * n,)
    
    inputs = bit_to_num(inputs.reshape(-1, 2))
    
    c = tc.Circuit(4 * n)
    # Bell state
    for i in range(n):
        c.H(2 * i)
        c.CNOT(2 * i, 2 * n + 2 * i)
        c.H(2 * i + 1)
        c.CNOT(2 * i + 1, 2 * n + 2 * i + 1)
    
    # U(weight)
    for i in range(n):
        c.EXP(2 * i, 2 * i + 1, theta=1, hamiltonian=weight_to_hamiltonian(weights_A[inputs[i]]))
        c.EXP(2 * n + 2 * i, 2 * n + 2 * i + 1, theta=1, hamiltonian=weight_to_hamiltonian(weights_B[inputs[n + i]]))
        
    # measure likelihood
    ops = [(bit_to_op[targets[i]], [i,]) for i in range(len(targets))]
    exp = c.expectation(*ops)
    
    return -jnp.log(jnp.real(exp) + 1e-10)

In [6]:
def sample(n, weights_A, weights_B, inputs):
    inputs = bit_to_num(inputs.reshape(-1, 2))
    
    c = tc.Circuit(4 * n)
    # Bell state
    for i in range(n):
        c.H(2 * i)
        c.CNOT(2 * i, 2 * n + 2 * i)
        c.H(2 * i + 1)
        c.CNOT(2 * i + 1, 2 * n + 2 * i + 1)
    
    # U(weight)
    for i in range(n):
        c.EXP(2 * i, 2 * i + 1, theta=1, hamiltonian=weight_to_hamiltonian(weights_A[inputs[i]]))
        c.EXP(2 * n + 2 * i, 2 * n + 2 * i + 1, theta=1, hamiltonian=weight_to_hamiltonian(weights_B[inputs[n + i]]))
        
    # sample a bit string
    sample = c.measure(*range(4 * n), with_prob=False)[0]
    
    return sample

sample_vmap = jax.vmap(sample, in_axes=(None, None, None, 0))

In [7]:
get_nll_vmap = jax.vmap(get_nll, in_axes=(None, None, None, 0, 0), out_axes=0)
get_nll_batch = lambda n, weights_A, weights_B, inputs, targets: jnp.mean(get_nll_vmap(n, weights_A, weights_B, inputs, targets))
grad_nll = jax.grad(get_nll_batch, argnums=(1, 2))

In [8]:
inputs = jnp.zeros((4,), dtype=jnp.int16)
targets = jnp.array([1, 0, 1, 0], dtype=jnp.int16)
get_nll(1, weight_A, weight_B, inputs, targets)

Array(2.9956663, dtype=float32)

In [9]:
n = 1
test_size = 10000
data = np.load(f"./data/data_{n}.npz")
data_size = len(data['X'])
x_train = data['X']; y_train = data['Y']
x_test = np.random.randint(0, 2, (test_size, x_train.shape[1]))

In [12]:
batch_size = 32
n_epochs = 100
n_batches = data_size // batch_size
lr = 1e-2
opt = adam(lr)
opt_state = opt.init((weight_A, weight_B))

for epoch in range(n_epochs):
    for i in range(n_batches):
        x_batch = x_train[i * batch_size: (i + 1) * batch_size]
        y_batch = y_train[i * batch_size: (i + 1) * batch_size]
        grad_A, grad_B = grad_nll(n, weight_A, weight_B, x_batch, y_batch)
        # updates, opt_state = opt.update((grad_A, grad_B), opt_state)
        weight_A = weight_A - lr * grad_A
        weight_B = weight_B - lr * grad_B
        print(f"Epoch {epoch}, batch {i}, loss {get_nll_batch(n, weight_A, weight_B, x_batch, y_batch)}")

Epoch 0, batch 0, loss 3.062504291534424
Epoch 0, batch 1, loss 2.95979905128479
Epoch 0, batch 2, loss 3.03666353225708
Epoch 0, batch 3, loss 3.407806873321533
Epoch 0, batch 4, loss 3.1409687995910645
Epoch 0, batch 5, loss 3.285081148147583
Epoch 0, batch 6, loss 3.2066898345947266
Epoch 0, batch 7, loss 3.138278007507324
Epoch 0, batch 8, loss 2.9444949626922607
Epoch 0, batch 9, loss 3.1234118938446045
Epoch 0, batch 10, loss 3.202165126800537
Epoch 0, batch 11, loss 3.304335594177246


Epoch 0, batch 12, loss 2.9373021125793457
Epoch 0, batch 13, loss 3.1398775577545166
Epoch 0, batch 14, loss 3.104675769805908
Epoch 0, batch 15, loss 3.284898281097412
Epoch 0, batch 16, loss 2.9489941596984863
Epoch 0, batch 17, loss 3.228365421295166
Epoch 0, batch 18, loss 3.0252938270568848
Epoch 0, batch 19, loss 3.1636557579040527
Epoch 0, batch 20, loss 3.3313472270965576
Epoch 0, batch 21, loss 3.441865921020508
Epoch 0, batch 22, loss 2.8106260299682617
Epoch 0, batch 23, loss 2.753493309020996
Epoch 0, batch 24, loss 2.8613219261169434
Epoch 0, batch 25, loss 3.100783348083496
Epoch 0, batch 26, loss 3.0145678520202637
Epoch 0, batch 27, loss 2.7095537185668945
Epoch 0, batch 28, loss 2.7729287147521973
Epoch 0, batch 29, loss 3.1637308597564697
Epoch 0, batch 30, loss 2.75455379486084
Epoch 0, batch 31, loss 3.0581605434417725
Epoch 0, batch 32, loss 3.195115566253662
Epoch 0, batch 33, loss 2.904262065887451
Epoch 0, batch 34, loss 2.7759616374969482
Epoch 0, batch 35, lo

KeyboardInterrupt: 

In [13]:
pred = sample_vmap(n, weight_A, weight_B, x_test)
qs = Quantum_Strategy(n)
check = qs.check_input_output(x_test, pred, flatten=True)
score = np.mean(check)
print(f"Score: {score}")

Score: 0.9393
