# RealNet MNIST

In [7]:
"""
Block 1: RealNet Training
==========================
Trains the baseline Real MLP on 3 seeds.
Saves the trained preprocessor for reuse in subsequent blocks.

Outputs:
- realnet_results.pt: Contains results dict and trained preprocessor state
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from collections import defaultdict

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Shared Bottleneck Preprocessor: 784 → 16
# ============================================

class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 784 → 16"""
    def __init__(self, input_dim=784, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc(x))
        return x


# ============================================
# Real-valued Head and Network
# ============================================

class RealHead(nn.Module):
    """Standard MLP: 16 → 64 → 10"""
    def __init__(self, bottleneck_dim=16, hidden_dim=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(bottleneck_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class RealNet(nn.Module):
    """Complete Real network: Preprocessor + RealHead"""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(784, 16)
        self.head = RealHead(16, 64, 10)

    def forward(self, x):
        features = self.preprocessor(x)
        return self.head(features)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample(dataset, n_samples_per_class):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Maintains class balance.
    """
    # Group indices by class
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    
    # Sample from each class
    sampled_indices = []
    for class_label in sorted(class_indices.keys()):
        indices = class_indices[class_label]
        # Use fixed random seed for reproducibility
        rng = np.random.RandomState(42)
        selected = rng.choice(indices, size=min(n_samples_per_class, len(indices)), 
                             replace=False)
        sampled_indices.extend(selected)
    
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 1: RealNet Training")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Batch size: 32")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: 784 → 16 → 64 → 10")
    print("=" * 70)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load full datasets
    full_train_ds = datasets.MNIST(root="./data", train=True,
                                   download=True, transform=transform)
    full_test_ds = datasets.MNIST(root="./data", train=False,
                                  download=True, transform=transform)

    # Create stratified samples
    print("\nCreating stratified samples...")
    train_indices = stratified_sample(full_train_ds, n_samples_per_class=1500)  # 15K total
    test_indices = stratified_sample(full_test_ds, n_samples_per_class=300)     # 3K total
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=32,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64,
                             shuffle=False, num_workers=0)

    seeds = [42, 123, 456]
    all_results = []
    trained_preprocessor_state = None

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training RealNet (seed={seed})...")
        real_model = RealNet().to(device)
        real_opt = torch.optim.Adam(real_model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            real_model, train_loader, test_loader, real_opt, device,
            max_epochs=40, patience=10, name="Real"
        )
        result["params"] = sum(p.numel() for p in real_model.parameters())
        result["seed"] = seed
        
        all_results.append(result)
        
        # Save preprocessor from first seed
        if trained_preprocessor_state is None:
            trained_preprocessor_state = real_model.preprocessor.state_dict()
            print(f"\n  → Saved preprocessor state from seed {seed}")

    # Summary
    print("\n" + "=" * 70)
    print("REALNET SUMMARY")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["params"]
    
    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results and preprocessor
    save_dict = {
        "results": all_results,
        "preprocessor_state": trained_preprocessor_state,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "params": params
        }
    }
    
    torch.save(save_dict, "realnet_results.pt")
    print(f"\n✓ Saved results to: realnet_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 1: RealNet Training

Configuration:
  • Batch size: 32
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: 784 → 16 → 64 → 10

Creating stratified samples...
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training RealNet (seed=42)...
  [Real] Epoch  1 | loss=0.7188 | test_acc=0.8977 | time=1.3s
  [Real] Epoch  2 | loss=0.2981 | test_acc=0.9090 | time=2.6s
  [Real] Epoch  3 | loss=0.2502 | test_acc=0.9157 | time=4.0s
  [Real] Epoch  4 | loss=0.2143 | test_acc=0.9173 | time=5.2s
  [Real] Epoch  5 | loss=0.1910 | test_acc=0.9293 | time=6.5s
  [Real] Epoch  6 | loss=0.1768 | test_acc=0.9210 | time=7.9s
  [Real] Epoch  7 | loss=0.1631 | test_acc=0.9253 | time=9.2s
  [Real] Epoch  8 | loss=0.1486 | test_acc=0.9290 | time=10.6s
  [Real] Epoch  9 | loss=0.1428 | test_acc=0.9260 | time=11.9s
  [Real] Epoch 10 | loss=0.1289 | test_acc=0.9250 | time=13.2s
  [Real] Epoch 11 | loss=0.1239 |

# Diagnostic MNIST

In [10]:
"""
Mini-script: Euclid vs Adam vs FS-NG covariance test (euclid vs adam vs diag vs block-diag vs full)
=================================================================================================

What this does
--------------
Runs K update steps on *multiple fixed minibatches* (10 minibatches) for each of 3 seeds,
updating ONLY the circuit weights θ, while holding the preprocessor + feature mapping + readout fixed.

We compare five update rules:

  1) euclid     : plain Euclidean gradient descent        θ <- θ - η * g
  2) adam       : Adam optimizer on θ (PyTorch-style)     θ <- AdamStep(g)
  3) diag       : FS/QFI natural gradient (metric_tensor approx="diag")
  4) block-diag : FS/QFI natural gradient (approx="block-diag")
  5) full       : FS/QFI natural gradient (approx=None, needs aux wire)

Optional hybrid (“shrinkage” FS metric toward Euclidean)
--------------------------------------------------------
Instead of pure FS-NG, you can blend the metric toward identity:

    G_eff = α * G_bar + (1-α) * I
    step  = (G_eff + λ I)^(-1) g

Set ALPHA_SHRINK in (0,1]. If you want *pure* FS-NG, set ALPHA_SHRINK=1.0.

Goal
----
- See whether FS covariance structure helps over Euclidean GD and Adam
- Quantify directional difference (cosine similarity vs full step direction)
- Keep runtime low: small minibatches, 10 minibatches, K steps each, 3 seeds

Requires
--------
- realnet_results.pt (from Block 1; contains frozen preprocessor weights)
- pennylane
- torchvision
- optional: pennylane-lightning-gpu (for lightning.gpu)

Notes
-----
- FULL metric_tensor (approx=None) requires an auxiliary wire on many devices.
  We allocate N_QUBITS + 1 wires and reserve the last wire as aux.
- Adam is implemented in a lightweight, deterministic “PyTorch Adam” way on θ only,
  so you can compare it without rerunning multi-day experiments.
"""

import time
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import pennylane as qml
from pennylane import numpy as pnp


# ---------------------------
# Config (tune for speed)
# ---------------------------
SEEDS = [42, 123, 456]

DATASET = "MNIST"          # fixed per your request
N_QUBITS = 4
N_LAYERS = 2
BATCH_SIZE = 8

N_MINIBATCHES = 10         # 10 minibatches per seed
K_STEPS = 10               # steps per minibatch (kept from prior diagnostic)

ETA = 0.01                 # Euclid/FS learning rate
LAM = 1e-3                 # FS damping

# Shrinkage: α=1 => pure FS, α<1 => blend toward Euclidean identity metric
ALPHA_SHRINK = 1.0         # try 0.8 or 0.5 for “Euclid+FS” blend

# Adam hyperparams (match your prior analysis spirit)
ADAM_LR = 1e-3
ADAM_BETAS = (0.9, 0.999)
ADAM_EPS = 1e-8
ADAM_WEIGHT_DECAY = 0.0

PREFERRED_DEVICE = "lightning.gpu"
FALLBACK_DEVICE = "default.qubit"


# ---------------------------
# Reproducibility
# ---------------------------
def set_all_seeds(seed: int):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    try:
        qml.numpy.random.seed(seed)
    except Exception:
        pass


# ---------------------------
# Shared Preprocessor (Block 1)
# ---------------------------
class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 784 → 16"""
    def __init__(self, input_dim=784, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return torch.tanh(self.fc(x))


# ---------------------------
# Device factory (+1 aux wire for full metric)
# ---------------------------
def make_device(n_qubits: int):
    n_wires_total = n_qubits + 1
    aux_wire = n_qubits
    try:
        dev = qml.device(PREFERRED_DEVICE, wires=n_wires_total)
        backend = PREFERRED_DEVICE
    except Exception as e:
        dev = qml.device(FALLBACK_DEVICE, wires=n_wires_total)
        backend = FALLBACK_DEVICE
        print(f"[device] Could not load {PREFERRED_DEVICE} ({type(e).__name__}: {e}). Using {FALLBACK_DEVICE}.")
    print(f"[device] backend={backend}, wires={n_wires_total} (aux_wire={aux_wire})")
    return dev, backend, aux_wire


# ---------------------------
# Helpers
# ---------------------------
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    na = np.linalg.norm(a) + 1e-12
    nb = np.linalg.norm(b) + 1e-12
    return float(np.dot(a, b) / (na * nb))


def _flat(theta) -> pnp.ndarray:
    return pnp.reshape(theta, (-1,))


def _unflat(v, shape) -> pnp.ndarray:
    return pnp.reshape(v, shape)


# ---------------------------
# Build circuit + metric fns
# ---------------------------
dev, backend, aux_wire = make_device(N_QUBITS)

@qml.qnode(dev, interface="autograd", diff_method="parameter-shift")
def circuit(inputs, weights):
    # circuit uses ONLY wires 0..N_QUBITS-1; last wire is aux-reserved
    for layer in range(N_LAYERS):
        for i in range(N_QUBITS):
            qml.RY(inputs[i], wires=i)
        for i in range(N_QUBITS):
            qml.RY(weights[layer, i, 0], wires=i)
            qml.RZ(weights[layer, i, 1], wires=i)

        for i in range(N_QUBITS - 1):
            qml.CNOT(wires=[i, i + 1])
        if N_QUBITS > 2:
            qml.CNOT(wires=[N_QUBITS - 1, 0])

    return (
        qml.expval(qml.PauliZ(0)),
        qml.expval(qml.PauliZ(1)),
        qml.expval(qml.PauliZ(2)),
        qml.expval(qml.PauliZ(3)),
        qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)),
        qml.expval(qml.PauliZ(2) @ qml.PauliZ(3)),
    )

