In [3]:
pip install jax==0.4.23 jaxlib==0.4.23 optax==0.1.7 pennylane==0.36.0 scipy==1.10.1

Collecting scipy==1.10.1
  Using cached scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl.metadata (53 kB)
Using cached scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl (28.8 MB)
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.15.3
    Uninstalling scipy-1.15.3:
      Successfully uninstalled scipy-1.15.3
Successfully installed scipy-1.10.1
Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, cohen_kappa_score, matthews_corrcoef
import pandas as pd
import time
from sklearn import datasets
from sklearn.datasets import load_wine
from sklearn.utils import shuffle
from copy import deepcopy
import pennylane as qml
import pennylane as qml
from pennylane import math as qmath 
from jax import config
config.update("jax_enable_x64", True)  # Habilita float64/complex128
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsp
import pennylane as qml
import optax

In [2]:

# -------- Gell-Mann apiladas (índice 1..8) --------
def _gellmann_stack(dtype=jnp.complex64):
    l0 = jnp.zeros((3,3), dtype=dtype)  # placeholder en 0
    l1 = jnp.array([[0,1,0],[1,0,0],[0,0,0]], dtype=dtype)
    l2 = jnp.array([[0,-1j,0],[1j,0,0],[0,0,0]], dtype=dtype)
    l3 = jnp.array([[1,0,0],[0,-1,0],[0,0,0]], dtype=dtype)
    l4 = jnp.array([[0,0,1],[0,0,0],[1,0,0]], dtype=dtype)
    l5 = jnp.array([[0,0,-1j],[0,0,0],[1j,0,0]], dtype=dtype)
    l6 = jnp.array([[0,0,0],[0,0,1],[0,1,0]], dtype=dtype)
    l7 = jnp.array([[0,0,0],[0,0,-1j],[0,1j,0]], dtype=dtype)
    l8 = (1/jnp.sqrt(3)) * jnp.array([[1,0,0],[0,1,0],[0,0,-2]], dtype=dtype)
    return jnp.stack([l0,l1,l2,l3,l4,l5,l6,l7,l8], axis=0)

LMATS = _gellmann_stack()  # (9,3,3) complejo

# -------- Secuencia temporal (primero actúa λ8) --------
TEMP_SEQ = (8, 3, 2, 3, 5, 3, 2, 3)

def su3_unitary(alphas: jnp.ndarray, seq: tuple = TEMP_SEQ) -> jnp.ndarray:
    """
    Construye U tal que al aplicar al estado se cumple:
        U|ψ> = F_3 F_2 F_3 F_5 F_3 F_2 F_3 F_8 |ψ>
    cuando len(alphas)=8 y seq = TEMP_SEQ; para n<8, usa el prefijo.
    Convención: F_k(a) = exp(i * a * λ_k).  Acumulamos U = F @ U.
    """
    n = alphas.shape[0]
    ks = jnp.array(seq[:n], dtype=jnp.int32)

    def body(U, ak):
        a, k = ak
        lam = LMATS[k]  # (3,3)
        F = jsp.expm(1j * a.astype(jnp.complex64) * lam)
        return F @ U, None  # pre-multiplico para respetar orden temporal

    U0 = jnp.eye(3, dtype=jnp.complex64)
    (U_final, _) = jax.lax.scan(body, U0, (alphas, ks))
    return U_final


In [16]:
# -------------------- QNode (JAX) --------------------
dev = qml.device("default.qutrit", wires=1)

@qml.qnode(dev, interface="jax", diff_method="backprop")
def qcircuit(params, x):
    for p in params:
        
        # encoding layer
        qml.QutritUnitary(su3_unitary(x), wires=0)
        # variational layer
        qml.QutritUnitary(su3_unitary(p), wires=0)
        
    return qml.state()

v_qcircuit = jax.vmap(qcircuit, in_axes=(None, 0), out_axes=0)

# ---------- quantum labels (|0>, |1> as density matrices) ----------
def create_dm_labels():
    zero = jnp.array([1.0, 0.0, 0.0], dtype=jnp.float32)
    one  = jnp.array([0.0, 1.0, 0.0], dtype=jnp.float32)
    two  = jnp.array([0.0, 0.0, 1.0], dtype=jnp.float32)
    return jnp.stack([zero, one, two]) 

dm_labels = create_dm_labels()

