In [55]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from classical_shadow import random_state

In [56]:
# ======================
# Configuration
# ======================
DEVICE = "cpu"
DTYPEC = torch.complex64   # complex dtype used for state vectors
DTYPR  = torch.float32     # real dtype used for model parameters
EPS    = 1e-10             # floor applied to probabilities before dividing / taking logs
SEED   = 0                 # global RNG seed so Restart & Run All is reproducible

# random_state() (from classical_shadow) is called below without a seed;
# presumably it draws from the global NumPy or torch generator — TODO confirm.
# Seeding both makes psi_true reproducible in either case.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [57]:
# ======================
# True state |psi_true> (1 qubit)
# ======================
psi_true = random_state()

# Fail fast if the helper's output breaks the contract the rest of the
# notebook relies on: a (2, 1) column vector with unit norm.
assert tuple(psi_true.shape) == (2, 1), f"expected shape (2, 1), got {tuple(psi_true.shape)}"
assert abs(torch.linalg.norm(psi_true).item() - 1.0) < 1e-4, "psi_true is not normalized"

psi_true

tensor([[0.4252+0.0000j],
        [0.5199-0.7408j]])

In [58]:
# ======================
# Exact outcome probabilities per measurement basis
# ======================
def probs_true_in_basis(psi, B):
    """
    Exact outcome distribution p(s|B), s in {0, 1}, of a single-qubit state.

    psi: complex tensor holding the two amplitudes (a (2,1) column here);
         it is flattened before use.
    B:   measurement basis, one of 'Z', 'X', 'Y'.

    Returns a tensor [p(0|B), p(1|B)]. It is renormalized to sum to 1, with
    the denominator floored at EPS.

    Outcome convention, with |±> = (|0> ± |1>)/sqrt(2) and
    |±i> = (|0> ± i|1>)/sqrt(2):
      Z: s=0 -> |0>,  s=1 -> |1>
      X: s=0 -> |+>,  s=1 -> |->
      Y: s=0 -> amplitude (a0 + i a1)/sqrt(2), s=1 -> (a0 - i a1)/sqrt(2).
         (a0 + i a1)/sqrt(2) equals <-i|psi>, so s=0 actually labels |-i>.
         NQS.prob_s_given_B uses the same assignment, so data and model agree.

    Raises ValueError for any other basis label.
    """
    amp0, amp1 = psi.flatten()
    root2 = np.sqrt(2)

    # Amplitudes of the two outcomes in the requested basis.
    if B == 'Z':
        c0, c1 = amp0, amp1
    elif B == 'X':
        c0, c1 = (amp0 + amp1) / root2, (amp0 - amp1) / root2
    elif B == 'Y':
        c0, c1 = (amp0 + 1j*amp1) / root2, (amp0 - 1j*amp1) / root2
    else:
        raise ValueError("Basis must be 'X','Y','Z'")

    q0 = (c0.conj()*c0).real
    q1 = (c1.conj()*c1).real
    # Guard against rounding drift (and a degenerate zero vector).
    total = (q0 + q1).clamp_min(EPS)
    return torch.stack([q0/total, q1/total])

In [59]:
# ======================
# Simulated measurement data (s, B)
# ======================
def sample_measurements(psi, M=2000, probs_bases=(1/3,1/3,1/3), seed=1):
    """
    Simulate M single-shot measurements of psi in randomly chosen Pauli bases.

    For every shot, a basis B in ('X', 'Y', 'Z') is drawn with probabilities
    probs_bases. An outcome s is then drawn from p_true(s|B), as computed by
    probs_true_in_basis.

    psi: state accepted by probs_true_in_basis (a (2,1) complex tensor here).
    M: number of shots.
    probs_bases: probabilities of picking X, Y, Z (in that order).
    seed: seed for the NumPy Generator; the same seed gives the same dataset.

    Returns a list of (s_tensor, B) pairs. s_tensor is a LongTensor of shape
    (1,) holding 0 or 1, and B is one of 'X', 'Y', 'Z'.
    """
    rng = np.random.default_rng(seed)
    bases = np.array(['X','Y','Z'])
    probsB = np.array(probs_bases, dtype=float)
    # p_true(s|B) does not change between shots, so compute it once per basis
    # instead of once per shot. This makes no RNG calls, so the sampled stream
    # is unchanged. detach()/cpu() let .numpy() work even if psi requires grad
    # or lives on another device.
    p_by_basis = {
        basis: probs_true_in_basis(psi, basis).detach().cpu().numpy().astype(np.float32)
        for basis in ('X', 'Y', 'Z')
    }
    data = []
    for _ in range(M):
        B = rng.choice(bases, p=probsB).item()
        s = int(rng.choice([0,1], p=p_by_basis[B]))
        data.append((torch.tensor([s], dtype=torch.long), B))
    return data