metric_full  = qml.metric_tensor(circuit, approx=None,       aux_wire=aux_wire)
metric_diag  = qml.metric_tensor(circuit, approx="diag")
metric_block = qml.metric_tensor(circuit, approx="block-diag")


# ---------------------------
# Load frozen preprocessor from Block 1
# ---------------------------
realnet = torch.load("realnet_results.pt", weights_only=False)
pre_state = realnet["preprocessor_state"]

pre = SharedPreprocessor(784, 16)
pre.load_state_dict(pre_state)
for p in pre.parameters():
    p.requires_grad = False
pre.eval()

# frozen feature_select (16 -> 4)
feature_select = nn.Linear(16, N_QUBITS)
for p in feature_select.parameters():
    p.requires_grad = False
feature_select.eval()


# ---------------------------
# Dataset: MNIST (fixed)
# ---------------------------
norm_mean, norm_std = (0.1307,), (0.3081,)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)


# ---------------------------
# Fixed readout (constant) to isolate θ update behavior
# ---------------------------
C = 10
W = pnp.array(np.random.randn(C, 6).astype(np.float64) * 0.1)  # (10,6)
b = pnp.array(np.zeros((C,), dtype=np.float64))


# ---------------------------
# Loss factory (depends on current minibatch)
# ---------------------------
def make_batch_ce_loss(x_np, y_np):
    def batch_ce_loss(theta):
        losses = []
        for i in range(len(y_np)):
            q_raw = circuit(x_np[i], theta)
            q = pnp.stack(q_raw)  # (6,)
            logits = W @ q + b
            logits = logits - pnp.max(logits)
            ex = pnp.exp(logits)
            probs = ex / pnp.sum(ex)
            losses.append(-pnp.log(probs[int(y_np[i])] + 1e-12))
        return pnp.mean(pnp.stack(losses))
    return batch_ce_loss


# ---------------------------
# FS-NG / Euclid delta
# ---------------------------
def fsng_or_euclid_delta(theta, approx: str, batch_ce_loss, x_np):
    """
    approx in {"euclid","diag","block-diag","full"}.

    - euclid: delta = g
    - others: delta = (G_eff + λI)^(-1) g
              with optional shrinkage: G_eff = α G_bar + (1-α) I
    """
    theta_shape = theta.shape
    P = int(np.prod(theta_shape))
    I = pnp.eye(P, dtype=pnp.float64)

    g = qml.grad(batch_ce_loss)(theta)
    g_flat = pnp.reshape(g, (P,))

    if approx == "euclid":
        return np.array(g_flat, dtype=np.float64)

    if approx == "diag":
        metric_fn = metric_diag
    elif approx == "block-diag":
        metric_fn = metric_block
    elif approx == "full":
        metric_fn = metric_full
    else:
        raise ValueError("approx must be one of: 'euclid','diag','block-diag','full'")

    G_sum = pnp.zeros((P, P), dtype=pnp.float64)
    for i in range(x_np.shape[0]):
        Gi = metric_fn(x_np[i], theta)
        Gi = pnp.reshape(Gi, (P, P))
        G_sum = G_sum + Gi
    G_bar = G_sum / float(x_np.shape[0])

    alpha = float(ALPHA_SHRINK)
    G_eff = alpha * G_bar + (1.0 - alpha) * I
    G_reg = G_eff + LAM * I

    delta = pnp.linalg.solve(G_reg, g_flat)
    return np.array(delta, dtype=np.float64)


# ---------------------------
# Adam (PyTorch-style) on θ only
# ---------------------------
class AdamState:
    def __init__(self, P, betas=(0.9, 0.999)):
        self.t = 0
        self.m = np.zeros((P,), dtype=np.float64)
        self.v = np.zeros((P,), dtype=np.float64)
        self.beta1, self.beta2 = betas

def adam_step(theta, grad_flat, state: AdamState, lr=1e-3, eps=1e-8, weight_decay=0.0):
    """
    Deterministic Adam update on a *flat* theta vector.
    Returns updated flat theta and the applied update direction (delta).
    """
    state.t += 1

    g = grad_flat.astype(np.float64)
    if weight_decay != 0.0:
        g = g + weight_decay * theta

    b1, b2 = state.beta1, state.beta2
    state.m = b1 * state.m + (1.0 - b1) * g
    state.v = b2 * state.v + (1.0 - b2) * (g * g)

    mhat = state.m / (1.0 - b1 ** state.t)
    vhat = state.v / (1.0 - b2 ** state.t)

    step = lr * mhat / (np.sqrt(vhat) + eps)   # this is the update magnitude
    theta_new = theta - step
    return theta_new, step


# ---------------------------
# One minibatch run for one method
# ---------------------------
def run_one_minibatch(theta0, method, x_np, y_np):
    """
    Runs K_STEPS on a single minibatch and returns:
      losses: (K_STEPS,)
      delta1: first-step update direction (flat, numpy)
      elapsed
    """
    theta = pnp.array(theta0, requires_grad=True)
    theta_shape = theta.shape
    P = int(np.prod(theta_shape))

    batch_ce_loss = make_batch_ce_loss(x_np, y_np)

    # Adam needs state per minibatch (matches your “diagnostic” intent)
    adam_state = AdamState(P, betas=ADAM_BETAS) if method == "adam" else None

    losses = []
    delta1 = None

    t0 = time.time()
    for k in range(K_STEPS):
        loss_before = float(batch_ce_loss(theta))

        if method == "adam":
            g = qml.grad(batch_ce_loss)(theta)
            g_flat = np.array(pnp.reshape(g, (P,)), dtype=np.float64)

            theta_flat = np.array(pnp.reshape(theta, (P,)), dtype=np.float64)
            theta_flat_new, step_vec = adam_step(
                theta_flat, g_flat, adam_state,
                lr=ADAM_LR, eps=ADAM_EPS, weight_decay=ADAM_WEIGHT_DECAY
            )
            # For cosine comparisons, treat Adam's *applied update* as delta
            delta = step_vec / (ADAM_LR + 1e-12)  # normalized-ish direction (optional)
            theta = _unflat(theta_flat_new, theta_shape)

        else:
            delta = fsng_or_euclid_delta(theta, method, batch_ce_loss, x_np)
            theta_flat = np.array(pnp.reshape(theta, (P,)), dtype=np.float64)
            theta = _unflat(theta_flat - ETA * delta, theta_shape)

        loss_after = float(batch_ce_loss(theta))
        losses.append(loss_after)

        if k == 0:
            # Store first-step applied direction in comparable units:
            # - for FS/euclid: delta is the preconditioned direction (so update is ETA*delta)
            # - for Adam: delta above is roughly direction; we’ll store the *applied step* too
            if method == "adam":
                delta1 = np.array(step_vec, dtype=np.float64)  # actual applied update
            else:
                delta1 = np.array(ETA * delta, dtype=np.float64)  # applied update

    return np.array(losses), delta1, time.time() - t0


