In [1]:
import pennylane as qml
import pennylane.numpy as np
import jax.numpy as jnp

import jax
import optax

from time import time
import pickle

import matplotlib.pyplot as plt

import os, sys

# Make the project's source directory importable. The relative path '../src' is
# resolved against the current working directory, so the notebook has to be
# launched from a directory whose sibling is `src`.
parent = os.path.abspath('../src')
# Index 1 (not 0) keeps the first sys.path entry ahead of the project directory.
sys.path.insert(1, parent)

# Project-local perceptron classes, importable thanks to the sys.path entry above.
from perceptron import NativePerceptron
from perceptron import Perceptron

# Set to float64 precision and remove jax CPU/GPU warning
# (x64 is enabled first; JAX is then pinned to the CPU platform).
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

In [2]:
# Setting up the problem: a 4-qubit perceptron with total evolution time t = 1000.
# NOTE(review): this comment previously read "t=100" while the code sets t = 1000 —
# confirm which value is intended.
perceptron_qubits = 4
# Size of the pulse basis: three times the number of qubits (12 here).
pulse_basis = 3*perceptron_qubits
# NOTE(review): `ts` is not referenced by any of the visible cells.
ts = jnp.array([1.0])
# `t` is the second argument passed to qml.evolve(perceptron.H)(...) inside `loss`.
t = 1000

# NOTE(review): `dev` is not referenced by any of the visible cells; `loss` works on
# operator matrices via qml.matrix rather than on a QNode bound to a device.
dev = qml.device("default.qubit.jax", wires = perceptron_qubits)


# Perceptron model from the project-local `perceptron` module, built with a Fourier
# pulse basis. The meaning of pulse_width and native_coupling is defined in that module.
perceptron = NativePerceptron(perceptron_qubits, pulse_basis, basis='fourier', pulse_width=0.005, native_coupling=1)

# Hamiltonian exposed by the perceptron (also accessed as perceptron.H in `loss`).
H =  perceptron.H

# Target 1D Ising Hamiltonian and its spectrum. Per the printed output, 0.1 ends up
# as the coefficient of the single-qubit X terms, and the nearest-neighbour ZZ terms
# have coefficient 1.0.
H_obj, H_obj_spectrum = perceptron.get_1d_ising_hamiltonian(0.1)

# e_ground_state_exact = H_obj_spectrum[0]

print(f'Ising Model Hamiltonian:\nH = {H_obj}')
# print(f'Exact ground state energy: {e_ground_state_exact}')

I0000 00:00:1718646602.332112       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


Ising Model Hamiltonian:
H =   (0.1) [X0]
+ (0.1) [X1]
+ (0.1) [X2]
+ (0.1) [X3]
+ (1.0) [Z0 Z1]
+ (1.0) [Z1 Z2]
+ (1.0) [Z2 Z3]




In [3]:
# Target unitary: matrix of the evolution generated by the Ising Hamiltonian
# H_obj with evolution coefficient 1.
V = qml.matrix(qml.evolve(H_obj, 1))


@jax.jit
def loss(param_vector):
    """Distance between the perceptron's evolution and the target unitary ``V``.

    The flat ``param_vector`` is converted into the perceptron's Hamiltonian
    parameters, the Hamiltonian ``perceptron.H`` is evolved with them for the
    global time ``t``, and the resulting matrix ``U`` is compared with ``V``.

    Args:
        param_vector: flat vector of trainable parameters, in the layout
            expected by ``perceptron.vector_to_hamiltonian_parameters``.

    Returns:
        The real part of ``frobenius_inner_product(conj(U - V), U - V)``.
        Assuming ``frobenius_inner_product`` sums the element-wise product
        without conjugating, this is the squared Frobenius norm of ``U - V``.

    Note:
        ``perceptron``, ``t`` and ``V`` are read from the notebook's global
        scope. Because the function is jitted, re-run this cell after changing
        any of them so the function is traced again.
    """
    hamiltonian_params = perceptron.vector_to_hamiltonian_parameters(param_vector)
    evolved_unitary = qml.matrix(qml.evolve(perceptron.H)(hamiltonian_params, t))

    # Build the residual once and reuse it for both arguments.
    residual = evolved_unitary - V
    overlap = qml.math.frobenius_inner_product(jnp.conjugate(residual), residual)

    return overlap.real

