In [2]:
! pip install optax



In [2]:
!pip install qiskit_dynamics

Collecting qiskit_dynamics
  Downloading qiskit_dynamics-0.3.2-py3-none-any.whl (154 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.4/154.4 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
Collecting multiset>=3.0.1
  Downloading multiset-3.0.1-py2.py3-none-any.whl (9.6 kB)
Installing collected packages: multiset, qiskit_dynamics
Successfully installed multiset-3.0.1 qiskit_dynamics-0.3.2


In [1]:
import random
from typing import Tuple

import optax
import jax.numpy as jnp
import jax
import numpy as np

from utils.utils import simulate, histogram_to_category

from sklearn.model_selection import train_test_split

# Import the circuits
from circuits.encoderFRQI import encode as frqi
from circuits.encoderBASIC import encode as basic_encoder

from circuits.weightsCircuit import encode as weight_layer

def run_training(X, y, encoder_fn, classifier_fn, backend='qiskit'):
    
    n_layers = 1
    batch_size = 32
    epochs = 1_000
    
    def circuit_wrapper(x: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        print(params['w'])
        print(params['w'].shape)
        #print(np.array(w[:][0]))
        categories = []
        for i_ in range(batch_size):
            circuit = encoder_fn(jnp.array(x[i_]))
            print(type(params['w']))
            #print(type(np.array(params['w'])))
            classifier = classifier_fn(jnp.array(params['w']))

            nq1 = circuit.width()
            nq2 = classifier.width()

            nq = max(nq1, nq2)
            qc = qiskit.QuantumCircuit(nq)
            qc.append(circuit.to_instruction(), list(range(nq1)))
            qc.append(classifier.to_instruction(), list(range(nq2)))

            histogram = simulate(qc)
            categories.append[histogram_to_category(histogram)]

        return categories

    
    def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
        predictions = circuit_wrapper(batch, params)
        
        y_pred = jax.nn.one_hot(y % 2, 2).astype(jnp.float32).reshape(len(labels), 2)
        loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

        return loss_value.mean()

    def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
        opt_state = optimizer.init(params)

        #@jax.jit
        def step(params, opt_state, batch, labels):
            loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_value

        for i in range(epochs):

            batch_index = np.random.randint(0, len(y), (batch_size,))
            x_train_batch = X[batch_index]
            y_train_batch = y[batch_index]


            params, opt_state, loss_value = step(params, opt_state, x_train_batch, y_train_batch)
            losses.append(loss_value)
            if i % 100 == 0:
                print(f'step {i}, loss: {loss_value}')

            return params, losses

    
    initial_params = {
        'w': jax.random.normal(shape=[16, n_layers], key=jax.random.PRNGKey(0)),
    }
    
    optimizer = optax.adam(learning_rate=1e-2)
    
    optimal_params, losses = fit(initial_params, optimizer)
    
    optimal_classifier = classifier_fn(optimal_params)
    
    return optimal_classifier, losses

In [2]:

X = np.load('data/images.npy')
y = np.load('data/labels.npy')

X_reshape = np.reshape(X, (len(X), 28*28))
y_reshape = jax.nn.one_hot(y % 2, 2).astype(jnp.float32).reshape(2000, 2)

X_train, X_test, y_train, y_test = train_test_split(X_reshape, y_reshape, test_size=0.2, random_state=42)

circuit_weights, losses = run_training(X_train, y_train, frqi, weight_layer)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Traced<ConcreteArray([[ 0.08482574]
 [ 1.9097648 ]
 [ 0.29561743]
 [ 1.120948  ]
 [ 0.33432344]
 [-0.82606775]
 [ 0.6481277 ]
 [ 1.0434873 ]
 [-0.7824839 ]
 [-0.4539462 ]
 [ 0.6297971 ]
 [ 0.81524646]
 [-0.32787678]
 [-1.1234448 ]
 [-1.6607416 ]
 [ 0.27290547]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[ 0.08482574],
       [ 1.9097648 ],
       [ 0.29561743],
       [ 1.120948  ],
       [ 0.33432344],
       [-0.82606775],
       [ 0.6481277 ],
       [ 1.0434873 ],
       [-0.7824839 ],
       [-0.4539462 ],
       [ 0.6297971 ],
       [ 0.81524646],
       [-0.32787678],
       [-1.1234448 ],
       [-1.6607416 ],
       [ 0.27290547]], dtype=float32)
  tangent = Traced<ShapedArray(float32[16,1])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16,1]), None)
    recipe = LambdaBinding()
(16, 1)
<class 'jax.interpreters.ad.JVPTracer'>
<class 'jax.interpreters.ad.JVPTrace'>


CircuitError: "Invalid param type <class 'jax.interpreters.ad.JVPTracer'> for gate ry."