# ---------------------------
# Collect minibatches (deterministic but seed-dependent)
# ---------------------------
def get_minibatches(seed, n_minibatches, batch_size):
    """
    Returns list of (x_np, y_np) minibatches.
    Uses a deterministic RNG over indices so we don't depend on DataLoader nondeterminism.
    """
    rng = np.random.RandomState(seed)
    batches = []
    for _ in range(n_minibatches):
        idx = rng.choice(len(train_ds), size=batch_size, replace=False)
        x0, y0 = zip(*[train_ds[i] for i in idx])
        x0 = torch.stack(list(x0), dim=0)
        y0 = torch.tensor(list(y0), dtype=torch.long)

        with torch.no_grad():
            feats16 = pre(x0)
            xq = torch.tanh(feature_select(feats16))

        batches.append((xq.cpu().numpy(), y0.cpu().numpy()))
    return batches


# ---------------------------
# Main experiment loop
# ---------------------------
METHODS = ["euclid", "adam", "diag", "block-diag", "full"]

def summarize(arr):
    return float(np.mean(arr)), float(np.std(arr))


def main():
    print(f"\n[Euclid vs Adam vs FS-NG] DATASET={DATASET} | backend={backend}")
    print(f"  layers={N_LAYERS} | batch={BATCH_SIZE} | minibatches={N_MINIBATCHES} | steps={K_STEPS}")
    print(f"  ETA={ETA} | LAM={LAM} | alpha_shrink={ALPHA_SHRINK} | Adam(lr={ADAM_LR}, betas={ADAM_BETAS})\n")

    # Store per-method results across (seed, minibatch)
    final_losses = {m: [] for m in METHODS}
    times = {m: [] for m in METHODS}
    cos_vs_full = {m: [] for m in METHODS if m != "full"}  # cosine of step1 update vs full step1 update

    for seed in SEEDS:
        set_all_seeds(seed)
        print(f"\n{'='*70}\nSEED {seed}\n{'='*70}")

        # same initial theta per seed across methods
        theta0 = pnp.array(np.random.randn(N_LAYERS, N_QUBITS, 2) * 0.1, requires_grad=True)

        minibatches = get_minibatches(seed, N_MINIBATCHES, BATCH_SIZE)

        for mb_i, (x_np, y_np) in enumerate(minibatches, start=1):
            # run full first for reference direction cosine
            full_losses, full_delta1, full_t = run_one_minibatch(theta0, "full", x_np, y_np)
            final_losses["full"].append(full_losses[-1])
            times["full"].append(full_t)

            # run others
            for m in ["euclid", "adam", "diag", "block-diag"]:
                losses, delta1, tsec = run_one_minibatch(theta0, m, x_np, y_np)
                final_losses[m].append(losses[-1])
                times[m].append(tsec)

                # cosine on step-1 APPLIED update (so compare apples-to-apples)
                cos_vs_full[m].append(_cosine(delta1, full_delta1))

            if mb_i in (1, N_MINIBATCHES):
                print(f"  minibatch {mb_i:2d}/{N_MINIBATCHES}: "
                      f"full_lossK={full_losses[-1]:.4f} | "
                      f"diag_lossK={final_losses['diag'][-1]:.4f} | "
                      f"adam_lossK={final_losses['adam'][-1]:.4f}")

    print("\n" + "="*70)
    print("AGGREGATE RESULTS (across 3 seeds × 10 minibatches = 30 runs per method)")
    print("="*70)

    # Report: mean±std final loss, mean±std time, mean cosine vs full
    for m in METHODS:
        muL, sdL = summarize(final_losses[m])
        muT, sdT = summarize(times[m])
        if m == "full":
            print(f"  {m:9s}: final_loss={muL:.4f} ± {sdL:.4f} | time={muT:.2f}s ± {sdT:.2f}s | cos(step1 vs full)=1.000")
        else:
            muC, sdC = summarize(cos_vs_full[m])
            print(f"  {m:9s}: final_loss={muL:.4f} ± {sdL:.4f} | time={muT:.2f}s ± {sdT:.2f}s | cos(step1 vs full)={muC:.3f} ± {sdC:.3f}")

    print("\nNotes:")
    print("  • Cosines compare the *applied step vector* at step 1 (not the raw gradient).")
    print("  • Adam’s delta is its actual applied update on θ (per minibatch, fresh state).")
    print("  • If you want Adam to carry momentum across minibatches, move AdamState outside run_one_minibatch().")


if __name__ == "__main__":
    main()


[device] backend=lightning.gpu, wires=5 (aux_wire=4)

[Euclid vs Adam vs FS-NG] DATASET=MNIST | backend=lightning.gpu
  layers=2 | batch=8 | minibatches=10 | steps=10
  ETA=0.01 | LAM=0.001 | alpha_shrink=1.0 | Adam(lr=0.001, betas=(0.9, 0.999))


SEED 42
  minibatch  1/10: full_lossK=2.2681 | diag_lossK=2.2690 | adam_lossK=2.2700
  minibatch 10/10: full_lossK=2.3659 | diag_lossK=2.3660 | adam_lossK=2.3659

SEED 123
  minibatch  1/10: full_lossK=2.2654 | diag_lossK=2.2655 | adam_lossK=2.2655
  minibatch 10/10: full_lossK=2.3312 | diag_lossK=2.3310 | adam_lossK=2.3307

SEED 456
  minibatch  1/10: full_lossK=2.2671 | diag_lossK=2.2671 | adam_lossK=2.2668
  minibatch 10/10: full_lossK=2.3362 | diag_lossK=2.3362 | adam_lossK=2.3360

AGGREGATE RESULTS (across 3 seeds × 10 minibatches = 30 runs per method)
  euclid   : final_loss=2.3163 ± 0.0389 | time=6.04s ± 0.36s | cos(step1 vs full)=0.940 ± 0.047
  adam     : final_loss=2.3154 ± 0.0389 | time=5.97s ± 0.40s | cos(step1 vs full)=0.728 ± 0.

# RealNet FashionMNIST

In [5]:
"""
Block 1: RealNet Training
==========================
Trains the baseline Real MLP on 3 seeds.
Saves the trained preprocessor for reuse in subsequent blocks.

Outputs:
- realnet_results.pt: Contains results dict and trained preprocessor state
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Shared Bottleneck Preprocessor: 784 → 16
# ============================================

class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 784 → 16"""
    def __init__(self, input_dim=784, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc(x))
        return x


# ============================================
# Real-valued Head and Network
# ============================================

class RealHead(nn.Module):
    """Standard MLP: 16 → 64 → 10"""
    def __init__(self, bottleneck_dim=16, hidden_dim=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(bottleneck_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class RealNet(nn.Module):
    """Complete Real network: Preprocessor + RealHead"""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(784, 16)
        self.head = RealHead(16, 64, 10)

    def forward(self, x):
        features = self.preprocessor(x)
        return self.head(features)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Uses dataset.targets directly - NO image loading during sampling.
    """
    # FashionMNIST and MNIST expose targets as a tensor
    targets = dataset.targets
    rng = np.random.RandomState(seed)
    
    sampled_indices = []
    for c in range(10):
        # Find all indices for class c
        idx_c = (targets == c).nonzero(as_tuple=False).view(-1).cpu().numpy()
        k = min(n_samples_per_class, len(idx_c))
        # Sample without replacement
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    # Shuffle the combined indices
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 1: RealNet Training")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Batch size: 32")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: 784 → 16 → 64 → 10")
    print("=" * 70)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,))
    ])

    # Load full datasets
    full_train_ds = datasets.FashionMNIST(root="./data", train=True,
                                         download=True, transform=transform)
    full_test_ds = datasets.FashionMNIST(root="./data", train=False,
                                        download=True, transform=transform)

    # Create stratified samples (fast - uses targets only, no image loading)
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=32,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64,
                             shuffle=False, num_workers=0)

    seeds = [42, 123, 456]
    all_results = []
    trained_preprocessor_state = None

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training RealNet (seed={seed})...")
        real_model = RealNet().to(device)
        real_opt = torch.optim.Adam(real_model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            real_model, train_loader, test_loader, real_opt, device,
            max_epochs=40, patience=10, name="Real"
        )
        result["params"] = sum(p.numel() for p in real_model.parameters())
        result["seed"] = seed
        
        all_results.append(result)
        
        # Save preprocessor from first seed
        if trained_preprocessor_state is None:
            trained_preprocessor_state = real_model.preprocessor.state_dict()
            print(f"\n  → Saved preprocessor state from seed {seed}")

    # Summary
    print("\n" + "=" * 70)
    print("REALNET SUMMARY")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["params"]
    
    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results and preprocessor
    save_dict = {
        "results": all_results,
        "preprocessor_state": trained_preprocessor_state,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "params": params
        }
    }
    
    torch.save(save_dict, "realnet_results1.pt")
    print(f"\n✓ Saved results to: realnet_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 1: RealNet Training

Configuration:
  • Batch size: 32
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: 784 → 16 → 64 → 10

Creating stratified samples...
  Sampling took 0.01s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training RealNet (seed=42)...
  [Real] Epoch  1 | loss=0.8276 | test_acc=0.7903 | time=1.3s
  [Real] Epoch  2 | loss=0.4974 | test_acc=0.8260 | time=2.6s
  [Real] Epoch  3 | loss=0.4442 | test_acc=0.8333 | time=3.9s
  [Real] Epoch  4 | loss=0.4149 | test_acc=0.8270 | time=5.2s
  [Real] Epoch  5 | loss=0.3965 | test_acc=0.8373 | time=6.6s
  [Real] Epoch  6 | loss=0.3760 | test_acc=0.8383 | time=8.1s
  [Real] Epoch  7 | loss=0.3641 | test_acc=0.8303 | time=9.5s
  [Real] Epoch  8 | loss=0.3509 | test_acc=0.8377 | time=10.9s
  [Real] Epoch  9 | loss=0.3429 | test_acc=0.8390 | time=12.3s
  [Real] Epoch 10 | loss=0.3271 | test_acc=0.8440 | time=13.6s
  [Real] Ep

# Diagnostic Fashion MNIST

In [11]:
"""
Mini-script: Euclid vs Adam vs FS-NG covariance test (euclid vs adam vs diag vs block-diag vs full)
=================================================================================================

What this does
--------------
Runs K update steps on *multiple fixed minibatches* (10 minibatches) for each of 3 seeds,
updating ONLY the circuit weights θ, while holding the preprocessor + feature mapping + readout fixed.

We compare five update rules:

  1) euclid     : plain Euclidean gradient descent        θ <- θ - η * g
  2) adam       : Adam optimizer on θ (PyTorch-style)     θ <- AdamStep(g)
  3) diag       : FS/QFI natural gradient (metric_tensor approx="diag")
  4) block-diag : FS/QFI natural gradient (approx="block-diag")
  5) full       : FS/QFI natural gradient (approx=None, needs aux wire)

