In [132]:
import numpy as np
from sympy.utilities.iterables import multiset_permutations

import jax.random
import jax.numpy as jnp
import jax.scipy.optimize
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)

import pennylane as qml
from functools import partial

from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
plt.ioff()

import time
import pandas as pd
from datetime import datetime
import os
import pytz
import ast
from pypdf import PdfMerger

import contextlib
import warnings
warnings.filterwarnings('ignore')

In [133]:
# ---------------------------------------------------------------------------
# Running parameters
# ---------------------------------------------------------------------------
num_iters = 500    # training iterations per run
num_runs = 10      # only recorded in the hyperparameter CSV; main() iterates over its own error-rate grid
# Training schedules to run. Options: "NCL" (no curriculum), "CL" (curriculum),
# "ACL" (anti-curriculum), "SPCL" (self-paced curriculum),
# "SPACL" (self-paced anti-curriculum).
cl_types = ["NCL"]

# ---------------------------------------------------------------------------
# Circuit parameters
# ---------------------------------------------------------------------------
nqubits = 6
error_type = "BitFlip"   # one of "BitFlip", "DepolarizingChannel", "xxChannel"
error_prob = 0.1         # default; main() overwrites this global while sweeping error rates

# ---------------------------------------------------------------------------
# Data hyper-parameters
# ---------------------------------------------------------------------------
batch_size = 36    # states drawn per iteration when cl == "NCL"
train_size = 36    # total states used for training
cl_batch_ratios = [1/3, 1/3, 1/3]  # fraction of train_size added at each curriculum stage
cl_iter_ratios = [1/4, 1/4, 1/2]   # fraction of num_iters spent in each curriculum stage

# ---------------------------------------------------------------------------
# Optimization parameters
# ---------------------------------------------------------------------------
optimizer = "Adam"           # "Adam", "GradientDescent" or "BFGS"
loss_type = "fidelity"       # only "fidelity" is implemented
initialization = "gaussian"  # "gaussian" or "uniform"
max_weight_init = 2*np.pi    # upper bound of the "uniform" initialisation (max 2*pi)
stepsize = 0.01              # optimizer step size

# ---------------------------------------------------------------------------
# Curriculum schedule
# ---------------------------------------------------------------------------
# cl_batches[it] is the number of (sorted) training states used at iteration `it`.
# Every stage except the last grows the batch by int(cl_batch_ratios[stage]*train_size)
# states for int(cl_iter_ratios[stage]*num_iters) iterations. The last stage uses the
# full training set for all remaining iterations, so len(cl_batches) == num_iters and
# the last entry of cl_batch_ratios is never read.
cl_batches = []
stage_batch = 0
final_stage = len(cl_iter_ratios) - 1
for stage, iter_ratio in enumerate(cl_iter_ratios):
    if stage == final_stage:
        stage_batch = train_size
        stage_iters = num_iters - len(cl_batches)
    else:
        stage_batch += int(cl_batch_ratios[stage]*train_size)
        stage_iters = int(iter_ratio*num_iters)
    cl_batches.extend([stage_batch] * stage_iters)

# Generate dataset

In [134]:
def X(i):
    """Return a Pauli-X operator acting on wire ``i``."""
    return qml.PauliX(wires=i)


def Y(i):
    """Return a Pauli-Y operator acting on wire ``i``."""
    return qml.PauliY(wires=i)


def Z(i):
    """Return a Pauli-Z operator acting on wire ``i``."""
    return qml.PauliZ(wires=i)

In [135]:
def generate_dataset():
    """Build the two-qubit product-state density matrices used as training data.

    The single-qubit states are the eigenvectors of Pauli X, Y and Z (two each,
    in the column order returned by ``np.linalg.eigh``, grouped X then Y then Z).
    Every ordered pair of those six states is combined with a tensor product,
    and each resulting 4-dimensional vector ``v`` becomes the matrix ``v v^*``.

    Returns:
        ndarray of shape (36, 4, 4).
    """
    per_pauli_eigvecs = []
    for pauli in (X, Y, Z):
        _, vecs = np.linalg.eigh(qml.matrix(pauli(0)))
        per_pauli_eigvecs.append(vecs)
    # 2 x 6 matrix whose columns are the six single-qubit eigenstates.
    basis = np.concatenate(per_pauli_eigvecs, axis=1)

    product_vectors = np.array([
        np.tensordot(first, second, axes=0).flatten()
        for first in basis.T
        for second in basis.T
    ])

    density_matrices = np.array([
        np.tensordot(vec, np.conjugate(vec), axes=0)
        for vec in product_vectors
    ])
    return density_matrices