# Dataset: M shots, each measured in X, Y or Z with equal probability.
# The whole dataset is used for (full-batch) training.
M = 50
train_data = data = sample_measurements(
    psi_true, M=M, probs_bases=(1/3, 1/3, 1/3), seed=2
)

In [60]:
# ======================
# Minimal NQS for 1 qubit
# ======================
class NQS(nn.Module):
    """
    Minimal single-qubit parameterised state.

    Parameters:
      - logits (2): unnormalised weights; softmax gives p(0), p(1), so the
        state is normalised automatically.
      - phases (2): real phases phi_0, phi_1.
    Amplitudes:
      alpha_0 = sqrt(p0) * exp(i phi_0)
      alpha_1 = sqrt(p1) * exp(i phi_1)
    A common shift of both phases leaves every p(s|B) unchanged, so only
    phi_1 - phi_0 is determined by measurement data.
    """
    def __init__(self):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(2, dtype=DTYPR))
        self.phases = nn.Parameter(torch.zeros(2, dtype=DTYPR))

    def alphas(self):
        """Return (amp, probs): complex amplitudes alpha_s (2,) and p(s) = |alpha_s|^2 (2,)."""
        probs = F.softmax(self.logits, dim=0)             # (2,), sums to 1
        amp   = torch.sqrt(probs).to(DTYPR) * torch.exp(1j*self.phases.to(DTYPR))
        return amp, probs

    @staticmethod
    def _outcome(s):
        """Convert an outcome given as an int or a 1-element tensor into a Python int in {0, 1}."""
        value = int(s.item()) if torch.is_tensor(s) else int(s)
        if value not in (0, 1):
            raise ValueError(f"Outcome must be 0 or 1, got {value}")
        return value

    def prob_s_given_B(self, s, B):
        """
        p_lambda(s|B): model probability of outcome s when measuring in basis B.

        s: outcome 0 or 1, as a Python int or a 1-element tensor (the form
           produced by sample_measurements).
        B: 'X', 'Y' or 'Z'.

        Returns a 0-d real tensor for every basis, floored at EPS so that
        log() stays finite.
        Raises ValueError for an unknown basis or an outcome other than 0/1.
        """
        s = self._outcome(s)
        al, probs = self.alphas()
        a0, a1 = al[0], al[1]
        if B == 'Z':
            # |alpha_s|^2 = p(s) directly
            return probs[s].clamp_min(EPS)
        elif B == 'X':
            # s=0 -> <+|psi> = (a0 + a1)/sqrt2 ; s=1 -> <-|psi> = (a0 - a1)/sqrt2
            A = (a0 + (1 if s == 0 else -1)*a1) / np.sqrt(2.0)
        elif B == 'Y':
            # s=0 -> (a0 + i a1)/sqrt2 ; s=1 -> (a0 - i a1)/sqrt2.
            # (a0 + i a1)/sqrt2 is <-i|psi>, so s=0 labels |-i> here. This is
            # the same assignment probs_true_in_basis uses, which keeps
            # training consistent with the simulated data.
            A = (a0 + (1j if s == 0 else -1j)*a1) / np.sqrt(2.0)
        else:
            raise ValueError("Basis must be 'X','Y','Z'")
        return (A.conj()*A).real.clamp_min(EPS)

    def statevector(self):
        """Return the model state as a (2, 1) DTYPEC column vector with unit norm."""
        al, _ = self.alphas()
        psi = al.reshape(2,1).to(DTYPEC)
        # Renormalise for safety (softmax already makes it unit norm up to rounding).
        psi = psi / torch.linalg.norm(psi)
        return psi