Optional hybrid (“shrinkage” FS metric toward Euclidean)
--------------------------------------------------------
Instead of pure FS-NG, you can blend the metric toward identity:

    G_eff = α * G_bar + (1-α) * I
    step  = (G_eff + λ I)^(-1) g

Set ALPHA_SHRINK in (0,1]. If you want *pure* FS-NG, set ALPHA_SHRINK=1.0.

Goal
----
- See whether FS covariance structure helps over Euclidean GD and Adam
- Quantify directional difference (cosine similarity vs full step direction)
- Keep runtime low: small minibatches, 10 minibatches, K steps each, 3 seeds

Requires
--------
- realnet_results.pt (from Block 1; contains frozen preprocessor weights)
- pennylane
- torchvision
- optional: pennylane-lightning-gpu (for lightning.gpu)

Notes
-----
- FULL metric_tensor (approx=None) requires an auxiliary wire on many devices.
  We allocate N_QUBITS + 1 wires and reserve the last wire as aux.
- Adam is implemented in a lightweight, deterministic “PyTorch Adam” way on θ only,
  so you can compare it without rerunning multi-day experiments.
"""

import time
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import pennylane as qml
from pennylane import numpy as pnp


# ---------------------------
# Config (tune for speed)
# ---------------------------
SEEDS = [42, 123, 456]

DATASET = "FashionMNIST"          # fixed per your request
N_QUBITS = 4
N_LAYERS = 2
BATCH_SIZE = 8

N_MINIBATCHES = 10         # 10 minibatches per seed
K_STEPS = 10               # steps per minibatch (kept from prior diagnostic)

ETA = 0.01                 # Euclid/FS learning rate
LAM = 1e-3                 # FS damping

# Shrinkage: α=1 => pure FS, α<1 => blend toward Euclidean identity metric
ALPHA_SHRINK = 1.0         # try 0.8 or 0.5 for “Euclid+FS” blend

# Adam hyperparams (match your prior analysis spirit)
ADAM_LR = 1e-3
ADAM_BETAS = (0.9, 0.999)
ADAM_EPS = 1e-8
ADAM_WEIGHT_DECAY = 0.0

PREFERRED_DEVICE = "lightning.gpu"
FALLBACK_DEVICE = "default.qubit"


# ---------------------------
# Reproducibility
# ---------------------------
def set_all_seeds(seed: int):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    try:
        qml.numpy.random.seed(seed)
    except Exception:
        pass


# ---------------------------
# Shared Preprocessor (Block 1)
# ---------------------------
class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 784 → 16"""
    def __init__(self, input_dim=784, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return torch.tanh(self.fc(x))


# ---------------------------
# Device factory (+1 aux wire for full metric)
# ---------------------------
def make_device(n_qubits: int):
    n_wires_total = n_qubits + 1
    aux_wire = n_qubits
    try:
        dev = qml.device(PREFERRED_DEVICE, wires=n_wires_total)
        backend = PREFERRED_DEVICE
    except Exception as e:
        dev = qml.device(FALLBACK_DEVICE, wires=n_wires_total)
        backend = FALLBACK_DEVICE
        print(f"[device] Could not load {PREFERRED_DEVICE} ({type(e).__name__}: {e}). Using {FALLBACK_DEVICE}.")
    print(f"[device] backend={backend}, wires={n_wires_total} (aux_wire={aux_wire})")
    return dev, backend, aux_wire


# ---------------------------
# Helpers
# ---------------------------
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    na = np.linalg.norm(a) + 1e-12
    nb = np.linalg.norm(b) + 1e-12
    return float(np.dot(a, b) / (na * nb))


def _flat(theta) -> pnp.ndarray:
    return pnp.reshape(theta, (-1,))


def _unflat(v, shape) -> pnp.ndarray:
    return pnp.reshape(v, shape)


# ---------------------------
# Build circuit + metric fns
# ---------------------------
dev, backend, aux_wire = make_device(N_QUBITS)

@qml.qnode(dev, interface="autograd", diff_method="parameter-shift")
def circuit(inputs, weights):
    # circuit uses ONLY wires 0..N_QUBITS-1; last wire is aux-reserved
    for layer in range(N_LAYERS):
        for i in range(N_QUBITS):
            qml.RY(inputs[i], wires=i)
        for i in range(N_QUBITS):
            qml.RY(weights[layer, i, 0], wires=i)
            qml.RZ(weights[layer, i, 1], wires=i)

        for i in range(N_QUBITS - 1):
            qml.CNOT(wires=[i, i + 1])
        if N_QUBITS > 2:
            qml.CNOT(wires=[N_QUBITS - 1, 0])

    return (
        qml.expval(qml.PauliZ(0)),
        qml.expval(qml.PauliZ(1)),
        qml.expval(qml.PauliZ(2)),
        qml.expval(qml.PauliZ(3)),
        qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)),
        qml.expval(qml.PauliZ(2) @ qml.PauliZ(3)),
    )

metric_full  = qml.metric_tensor(circuit, approx=None,       aux_wire=aux_wire)
metric_diag  = qml.metric_tensor(circuit, approx="diag")
metric_block = qml.metric_tensor(circuit, approx="block-diag")


# ---------------------------
# Load frozen preprocessor from Block 1
# ---------------------------
realnet = torch.load("realnet_results1.pt", weights_only=False)
pre_state = realnet["preprocessor_state"]

pre = SharedPreprocessor(784, 16)
pre.load_state_dict(pre_state)
for p in pre.parameters():
    p.requires_grad = False
pre.eval()

# frozen feature_select (16 -> 4)
feature_select = nn.Linear(16, N_QUBITS)
for p in feature_select.parameters():
    p.requires_grad = False
feature_select.eval()


# ---------------------------
# Dataset: MNIST (fixed)
# ---------------------------
norm_mean, norm_std = (0.1307,), (0.3081,)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)


# ---------------------------
# Fixed readout (constant) to isolate θ update behavior
# ---------------------------
C = 10
W = pnp.array(np.random.randn(C, 6).astype(np.float64) * 0.1)  # (10,6)
b = pnp.array(np.zeros((C,), dtype=np.float64))