# Quantum Circuit

In [136]:
def general_unitary_2q(q1, q2, weights):
    """Apply a 15-parameter two-qubit unitary on wires ``q1`` and ``q2``.

    Gate sequence::

        U3(weights[0:3]) on q1,  U3(weights[3:6]) on q2
        CNOT(q2 -> q1)
        RZ(weights[6]) on q1,    RY(weights[7]) on q2
        CNOT(q1 -> q2)
        RY(weights[8]) on q2
        CNOT(q2 -> q1)
        U3(weights[9:12]) on q1, U3(weights[12:15]) on q2

    Args:
        q1, q2: the two wires the gate acts on.
        weights: indexable container with at least 15 rotation angles.
    """
    # Leading local layer: one U3 per qubit, so each qubit gets its own
    # arbitrary single-qubit rotation (two U3s on the same wire would be
    # redundant and leave q2 without local rotations).
    qml.U3(wires=q1, theta=weights[0], phi=weights[1], delta=weights[2])
    qml.U3(wires=q2, theta=weights[3], phi=weights[4], delta=weights[5])

    # Three-CNOT entangling core.
    qml.CNOT(wires=[q2, q1])
    qml.RZ(wires=q1, phi=weights[6])
    qml.RY(wires=q2, phi=weights[7])
    qml.CNOT(wires=[q1, q2])
    qml.RY(wires=q2, phi=weights[8])
    qml.CNOT(wires=[q2, q1])

    # Trailing local layer: again one U3 per qubit.
    qml.U3(wires=q1, theta=weights[9], phi=weights[10], delta=weights[11])
    qml.U3(wires=q2, theta=weights[12], phi=weights[13], delta=weights[14])

def bitFlip(error_prob, n):
    """Apply an independent bit-flip channel with probability ``error_prob`` to wires 0..n-1.

    Returns:
        0 (the caller ``noise_channel`` ignores the value).
    """
    for wire in range(n):
        qml.BitFlip(error_prob, wires=wire)
    return 0

def depolarizingChannel(error_prob, n):
    """Apply an independent depolarizing channel with probability ``error_prob`` to wires 0..n-1.

    Returns:
        0 (the caller ``noise_channel`` ignores the value).
    """
    for wire in range(n):
        qml.DepolarizingChannel(error_prob, wires=wire)
    return 0

def xxChannel_2q(error_prob, i):
    """Correlated X(x)X error on wires ``i`` and ``i+1``.

    Kraus operators: sqrt(1 - p) * I_4 (no error) and sqrt(p) * X_i X_{i+1}.
    """
    identity_kraus = jnp.sqrt(1-error_prob)*np.eye(4)
    xx_kraus = jnp.sqrt(error_prob)*qml.matrix(X(i) @ X(i+1))
    qml.QubitChannel([identity_kraus, xx_kraus], wires=[i, i+1])

def xxChannel_8q(error_prob):
    """XX noise on every neighbouring pair inside the blocks 0-3 and 4-7.

    Pairs, in order: (0,1), (1,2), (2,3), (4,5), (5,6), (6,7). Returns 0.
    """
    for block_start in (0, 4):
        for offset in range(3):
            xxChannel_2q(error_prob, block_start + offset)
    return 0

def xxChannel_6q(error_prob):
    """XX noise on every neighbouring pair inside the blocks 0-2 and 3-5.

    Pairs, in order: (0,1), (1,2), (3,4), (4,5). Returns 0.
    """
    for block_start in (0, 3):
        for offset in range(2):
            xxChannel_2q(error_prob, block_start + offset)
    return 0

def xxChannel(error_prob, n):
    """Apply the correlated XX noise model for an ``n``-qubit register.

    Raises:
        ValueError: if ``n`` is not 6 or 8, the only layouts defined.
    """
    if n == 8:
        xxChannel_8q(error_prob)
    elif n == 6:
        xxChannel_6q(error_prob)
    else:
        # Previously an unsupported n silently applied no noise at all.
        raise ValueError(f"xxChannel is only defined for 6 or 8 qubits, got n={n}")

