In [1]:
import pickle 
import pennylane as qml
import pennylane.numpy as np
import jax
import jax.numpy as jnp
import optax
import jaxopt


import os, sys, argparse

parent = os.path.abspath('../src')
sys.path.insert(1, parent)

from perceptron import Perceptron

from perceptron import NativePerceptron
import time 

import matplotlib.pyplot as plt


In [2]:
# --- JAX runtime flags ------------------------------------------------------
# Set before any jnp array below is created: double precision, CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# --- Problem size -----------------------------------------------------------
Ntrials = 100
N = 4
P = 5 * N

# --- Quantum perceptron problem settings ------------------------------------
perceptron_qubits = N   # number of wires / qubits
n_axis = 2
pulse_basis = P         # size of the pulse basis handed to the perceptron
sigma = 0.1             # passed to the perceptron as `pulse_width`
save_path = ''
n_epochs = 200

# --- Evolution time and interior time grid ----------------------------------
ts = jnp.array([1.0])
t = 1
# pulse_basis + 2 equally spaced points on [0, t], with both endpoints dropped.
times = jnp.linspace(0, t, pulse_basis + 2)[1:-1]
dev = qml.device("default.qubit.jax", wires=perceptron_qubits)

# Build the perceptron model: Fourier pulse basis, native coupling set to 1.
perceptron = NativePerceptron(
    perceptron_qubits,
    pulse_basis,
    basis='fourier',
    pulse_width=sigma,
    native_coupling=1,
)
H = perceptron.H

# Target unitary W: matrix of the evolution generated by the 1D Ising
# Hamiltonian returned by the perceptron helper.
# NOTE(review): the meaning of the 0.1 argument is defined in
# perceptron.get_1d_ising_hamiltonian, which is not visible here -- confirm.
H_obj, H_obj_spectrum = perceptron.get_1d_ising_hamiltonian(0.1)
W = qml.matrix(qml.evolve(H_obj, coeff=1))

# Pauli-X on every wire followed by Pauli-Y on every wire (X0..X{N-1}, Y0..Y{N-1}).
hcs = [
    pauli(wire)
    for pauli in (qml.PauliX, qml.PauliY)
    for wire in range(perceptron_qubits)
]

@jax.jit
def loss(param_vector):
    """Distance between the perceptron's evolution unitary and the target ``W``.

    Args:
        param_vector: flat parameter vector; it is converted with
            ``perceptron.vector_to_hamiltonian_parameters`` before being fed
            to the parametrized Hamiltonian ``perceptron.H``.

    Returns:
        The real part of ``frobenius_inner_product(conj(U - W), U - W)``,
        where ``U`` is the matrix of ``qml.evolve(perceptron.H)`` evaluated
        with the converted parameters and the module-level time ``t``.
        PennyLane's Frobenius inner product is the unconjugated element-wise
        sum, so this equals ``sum(|U - W|**2)``.
    """
    hamiltonian_params = perceptron.vector_to_hamiltonian_parameters(param_vector)
    U = qml.matrix(qml.evolve(perceptron.H)(hamiltonian_params, t))

    # Form the residual once and reuse it for both arguments.
    residual = U - W
    return qml.math.frobenius_inner_product(jnp.conjugate(residual), residual).real


I0000 00:00:1708315518.244198       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [3]:
# Draw a random starting point. The seed is taken from the wall clock, so each
# execution of this cell starts somewhere different and the printed numbers
# are not reproducible from run to run.
random_seed = int(time.time() * 1000)
param_vector = perceptron.get_random_parameter_vector(random_seed)

# Minimise the unitary-matching loss with L-BFGS (at most 1000 iterations).
solver = jaxopt.LBFGS(loss, maxiter=1000)

res = solver.run(param_vector)

print('Final loss: ', res.state.value)
# Report the largest gradient *magnitude*: a plain max() ignores large negative
# components and could make an unconverged run look converged.
print('Max grad: ', jnp.max(jnp.abs(res.state.grad)))


Final loss:  14.463098351332391
Max grad:  0.00034394366469295474


In [53]:
# Baseline: evaluate the loss at a freshly drawn random parameter vector
# (clock-seeded, so the printed value differs on every execution).
random_seed = int(time.time() * 1000)
baseline_params = perceptron.get_random_parameter_vector(random_seed)
print(loss(baseline_params))


28.828260482544017


In [4]:
# Show the parameter vector returned by the L-BFGS run (`res = solver.run(...)`);
# as the cell's last expression it is rendered as the cell output.
res.params

Array([ 1.50030337e-01,  2.98020627e-01,  2.73625758e-01,  2.16755777e-01,
        3.75886420e-01,  2.54269227e-01,  2.94418476e-01,  2.64722767e-01,
        3.38836758e-01,  2.66611888e-01,  2.90807854e-01,  2.96020825e-01,
        3.17655938e-01,  2.86353960e-01,  3.04522650e-01,  2.86969655e-01,
        2.84762637e-01,  2.58207542e-01,  2.63904294e-01,  3.13299893e-01,
        2.15430887e-01,  3.23372165e-01,  3.67676056e-01,  2.23407878e-01,
        5.03169293e-01,  3.06888072e-01,  3.62825570e-01,  3.15884533e-01,
        4.12356795e-01,  3.61661408e-01,  3.35834166e-01,  3.70987430e-01,
        3.49376230e-01,  3.65405312e-01,  3.54721349e-01,  3.59567604e-01,
        3.15476881e-01,  3.89139899e-01,  3.73112284e-01,  3.66674193e-01,
        6.68169478e-01,  6.35309339e+00,  3.99863200e-02,  6.08469133e+00,
        8.37435869e-01,  6.65217304e+00,  1.70705448e-01,  6.20249836e+00,
        3.81395184e-01,  6.42829485e+00,  7.80425617e-01,  6.73841564e+00,
        5.63928471e-01,  