# --------------------------- cost function --------------------------
# def cost(params, X, y, dm_labels):
#     # y debe ser int para indexar dm_labels[y]
#     y = y.astype(jnp.int32)
#     dm_y = dm_labels[y]                    # (batch, 2, 2)
#     f = v_qcircuit(params, X, dm_y)    # (batch,)
#     return jnp.mean((1.0 - f) ** 2)
def cost(params, X, y, dm_labels):
    states = v_qcircuit(params, X)        # (batch, 3)
    targets = dm_labels[y]                # (batch, 3)

    # fidelity = jnp.abs(jnp.vdot(states, targets)) ** 2
    fidelity = jnp.sum(jnp.conj(states) * targets, axis=1)
    fidelity = jnp.abs(fidelity) ** 2
    return jnp.mean((1 - fidelity) ** 2)

# ------------------------------ metrics ------------------------------
# def predict(params, X, dm_labels):
#     def fidelities_for_sample(xi):
#         # evalúa fidelidad contra cada label_dm
#         return jnp.array([qcircuit(params, xi, dm) for dm in dm_labels])
#     fidelities = jax.vmap(fidelities_for_sample)(X)   # (batch, n_labels)
#     return jnp.argmax(fidelities, axis=1)
def predict(params, X, state_labels):
    """
    Retorna la clase predicha para cada muestra de X,
    usando fidelidad con los vectores `state_labels`.
    """
    states = v_qcircuit(params, X)  # (batch, 3)

    # Calculamos fidelidad con cada label
    def fidelity_with_label(target_state):
        # fidelity: batch de productos ⟨target|state⟩
        inner = jnp.sum(jnp.conj(states) * target_state, axis=1)
        return jnp.abs(inner) ** 2  # shape (batch,)

    fidelities = jnp.stack([
        fidelity_with_label(target_state)
        for target_state in state_labels
    ], axis=1)  # shape (batch, num_labels)

    return jnp.argmax(fidelities, axis=1)  # clase con mayor fidelidad

@jax.jit
def accuracy(y_true, y_pred):
    return jnp.mean((y_true.astype(jnp.int32) == y_pred.astype(jnp.int32)).astype(jnp.float32))

def evaluate_metrics(X, y, params, dm_labels):
    """
    Evalúa métricas de clasificación para el clasificador cuántico.

    Devuelve:
      - accuracy (%)
      - f1 macro (%)
      - precision macro (%)
      - recall macro (%)
      - cohen's kappa
      - matthews corrcoef (MCC)
      - y_true (numpy array)
      - y_pred (numpy array)
    """
    # y puede venir como DeviceArray -> pasamos a numpy
    y_true = np.asarray(y)
    # predict(...) ya vectoriza internamente con vmap
    y_pred = np.asarray(predict(params, X, dm_labels))

    acc = accuracy_score(y_true, y_pred) * 100.0
    f1  = f1_score(y_true, y_pred, average="macro", zero_division=0) * 100.0
    prec = precision_score(y_true, y_pred, average="macro", zero_division=0) * 100.0
    rec  = recall_score(y_true, y_pred, average="macro", zero_division=0) * 100.0
    kappa = cohen_kappa_score(y_true, y_pred)
    mcc   = matthews_corrcoef(y_true, y_pred)

    return acc, f1, prec, rec, kappa, mcc, y_true, y_pred

# --------------------------- training step ---------------------------
@jax.jit
def train_step(params, opt_state, X_batch, y_batch, dm_labels):
    y_batch = y_batch.astype(jnp.int32)

    loss_fn = lambda prms: cost(prms, X_batch, y_batch, dm_labels)
    loss, grads = jax.value_and_grad(loss_fn)(params)

    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

# ------------------------ iterate_minibatches ------------------------
def iterate_minibatches(X, y, batch_size):
  n = X.shape[0]
  for i in range(0, n, batch_size):
    yield X[i:i+batch_size], y[i:i+batch_size]

# ----------------- auxiliary function to plot data  -----------------
def plot_data(X, labels, ax, title=""):
    X_np = np.array(X)
    y_np = np.array(labels)
    ax.scatter(X_np[y_np==0, 0], X_np[y_np==0, 1], s=10, alpha=0.7, label="Clase 0")
    ax.scatter(X_np[y_np==1, 0], X_np[y_np==1, 1], s=10, alpha=0.7, label="Clase 1")
    ax.set_title(title)
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.legend(loc="best", fontsize=8)
    

In [17]:

'''
      ==================
            data
      ==================
'''

# create MinMaxScaler object
scaler = MinMaxScaler()


SEED = 1337
np.random.seed(SEED) # numpy
key = jax.random.PRNGKey(SEED) # jax

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


def load_dataset(n_components):

    # dataset
    wine = load_wine()

    # features
    X_data = wine.data
    X_data = scaler.fit_transform(X_data)

    # labels
    y_data = wine.target

    pca = PCA(n_components=n_components)

    x_data = pca.fit_transform(X_data)

    x_data, y_data = shuffle(x_data, y_data, random_state=42)

    return x_data, y_data, pca.explained_variance_ratio_.sum()