def noise_channel(error_prob, error_type, n):
    """Apply the noise model named by ``error_type`` to an ``n``-qubit register.

    Args:
        error_prob: error probability passed to the channel.
        error_type: "BitFlip", "DepolarizingChannel" or "xxChannel".
        n: number of wires the noise acts on.

    Raises:
        ValueError: for an unknown ``error_type`` (a typo would otherwise
            produce a silently noiseless circuit).
    """
    if error_type == "BitFlip":
        bitFlip(error_prob, n)
    elif error_type == "DepolarizingChannel":
        depolarizingChannel(error_prob, n)
    elif error_type == "xxChannel":
        xxChannel(error_prob, n)
    else:
        raise ValueError(
            f"Unknown error_type {error_type!r}; expected 'BitFlip', "
            "'DepolarizingChannel' or 'xxChannel'"
        )

In [137]:
def variational_channel(weights, error_prob, error_type):
    """8-qubit channel: adjoint two-qubit layers, noise, then the forward layers mirrored.

    weights[0:15] drive ``general_unitary_2q`` on the pairs (1, 2) and (5, 6);
    weights[15:30] drive it on the pairs (0, 1), (2, 3), (4, 5) and (6, 7).
    Before the 8-qubit noise channel the adjoints are applied (first the
    (1, 2)/(5, 6) pairs, then the others); afterwards the gates themselves are
    applied in exactly the reverse order.
    """
    # https://pennylane.ai/blog/2021/05/how-to-simulate-noise-with-pennylane/
    u1_weights, u2_weights = weights[0:15], weights[15:30]
    u1_pairs = [(1, 2), (5, 6)]
    u2_pairs = [(0, 1), (2, 3), (4, 5), (6, 7)]

    for a, b in u1_pairs:
        qml.adjoint(general_unitary_2q)(a, b, u1_weights)
    for a, b in u2_pairs:
        qml.adjoint(general_unitary_2q)(a, b, u2_weights)

    noise_channel(error_prob, error_type, n=8)

    for a, b in reversed(u2_pairs):
        general_unitary_2q(a, b, u2_weights)
    for a, b in reversed(u1_pairs):
        general_unitary_2q(a, b, u1_weights)

# variational_channel addresses wires 0..7 and applies the 8-qubit noise model,
# so this device needs 8 wires regardless of `nqubits` (a 6-wire device cannot
# host the circuit).
dev = qml.device("default.mixed", wires=8)
@partial(jax.jit, static_argnames=["error_type"])
@qml.qnode(dev, interface="jax")
def variational_circuit(weights, state_ini, error_prob, error_type):
    """Load ``state_ini`` on wires 1 and 5, run ``variational_channel``, and
    return the reduced density matrix of wires 1 and 5.

    ``error_type`` is static under ``jax.jit`` because it selects Python-level
    branches in the noise model.
    """
    qml.QubitDensityMatrix(state_ini, wires=[1, 5])
    variational_channel(weights, error_prob, error_type)
    return qml.density_matrix([1, 5])

In [138]:
def variational_channel3(weights, error_prob, error_type):
    """6-qubit ansatz with a full layer of Rot gates around every fixed gate.

    Order: Rot layer, CNOT(0,1), Rot layer, CNOT(3,4), Rot layer, CNOT(0,2),
    Rot layer, CNOT(3,5), Rot layer, noise (n=6), Rot layer, CNOT(0,1),
    Rot layer, CNOT(3,4), Rot layer, CNOT(0,2), Rot layer, CNOT(3,5),
    Rot layer, Toffoli(2,1,0), Rot layer, Toffoli(5,4,3), Rot layer.

    Each Rot layer puts one Rot on each of wires 0..5 and consumes 18
    consecutive weights, so the 12 layers read weights[0:216].
    """
    def rot_layer(offset):
        # One Rot per wire; wire w uses weights[offset + 3w : offset + 3w + 3].
        for wire in range(6):
            base = offset + 3 * wire
            qml.Rot(weights[base], weights[base + 1], weights[base + 2], wires=wire)

    between_layers = [
        lambda: qml.CNOT(wires=[0, 1]),
        lambda: qml.CNOT(wires=[3, 4]),
        lambda: qml.CNOT(wires=[0, 2]),
        lambda: qml.CNOT(wires=[3, 5]),
        lambda: noise_channel(error_prob, error_type, n=6),
        lambda: qml.CNOT(wires=[0, 1]),
        lambda: qml.CNOT(wires=[3, 4]),
        lambda: qml.CNOT(wires=[0, 2]),
        lambda: qml.CNOT(wires=[3, 5]),
        lambda: qml.Toffoli(wires=[2, 1, 0]),
        lambda: qml.Toffoli(wires=[5, 4, 3]),
    ]
    for layer, fixed_gate in enumerate(between_layers):
        rot_layer(18 * layer)
        fixed_gate()
    rot_layer(18 * len(between_layers))

