In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import product
from tqdm import tqdm
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [20]:
# =======================
# Configuración general
# =======================
DEVICE  = "cpu"   # "cuda" si tienes GPU
CDTYPE  = torch.complex64
FDTYPE  = torch.float32
EPS     = 1e-10
N       = 3       # qubits (este script es para N=3)

# Tipos y constantes
DTYPR  = torch.float32
DTYPEC = torch.complex64
EPS    = 1e-12

In [21]:
# =======================
# Estado objetivo aleatorio |psi_true> (3 qubits)
# =======================
def random_pure_state(n=3):
    dim = 2**n
    re = torch.randn(dim, dtype=FDTYPE, device=DEVICE)
    im = torch.randn(dim, dtype=FDTYPE, device=DEVICE)
    psi = (re + 1j*im).to(CDTYPE).view(dim,1)
    psi /= torch.linalg.norm(psi)
    return psi

psi_true = random_pure_state(N)  # <- estado objetivo aleatorio (cambia la seed si quieres)

psi_true

tensor([[-0.1641-0.1823j],
        [-0.1027-0.0645j],
        [-0.3274-0.2322j],
        [-0.3430+0.4436j],
        [-0.0708-0.0008j],
        [-0.4167-0.3340j],
        [ 0.3602-0.1691j],
        [-0.0329-0.0010j]])

In [22]:
# ======================
# Probabilidades teóricas por base
# ======================

# Unitaries locales (Z: I, X: H, Y: S^† H)
inv_sqrt2 = 1.0 / torch.sqrt(torch.tensor(2.0, dtype=DTYPR))
H = inv_sqrt2 * torch.tensor([[1, 1],
                              [1,-1]], dtype=DTYPEC)
S = torch.tensor([[1, 0],
                  [0, 1j]], dtype=DTYPEC)
Umap = {'Z': torch.eye(2, dtype=DTYPEC),
        'X': H,
        'Y': S.conj().T @ H}  # S^\dagger H

def kron3(A, B, C):
    return torch.kron(torch.kron(A, B), C)

def probs_true_in_basis(psi, B: str):
    """
    Devuelve p(s|B) para s∈{0,1}^3 y B in {'X','Y','Z'}^3.
    psi: (8,1) complejo (normalizado idealmente)
    B: string de longitud 3, e.g. 'ZXY'
    return: tensor real shape (8,) con suma = 1
    """
    assert psi.shape == (8,1), "psi debe ser (8,1)"
    assert len(B) == 3 and all(c in "XYZ" for c in B), "B debe ser p.ej. 'ZXY'"

    # U_B = ⊗_i U_{P_i}
    U_B = kron3(Umap[B[0]], Umap[B[1]], Umap[B[2]])       # (8,8)

    # |psi_B> = U_B^† |psi|
    v = U_B.conj().T @ psi                                 # (8,1)

    # Probabilidades = |coef|^2 en base Z
    p = (v.conj() * v).real.view(-1).to(DTYPR)             # (8,)
    p = p / p.sum()                         # normalización robusta
    return p

In [23]:
# =======================
# Simulación de (s,B) reales
# =======================

# --- helpers: índice <-> bitstring (N=3) ---
def idx_to_bits(idx):
    # idx 0..7 -> tensor([b0,b1,b2]) con orden binario b0 b1 b2
    return torch.tensor([(idx>>2)&1, (idx>>1)&1, idx&1], dtype=torch.long)

def bits_to_idx(bits):  # bits: LongTensor([b0,b1,b2])
    b0,b1,b2 = int(bits[0].item()), int(bits[1].item()), int(bits[2].item())
    return (b0<<2) | (b1<<1) | b2  # índice 0..7

# ======================
# Simulación de datos (s, B) para 3 qubits
# ======================
def sample_measurements(psi, M=8000, probs_bases=(1/3,1/3,1/3), seed=1):
    """
    Genera M tiros. Para cada qubit elige P_i ∈ {X,Y,Z} con probs_bases,
    forma B = P1 P2 P3 (string) y luego samplea s ~ p_true(s|B) con probs_true_in_basis_3q.
    Devuelve lista de (s_bits, B_str) donde:
      - s_bits: LongTensor shape (3,) con 0/1
      - B_str: string tipo 'ZXY'
    """
    rng = np.random.default_rng(seed)
    letters = np.array(['X','Y','Z'])
    probsB  = np.array(probs_bases, dtype=float)
    assert np.isclose(probsB.sum(), 1.0), "probs_bases debe sumar 1"

    data = []
    for _ in range(M):
        # Base local por qubit (independiente por cada posición)
        B = ''.join(rng.choice(letters, size=3, p=probsB))  # p.ej. 'ZXY'
        # Probabilidades verdaderas p(s|B) sobre los 8 bitstrings
        p = probs_true_in_basis(psi, B).detach().cpu().numpy()
        # Muestra un índice 0..7 y pásalo a bits
        idx = int(rng.choice(8, p=p))
        s_bits = idx_to_bits(idx)
        data.append((s_bits, B))
    return data