# ---------------------------
# Loss factory (depends on current minibatch)
# ---------------------------
def make_batch_ce_loss(x_np, y_np):
    def batch_ce_loss(theta):
        losses = []
        for i in range(len(y_np)):
            q_raw = circuit(x_np[i], theta)
            q = pnp.stack(q_raw)  # (6,)
            logits = W @ q + b
            logits = logits - pnp.max(logits)
            ex = pnp.exp(logits)
            probs = ex / pnp.sum(ex)
            losses.append(-pnp.log(probs[int(y_np[i])] + 1e-12))
        return pnp.mean(pnp.stack(losses))
    return batch_ce_loss


# ---------------------------
# FS-NG / Euclid delta
# ---------------------------
def fsng_or_euclid_delta(theta, approx: str, batch_ce_loss, x_np):
    """
    approx in {"euclid","diag","block-diag","full"}.

    - euclid: delta = g
    - others: delta = (G_eff + λI)^(-1) g
              with optional shrinkage: G_eff = α G_bar + (1-α) I
    """
    theta_shape = theta.shape
    P = int(np.prod(theta_shape))
    I = pnp.eye(P, dtype=pnp.float64)

    g = qml.grad(batch_ce_loss)(theta)
    g_flat = pnp.reshape(g, (P,))

    if approx == "euclid":
        return np.array(g_flat, dtype=np.float64)

    if approx == "diag":
        metric_fn = metric_diag
    elif approx == "block-diag":
        metric_fn = metric_block
    elif approx == "full":
        metric_fn = metric_full
    else:
        raise ValueError("approx must be one of: 'euclid','diag','block-diag','full'")

    G_sum = pnp.zeros((P, P), dtype=pnp.float64)
    for i in range(x_np.shape[0]):
        Gi = metric_fn(x_np[i], theta)
        Gi = pnp.reshape(Gi, (P, P))
        G_sum = G_sum + Gi
    G_bar = G_sum / float(x_np.shape[0])

    alpha = float(ALPHA_SHRINK)
    G_eff = alpha * G_bar + (1.0 - alpha) * I
    G_reg = G_eff + LAM * I

    delta = pnp.linalg.solve(G_reg, g_flat)
    return np.array(delta, dtype=np.float64)


# ---------------------------
# Adam (PyTorch-style) on θ only
# ---------------------------
class AdamState:
    def __init__(self, P, betas=(0.9, 0.999)):
        self.t = 0
        self.m = np.zeros((P,), dtype=np.float64)
        self.v = np.zeros((P,), dtype=np.float64)
        self.beta1, self.beta2 = betas

def adam_step(theta, grad_flat, state: AdamState, lr=1e-3, eps=1e-8, weight_decay=0.0):
    """
    Deterministic Adam update on a *flat* theta vector.
    Returns updated flat theta and the applied update direction (delta).
    """
    state.t += 1

    g = grad_flat.astype(np.float64)
    if weight_decay != 0.0:
        g = g + weight_decay * theta

    b1, b2 = state.beta1, state.beta2
    state.m = b1 * state.m + (1.0 - b1) * g
    state.v = b2 * state.v + (1.0 - b2) * (g * g)

    mhat = state.m / (1.0 - b1 ** state.t)
    vhat = state.v / (1.0 - b2 ** state.t)

    step = lr * mhat / (np.sqrt(vhat) + eps)   # this is the update magnitude
    theta_new = theta - step
    return theta_new, step


# ---------------------------
# One minibatch run for one method
# ---------------------------
def run_one_minibatch(theta0, method, x_np, y_np):
    """
    Runs K_STEPS on a single minibatch and returns:
      losses: (K_STEPS,)
      delta1: first-step update direction (flat, numpy)
      elapsed
    """
    theta = pnp.array(theta0, requires_grad=True)
    theta_shape = theta.shape
    P = int(np.prod(theta_shape))

    batch_ce_loss = make_batch_ce_loss(x_np, y_np)

    # Adam needs state per minibatch (matches your “diagnostic” intent)
    adam_state = AdamState(P, betas=ADAM_BETAS) if method == "adam" else None

    losses = []
    delta1 = None

    t0 = time.time()
    for k in range(K_STEPS):
        loss_before = float(batch_ce_loss(theta))

        if method == "adam":
            g = qml.grad(batch_ce_loss)(theta)
            g_flat = np.array(pnp.reshape(g, (P,)), dtype=np.float64)

            theta_flat = np.array(pnp.reshape(theta, (P,)), dtype=np.float64)
            theta_flat_new, step_vec = adam_step(
                theta_flat, g_flat, adam_state,
                lr=ADAM_LR, eps=ADAM_EPS, weight_decay=ADAM_WEIGHT_DECAY
            )
            # For cosine comparisons, treat Adam's *applied update* as delta
            delta = step_vec / (ADAM_LR + 1e-12)  # normalized-ish direction (optional)
            theta = _unflat(theta_flat_new, theta_shape)

        else:
            delta = fsng_or_euclid_delta(theta, method, batch_ce_loss, x_np)
            theta_flat = np.array(pnp.reshape(theta, (P,)), dtype=np.float64)
            theta = _unflat(theta_flat - ETA * delta, theta_shape)

        loss_after = float(batch_ce_loss(theta))
        losses.append(loss_after)

        if k == 0:
            # Store first-step applied direction in comparable units:
            # - for FS/euclid: delta is the preconditioned direction (so update is ETA*delta)
            # - for Adam: delta above is roughly direction; we’ll store the *applied step* too
            if method == "adam":
                delta1 = np.array(step_vec, dtype=np.float64)  # actual applied update
            else:
                delta1 = np.array(ETA * delta, dtype=np.float64)  # applied update

    return np.array(losses), delta1, time.time() - t0


# ---------------------------
# Collect minibatches (deterministic but seed-dependent)
# ---------------------------
def get_minibatches(seed, n_minibatches, batch_size):
    """
    Returns list of (x_np, y_np) minibatches.
    Uses a deterministic RNG over indices so we don't depend on DataLoader nondeterminism.
    """
    rng = np.random.RandomState(seed)
    batches = []
    for _ in range(n_minibatches):
        idx = rng.choice(len(train_ds), size=batch_size, replace=False)
        x0, y0 = zip(*[train_ds[i] for i in idx])
        x0 = torch.stack(list(x0), dim=0)
        y0 = torch.tensor(list(y0), dtype=torch.long)

        with torch.no_grad():
            feats16 = pre(x0)
            xq = torch.tanh(feature_select(feats16))

        batches.append((xq.cpu().numpy(), y0.cpu().numpy()))
    return batches


# ---------------------------
# Main experiment loop
# ---------------------------
METHODS = ["euclid", "adam", "diag", "block-diag", "full"]

def summarize(arr):
    return float(np.mean(arr)), float(np.std(arr))


def main():
    print(f"\n[Euclid vs Adam vs FS-NG] DATASET={DATASET} | backend={backend}")
    print(f"  layers={N_LAYERS} | batch={BATCH_SIZE} | minibatches={N_MINIBATCHES} | steps={K_STEPS}")
    print(f"  ETA={ETA} | LAM={LAM} | alpha_shrink={ALPHA_SHRINK} | Adam(lr={ADAM_LR}, betas={ADAM_BETAS})\n")

    # Store per-method results across (seed, minibatch)
    final_losses = {m: [] for m in METHODS}
    times = {m: [] for m in METHODS}
    cos_vs_full = {m: [] for m in METHODS if m != "full"}  # cosine of step1 update vs full step1 update

    for seed in SEEDS:
        set_all_seeds(seed)
        print(f"\n{'='*70}\nSEED {seed}\n{'='*70}")

        # same initial theta per seed across methods
        theta0 = pnp.array(np.random.randn(N_LAYERS, N_QUBITS, 2) * 0.1, requires_grad=True)

        minibatches = get_minibatches(seed, N_MINIBATCHES, BATCH_SIZE)

        for mb_i, (x_np, y_np) in enumerate(minibatches, start=1):
            # run full first for reference direction cosine
            full_losses, full_delta1, full_t = run_one_minibatch(theta0, "full", x_np, y_np)
            final_losses["full"].append(full_losses[-1])
            times["full"].append(full_t)

            # run others
            for m in ["euclid", "adam", "diag", "block-diag"]:
                losses, delta1, tsec = run_one_minibatch(theta0, m, x_np, y_np)
                final_losses[m].append(losses[-1])
                times[m].append(tsec)

                # cosine on step-1 APPLIED update (so compare apples-to-apples)
                cos_vs_full[m].append(_cosine(delta1, full_delta1))

            if mb_i in (1, N_MINIBATCHES):
                print(f"  minibatch {mb_i:2d}/{N_MINIBATCHES}: "
                      f"full_lossK={full_losses[-1]:.4f} | "
                      f"diag_lossK={final_losses['diag'][-1]:.4f} | "
                      f"adam_lossK={final_losses['adam'][-1]:.4f}")

    print("\n" + "="*70)
    print("AGGREGATE RESULTS (across 3 seeds × 10 minibatches = 30 runs per method)")
    print("="*70)

    # Report: mean±std final loss, mean±std time, mean cosine vs full
    for m in METHODS:
        muL, sdL = summarize(final_losses[m])
        muT, sdT = summarize(times[m])
        if m == "full":
            print(f"  {m:9s}: final_loss={muL:.4f} ± {sdL:.4f} | time={muT:.2f}s ± {sdT:.2f}s | cos(step1 vs full)=1.000")
        else:
            muC, sdC = summarize(cos_vs_full[m])
            print(f"  {m:9s}: final_loss={muL:.4f} ± {sdL:.4f} | time={muT:.2f}s ± {sdT:.2f}s | cos(step1 vs full)={muC:.3f} ± {sdC:.3f}")

    print("\nNotes:")
    print("  • Cosines compare the *applied step vector* at step 1 (not the raw gradient).")
    print("  • Adam’s delta is its actual applied update on θ (per minibatch, fresh state).")
    print("  • If you want Adam to carry momentum across minibatches, move AdamState outside run_one_minibatch().")