def variational_channel2(weights, error_prob, error_type):
    """6-qubit ansatz with a single Rot gate before every fixed gate and one at the end.

    Fixed gates, in order: CNOT(0,1), CNOT(3,4), CNOT(0,2), CNOT(3,5),
    noise (n=6), CNOT(0,1), CNOT(3,4), CNOT(0,2), CNOT(3,5), Toffoli(2,1,0),
    Toffoli(5,4,3). The k-th Rot (k = 0..11) acts on wire ``k % 6`` with
    angles weights[3k : 3k+3], so 36 weights are read in total.
    """
    fixed_gates = [
        lambda: qml.CNOT(wires=[0, 1]),
        lambda: qml.CNOT(wires=[3, 4]),
        lambda: qml.CNOT(wires=[0, 2]),
        lambda: qml.CNOT(wires=[3, 5]),
        lambda: noise_channel(error_prob, error_type, n=6),
        lambda: qml.CNOT(wires=[0, 1]),
        lambda: qml.CNOT(wires=[3, 4]),
        lambda: qml.CNOT(wires=[0, 2]),
        lambda: qml.CNOT(wires=[3, 5]),
        lambda: qml.Toffoli(wires=[2, 1, 0]),
        lambda: qml.Toffoli(wires=[5, 4, 3]),
    ]
    num_rotations = len(fixed_gates) + 1
    for k in range(num_rotations):
        qml.Rot(weights[3*k], weights[3*k + 1], weights[3*k + 2], wires=k % 6)
        if k < len(fixed_gates):
            fixed_gates[k]()

dev = qml.device("default.mixed", wires=6)

@partial(jax.jit, static_argnames=["error_type"])
@qml.qnode(dev, interface="jax")
def variational_circuit2(weights, state_ini, error_prob, error_type):
    """Load ``state_ini`` on wires 0 and 3, run ``variational_channel2``, and
    return the reduced density matrix of wires 0 and 3.

    ``error_type`` is static under ``jax.jit`` because it selects Python-level
    branches in the noise model.
    """
    logical_wires = [0, 3]
    qml.QubitDensityMatrix(state_ini, wires=logical_wires)
    variational_channel2(weights, error_prob, error_type)
    return qml.density_matrix(wires=logical_wires)

In [139]:
def shor_channel(error_prob, error_type):
    """Fixed (parameter-free) 6-qubit reference channel.

    CNOTs (0,1), (3,4), (0,2), (3,5); the 6-qubit noise channel; the same four
    CNOTs again; then Toffoli(2,1,0) and Toffoli(5,4,3).
    """
    spread_pairs = [(0, 1), (3, 4), (0, 2), (3, 5)]
    for control, target in spread_pairs:
        qml.CNOT(wires=[control, target])
    noise_channel(error_prob, error_type, n=6)
    for control, target in spread_pairs:
        qml.CNOT(wires=[control, target])
    qml.Toffoli(wires=[2, 1, 0])
    qml.Toffoli(wires=[5, 4, 3])

dev = qml.device("default.mixed", wires=6)

@qml.qnode(dev)
def shor_circuit(state_ini, error_prob, error_type):
    """Load ``state_ini`` on wires 0 and 3, run ``shor_channel``, and return the
    reduced density matrix of wires 0 and 3."""
    logical_wires = [0, 3]
    qml.QubitDensityMatrix(state_ini, wires=logical_wires)
    shor_channel(error_prob, error_type)
    return qml.density_matrix(wires=logical_wires)

In [140]:
dev = qml.device("default.mixed", wires=nqubits)

@qml.qnode(dev)
def only_noise_circuit(state_ini, error_prob, error_type):
    """Baseline: load ``state_ini`` on wires 1 and 5, apply only the noise channel
    to all ``nqubits`` wires, and return the reduced density matrix of wires 1 and 5."""
    data_wires = [1, 5]
    qml.QubitDensityMatrix(state_ini, wires=data_wires)
    noise_channel(error_prob, error_type, n=nqubits)
    return qml.density_matrix(wires=data_wires)

In [141]:
def circuit(weights, state_ini, error_prob, error_type, circuit_type):
    """Run one of the circuits on ``state_ini`` and return its output density matrix.

    Args:
        weights: variational parameters (only used by "Variational").
        state_ini: input two-qubit density matrix.
        error_prob, error_type: noise model.
        circuit_type: "Variational", "Shor" or "Only noise".

    Raises:
        ValueError: for an unknown ``circuit_type`` (previously this surfaced
            as an UnboundLocalError on ``state_out``).
    """
    if circuit_type == "Variational":
        return variational_circuit2(weights, state_ini, error_prob, error_type)
    if circuit_type == "Shor":
        return shor_circuit(state_ini, error_prob, error_type)
    if circuit_type == "Only noise":
        return only_noise_circuit(state_ini, error_prob, error_type)
    raise ValueError(
        f"Unknown circuit_type {circuit_type!r}; expected 'Variational', 'Shor' or 'Only noise'"
    )