# ===== Ejemplo de uso / dataset =====
M = 100
data = sample_measurements(psi_true, M=M, probs_bases=(0.4,0.4,0.2))
split = int(0.8*M)
train_data = data[:split]
val_data   = data[split:]
# print(len(train_data), len(val_data), train_data[0])

In [24]:
# ======================
# NQS mínimo para 3 qubits
# ======================
class ThreeQubitNQS(nn.Module):
    """
    Parámetros:
      - logits (8): softmax -> p(t) sobre t∈{0..7}
      - phases (8): phi_t (reales)
    Amplitudes en Z:
      alpha_t = sqrt(p_t) * exp(i phi_t)
    """
    def __init__(self):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(8, dtype=DTYPR))
        self.phases = nn.Parameter(torch.zeros(8, dtype=DTYPR))

    def alphas(self):
        probs = F.softmax(self.logits, dim=0)  # (8,)
        phi = self.phases
        phi = phi - phi[0]                     # fija fase global -> φ_0 = 0
        amps = torch.sqrt(probs).to(DTYPEC) * torch.exp(1j * phi.to(DTYPR))
        return amps, probs

    def statevector(self):
        amps, _ = self.alphas()
        psi = amps.reshape(8,1).to(DTYPEC)
        # (por construcción ya está normalizado, igual renormalizamos por seguridad)
        psi = psi / torch.linalg.norm(psi)
        return psi  # (8,1)

    def prob_s_given_B(self, s_bits: torch.LongTensor, B: str):
        """
        p_lambda(s|B) = |<s,B|psi>|^2
        - s_bits: tensor long shape (3,), con bits 0/1
        - B: string 'XYZ' (long 3)
        """
        psi = self.statevector()                       # (8,1)
        p_all = probs_true_in_basis(psi, B)         # (8,)
        idx = bits_to_idx(s_bits)
        return p_all[idx].clamp_min(EPS)               # escalar >= EPS

    # (opcional) para obtener todas las probs p(s|B) de una:
    def probs_all_s_given_B(self, B: str):
        return probs_true_in_basis(self.statevector(), B)  # (8,)

In [25]:
# ======================
# Loss (NLL sobre (s,B))
# ======================
def nll_batch(model, batch):
    logps = []
    for s_bits, B in batch:
        p = model.prob_s_given_B(s_bits, B)
        logps.append(torch.log2(p))        # usa torch.log2(p) si quieres bits
    
    logps = sum(logps)/len(logps) 
    return -logps

In [26]:
# ======================
# Entrenamiento (3 qubits)
# ======================
model = ThreeQubitNQS().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
EPOCHS = 500
BATCH  = 50

loss_history = []

def batches(lst, bs):
    for i in range(0, len(lst), bs):
        yield lst[i:i+bs]

for epoch in tqdm(range(1, EPOCHS+1), desc="Training"):
    model.train()
    losses = []
    for batch in batches(train_data, BATCH):
        loss = nll_batch(model, batch)  # True => loss en bits
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # opcional, estable
        opt.step()
        losses.append(loss.item())
        
    loss_history.append(np.mean(losses))

    # validación
    model.eval()
    with torch.no_grad():
        val_loss = nll_batch(model, val_data).item()

    if epoch % 20 == 0 or epoch == 1 or epoch == EPOCHS:
        clear_output(wait=True)
        print(f"Epoch {epoch:3d} | train NLL (bits): {np.mean(losses):.6f} | val NLL (bits): {val_loss:.6f}")

Training: 100%|██████████| 500/500 [00:18<00:00, 26.41it/s]

Epoch 500 | train NLL (bits): 2.332618 | val NLL (bits): 2.680105





In [27]:
# ======================
# Fidelidad final (3 qubits)
# ======================
with torch.no_grad():
    psi_est = model.statevector()           # (8,1)
    fidelity = (torch.norm((psi_true.conj().T @ psi_est))**2).item()  # escalar real
    norm_est = float((psi_est.conj()*psi_est).sum().real)

print("\n=== Resultados (3 qubits) ===")
print(f"Fidelidad |<psi_true|psi_est>|^2 : {fidelity:.6f}")
print(f"Norma de |psi_est> (≈1):          {norm_est:.6f}")

# También puedes ver las probs y fases aprendidas:
with torch.no_grad():
    al, probs = model.alphas()
    phases = model.phases.tolist()
    print("Probabilidades estimadas (p_t para t=0..7):")
    print([f"{p.item():.4f}" for p in probs])
    print("Fases φ_t (rad, t=0..7):")
    print([f"{phi:.3f}" for phi in phases])


=== Resultados (3 qubits) ===
Fidelidad |<psi_true|psi_est>|^2 : 0.817136
Norma de |psi_est> (≈1):          1.000000
Probabilidades estimadas (p_t para t=0..7):
['0.0590', '0.1261', '0.2327', '0.1861', '0.0491', '0.2256', '0.1104', '0.0110']
Fases φ_t (rad, t=0..7):
['0.164', '-0.321', '-0.208', '-1.644', '0.136', '0.456', '1.738', '-0.932']