In [61]:
# ======================
# Loss (NLL over (s, B) pairs)
# ======================
def nll_batch(model, batch):
    """
    Mean negative log-likelihood of a batch of shots, in bits (log base 2).

    model: object exposing prob_s_given_B(s, B) that returns a tensor holding
           a single probability (0-d or shape (1,)).
    batch: non-empty sequence of (s, B) pairs.

    Returns a 0-d tensor, suitable for loss.backward().
    Raises ValueError if batch is empty (the mean would be undefined).
    """
    if len(batch) == 0:
        raise ValueError("nll_batch needs a non-empty batch")
    # reshape(-1) flattens both 0-d and (1,)-shaped probabilities so they
    # concatenate into a single 1-D vector of log-probabilities.
    logps = torch.cat([
        torch.log2(model.prob_s_given_B(s, B)).reshape(-1)
        for s, B in batch
    ])
    return -logps.mean()

In [62]:
# ======================
# Training setup
# ======================
EPOCHS = 500
BATCH  = M    # full batch: one optimizer step per epoch

model = NQS().to(DEVICE)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-2)

def batches(lst, bs):
    """Lazily yield consecutive slices of `lst` of length `bs`; the last slice may be shorter."""
    starts = range(0, len(lst), bs)
    yield from (lst[k:k + bs] for k in starts)

LOG_EVERY = 0  # set to e.g. 20 to print the training NLL periodically; 0 disables logging

# Mean training NLL (bits) per epoch, kept for later inspection/plotting.
loss_history = []
model.train()
for epoch in range(1, EPOCHS+1):
    losses = []
    for batch in batches(train_data, BATCH):
        loss = nll_batch(model, batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    epoch_loss = float(np.mean(losses))
    # Stop early rather than silently training on NaN/inf (also catches empty train_data).
    if not np.isfinite(epoch_loss):
        raise FloatingPointError(f"non-finite training loss at epoch {epoch}: {epoch_loss}")
    loss_history.append(epoch_loss)

    if LOG_EVERY and (epoch % LOG_EVERY == 0 or epoch == 1 or epoch == EPOCHS):
        print(f"Epoch {epoch:3d} | train NLL: {epoch_loss:.6f}")

In [63]:
# ======================
# Final fidelity
# ======================
with torch.no_grad():
    psi_est = model.statevector()           # (2,1), unit norm
    # Cast the reference to the estimate's dtype and renormalise it, so that
    # F = |<psi_true|psi_est>|^2 is well defined even if random_state()
    # returns another dtype or a slightly unnormalised vector.
    psi_ref = psi_true.to(psi_est.dtype)
    psi_ref = psi_ref / torch.linalg.norm(psi_ref)
    # torch.vdot conjugates its first argument: <psi_ref|psi_est>.
    overlap = torch.vdot(psi_ref.flatten(), psi_est.flatten())
    fidelity = (overlap.abs() ** 2).item()
    norm_est = float((psi_est.conj()*psi_est).sum().real)

print("\n=== Resultados (1 qubit) ===")
print(f"Fidelidad |<psi_true|psi_est>|^2 : {fidelity:.6f}")
print(f"Norma de |psi_est> (≈1):          {norm_est:.6f}")

# Learned probabilities and phases. A common shift of both phases leaves
# every p(s|B) unchanged, so the individual phases are only defined up to an
# offset; the relative phase phi1 - phi0 is the physically meaningful one.
with torch.no_grad():
    al, probs = model.alphas()
    phi0, phi1 = model.phases.tolist()
    print(f"p(0), p(1) estimadas: {probs[0].item():.4f}, {probs[1].item():.4f}")
    print(f"Fases phi0, phi1 (rad): {phi0:.3f}, {phi1:.3f}")
    print(f"Fase relativa phi1 - phi0 (rad): {phi1 - phi0:.3f}")


=== Resultados (1 qubit) ===
Fidelidad |<psi_true|psi_est>|^2 : 0.997772
Norma de |psi_est> (≈1):          1.000000
p(0), p(1) estimadas: 0.1604, 0.8396
Fases phi0, phi1 (rad): 0.531, -0.531