if __name__ == "__main__":
    main()


[device] backend=lightning.gpu, wires=5 (aux_wire=4)

[Euclid vs Adam vs FS-NG] DATASET=FashionMNIST | backend=lightning.gpu
  layers=2 | batch=8 | minibatches=10 | steps=10
  ETA=0.01 | LAM=0.001 | alpha_shrink=1.0 | Adam(lr=0.001, betas=(0.9, 0.999))


SEED 42
  minibatch  1/10: full_lossK=2.2397 | diag_lossK=2.2400 | adam_lossK=2.2405
  minibatch 10/10: full_lossK=2.3050 | diag_lossK=2.3050 | adam_lossK=2.3046

SEED 123
  minibatch  1/10: full_lossK=2.4305 | diag_lossK=2.4305 | adam_lossK=2.4304
  minibatch 10/10: full_lossK=2.2562 | diag_lossK=2.2560 | adam_lossK=2.2568

SEED 456
  minibatch  1/10: full_lossK=2.2706 | diag_lossK=2.2707 | adam_lossK=2.2705
  minibatch 10/10: full_lossK=2.2765 | diag_lossK=2.2765 | adam_lossK=2.2762

AGGREGATE RESULTS (across 3 seeds × 10 minibatches = 30 runs per method)
  euclid   : final_loss=2.3014 ± 0.0516 | time=5.95s ± 0.41s | cos(step1 vs full)=0.937 ± 0.074
  adam     : final_loss=2.3005 ± 0.0515 | time=5.92s ± 0.47s | cos(step1 vs full)=0.7

# RealNet CIFAR

In [12]:
"""
Block 1: RealNet Training (CIFAR-10)
=====================================
Trains the baseline Real MLP on 3 seeds using CIFAR-10.
Saves the trained preprocessor for reuse in subsequent blocks.

Outputs:
- realnet_results.pt: Contains results dict and trained preprocessor state
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Shared Bottleneck Preprocessor: 3072 → 16
# ============================================

class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 3072 → 16"""
    def __init__(self, input_dim=3072, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten to (batch, 3072)
        x = torch.tanh(self.fc(x))
        return x


# ============================================
# Real-valued Head and Network
# ============================================

class RealHead(nn.Module):
    """Standard MLP: 16 → 64 → 10"""
    def __init__(self, bottleneck_dim=16, hidden_dim=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(bottleneck_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class RealNet(nn.Module):
    """Complete Real network: Preprocessor + RealHead"""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(3072, 16)
        self.head = RealHead(16, 64, 10)

    def forward(self, x):
        features = self.preprocessor(x)
        return self.head(features)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Uses dataset.targets directly - NO image loading during sampling.
    """
    # CIFAR-10 exposes targets as a list
    if hasattr(dataset, 'targets'):
        targets = np.array(dataset.targets)
    else:
        # Fallback for wrapped datasets
        targets = np.array([dataset[i][1] for i in range(len(dataset))])
    
    rng = np.random.RandomState(seed)
    
    sampled_indices = []
    for c in range(10):
        # Find all indices for class c
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        # Sample without replacement
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    # Shuffle the combined indices
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 1: RealNet Training (CIFAR-10)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • Input dimension: 3072 (32×32×3 flattened)")
    print("  • Batch size: 128")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: 3072 → 16 → 64 → 10")
    print("=" * 70)

    # CIFAR-10 normalization (channel-wise means and stds)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        )
    ])

    # Load full datasets
    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples (fast - uses targets only, no image loading)
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=128,
                              shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256,
                             shuffle=False, num_workers=4)

    seeds = [42, 123, 456]
    all_results = []
    trained_preprocessor_state = None

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training RealNet (seed={seed})...")
        real_model = RealNet().to(device)
        real_opt = torch.optim.Adam(real_model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            real_model, train_loader, test_loader, real_opt, device,
            max_epochs=40, patience=10, name="Real"
        )
        result["params"] = sum(p.numel() for p in real_model.parameters())
        result["seed"] = seed
        
        all_results.append(result)
        
        # Save preprocessor from first seed
        if trained_preprocessor_state is None:
            trained_preprocessor_state = real_model.preprocessor.state_dict()
            print(f"\n  → Saved preprocessor state from seed {seed}")

    # Summary
    print("\n" + "=" * 70)
    print("REALNET SUMMARY (CIFAR-10)")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["params"]
    
    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results and preprocessor
    save_dict = {
        "results": all_results,
        "preprocessor_state": trained_preprocessor_state,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "params": params
        }
    }
    
    torch.save(save_dict, "realnet_cifar10_results.pt")
    print(f"\n✓ Saved results to: realnet_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 1: RealNet Training (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • Input dimension: 3072 (32×32×3 flattened)
  • Batch size: 128
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: 3072 → 16 → 64 → 10

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training RealNet (seed=42)...
  [Real] Epoch  1 | loss=2.0209 | test_acc=0.3273 | time=0.5s
  [Real] Epoch  2 | loss=1.8281 | test_acc=0.3470 | time=0.9s
  [Real] Epoch  3 | loss=1.7635 | test_acc=0.3577 | time=1.4s
  [Real] Epoch  4 | loss=1.7250 | test_acc=0.3617 | time=1.8s
  [Real] Epoch  5 | loss=1.6946 | test_acc=0.3710 | time=2.3s
  [Real] Epoch  6 | loss=1.6776 | test_acc=0.3710 | time=2.7s
  [Real] Epoch  7 | loss=1.6463 | test_acc=0.3630 | time=3.2s
  [Real] Epoch  8 | loss=1.6154 | test_acc=0.3730 | time=3.6s
  [Real] Epoch  9 | loss=1.6088 | test_acc=0.

# Diagnostic CIFAR

In [13]:
"""
Mini-script: Euclid vs Adam vs FS-NG covariance test (CIFAR-10)
================================================================

What this does
--------------
Runs K update steps on *multiple fixed minibatches* (10 minibatches) for each of 3 seeds,
updating ONLY the circuit weights θ, while holding the preprocessor + feature mapping + readout fixed.

We compare five update rules:

  1) euclid     : plain Euclidean gradient descent        θ <- θ - η * g
  2) adam       : Adam optimizer on θ (PyTorch-style)     θ <- AdamStep(g)
  3) diag       : FS/QFI natural gradient (metric_tensor approx="diag")
  4) block-diag : FS/QFI natural gradient (approx="block-diag")
  5) full       : FS/QFI natural gradient (approx=None, needs aux wire)

Optional hybrid ("shrinkage" FS metric toward Euclidean)
--------------------------------------------------------
Instead of pure FS-NG, you can blend the metric toward identity:

    G_eff = α * G_bar + (1-α) * I
    step  = (G_eff + λ I)^(-1) g

Set ALPHA_SHRINK in (0,1]. If you want *pure* FS-NG, set ALPHA_SHRINK=1.0.

Goal
----
- See whether FS covariance structure helps over Euclidean GD and Adam
- Quantify directional difference (cosine similarity vs full step direction)
- Keep runtime low: small minibatches, 10 minibatches, K steps each, 3 seeds

Requires
--------
- realnet_cifar10_results.pt (from Block 1; contains frozen preprocessor weights)
- pennylane
- torchvision
- optional: pennylane-lightning-gpu (for lightning.gpu)