In [4]:
# Compiled function returning (loss value, gradient w.r.t. param_vector) in one call.
# `loss` is already decorated with @jax.jit; the outer jax.jit here additionally
# compiles the combined value-and-gradient computation.
value_and_grad = jax.jit(jax.value_and_grad(loss))

In [5]:
from datetime import datetime

n_epochs = 1000
# 118293 is passed to the perceptron's random-parameter generator
# (presumably a PRNG seed — verify against the perceptron module).
param_vector = perceptron.get_random_parameter_vector(118293)


# Alternative optimizer settings tried during experimentation (kept for reference).
# The block below builds a piecewise-constant learning-rate schedule
# that drops from 0.1 to 0.05 after 200 optimizer steps.
# schedule0 = optax.constant_schedule(0.1)
# schedule1 = optax.constant_schedule(0.05)
# # schedule2 = optax.constant_schedule(0.001)
# # schedule = optax.join_schedules([schedule0, schedule1, schedule2], [200, 3000])
# schedule = optax.join_schedules([schedule0, schedule1], [200])

# optimizer = optax.adam(learning_rate=schedule)

# optimizer = optax.adam(learning_rate=0.1)

optimizer = optax.adam(learning_rate=0.001)

# optimizer = optax.sgd(learning_rate=0.005)
# optimizer = optax.adabelief(0.1)
opt_state = optimizer.init(param_vector)

# Per-epoch records. energies[n] is the value returned by `loss` at the
# parameters stored in param_trajectory[n], i.e. *before* that epoch's update.
energies = np.zeros(n_epochs)
# energy[0] = loss(param_vector)
mean_gradients = np.zeros(n_epochs)

gradients_trajectory = []
param_trajectory = []

# ## Compile the evaluation and gradient function and report compilation time
# time0 = time()
# _ = value_and_grad(param_vector)
# time1 = time()

# print(f"grad and val compilation time: {time1 - time0}")


## Optimization loop
for n in range(n_epochs):
    val, grads = value_and_grad(param_vector)
    updates, opt_state = optimizer.update(grads, opt_state)

    # Record the state that produced `val` before the parameters are overwritten.
    mean_gradients[n] = np.mean(np.abs(grads))
    energies[n] = val
    param_trajectory.append(param_vector)
    gradients_trajectory.append(grads)

    param_vector = optax.apply_updates(param_vector, updates)

    print(val)

    # if not n % 10:
    #     print(f"{n+1} / {n_epochs}; Frobenius norm: {val}")
    #     print(f"    mean grad: {mean_gradients[n]}")
    #     print(f'    gradient norm: {jnp.linalg.norm(grads)}')
    #     if n>=2:
    #         print(f'    difference of gradients: {jnp.linalg.norm(grads-gradients_trajectory[-2])}')


# Adam does not decrease the loss monotonically, so the last recorded value is not
# necessarily the best one. Report the minimum over the whole run and keep the
# parameter vector that produced it.
best_epoch = int(np.argmin(energies))
best_param_vector = param_trajectory[best_epoch]

print(f"Optimal Frobenius Norm Found: {energies[best_epoch]}")


29.511019978067697
25.899743620993444
26.674361725905936
25.290815180948513
24.58103910390865
24.91141643786822
24.754351136892243
24.073441643909597
23.65716514055216
23.711270737606448
23.691532381784747
23.357608138645276
23.043018265487074
23.018333955843058
23.082474424113325
22.937017776527703
22.624534809035236
22.410523345551226
22.372930370864257
22.32157961797252
22.139384069327377
21.938587276458957
21.83395838303755
21.759912653257903
21.61455701421901
21.443840927947935
21.34549332027177
21.286239693690096
21.170019671951607
21.015764386201823
20.9181269711858
20.871253066674008
20.790584064087156
20.66977165643018
20.582206029608678
20.53794935811573
20.477049483544754
20.39228078273947
20.327823011017195
20.28446422510715
20.2283874578827
20.166086310757247
20.118620595681445
20.07079390768991
20.012183246056523
19.962258231677115
19.920582435274614
19.86696215451567
19.811392146324387
19.77043365546786
19.728117575886582
19.675511702359707
19.63082838105385
19.595036816