# Loss and accuracy

In [142]:
def single_loss(weights, ini_state, error_prob, error_type, circuit_type):
    """Loss of a single input state.

    For ``loss_type == "fidelity"`` this is ``1 - F(out, in)``, where ``out`` is
    the circuit output for ``ini_state``.

    Raises:
        ValueError: if the global ``loss_type`` is not supported (previously
            this surfaced as an UnboundLocalError on ``cost``).
    """
    if loss_type != "fidelity":
        raise ValueError(f"Unsupported loss_type {loss_type!r}; only 'fidelity' is implemented")

    out_state = circuit(weights, ini_state, error_prob, error_type, circuit_type)
    return 1 - qml.math.fidelity(jnp.array(out_state), ini_state)


def loss(weights, ini_states, error_prob, error_type, circuit_type):
    """Mean ``single_loss`` over the states along the leading axis of ``ini_states``."""
    # Only the states are batched; weights and noise/circuit settings are shared.
    costs = jax.vmap(single_loss, in_axes=[None, 0, None, None, None])(
        weights, ini_states, error_prob, error_type, circuit_type)
    return costs.sum()/len(ini_states)

grad = jax.grad(loss, argnums=0)

# Processing Data

In [143]:
def save_multi_image(filename):
    """Write every open matplotlib figure, one per page, to the PDF ``filename``.

    The PDF writer is closed even if saving one of the figures raises.
    """
    with PdfPages(filename) as pdf:
        for fig_num in plt.get_fignums():
            plt.figure(fig_num).savefig(pdf, bbox_inches='tight', format='pdf')


def close_all_figures():
    """Close every open matplotlib figure."""
    plt.close("all")


def save_plots(time_now, folder_name, file_name, plot_run, losses_train):
    """Plot one run's training-loss curve and append it to that curve type's PDF.

    Args:
        time_now: timestamp string used as the file-name prefix.
        folder_name: output directory.
        file_name: suffix of the PDF name (the curriculum type in this notebook).
        plot_run: run index shown in the title.
        losses_train: per-iteration training losses followed by two extra
            entries (the Shor-circuit loss and the noise-only loss appended by
            ``train_qcnn``); those last two are not part of the curve.

    The PDF is "{folder_name}/{time_now} - Plots - {file_name}.pdf". If it
    already exists, the new page is written to a temporary file and merged in.
    All open figures are closed afterwards, even on error, so a stale figure
    cannot leak into the next PDF.
    """
    try:
        fig, axis = plt.subplots(1, 1)

        train_curve = losses_train[:-2]
        # Derive the x-range from the data so it always matches the curve length.
        iterations = range(1, len(train_curve) + 1)

        color2 = 'darkblue'
        axis.set_xlabel('Iteration')
        axis.set_ylabel('Loss', color=color2)
        axis.plot(iterations, train_curve, label="Train Loss", color=color2)
        axis.tick_params(axis='y', labelcolor=color2)
        axis.set_yscale("log")
        axis.set_title(f"Training Loss - Run {plot_run} ({losses_train[-3]:.2E})")
        # Lay out after labels and title exist; before plotting there is nothing to fit.
        fig.tight_layout()

        plots_pdf_name = f"{folder_name}/{time_now} - Plots - {file_name}.pdf"

        if not os.path.isfile(plots_pdf_name):
            save_multi_image(plots_pdf_name)
        else:
            # Write the new page to a temporary file, merge it after the existing
            # pages, then delete the temporary file.
            tmp_pdf_name = plots_pdf_name + "2"
            save_multi_image(tmp_pdf_name)
            merger = PdfMerger()
            try:
                merger.append(plots_pdf_name)
                merger.append(tmp_pdf_name)
                merger.write(plots_pdf_name)
            finally:
                merger.close()
            os.remove(tmp_pdf_name)
    finally:
        close_all_figures()

