In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
from typing import Sequence
import itertools
import time
import tqdm
import matplotlib.pyplot as plt

import pennylane as qml
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn

from queso.sensor import Sensor, RegressionEstimator


n = 4
k = 4
key = jax.random.PRNGKey(0)
phi = jnp.array(0.0)
theta = jax.random.uniform(key, shape=[n, 3*k])
mu = jax.random.uniform(key, shape=[n, 3])
sensor = Sensor(n=n, k=k, shots=10)

print(sensor.state(theta, phi, mu))
print(sensor.probs(theta, phi, mu))
print(sensor.sample(theta, phi, mu, shots=10))
# print(sensor.counts(theta, phi, mu, shots=10))
print(sensor.qfi(theta, phi, mu))
print(sensor.cfi(theta, phi, mu))

[ 0.31477246-0.3843014j   0.19758867+0.01313797j -0.12203943-0.03752252j
 -0.07016478+0.19542567j  0.25725755+0.37216403j -0.05611895+0.1215237j
 -0.12127692-0.23931877j -0.37968784-0.03291814j -0.22870433+0.02563431j
  0.07229804-0.03830751j -0.22201793+0.03540559j -0.09960696+0.10683048j
 -0.09485266-0.23841816j  0.01973272-0.07862855j -0.01788557+0.00127606j
  0.09885021+0.02694576j]
[0.24676927 0.03921389 0.01630156 0.04311429 0.20468751 0.01791735
 0.07198157 0.14524646 0.05296279 0.00669447 0.05054552 0.0213343
 0.06584024 0.00657183 0.00032152 0.01049744]
[[0 1 0 0]
 [0 0 0 0]
 [1 0 1 0]
 [1 0 0 0]
 [0 0 1 1]
 [1 1 0 0]
 [0 0 0 0]
 [1 0 1 1]
 [0 0 0 0]
 [0 1 1 1]]
2.983996744457663
1.2025993538923774


In [11]:
%%timeit
sensor.sample(theta, phi, mu, shots=100)

6.84 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [43]:
# %%timeit
def test(theta, phi, mu):
    return sensor.sample_nonjit(theta, phi, mu, shots=100)

test_jit = jax.jit(test)
test_jit(theta, phi, mu)

InterfaceUnsupportedError: The new JAX JIT interface of PennyLane requires JAX and and JAX lib version below 0.4.4. Please downgrade these packages.If you are using pip to manage your packages, you can run the following command:

	pip install 'jax==0.4.3' 'jaxlib==0.4.3'

If you are using conda to manage your packages, you can run the following command:

	conda install 'jax==0.4.3' 'jaxlib==0.4.3'



In [23]:
%%timeit
sensor.counts(theta, phi, mu, shots=100)

62.4 ms ± 848 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [24]:
%%timeit
sensor.counts_nonjit(theta, phi, mu, shots=100)

65.9 ms ± 5.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [41]:
key = jax.random.PRNGKey(time.time_ns())
shots = 100
def counts(theta, phi, mu, shots, key):
    probs = sensor.probs(theta, phi, mu)
    # print(probs)
    # inds = jax.random.choice(key, len(probs), shape=(shots,), replace=True, p=probs)

    inds = jnp.array([jax.random.choice(key, len(probs), p=probs) for i in range(shots)])
    unique, counts = jnp.unique(inds, return_counts=True)
    return unique, counts

counts_jit = jax.jit(counts)
unique, counts = counts(theta, phi, mu, shots, key)
# unique, counts = count_jit(theta, phi, mu, shots, key)
    # bases = list(itertools.product(*n * [[0, 1]]))
print(unique, counts)
# samples = jnp.array([bases[ind] for ind in inds])

[11] [100]
