In [3]:
! pip install optax

Collecting optax
  Using cached optax-0.1.4-py3-none-any.whl (154 kB)
Collecting chex>=0.1.5
  Using cached chex-0.1.5-py3-none-any.whl (85 kB)
Collecting jaxlib>=0.1.37
  Using cached jaxlib-0.4.2-cp39-cp39-manylinux2014_x86_64.whl (71.9 MB)
Collecting absl-py>=0.7.1
  Using cached absl_py-1.4.0-py3-none-any.whl (126 kB)
Collecting jax>=0.1.55
  Using cached jax-0.4.2-py3-none-any.whl
Collecting dm-tree>=0.1.5
  Using cached dm_tree-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (153 kB)
Collecting toolz>=0.9.0
  Using cached toolz-0.12.0-py3-none-any.whl (55 kB)
Collecting opt-einsum
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Installing collected packages: dm-tree, toolz, opt-einsum, absl-py, jaxlib, jax, chex, optax
Successfully installed absl-py-1.4.0 chex-0.1.5 dm-tree-0.1.8 jax-0.4.2 jaxlib-0.4.2 opt-einsum-3.3.0 optax-0.1.4 toolz-0.12.0


In [12]:
import random
from typing import Tuple

import optax
import jax.numpy as jnp
import jax
import numpy as np

from utils.utils import simulate, histogram_to_category

from sklearn.model_selection import train_test_split

# Import the circuits
from circuits.encoderFRQI import encode as frqi
from circuits.encoderBASIC import encode as basic_encoder

from circuits.weightsCircuit import encode as weight_layer

def run_training(X, y, encoder_fn, classifier_fn, backend='qiskit'):
    
    n_layers = 2
    batch_size = 32
    epochs = 1_000
    
    def circuit_wrapper(x: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        circuit = encoder_fn(x)
        classifier = classifier_fn(w)
        
        nq1 = circuit.width()
        nq2 = classifier.width()
        
        nq = max(nq1, nq2)
        qc = qiskit.QuantumCircuit(nq)
        qc.append(circuit.to_instruction(), list(range(nq1)))
        qc.append(classifier.to_instruction(), list(range(nq2)))

        histogram = simulate(qc)
        
        return histogram_to_category(histogram)

    
    def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
        predictions = circuit_wrapper(batch, params)
        
        y_pred = jax.nn.one_hot(y % 2, 2).astype(jnp.float32).reshape(len(labels), 2)
        loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

        return loss_value.mean()

    def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
        opt_state = optimizer.init(params)

        @jax.jit
        def step(params, opt_state, batch, labels):
            loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_value

        for i in range(epochs):

            batch_index = np.random.randint(0, len(y), (batch_size,))
            x_train_batch = X[batch_index]
            y_train_batch = y[batch_index]


            params, opt_state, loss_value = step(params, opt_state, x_train_batch, y_train_batch)
            losses.append(loss_value)
            if i % 100 == 0:
                print(f'step {i}, loss: {loss_value}')

            return params, losses

    
    initial_params = {
        'w': jax.random.normal(shape=[16, n_layers], key=jax.random.PRNGKey(0)),
    }
    
    optimizer = optax.adam(learning_rate=1e-2)
    
    optimal_params, losses = fit(initial_params, optimizer)
    
    optimal_classifier = classifier_fn(optimal_params)
    
    return optimal_classifier, losses



In [13]:

X = np.load('data/images.npy')
y = np.load('data/labels.npy')
y_reshape = jax.nn.one_hot(y % 2, 2).astype(jnp.float32).reshape(2000, 2)

X_train, X_test, y_train, y_test = train_test_split(X, y_reshape, test_size=0.33, random_state=42)

circuit_weights, losses = run_training(X_train, y_train, frqi, weight_layer)

CircuitError: "Invalid param type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for gate ry."