In [144]:
def save_hyperparameters(time_now, folder_name, file_name):
    """Append the current global settings as one row of a CSV file.

    The file is "{folder_name}/{time_now} - Hyperparameters{file_name}.csv".
    The header row is written only when the file does not exist yet.
    """
    settings = {
        "num_iters": num_iters,
        "num_runs": num_runs,
        "cl_types": cl_types,
        "nqubits": nqubits,
        "error_type": error_type,
        "error_prob": error_prob,
        "batch_size": batch_size,
        "train_size": train_size,
        "cl_batch_ratios": cl_batch_ratios,
        "cl_iter_ratios": cl_iter_ratios,
        "optimizer": optimizer,
        "loss_type": loss_type,
        "initialization": initialization,
        "max_weight_init": max_weight_init,
        "stepsize": stepsize,
        "key": time_now,
    }
    # Wrap each value in a one-element list so every setting (including
    # list-valued ones such as cl_types) becomes exactly one cell.
    row = pd.DataFrame({column: [value] for column, value in settings.items()})

    csv_path = f"{folder_name}/{time_now} - Hyperparameters{file_name}.csv"
    write_header = not os.path.exists(csv_path)
    row.to_csv(csv_path, index=False, mode='a', header=write_header)

In [145]:
def save_data(time_now,
              folder_name,
              run,
              weights,
              losses_train,
              gradients,
              run_time,
              cl
              ):
    """Persist one training run and print a summary line.

    Appends a row to "{folder_name}/{time_now} - Data - {cl}.csv" (header only
    when the file is new), appends the loss plot via ``save_plots``, and prints
    "<cl padded to 5> | <run> | <seconds>".

    Args:
        time_now: timestamp string used as the file-name prefix.
        folder_name: output directory.
        run: run index.
        weights: per-iteration weights (stored as a single cell).
        losses_train: loss list as returned by ``train_qcnn``.
        gradients: gradient list (stored as a single cell).
        run_time: training time in seconds.
        cl: curriculum type label.
    """
    # Scalars broadcast to one row; list-valued results are wrapped so each is one cell.
    data = pd.DataFrame({
        "run": run,
        "run_time": run_time,
        "weights": [weights],
        "losses_train": [losses_train],
        "gradients": [gradients],
    })

    data_file_name = f"{folder_name}/{time_now} - Data - {cl}.csv"
    data.to_csv(data_file_name, index=False, mode='a', header=not os.path.exists(data_file_name))

    save_plots(time_now, folder_name, cl, run, losses_train)

    # Left-align to the width of the longest label ("SPACL"). This matches the
    # old hand-padded strings and no longer fails for labels outside that set.
    cl_str = f"{cl:<5}"
    print(
        f" {cl_str} |"
        f" {run:3d} |"
        f" {run_time:0.0f}"
    )
    
    

In [146]:
def save_plots_error_rate(time_now, folder_name):
    """Plot the final losses against the error rate for each curriculum type.

    Error rates come from the "error_prob" column of
    "{folder_name}/{time_now} - Hyperparameters.csv". Losses come from the
    "losses_train" column of each "... - Data - {cl}.csv". Each row of that
    column holds the per-iteration training losses followed by the Shor-circuit
    loss and the noise-only loss. All figures are written to
    "{folder_name}/{time_now} - Plots error rate.pdf".
    """
    hyperparams_file_name = f"{folder_name}/{time_now} - Hyperparameters.csv"
    read_hyperparams = pd.read_csv(hyperparams_file_name, usecols=["error_prob"])
    error_probs = list(read_hyperparams["error_prob"])

    if cl_types == ["NCL", "CL", "ACL", "SPCL", "SPACL"]:
        arr_file_names = [["NCL", "CL", "ACL"], ["SPCL", "SPACL"]]
    else:
        arr_file_names = [cl_types]

    color1 = "darkred"
    color2 = "darkblue"
    color3 = "black"

    for file_names in arr_file_names:
        # squeeze=False always yields a 2-D array of axes, even for one panel.
        fig, axes = plt.subplots(1, len(file_names), squeeze=False)
        multi_panel = len(file_names) > 1
        if multi_panel:
            fig.set_figheight(6)
            fig.set_figwidth(7*len(file_names))

        for ax, file_name in zip(axes[0], file_names):
            data_file_name = f"{folder_name}/{time_now} - Data - {file_name}.csv"
            read_data = pd.read_csv(data_file_name,
                                    usecols=["losses_train"],
                                    converters={"losses_train": ast.literal_eval})
            # Parse once: one row per run, columns = losses as saved by save_data.
            all_losses = np.array([np.array(row) for row in read_data["losses_train"]])
            last_losses_train = all_losses[:, -3]
            losses_shor = all_losses[:, -2]
            losses_noise = all_losses[:, -1]

            ax.set_xlabel('Error rate')
            ax.set_ylabel('Loss')
            ax.plot(error_probs, losses_shor, label="Error code loss", color=color1)
            ax.plot(error_probs, last_losses_train, label="Variational loss", color=color2)
            ax.plot(error_probs, losses_noise, '-.', label="Only noise loss", color=color3)
            ax.set_yscale("log")
            ax.set_xscale("log")
            ax.set_title(f"Last training loss vs error rate with {error_type} error - {file_name}")
            ax.legend()

        # Lay out after everything is drawn (the former zero-area rect=(0,0,0,0)
        # for the single-panel case requested an empty layout region).
        if multi_panel:
            fig.tight_layout(pad=4, w_pad=7)
        else:
            fig.tight_layout()

    plots_pdf_name = f"{folder_name}/{time_now} - Plots error rate.pdf"
    save_multi_image(plots_pdf_name)
    close_all_figures()