Notes
-----
- FULL metric_tensor (approx=None) requires an auxiliary wire on many devices.
  We allocate N_QUBITS + 1 wires and reserve the last wire as aux.
- Adam is implemented in a lightweight, deterministic "PyTorch Adam" way on θ only,
  so you can compare it without rerunning multi-day experiments.
"""

import time
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import pennylane as qml
from pennylane import numpy as pnp


# ---------------------------
# Config (tune for speed)
# ---------------------------
SEEDS = [42, 123, 456]

DATASET = "CIFAR-10"               # changed from FashionMNIST
N_QUBITS = 4
N_LAYERS = 2
BATCH_SIZE = 8

N_MINIBATCHES = 10         # 10 minibatches per seed
K_STEPS = 10               # steps per minibatch (kept from prior diagnostic)

ETA = 0.01                 # Euclid/FS learning rate
LAM = 1e-3                 # FS damping

# Shrinkage: α=1 => pure FS, α<1 => blend toward Euclidean identity metric
ALPHA_SHRINK = 1.0         # try 0.8 or 0.5 for "Euclid+FS" blend

# Adam hyperparams (match your prior analysis spirit)
ADAM_LR = 1e-3
ADAM_BETAS = (0.9, 0.999)
ADAM_EPS = 1e-8
ADAM_WEIGHT_DECAY = 0.0

PREFERRED_DEVICE = "lightning.gpu"
FALLBACK_DEVICE = "default.qubit"


# ---------------------------
# Reproducibility
# ---------------------------
def set_all_seeds(seed: int):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    try:
        qml.numpy.random.seed(seed)
    except Exception:
        pass


# ---------------------------
# Shared Preprocessor (Block 1) - CIFAR-10 version
# ---------------------------
class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 3072 → 16"""
    def __init__(self, input_dim=3072, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten to (batch, 3072)
        return torch.tanh(self.fc(x))


# ---------------------------
# Device factory (+1 aux wire for full metric)
# ---------------------------
def make_device(n_qubits: int):
    n_wires_total = n_qubits + 1
    aux_wire = n_qubits
    try:
        dev = qml.device(PREFERRED_DEVICE, wires=n_wires_total)
        backend = PREFERRED_DEVICE
    except Exception as e:
        dev = qml.device(FALLBACK_DEVICE, wires=n_wires_total)
        backend = FALLBACK_DEVICE
        print(f"[device] Could not load {PREFERRED_DEVICE} ({type(e).__name__}: {e}). Using {FALLBACK_DEVICE}.")
    print(f"[device] backend={backend}, wires={n_wires_total} (aux_wire={aux_wire})")
    return dev, backend, aux_wire


# ---------------------------
# Helpers
# ---------------------------
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    na = np.linalg.norm(a) + 1e-12
    nb = np.linalg.norm(b) + 1e-12
    return float(np.dot(a, b) / (na * nb))


def _flat(theta) -> pnp.ndarray:
    return pnp.reshape(theta, (-1,))


def _unflat(v, shape) -> pnp.ndarray:
    return pnp.reshape(v, shape)


# ---------------------------
# Build circuit + metric fns
# ---------------------------
dev, backend, aux_wire = make_device(N_QUBITS)

@qml.qnode(dev, interface="autograd", diff_method="parameter-shift")
def circuit(inputs, weights):
    # circuit uses ONLY wires 0..N_QUBITS-1; last wire is aux-reserved
    for layer in range(N_LAYERS):
        for i in range(N_QUBITS):
            qml.RY(inputs[i], wires=i)
        for i in range(N_QUBITS):
            qml.RY(weights[layer, i, 0], wires=i)
            qml.RZ(weights[layer, i, 1], wires=i)

        for i in range(N_QUBITS - 1):
            qml.CNOT(wires=[i, i + 1])
        if N_QUBITS > 2:
            qml.CNOT(wires=[N_QUBITS - 1, 0])

    return (
        qml.expval(qml.PauliZ(0)),
        qml.expval(qml.PauliZ(1)),
        qml.expval(qml.PauliZ(2)),
        qml.expval(qml.PauliZ(3)),
        qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)),
        qml.expval(qml.PauliZ(2) @ qml.PauliZ(3)),
    )

metric_full  = qml.metric_tensor(circuit, approx=None,       aux_wire=aux_wire)
metric_diag  = qml.metric_tensor(circuit, approx="diag")
metric_block = qml.metric_tensor(circuit, approx="block-diag")


# ---------------------------
# Load frozen preprocessor from Block 1
# ---------------------------
realnet = torch.load("realnet_cifar10_results.pt", weights_only=False)
pre_state = realnet["preprocessor_state"]

pre = SharedPreprocessor(3072, 16)  # CIFAR-10: 3072 input dims
pre.load_state_dict(pre_state)
for p in pre.parameters():
    p.requires_grad = False
pre.eval()

# frozen feature_select (16 -> 4)
feature_select = nn.Linear(16, N_QUBITS)
for p in feature_select.parameters():
    p.requires_grad = False
feature_select.eval()


# ---------------------------
# Dataset: CIFAR-10
# ---------------------------
# CIFAR-10 normalization (channel-wise)
norm_mean = [0.4914, 0.4822, 0.4465]
norm_std = [0.2470, 0.2435, 0.2616]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

train_ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)


# ---------------------------
# Fixed readout (constant) to isolate θ update behavior
# ---------------------------
C = 10
W = pnp.array(np.random.randn(C, 6).astype(np.float64) * 0.1)  # (10,6)
b = pnp.array(np.zeros((C,), dtype=np.float64))


# ---------------------------
# Loss factory (depends on current minibatch)
# ---------------------------
def make_batch_ce_loss(x_np, y_np):
    def batch_ce_loss(theta):
        losses = []
        for i in range(len(y_np)):
            q_raw = circuit(x_np[i], theta)
            q = pnp.stack(q_raw)  # (6,)
            logits = W @ q + b
            logits = logits - pnp.max(logits)
            ex = pnp.exp(logits)
            probs = ex / pnp.sum(ex)
            losses.append(-pnp.log(probs[int(y_np[i])] + 1e-12))
        return pnp.mean(pnp.stack(losses))
    return batch_ce_loss


# ---------------------------
# FS-NG / Euclid delta
# ---------------------------
def fsng_or_euclid_delta(theta, approx: str, batch_ce_loss, x_np):
    """
    approx in {"euclid","diag","block-diag","full"}.

    - euclid: delta = g
    - others: delta = (G_eff + λI)^(-1) g
              with optional shrinkage: G_eff = α G_bar + (1-α) I
    """
    theta_shape = theta.shape
    P = int(np.prod(theta_shape))
    I = pnp.eye(P, dtype=pnp.float64)

    g = qml.grad(batch_ce_loss)(theta)
    g_flat = pnp.reshape(g, (P,))

    if approx == "euclid":
        return np.array(g_flat, dtype=np.float64)

    if approx == "diag":
        metric_fn = metric_diag
    elif approx == "block-diag":
        metric_fn = metric_block
    elif approx == "full":
        metric_fn = metric_full
    else:
        raise ValueError("approx must be one of: 'euclid','diag','block-diag','full'")

    G_sum = pnp.zeros((P, P), dtype=pnp.float64)
    for i in range(x_np.shape[0]):
        Gi = metric_fn(x_np[i], theta)
        Gi = pnp.reshape(Gi, (P, P))
        G_sum = G_sum + Gi
    G_bar = G_sum / float(x_np.shape[0])

    alpha = float(ALPHA_SHRINK)
    G_eff = alpha * G_bar + (1.0 - alpha) * I
    G_reg = G_eff + LAM * I

    delta = pnp.linalg.solve(G_reg, g_flat)
    return np.array(delta, dtype=np.float64)


# ---------------------------
# Adam (PyTorch-style) on θ only
# ---------------------------
class AdamState:
    def __init__(self, P, betas=(0.9, 0.999)):
        self.t = 0
        self.m = np.zeros((P,), dtype=np.float64)
        self.v = np.zeros((P,), dtype=np.float64)
        self.beta1, self.beta2 = betas

def adam_step(theta, grad_flat, state: AdamState, lr=1e-3, eps=1e-8, weight_decay=0.0):
    """
    Deterministic Adam update on a *flat* theta vector.
    Returns updated flat theta and the applied update direction (delta).
    """
    state.t += 1

    g = grad_flat.astype(np.float64)
    if weight_decay != 0.0:
        g = g + weight_decay * theta

    b1, b2 = state.beta1, state.beta2
    state.m = b1 * state.m + (1.0 - b1) * g
    state.v = b2 * state.v + (1.0 - b2) * (g * g)

    mhat = state.m / (1.0 - b1 ** state.t)
    vhat = state.v / (1.0 - b2 ** state.t)

    step = lr * mhat / (np.sqrt(vhat) + eps)   # this is the update magnitude
    theta_new = theta - step
    return theta_new, step