# stratified split with seed (80% train / 20% test)
X_np, y_np, _ = load_dataset(8)

X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
    X_np, y_np,
    train_size=0.8,
    test_size=0.2,
    stratify=y_np,
    shuffle=True,
    random_state=SEED,
)

# numpy data to jax data
X_train = jnp.array(X_train_np, dtype=jnp.float32)
X_test  = jnp.array(X_test_np,  dtype=jnp.float32)
y_train = jnp.array(y_train_np, dtype=jnp.int32)
y_test  = jnp.array(y_test_np,  dtype=jnp.int32)


In [19]:
''' -----------------------------------
             1 qutrit training
    -----------------------------------
'''

# init
LEARNING_RATE = 0.001
NUM_LAYERS    = 6
INIT_PARAMS   = jax.random.uniform(key, shape=(NUM_LAYERS, 8), minval=-1.0, maxval=1.0)
optimizer     = optax.adam(LEARNING_RATE)
opt_state     = optimizer.init(INIT_PARAMS)
EPOCHS        = 200
BATCH_SIZE    = 32

# initiate timer
start_time = time.time()

# start training
params = INIT_PARAMS
for epoch in range(EPOCHS):
    for Xb, yb in iterate_minibatches(X_train, y_train, BATCH_SIZE):
        params, opt_state, loss = train_step(params, opt_state, Xb, yb, dm_labels)

    y_pred_train = predict(params, X_train, dm_labels)
    y_pred_test  = predict(params, X_test,  dm_labels)
    acc_train = accuracy(y_train, y_pred_train)
    acc_test = accuracy(y_test,  y_pred_test)

    print(f"Epoch {epoch+1:02d} | Loss: {float(loss):.4f} | Train Acc: {float(acc_train):.3f} | Test Acc: {float(acc_test):.3f}")

# total elapsed time
elapsed_time = time.time() - start_time

# metrics
final_loss_train = float(cost(params, X_train, y_train, dm_labels))
acc_train, _, _, _, _, _, _, _ = evaluate_metrics(X_train, y_train, params, dm_labels)

final_loss_test = float(cost(params, X_test, y_test, dm_labels))
acc_test, f1, prec, rec, kappa, mcc, y_true, y_pred = evaluate_metrics(X_test, y_test, params, dm_labels)

# print metrics
print(
    f"Total time: {elapsed_time:.2f}s"
    f"Train Loss: {final_loss_train:.6f} | "
    f"Test Loss: {final_loss_test:.6f} | "
    f"Acc: {acc_test:.2f}% | F1: {f1:.2f}% | "
    f"Prec: {prec:.2f}% | Rec: {rec:.2f}% | "
    f"κ: {kappa:.3f} | MCC: {mcc:.3f}"
)


Epoch 01 | Loss: 0.5318 | Train Acc: 0.359 | Test Acc: 0.444
Epoch 02 | Loss: 0.4339 | Train Acc: 0.472 | Test Acc: 0.583
Epoch 03 | Loss: 0.3497 | Train Acc: 0.599 | Test Acc: 0.639
Epoch 04 | Loss: 0.2879 | Train Acc: 0.655 | Test Acc: 0.639
Epoch 05 | Loss: 0.2464 | Train Acc: 0.697 | Test Acc: 0.694
Epoch 06 | Loss: 0.2205 | Train Acc: 0.704 | Test Acc: 0.722
Epoch 07 | Loss: 0.2061 | Train Acc: 0.732 | Test Acc: 0.694
Epoch 08 | Loss: 0.1992 | Train Acc: 0.739 | Test Acc: 0.722
Epoch 09 | Loss: 0.1961 | Train Acc: 0.761 | Test Acc: 0.750
Epoch 10 | Loss: 0.1945 | Train Acc: 0.782 | Test Acc: 0.750
Epoch 11 | Loss: 0.1928 | Train Acc: 0.789 | Test Acc: 0.778
Epoch 12 | Loss: 0.1902 | Train Acc: 0.789 | Test Acc: 0.778
Epoch 13 | Loss: 0.1864 | Train Acc: 0.803 | Test Acc: 0.778
Epoch 14 | Loss: 0.1818 | Train Acc: 0.796 | Test Acc: 0.778
Epoch 15 | Loss: 0.1771 | Train Acc: 0.803 | Test Acc: 0.806
Epoch 16 | Loss: 0.1727 | Train Acc: 0.810 | Test Acc: 0.806
Epoch 17 | Loss: 0.1686 