In [97]:
import numpy as np
from sympy.utilities.iterables import multiset_permutations

import jax.random
import jax.numpy as jnp
import jax.scipy.optimize
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)

import pennylane as qml
from functools import partial

from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
plt.ioff()

import time
import pandas as pd
from datetime import datetime
import os
import pytz
import ast
from pypdf import PdfMerger

import contextlib
import warnings
warnings.filterwarnings('ignore')

In [98]:
# Running parameters
num_iters = 2000    # Number of training iterations
num_runs = 10
cl_types = ["NCL"]#, "CL", "ACL", "SPCL", "SPACL"]     # "NCL" - No curriculum, "CL" - Curriculum, "ACL" - Anti-curriculum, "SPCL" - Self paced curriculum, "SPACL" - Self paced anti-curriculum

# Circuit parameters
nqubits = 6         # Num qubits, min 4, always 2**num_layers qubits
layers = 15          # Num layers of the variational unitary
qcnn_mode = False   # True: All the gates in the same layer share weights. False: Each gates has its own weight.
device = "default.qubit"

# Data hyper-parameters
dataset_type = "haar"  # "basis" or "haar"
batch_size = 15    # batch training size
train_size = 15    # Total states that will be used for training
val_size = 50      # Total states that will be used for validation
cl_batch_ratios = [0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]  #[0.4, 0.3, 0.2, 0.1]    # [0.1, 0.2, 0.3, 0.4], [0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.3, 0.2, 0.2, 0.2, 0.1]
cl_iter_ratios  = [0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2]

# Optimization parameters
optimizer = "Adam"  # "Adam", "GradientDescent", "BFGS"
loss_type = "fidelity" # "fidelity"
dist_type = "fro"  # "fro", np.inf, 2, 1
initialization = "gaussian" # "gaussian", "uniform"
max_weight_init = 2*np.pi  # weight_init goes from 0 to this number. Max = 2*np.pi. Other options = 0.01
stepsize = 0.01         # stepsize of the gradient descent.

# Constant definitions
if qcnn_mode:
    nweights = 15*layers
