In [24]:
import jax
import jax.numpy as np
from jax import jit, vmap
from qiskit import QuantumCircuit, transpile, Aer
from qiskit.opflow import StateFn, PauliExpectation, Gradient
from qiskit.utils import QuantumInstance

# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

# import Array and set default backend
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')

from qiskit_dynamics.array import wrap

jit = wrap(jit, decorator=True)

def prepare_data_circuit(x, num_qubits):
    qc = QuantumCircuit(num_qubits)
    #for i in range(num_qubits):
        #qc.ry(x[i], i)
    return qc

def hardware_efficient_ansatz(num_qubits, parameters):
    qc = QuantumCircuit(num_qubits)
    num_params_per_qubit = len(parameters) // num_qubits
    for i in range(num_qubits):
        for j in range(0, num_params_per_qubit, 3):
            qc.ry(parameters[i * num_params_per_qubit + j], i)
            qc.rz(parameters[i * num_params_per_qubit + j + 1], i)
            qc.rz(parameters[i * num_params_per_qubit + j + 2], i)
        if i < num_qubits - 1:
            qc.cz(i, i + 1)
    return qc

def quantum_neural_network(x, parameters):
    num_qubits=5 
    measure_qubit=0
    data_circuit = prepare_data_circuit(x, num_qubits)
    ansatz_circuit = hardware_efficient_ansatz(num_qubits, parameters)
    qnn_circuit = data_circuit.compose(ansatz_circuit)

    backend = Aer.get_backend('statevector_simulator')
    quantum_instance = QuantumInstance(backend)

    qubit_op = StateFn(qnn_circuit).adjoint().compose(PauliExpectation().convert(StateFn(qnn_circuit)))

    return np.real(qubit_op.eval(quantum_instance))

In [25]:
jit_qnn = jit(quantum_neural_network, static_argnums=(2, 3))

def batch_qnn(params, batch_x, num_qubits=5, measure_qubit=0):
    return vmap(jit_qnn, in_axes=(0, None))(batch_x, params)

In [26]:
import numpy as jnp
# assume that we have defined the `num_params` and `batch_size` variables

# create the `params` variable
params = jnp.random.normal(size=(15,))

# create the `batch_x` variable
batch_x = jnp.random.normal(size=(200, 5))

batch_outputs = batch_qnn(params, batch_x)

CircuitError: "Invalid param type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for gate ry."