# Training

In [147]:
def sort_states(w, ini_states, ascending, error_prob, error_type, circuit_type):
    """Return ``ini_states`` reordered by their ``single_loss`` under weights ``w``.

    ``ascending=True`` puts the lowest-loss states first; ``False`` puts the
    highest-loss states first.
    """
    score_fn = jax.vmap(single_loss, in_axes=[None, 0, None, None, None])
    scores = score_fn(jnp.array(w), jnp.array(ini_states), error_prob, error_type, circuit_type)

    order = jnp.argsort(scores)
    order = jnp.where(ascending, order, order[::-1])
    return ini_states[order]

In [148]:
def train_qcnn(train_states, opt, cl, error_prob, error_type, opt_update):
    """Train the "Variational" circuit's 36 weights under one curriculum schedule.

    Args:
        train_states: training density matrices stacked along the leading axis.
        opt: a jaxopt solver built around ``loss``.
        cl: "NCL", "CL", "ACL", "SPCL" or "SPACL". "NCL" draws ``batch_size``
            random states per iteration. The others take the first
            ``cl_batches[it]`` states, and the self-paced variants re-sort the
            states with the current weights every iteration.
        error_prob, error_type: noise model passed to the circuits.
        opt_update: (typically jitted) ``opt.update``.

    Returns:
        (weights, losses_train, gradients):
        - weights: the weights after every iteration, as lists.
        - losses_train: the training loss per iteration, followed by the
          "Shor" loss and the "Only noise" loss on the full ``train_states``.
        - gradients: a list that is currently never filled.

    Raises:
        ValueError: for an unknown ``initialization`` setting or ``cl`` value
            (previously both surfaced as UnboundLocalError).
    """
    num_weights = 36  # variational_channel2 uses 12 Rot gates x 3 angles

    if initialization == "uniform":
        weights_init = np.random.uniform(0, max_weight_init, num_weights)
    elif initialization == "gaussian":
        weights_init = np.random.normal(0, 1/np.sqrt(nqubits), num_weights)
    else:
        raise ValueError(f"Unknown initialization {initialization!r}; expected 'uniform' or 'gaussian'")

    if cl not in ("NCL", "CL", "ACL", "SPCL", "SPACL"):
        raise ValueError(f"Unknown curriculum type {cl!r}")

    weights = []
    losses_train = []
    gradients = []

    w = weights_init
    state = opt.init_state(weights_init, train_states[:2], error_prob, error_type, circuit_type="Variational")

    # One (unseeded) generator for all NCL batch draws instead of a new one per iteration.
    rng = np.random.default_rng()

    for it in range(num_iters):
        # Self-paced schedules re-sort the states with the current weights every iteration.
        if cl in ("SPCL", "SPACL"):
            train_states = sort_states(w, train_states, cl == "SPCL", error_prob, error_type,
                                       circuit_type="Variational")

        if cl == "NCL":
            batch_index = rng.choice(len(train_states), size=batch_size, replace=False)
            train_states_batch = train_states[batch_index]
        else:
            # Curriculum schedules use the first cl_batches[it] states of the sorted set.
            train_states_batch = train_states[:cl_batches[it]]

        w, state = opt_update(w, state, train_states_batch, error_prob,
                              error_type=error_type, circuit_type="Variational")

        if optimizer == "GradientDescent":
            l_train = loss(w, train_states_batch, error_prob, error_type, circuit_type="Variational")
        else:
            l_train = state.value

        weights.append(w.tolist())
        losses_train.append(float(l_train))

    # Reference losses on the full training set, appended after the training curve.
    l_shor = loss(w, train_states, error_prob, error_type, circuit_type="Shor")
    l_noise = loss(w, train_states, error_prob, error_type, circuit_type="Only noise")
    losses_train.append(float(l_shor))
    losses_train.append(float(l_noise))

    return weights, losses_train, gradients