else:
    nweights = 15*(layers//2*(int(nqubits//2) + int((nqubits-1)//2)) + layers%2*(int(nqubits//2)))
dev = qml.device(device, wires=nqubits)

cl_batches = []
i_batch_size = 0
for i in range(len(cl_iter_ratios)):
    if i < len(cl_iter_ratios)-1:
        i_batch_size += int(cl_batch_ratios[i]*train_size)
        i_num_iters = int(cl_iter_ratios[i]*num_iters)
    else:
        i_batch_size = train_size
        i_num_iters = num_iters - len(cl_batches)
        
    cl_batches += [i_batch_size]*i_num_iters

# Generate dataset

In [99]:
def X(i):
    return qml.PauliX(i)

def Y(i):
    return qml.PauliY(i)

def Z(i):
    return qml.PauliZ(i)

In [100]:
def hamiltonian_unitary():
    hamiltonian = sum(X(i) @ X((i+1)) + Y(i) @ Y((i+1)) + 1.5 * Z(i) @ Z((i+1)) for i in range(nqubits-1))
    qml.ApproxTimeEvolution(hamiltonian, time=0.5, n=10)


@qml.qnode(dev)
def hamiltonian_evolution(ini_state):
    qml.QubitStateVector(ini_state, wires=range(nqubits))
    hamiltonian_unitary()
    return qml.state()


def get_random_basis_state():
    choice_arr = np.array([1]+[0]*(2**nqubits-1))
    state = np.random.choice(choice_arr, size=2**nqubits, replace=False)
    return state


def generate_basis_dataset(num_points):
    basis_states = []
    target_states = []
    if num_points > 2**nqubits:
        raise ValueError(f"Error trying to generate {num_points} basis states. You can not generate more than {2**nqubits}.")
    while len(basis_states) < num_points:
        ini_state = get_random_basis_state()
        if not np.array([(ini_state==state).all() for state in basis_states]).any():  # Checks if ini_state has already been sampled
            target_state = hamiltonian_evolution(ini_state)
            basis_states.append(ini_state)
            target_states.append(target_state)

    return np.array(basis_states), np.array(target_states)


def generate_all_basis_dataset():
    
    choice_arr = np.array([1]+[0]*(2**nqubits-1))
    all_basis_states = np.array(list(multiset_permutations(choice_arr)))
    np.random.shuffle(all_basis_states)

    target_states = []
    for ini_state in all_basis_states:
        target_states.append(hamiltonian_evolution(ini_state))

    return all_basis_states, np.array(target_states)

In [101]:
def qr_haar(N):
    """Generate a Haar-random matrix using the QR decomposition."""
    """https://pennylane.ai/qml/demos/tutorial_haar_measure/"""
    
    A, B = np.random.normal(size=(N, N)), np.random.normal(size=(N, N))
    Z = A + 1j * B

    Q, R = np.linalg.qr(Z)
    Lambda = np.diag([R[i, i] / np.abs(R[i, i]) for i in range(N)])

    return np.dot(Q, Lambda)


def get_random_haar_state(num_qubits):
    qml.QubitUnitary(qr_haar(2**num_qubits), wires=range(num_qubits))
    return qml.state()


def generate_haar_dataset(num_points):
    global get_random_haar_state
    dev_haar = qml.device(device, wires=nqubits)
    get_random_haar_state = qml.QNode(get_random_haar_state, dev_haar)

    haar_states = []
    target_states = []
    while len(haar_states) < num_points:
        ini_state = get_random_haar_state(nqubits)
        target_state = hamiltonian_evolution(ini_state)
        haar_states.append(ini_state)
        target_states.append(target_state)

    return np.array(haar_states), np.array(target_states)


def generate_local_haar_dataset(num_points):
    global get_random_haar_state
    dev_haar = qml.device(device, wires=1)
    get_random_haar_state = qml.QNode(get_random_haar_state, dev_haar)
    
    haar_states = []
    target_states = []
    while len(haar_states) < num_points:
        ini_state = 1
        for _ in range(nqubits):
            ini_state = np.tensordot(ini_state, get_random_haar_state(1), axes=0).flatten()
        target_state = hamiltonian_evolution(ini_state)
        haar_states.append(ini_state)
        target_states.append(target_state)

    return np.array(haar_states), np.array(target_states)

# Quantum Circuit

In [102]:
def general_unitary_2q(q1, q2, weights):
    qml.U3(wires=q1, theta=weights[0], phi=weights[1], delta=weights[2])
    qml.U3(wires=q1, theta=weights[3], phi=weights[4], delta=weights[5])
    qml.CNOT(wires=[q2, q1])
    qml.RZ(wires=q1, phi=weights[6])
    qml.RY(wires=q2, phi=weights[7])
    qml.CNOT(wires=[q1, q2])
    qml.RY(wires=q2, phi=weights[8])
    qml.CNOT(wires=[q2, q1])
    qml.U3(wires=q1, theta=weights[9], phi=weights[10], delta=weights[11])
    qml.U3(wires=q1, theta=weights[12], phi=weights[13], delta=weights[14])



def variational_unitary(weights):
    k = 0
    for _ in range(layers//2):
        
        i = 0
        while 2*i+1 < nqubits:
            general_unitary_2q(2*i, 2*i+1, weights[k:k+15])
            k += 15 if not qcnn_mode else 0
            i += 1
        k += 0 if not qcnn_mode else 15

        i = 0
        while 2*i+2 < nqubits:
            general_unitary_2q(2*i+1, 2*i+2, weights[k:k+15])
            k += 15 if not qcnn_mode else 0
            i += 1
        k += 0 if not qcnn_mode else 15
    
    if layers % 2 != 0:
        i = 0
        while 2*i+1 < nqubits:
            general_unitary_2q(2*i, 2*i+1, weights[k:k+15])
            k += 15 if not qcnn_mode else 0
            i += 1


@jax.jit
@qml.qnode(dev, interface="jax", diff_method="best")
def variational_circuit(weights, state_ini):
    qml.QubitStateVector(state_ini, wires=range(nqubits))
    variational_unitary(weights)
    return qml.state()

# Loss and accuracy

In [103]:
@jax.jit
def single_loss(weights, ini_state, target_state):
    
    out_state = variational_circuit(weights, ini_state)

    if loss_type == "fidelity":
        cost = 1 - jnp.abs(jnp.vdot(out_state, target_state))**2
    
    return cost


@jax.jit
def loss(weights, ini_states, target_states):
    costs = jax.vmap(single_loss, in_axes=[None, 0, 0])(weights, ini_states, target_states)
    return costs.sum()/len(ini_states)

In [104]:
@jax.jit
def unitary_distance(weights):
    unitary_matrix = qml.matrix(hamiltonian_unitary)()
    variational_matrix = qml.matrix(variational_unitary)(weights)
    return jnp.linalg.norm(variational_matrix - unitary_matrix, ord=dist_type)

# Processing Data

In [105]:
def save_multi_image(filename):
    pp = PdfPages(filename)
    fig_nums = plt.get_fignums()
    figs = [plt.figure(n) for n in fig_nums]
    for fig in figs:
        fig.savefig(pp, bbox_inches='tight', format='pdf')
    pp.close()


def close_all_figures():
    fig_nums = plt.get_fignums()
    for n in fig_nums:
        plt.figure(n)
        plt.close()


def save_plots(time_now,
               folder_name,
               file_name,
               plot_run,
               uni_dists,
               losses_train,
               losses_val,
               losses_val_all_states
              ):


    fig, axis = plt.subplots(1,2)
    fig.set_figheight(6.5)
    fig.set_figwidth(15)
    fig.tight_layout(w_pad=6)  # otherwise the right y-label is slightly clipped # rect=(1,1,5,1)

    # ---------------------------------------------------------------------- #
    # -------------------- Loss and accuracy figure ------------------------ #
    # ---------------------------------------------------------------------- #

    iterations = range(1, num_iters+1)

    color1 = 'darkred'
    axis[0].set_xlabel('Iterations')
    axis[0].set_ylabel('Unitary dist.', color=color1)
    axis[0].plot(iterations, uni_dists, label="Unitary dist.", color=color1)
    axis[0].tick_params(axis='y', labelcolor=color1)
    axis[0].set_ylim(bottom=0)

    ax2 = axis[0].twinx()  # instantiate a second axes that shares the same x-axis

    color2 = 'darkblue'
    ax2.set_ylabel('Loss', color=color2)  # we already handled the x-label with axis[0]
    ax2.plot(iterations, losses_train, label="Train Loss", color=color2)
    ax2.plot(iterations, losses_val, '-.', label="Val. Loss", color=color2)
    ax2.tick_params(axis='y', labelcolor=color2)
    ax2.set_yscale("log")
    ax2.set_ylim(bottom=0.0, top=1.0)

    # plt.legend()
    axis[0].set_title(f"Unitary dist. and Loss - Run {plot_run}")

    # -------------------------------------------------------------------- #
    # -------------------- Loss validation figure ------------------------ #
    # -------------------------------------------------------------------- #

    iterations = range(1, num_iters+1)

    axis[1].yaxis.set_label_position("right")
    axis[1].yaxis.tick_right()

    color2 = 'darkblue'
    axis[1].set_ylabel('Loss', color=color2)  # we already handled the x-label with axis[0]

    for loss in losses_val_all_states:
        axis[1].plot(iterations, loss, color=color2)
    
    axis[1].tick_params(axis='y', labelcolor=color2)
    axis[1].set_yscale("log")
    axis[1].set_ylim(bottom=0.0, top=1.0)

    # plt.legend()
    axis[1].set_title(f"Loss all validation states - Run {plot_run}")


    
    # ---------------------------------------------------------------------- #
    # --------------------------- Save plots ------------------------------- #
    # ---------------------------------------------------------------------- #

    plots_pdf_name = f"{folder_name}/{time_now} - Plots - {file_name}.pdf"
    
    
    # If the file doesn't exist we save it. If it does, we merge it.
    if not os.path.isfile(plots_pdf_name):
        save_multi_image(plots_pdf_name)
    
    else:
        save_multi_image(plots_pdf_name + "2")
        # Merge the new plot with the rest and delete the last file
        merger = PdfMerger()
        merger.append(plots_pdf_name)
        merger.append(plots_pdf_name + "2")
        merger.write(plots_pdf_name)
        merger.close()
        os.remove(plots_pdf_name + "2")
    
    close_all_figures()

In [106]:
def save_hyperparameters(time_now, folder_name, file_name):
    
    # --------------- Hyperparameters -----------------#
    hyperparameters = {}
    hyperparameters["num_iters"] = [num_iters]
    hyperparameters["num_runs"] = [num_runs]
    hyperparameters["cl_types"] = [cl_types]
    hyperparameters["nqubits"] = [nqubits]
    hyperparameters["layers"] = [layers]
    hyperparameters["qcnn_mode"] = qcnn_mode
    hyperparameters["device"] = device
    hyperparameters["dataset_type"] = dataset_type
    hyperparameters["batch_size"] = [batch_size]
    hyperparameters["train_size"] = [train_size]
    hyperparameters["val_size"] = val_size
    hyperparameters["cl_batch_ratios"] = [cl_batch_ratios]
    hyperparameters["cl_iter_ratios"] = [cl_iter_ratios]
    hyperparameters["optimizer"] = [optimizer]
    hyperparameters["loss_type"] = [loss_type]
    hyperparameters["dist_type"] = [dist_type]
    hyperparameters["initialization"] = [initialization]
    hyperparameters["max_weight_init"] = [max_weight_init]
    hyperparameters["stepsize"] = [stepsize]
    hyperparameters["key"] = [time_now]

    hyperparameters = pd.DataFrame(hyperparameters)

    hyperparameters_file_name = f"{folder_name}/{time_now} - Hyperparameters{file_name}.csv"
    hyperparameters.to_csv(hyperparameters_file_name, index=False)

In [107]:
def save_data(time_now,
              folder_name,
              run,
              weights,
              losses_train,
              losses_val,
              uni_dists,
              run_time,
              cl,
              losses_val_all_states
              ):
    
    # -------------------- Total Data -------------------- #
    data = {}
    data["run"] = run
    
    it_min = np.argmin(np.array(uni_dists))
    uni_dist_min = uni_dists[it_min]
    uni_dist_last = uni_dists[num_iters-1]
    
    data["it_min"] = it_min
    data["uni_dist_min"] = uni_dist_min
    data["uni_dist_last"] = uni_dist_last
    data["run_time"] = run_time
    
    data["weights"] = [weights]
    data["losses_train"] = [losses_train]
    data["losses_val"] = [losses_val]
    data["uni_dists"] = [uni_dists]
    

    data = pd.DataFrame(data)
    
    data_file_name = f"{folder_name}/{time_now} - Data - {cl}.csv"
    data.to_csv(data_file_name, index=False, mode='a', header = not os.path.exists(data_file_name))
    
    
    # ------------------- Results ------------------- #
    
    read_data = pd.read_csv(data_file_name,
                     usecols=["it_min",
                              "uni_dist_min",
                              "uni_dists"],
                     converters={"uni_dists":ast.literal_eval})
    
    total_it_min= read_data["it_min"]
    total_uni_dist_min = read_data["uni_dist_min"]
    total_uni_dists = read_data["uni_dists"].tolist()
    
    best_run_min = total_uni_dist_min.argmax()
    best_it_min = total_it_min[best_run_min]
    avg_uni_dist_min = total_uni_dist_min.mean()
    
    best_run_last = np.argmax(np.array(total_uni_dists)[:,num_iters-1])
    avg_uni_dist_last = np.mean(np.array(total_uni_dists)[:,num_iters-1])

    results = {}
    results["type_cl"] = [cl]
    results["num_runs"] = [run+1]
    results["best_run_min"] = [best_run_min]
    results["best_run_last"] = [best_run_last]
    results["best_it_min"] = [best_it_min]
    results["best_it_last"] = [num_iters-1]
    results["best_uni_dist_min"] = [total_uni_dists[best_run_min][best_it_min]]
    results["best_uni_dist_last"] = [total_uni_dists[best_run_last][num_iters-1]]
    results["avg_uni_dist_min"] = [avg_uni_dist_min]
    results["avg_uni_dist_last"] = [avg_uni_dist_last]
    results = pd.DataFrame(results)

    results_file_name = f"{folder_name}/{time_now} - Results.csv"
    
    # If file exists, we update the info
    if os.path.exists(results_file_name):
        read_results = pd.read_csv(results_file_name)
        row_index = read_results.loc[read_results["type_cl"] == cl].index
        
        if row_index.shape != (0,):
            read_results.drop(labels=row_index[0], axis=0, inplace=True) # we delete the line if it already exists
            
        results = pd.concat([read_results, results], ignore_index=True)
    
    results.to_csv(results_file_name, index=False)
    
    
    
    # ------------------- Plots ------------------- #
    save_plots(time_now,
               folder_name,
               cl,
               run,
               uni_dists,
               losses_train,
               losses_val,
               losses_val_all_states
              )
    
    if cl == "NCL":
        cl_str = "NCL  "
    elif cl=="CL":
        cl_str = "CL   "
    elif cl=="ACL":
        cl_str = "ACL  "
    elif cl=="SPCL":
        cl_str = "SPCL "
    elif cl=="SPACL":
        cl_str = "SPACL"
        
    print(
        f" {cl_str} |"
        f" {run:3d} |"
        f" {it_min:4d}/{num_iters-1:4d} |"
        f"  {uni_dists[it_min]:0.0f}/{uni_dists[num_iters-1]:0.0f}  |"
        f" {run_time:0.0f}"
    )
    
    

In [108]:
def save_average_plots(time_now, folder_name):
        
        transparency = 0.1
        
        if cl_types == ["NCL", "CL", "ACL", "SPCL", "SPACL"]:
                arr_file_names = [["NCL","CL","ACL"], ["SPCL","SPACL"]]
        else:
                arr_file_names = [cl_types]

        for file_names in arr_file_names:

                fig, axis = plt.subplots(1,len(file_names))
                if len(file_names) == 1:
                        axis = [axis]
                        fig.tight_layout(rect=(0,0,0,0))
                else:
                        fig.set_figheight(6)
                        fig.set_figwidth(7*len(file_names))
                        fig.tight_layout(pad=4, w_pad=7)

                i = 0
                for file_name in file_names:

                        data_file_name = f"{folder_name}/{time_now} - Data - {file_name}.csv"

                        # Read the saved data #####################
                        read_data = pd.read_csv(data_file_name,
                                                usecols=["losses_train",
                                                         "losses_val",
                                                         "uni_dists"],
                                                converters={"losses_train":ast.literal_eval,
                                                            "losses_val":ast.literal_eval,
                                                            "uni_dists":ast.literal_eval})

                        all_runs_losses_train = list(map(np.array, read_data["losses_train"]))
                        all_runs_losses_val = list(map(np.array, read_data["losses_val"]))
                        all_runs_uni_dists = list(map(np.array, read_data["uni_dists"]))

                        # We take the averages
                        losses_train = sum(all_runs_losses_train)/num_runs
                        losses_val = sum(all_runs_losses_val)/num_runs
                        uni_dists = sum(all_runs_uni_dists)/num_runs

                        iterations = range(1, num_iters+1)

                        color1 = 'darkred'
                        axis[i].set_xlabel('Iterations')
                        axis[i].set_ylabel('Unitary dist.', color=color1)

                        axis[i].plot(iterations, uni_dists, label="Unitary dist.", color=color1)
                        for uni_dist in all_runs_uni_dists:
                                axis[i].plot(iterations, uni_dist, alpha=transparency, color=color1)

                        axis[i].tick_params(axis='y', labelcolor=color1)
                        axis[i].set_ylim(bottom=0)

                        ax2 = axis[i].twinx()  # instantiate a second axes that shares the same x-axis

                        color2 = 'darkblue'
                        ax2.set_ylabel('Loss', color=color2)  # we already handled the x-label with axis[0]

                        ax2.plot(iterations, losses_train, label="Train Loss", color=color2)
                        for loss_train in all_runs_losses_train:
                                ax2.plot(iterations, loss_train, alpha=transparency, color=color2)

                        ax2.plot(iterations, losses_val, '-.', label="Val. Loss", color=color2)
                        for loss_val in all_runs_losses_val:
                                ax2.plot(iterations, loss_val, '-.', alpha=transparency, color=color2)

                        ax2.tick_params(axis='y', labelcolor=color2)
                        ax2.set_yscale("log")
                        ax2.set_ylim(bottom=0.0, top=1.0)

                        # fig.tight_layout()  # otherwise the right y-label is slightly clipped
                        # plt.legend()
                        axis[i].set_title(f"Unitary dist. and Loss - Average {file_name} - ({round(uni_dists[num_iters-1],3)})")
                        
                        i += 1
                        
        plots_pdf_name = f"{folder_name}/{time_now} - Average plots.pdf"
        save_multi_image(plots_pdf_name)
        close_all_figures()

# Training

In [109]:
@jax.jit
def sort_states(w, ini_states, target_states, ascending):    
    scores = jax.vmap(single_loss, in_axes=[None, 0, 0])(jnp.array(w), jnp.array(ini_states), jnp.array(target_states))
    
    p = jnp.where(ascending, scores.argsort(), scores.argsort()[::-1])
    
    return ini_states[p], target_states[p]

In [110]:
def train_qcnn(train_states, train_target_states, val_states, val_target_states, opt, cl):
    
    if initialization == "uniform":
        weights_init = np.random.uniform(0, max_weight_init, nweights)
    elif initialization == "gaussian":
        weights_init = np.random.normal(0, 1/np.sqrt(nqubits), nweights)
        
    #Initiaize variables
    weights = []
    losses_train = []
    losses_val = []
    uni_dists = []
    losses_val_all_states = []

    w = weights_init
    state = opt.init_state(weights_init, train_states[:2], train_target_states[:2])
    
    for it in range(num_iters):
        
        # For self paced learning, we sort the datapoints at every iteration
        if cl in ["SPCL", "SPACL"]:
            ascending = True if cl == "SPCL" else False
            train_states, train_target_states = sort_states(w, train_states, train_target_states, ascending)
            
        # Once they are sorted, we select the first datapoints into the batch lists
        if cl in ["CL", "ACL", "SPCL", "SPACL"]:
            train_states_batch = train_states[:cl_batches[it]]
            train_target_states_batch = train_target_states[:cl_batches[it]]
        
        elif cl == "NCL":
            batch_index = np.random.default_rng().choice(len(train_states), size=batch_size, replace=False)
            
            train_states_batch = train_states[batch_index]
            train_target_states_batch = train_target_states[batch_index]

            
        # Update the weights by one optimizer step
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            w, state = opt.update(w, state, train_states_batch, train_target_states_batch)
        
        if optimizer == "GradientDescent":
            l_train = loss(w, train_states_batch, train_target_states_batch)
        else:
            l_train = state.value
        
        l_val = loss(w, val_states, val_target_states)

        # Compute difference between variational unitary and target unitary from hamiltonian evolution
        uni_dist = unitary_distance(w)
        
        weights.append(w.tolist())
        losses_train.append(float(l_train))
        losses_val.append(float(l_val))
        uni_dists.append(float(uni_dist))

        loss_val_all_states = jax.vmap(single_loss, in_axes=[None, 0, 0])(jnp.array(w), jnp.array(val_states), jnp.array(val_target_states))
        losses_val_all_states.append(np.array(loss_val_all_states))

    losses_val_all_states = np.array(losses_val_all_states).transpose()

    return weights, losses_train, losses_val, uni_dists, losses_val_all_states

In [111]:
def main():

    # with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):

    time_now = datetime.now(pytz.timezone('Europe/Andorra')).strftime("%Y-%m-%d %H-%M-%S")

    folder_name = f"Results/{nqubits}q - {num_iters:} iters/"
    if not os.path.isdir(f'{folder_name}'):
        os.makedirs(f'{folder_name}')

    save_hyperparameters(time_now, folder_name, file_name="")

    # t = time.time()
    # key = jax.random.PRNGKey(int((t-int(t))*10**10))  # We start with a random (depending on time) seed for the jax keys

    # choose variational classifier
    if optimizer == "GradientDescent":
        opt = jaxopt.GradientDescent(loss, stepsize=stepsize, verbose=False, jit=True)
    elif optimizer == "Adam":
        opt = jaxopt.OptaxSolver(loss, optax.adam(stepsize), verbose=False, jit=True)
    elif optimizer == "BFGS":
        opt = jaxopt.BFGS(loss, verbose=False, jit=True)
    
    
    for run in range (num_runs):
        
        # -------------------------------------------------------------- #
        # ------------------- Generate ground states ------------------- #
        # -------------------------------------------------------------- #
        
        print("Generating dataset...")
        start_time = time.time()

        if dataset_type == "basis":
            basis_states, target_states = generate_all_basis_dataset()
        elif dataset_type == "haar":
            basis_states, target_states = generate_local_haar_dataset(train_size + val_size)

        train_states, train_target_states = basis_states[:train_size], target_states[:train_size]
        val_states, val_target_states = basis_states[train_size:], target_states[train_size:]
        
        run_time = time.time() - start_time

        print(f"Dataset generated - {run_time:.0f}s")
        print()
        print("Max train / Last run")
        print("-------------------------------------------")
        print("  CL   | Run |   Iter    |Acc train| Time  ")
        print("-------------------------------------------")
        
        
        for cl in cl_types:
            # ----------------------------------------------------------------------------------------------- #
            # ------------------------ Sort training gs by their score if curriculum ------------------------ #
            # ----------------------------------------------------------------------------------------------- #

            if cl in ["CL", "ACL"]:
                score_it = num_iters-1
                ascending = True if cl == "CL" else False
                train_states, train_target_states = sort_states(weights_ncl[score_it], train_states, train_target_states, ascending)

            # ------------------------------------------------------------ #
            # ------------------------ Train QCNN ------------------------ #
            # ------------------------------------------------------------ #

            start_time = time.time()

            weights, \
            losses_train, \
            losses_val, \
            uni_dists, \
            losses_val_all_states = train_qcnn(train_states,
                                            train_target_states,
                                            val_states,
                                            val_target_states,
                                            opt=opt,
                                            cl=cl
                                            )

            run_time = time.time() - start_time
            
            if cl == "NCL":
                weights_ncl = weights

            # --------------------------------------------------------- #
            # ------------------- Save calculations ------------------- #
            # --------------------------------------------------------- #
            save_data(time_now,
                      folder_name,
                      run,
                      weights,
                      losses_train,
                      losses_val,
                      uni_dists,
                      run_time,
                      cl=cl,
                      losses_val_all_states = losses_val_all_states
                      )

        print("-------------------------------------------")
        print()
    
    save_average_plots(time_now, folder_name)

# Execution

In [112]:
# for dist_type in ["fro", np.inf]:
main()

Generating dataset...
Dataset generated - 3s

Max train / Last run
-------------------------------------------
  CL   | Run |   Iter    |Acc train| Time  
-------------------------------------------
 NCL   |   0 |  237/1999 |  11/11  | 197
-------------------------------------------

Generating dataset...
Dataset generated - 3s

Max train / Last run
-------------------------------------------
  CL   | Run |   Iter    |Acc train| Time  
-------------------------------------------
 NCL   |   1 |  849/1999 |  11/11  | 97
-------------------------------------------

Generating dataset...
Dataset generated - 3s

Max train / Last run
-------------------------------------------
  CL   | Run |   Iter    |Acc train| Time  
-------------------------------------------
 NCL   |   2 |  946/1999 |  11/11  | 96
-------------------------------------------

Generating dataset...
Dataset generated - 3s

Max train / Last run
-------------------------------------------
  CL   | Run |   Iter    |Acc train|

# Miscellaneous

In [30]:
def print_matrices():
    def print_matrix(matrix):
        for row in matrix:
            for value in row:
                print(f"{round(value.real, 3) + round(value.imag, 3) * 1j:14}", end='\t')
            print('\n')

    time_now = "2024-02-06 18-39-30"
    run = 0

    folder_name = f"Results/{nqubits}q - {num_iters:} iters/"
    data_file_name = f"{folder_name}/{time_now} - Data - NCL.csv"
    read_data = pd.read_csv(data_file_name,
                            usecols=["weights"],
                            converters={"weights":ast.literal_eval})

    all_runs_weights = list(map(np.array, read_data["weights"]))
    weights = all_runs_weights[run][num_iters-1]

    unitary_matrix = qml.matrix(hamiltonian_unitary)()
    variational_matrix = qml.matrix(variational_unitary)(weights)
    fro_norm = jnp.linalg.norm(variational_matrix - unitary_matrix, ord="fro")
    inf_norm = jnp.linalg.norm(variational_matrix - unitary_matrix, ord=np.inf)

    clean_unitary_matrix = np.array([[v if np.abs(v)>10**(-1) else 0 for v in r] for r in unitary_matrix])
    clean_variational_matrix = np.array([[v if np.abs(v)>10**(-1) else 0 for v in r] for r in variational_matrix])


    data = unitary_matrix - variational_matrix
    data = pd.DataFrame(data)
    data_file_name = f"{folder_name}/{time_now} - Matrix difference - run {run}.csv"
    data.to_csv(data_file_name, index=False, header=False)
    
    data2 = clean_unitary_matrix - clean_variational_matrix
    data2 = pd.DataFrame(data2)
    data_file_name2 = f"{folder_name}/{time_now} - Matrix difference clean - run {run}.csv"
    data2.to_csv(data_file_name2, index=False, header=False)

    # print_matrix(clean_unitary_matrix)
    # print()
    # print_matrix(clean_variational_matrix)
    # print()
    # print(fro_norm)
    # print(inf_norm)

In [31]:
print_matrices()

In [4]:
time_now = "2024-02-06 18-39-30"
run = 0

folder_name = f"Results/{nqubits}q - {num_iters:} iters/"
data_file_name = f"{folder_name}/{time_now} - Data - NCL.csv"
read_data = pd.read_csv(data_file_name,
                        usecols=["weights"],
                        converters={"weights":ast.literal_eval})

all_runs_weights = list(map(np.array, read_data["weights"]))
weights = all_runs_weights[run][num_iters-1]

In [126]:
weights = all_runs_weights[run][0]

In [127]:
unitary_matrix = qml.matrix(hamiltonian_unitary)()
variational_matrix = qml.matrix(variational_unitary)(weights)
diff_matrix = unitary_matrix - variational_matrix

fro_norm = jnp.linalg.norm(diff_matrix, ord="fro")
inf_norm = jnp.linalg.norm(diff_matrix, ord=np.inf)

In [128]:
clean_unitary_matrix = np.array([[v if np.abs(v)>10**(-1) else 0 for v in r] for r in unitary_matrix])
clean_variational_matrix = np.array([[v if np.abs(v)>10**(-1) else 0 for v in r] for r in variational_matrix])
clean_diff_matrix = clean_unitary_matrix - clean_variational_matrix

clean_fro_norm = jnp.linalg.norm(clean_diff_matrix, ord="fro")
clean_inf_norm = jnp.linalg.norm(clean_diff_matrix, ord=np.inf)

In [129]:
print(fro_norm, clean_fro_norm)
print(inf_norm, clean_inf_norm)

11.327905790605817 10.898848660369541
9.925183900614426 8.598933662401748


In [130]:
ini_state = np.array([0, 1]+[0]*(2**nqubits-2))
target_state = unitary_matrix @ ini_state
out_state = variational_matrix @ ini_state

loss = jnp.abs(jnp.vdot(out_state, target_state))**2

clean_target_state = np.array([v if np.abs(v)>10**(-1) else 0 for v in target_state])
clean_out_state = np.array([v if np.abs(v)>10**(-1) else 0 for v in out_state])

clean_loss = jnp.abs(jnp.vdot(clean_out_state, clean_target_state))**2


print(clean_loss)

0.002566376434289157


In [48]:
np.savetxt('clean_target_state.csv', clean_target_state, delimiter=',')
np.savetxt('clean_out_state.csv', clean_out_state, delimiter=',')

In [131]:
abs_unitary_matrix = np.array([[np.abs(v) for v in r] for r in unitary_matrix])
abs_variational_matrix = np.array([[np.abs(v) for v in r] for r in variational_matrix])
abs_diff_matrix = abs_unitary_matrix - abs_variational_matrix

abs_fro_norm = jnp.linalg.norm(abs_diff_matrix, ord="fro")
abs_inf_norm = jnp.linalg.norm(abs_diff_matrix, ord=np.inf)


In [132]:
print(fro_norm, clean_fro_norm, abs_fro_norm)
print(inf_norm, clean_inf_norm, abs_inf_norm)

11.327905790605817 10.898848660369541 9.181683363272231
9.925183900614426 8.598933662401748 8.359812921184403


In [118]:
def global_phase_unitary_dist(U, V):
    N = U.shape[0]
    return np.sqrt(1 - np.abs(np.trace(U.conj().T @ V))/N)

In [133]:
print(global_phase_unitary_dist(unitary_matrix, variational_matrix))

0.9897180782176259