# ---------------------------
# One minibatch run for one method
# ---------------------------
def run_one_minibatch(theta0, method, x_np, y_np):
    """
    Runs K_STEPS on a single minibatch and returns:
      losses: (K_STEPS,)
      delta1: first-step update direction (flat, numpy)
      elapsed
    """
    theta = pnp.array(theta0, requires_grad=True)
    theta_shape = theta.shape
    P = int(np.prod(theta_shape))

    batch_ce_loss = make_batch_ce_loss(x_np, y_np)

    # Adam needs state per minibatch (matches your "diagnostic" intent)
    adam_state = AdamState(P, betas=ADAM_BETAS) if method == "adam" else None

    losses = []
    delta1 = None

    t0 = time.time()
    for k in range(K_STEPS):
        loss_before = float(batch_ce_loss(theta))

        if method == "adam":
            g = qml.grad(batch_ce_loss)(theta)
            g_flat = np.array(pnp.reshape(g, (P,)), dtype=np.float64)

            theta_flat = np.array(pnp.reshape(theta, (P,)), dtype=np.float64)
            theta_flat_new, step_vec = adam_step(
                theta_flat, g_flat, adam_state,
                lr=ADAM_LR, eps=ADAM_EPS, weight_decay=ADAM_WEIGHT_DECAY
            )
            # For cosine comparisons, treat Adam's *applied update* as delta
            delta = step_vec / (ADAM_LR + 1e-12)  # normalized-ish direction (optional)
            theta = _unflat(theta_flat_new, theta_shape)

        else:
            delta = fsng_or_euclid_delta(theta, method, batch_ce_loss, x_np)
            theta_flat = np.array(pnp.reshape(theta, (P,)), dtype=np.float64)
            theta = _unflat(theta_flat - ETA * delta, theta_shape)

        loss_after = float(batch_ce_loss(theta))
        losses.append(loss_after)

        if k == 0:
            # Store first-step applied direction in comparable units:
            # - for FS/euclid: delta is the preconditioned direction (so update is ETA*delta)
            # - for Adam: delta above is roughly direction; we'll store the *applied step* too
            if method == "adam":
                delta1 = np.array(step_vec, dtype=np.float64)  # actual applied update
            else:
                delta1 = np.array(ETA * delta, dtype=np.float64)  # applied update

    return np.array(losses), delta1, time.time() - t0


# ---------------------------
# Collect minibatches (deterministic but seed-dependent)
# ---------------------------
def get_minibatches(seed, n_minibatches, batch_size):
    """
    Returns list of (x_np, y_np) minibatches.
    Uses a deterministic RNG over indices so we don't depend on DataLoader nondeterminism.
    """
    rng = np.random.RandomState(seed)
    batches = []
    for _ in range(n_minibatches):
        idx = rng.choice(len(train_ds), size=batch_size, replace=False)
        x0, y0 = zip(*[train_ds[i] for i in idx])
        x0 = torch.stack(list(x0), dim=0)
        y0 = torch.tensor(list(y0), dtype=torch.long)

        with torch.no_grad():
            feats16 = pre(x0)
            xq = torch.tanh(feature_select(feats16))

        batches.append((xq.cpu().numpy(), y0.cpu().numpy()))
    return batches


# ---------------------------
# Main experiment loop
# ---------------------------
METHODS = ["euclid", "adam", "diag", "block-diag", "full"]

def summarize(arr):
    return float(np.mean(arr)), float(np.std(arr))


def main():
    print(f"\n[Euclid vs Adam vs FS-NG] DATASET={DATASET} | backend={backend}")
    print(f"  layers={N_LAYERS} | batch={BATCH_SIZE} | minibatches={N_MINIBATCHES} | steps={K_STEPS}")
    print(f"  ETA={ETA} | LAM={LAM} | alpha_shrink={ALPHA_SHRINK} | Adam(lr={ADAM_LR}, betas={ADAM_BETAS})\n")

    # Store per-method results across (seed, minibatch)
    final_losses = {m: [] for m in METHODS}
    times = {m: [] for m in METHODS}
    cos_vs_full = {m: [] for m in METHODS if m != "full"}  # cosine of step1 update vs full step1 update

    for seed in SEEDS:
        set_all_seeds(seed)
        print(f"\n{'='*70}\nSEED {seed}\n{'='*70}")

        # same initial theta per seed across methods
        theta0 = pnp.array(np.random.randn(N_LAYERS, N_QUBITS, 2) * 0.1, requires_grad=True)

        minibatches = get_minibatches(seed, N_MINIBATCHES, BATCH_SIZE)

        for mb_i, (x_np, y_np) in enumerate(minibatches, start=1):
            # run full first for reference direction cosine
            full_losses, full_delta1, full_t = run_one_minibatch(theta0, "full", x_np, y_np)
            final_losses["full"].append(full_losses[-1])
            times["full"].append(full_t)

            # run others
            for m in ["euclid", "adam", "diag", "block-diag"]:
                losses, delta1, tsec = run_one_minibatch(theta0, m, x_np, y_np)
                final_losses[m].append(losses[-1])
                times[m].append(tsec)

                # cosine on step-1 APPLIED update (so compare apples-to-apples)
                cos_vs_full[m].append(_cosine(delta1, full_delta1))

            if mb_i in (1, N_MINIBATCHES):
                print(f"  minibatch {mb_i:2d}/{N_MINIBATCHES}: "
                      f"full_lossK={full_losses[-1]:.4f} | "
                      f"diag_lossK={final_losses['diag'][-1]:.4f} | "
                      f"adam_lossK={final_losses['adam'][-1]:.4f}")

    print("\n" + "="*70)
    print("AGGREGATE RESULTS (across 3 seeds × 10 minibatches = 30 runs per method)")
    print("="*70)

    # Report: mean±std final loss, mean±std time, mean cosine vs full
    for m in METHODS:
        muL, sdL = summarize(final_losses[m])
        muT, sdT = summarize(times[m])
        if m == "full":
            print(f"  {m:9s}: final_loss={muL:.4f} ± {sdL:.4f} | time={muT:.2f}s ± {sdT:.2f}s | cos(step1 vs full)=1.000")
        else:
            muC, sdC = summarize(cos_vs_full[m])
            print(f"  {m:9s}: final_loss={muL:.4f} ± {sdL:.4f} | time={muT:.2f}s ± {sdT:.2f}s | cos(step1 vs full)={muC:.3f} ± {sdC:.3f}")

    print("\nNotes:")
    print("  • Cosines compare the *applied step vector* at step 1 (not the raw gradient).")
    print("  • Adam's delta is its actual applied update on θ (per minibatch, fresh state).")
    print("  • If you want Adam to carry momentum across minibatches, move AdamState outside run_one_minibatch().")


if __name__ == "__main__":
    main()

[device] backend=lightning.gpu, wires=5 (aux_wire=4)

[Euclid vs Adam vs FS-NG] DATASET=CIFAR-10 | backend=lightning.gpu
  layers=2 | batch=8 | minibatches=10 | steps=10
  ETA=0.01 | LAM=0.001 | alpha_shrink=1.0 | Adam(lr=0.001, betas=(0.9, 0.999))


SEED 42
  minibatch  1/10: full_lossK=2.3282 | diag_lossK=2.3282 | adam_lossK=2.3290
  minibatch 10/10: full_lossK=2.3053 | diag_lossK=2.3053 | adam_lossK=2.3054

SEED 123
  minibatch  1/10: full_lossK=2.3336 | diag_lossK=2.3338 | adam_lossK=2.3337
  minibatch 10/10: full_lossK=2.3474 | diag_lossK=2.3473 | adam_lossK=2.3479

SEED 456
  minibatch  1/10: full_lossK=2.3449 | diag_lossK=2.3449 | adam_lossK=2.3446
  minibatch 10/10: full_lossK=2.3083 | diag_lossK=2.3082 | adam_lossK=2.3079

AGGREGATE RESULTS (across 3 seeds × 10 minibatches = 30 runs per method)
  euclid   : final_loss=2.3204 ± 0.0357 | time=6.09s ± 0.52s | cos(step1 vs full)=0.958 ± 0.031
  adam     : final_loss=2.3196 ± 0.0357 | time=5.97s ± 0.39s | cos(step1 vs full)=0.689 ±