In [149]:
def _make_solver():
    """Build the jaxopt solver selected by the global ``optimizer`` setting."""
    if optimizer == "GradientDescent":
        return jaxopt.GradientDescent(loss, stepsize=stepsize, verbose=False, jit=True)
    if optimizer == "Adam":
        return jaxopt.OptaxSolver(loss, optax.adam(stepsize), verbose=False, jit=False)
    if optimizer == "BFGS":
        return jaxopt.BFGS(loss, verbose=False, jit=True)
    raise ValueError(f"Unknown optimizer {optimizer!r}; expected 'GradientDescent', 'Adam' or 'BFGS'")


def main():
    """Sweep error rates, train every type in ``cl_types`` at each rate, and save the results.

    The 10 error rates are log-spaced from 1e-4 to 1e-1. Results go under
    "Results/{nqubits}q - {num_iters} iters/" with a timestamp prefix. At the
    end, the loss-vs-error-rate summary plot is written.

    Raises:
        ValueError: for an unknown optimizer, or when "CL"/"ACL" appear in
            ``cl_types`` before "NCL" (they sort states with the NCL weights).
    """
    time_now = datetime.now(pytz.timezone('Europe/Andorra')).strftime("%Y-%m-%d %H-%M-%S")

    folder_name = f"Results/{nqubits}q - {num_iters:} iters/"
    os.makedirs(folder_name, exist_ok=True)

    opt = _make_solver()
    opt_update = jax.jit(opt.update, static_argnames=["error_type", "circuit_type"])

    # 10 error rates equally spaced on a log scale between 1e-4 and 1e-1.
    error_probs = np.logspace(-4, -1, 10)

    # save_hyperparameters reads the global error_prob, so the loop updates it.
    global error_prob
    for run, error_prob in enumerate(error_probs):

        save_hyperparameters(time_now, folder_name, file_name="")

        # ------------------- Generate the training states ------------------- #
        print("Generating dataset...")
        start_time = time.time()
        states = generate_dataset()
        train_states = states[:train_size]
        run_time = time.time() - start_time

        print(f"Dataset generated - {run_time:.0f}s")
        print()
        print("Max train / Last run")
        print("---------------------")
        print("  CL   | Run | Time  ")
        print("---------------------")

        # Final-iteration weights of the NCL run, used to sort states for CL/ACL.
        weights_ncl = None

        for cl in cl_types:
            if cl in ["CL", "ACL"]:
                if weights_ncl is None:
                    raise ValueError(
                        f"{cl!r} sorts the states with the trained NCL weights, "
                        "so 'NCL' must come before it in cl_types"
                    )
                score_it = num_iters-1
                train_states = sort_states(weights_ncl[score_it], train_states, cl == "CL",
                                           error_prob, error_type, circuit_type="Variational")

            start_time = time.time()
            weights, losses_train, gradients = train_qcnn(train_states, opt, cl, error_prob, error_type, opt_update)
            run_time = time.time() - start_time

            if cl == "NCL":
                weights_ncl = weights

            save_data(time_now,
                      folder_name,
                      run,
                      weights,
                      losses_train,
                      gradients,
                      run_time,
                      cl
                      )

        print("---------------------")
        print()

    save_plots_error_rate(time_now, folder_name)

# Execution

In [150]:
# Run the full error-rate sweep; CSVs and PDFs are written under "Results/".
main()

Generating dataset...
Dataset generated - 0s

Max train / Last run
---------------------
  CL   | Run | Time  
---------------------
 NCL   |   0 | 23
---------------------

Generating dataset...
Dataset generated - 0s

Max train / Last run
---------------------
  CL   | Run | Time  
---------------------
 NCL   |   1 | 20
---------------------

Generating dataset...
Dataset generated - 0s

Max train / Last run
---------------------
  CL   | Run | Time  
---------------------
 NCL   |   2 | 19
---------------------

Generating dataset...
Dataset generated - 0s

Max train / Last run
---------------------
  CL   | Run | Time  
---------------------
 NCL   |   3 | 19
---------------------

Generating dataset...
Dataset generated - 0s

Max train / Last run
---------------------
  CL   | Run | Time  
---------------------
 NCL   |   4 | 19
---------------------

Generating dataset...
Dataset generated - 0s

Max train / Last run
---------------------
  CL   | Run | Time  
-------------------

# Miscellaneous