# Code

In [1]:
"""
Clean SU(2) Comparison: Real vs Quaternion vs Quantum (LIGHTNING)
===================================================================

Research Question:
Are quaternion networks (classical SU(2)) approximations that capture
most of what quantum circuits do, without needing entanglement?

Lightning Acceleration:
- Uses pennylane-lightning-gpu if available (~50-100x speedup)
- Falls back to lightning.qubit (CPU) if no GPU (~10x speedup)
- Same architecture as SANER version (16-D bottleneck)

Install: pip install pennylane pennylane-lightning pennylane-lightning-gpu
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Try to import PennyLane and pick the fastest simulator that actually loads.
# Importing `pennylane` says nothing about which plugins work: `qml.device`
# raises when a plugin package (or its CUDA runtime / GPU) is missing, so each
# candidate is probed in order of expected speed.
_QUANTUM_DEVICE_CANDIDATES = ("lightning.gpu", "lightning.qubit", "default.qubit")

try:
    import pennylane as qml
except ImportError:
    PENNYLANE_AVAILABLE = False
    QUANTUM_DEVICE = None
    print("✗ PennyLane not installed - quantum models disabled")
    print("  Install with: pip install pennylane pennylane-lightning-gpu")
else:
    QUANTUM_DEVICE = None
    for _candidate in _QUANTUM_DEVICE_CANDIDATES:
        try:
            qml.device(_candidate, wires=1)
        except Exception:  # capability probe: plugin missing, no GPU, driver mismatch...
            continue
        QUANTUM_DEVICE = _candidate
        break
    PENNYLANE_AVAILABLE = QUANTUM_DEVICE is not None
    if PENNYLANE_AVAILABLE:
        print(f"✓ Using {QUANTUM_DEVICE} device")
    else:
        print("✗ No usable PennyLane device found - quantum models disabled")

# Fall back to CPU so the classical models still run on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Quaternion utilities (PyTorch tensors)
# ============================================

def q_normalize(q):
    """Scale each quaternion (last axis) to unit length.

    A 1e-8 epsilon is added to the norm so all-zero quaternions map to zero
    instead of producing NaNs.
    """
    magnitude = torch.linalg.vector_norm(q, dim=-1, keepdim=True)
    return q / (magnitude + 1e-8)

def q_conj(q):
    """Quaternion conjugate: keep the scalar part w, negate the vector part (x, y, z)."""
    scalar_part = q[..., :1]
    vector_part = q[..., 1:]
    return torch.cat([scalar_part, -vector_part], dim=-1)

def q_mul(a, b):
    """Hamilton product a ⊗ b of quaternions stored as (..., 4) in (w, x, y, z) order.

    Leading axes broadcast like ordinary elementwise ops. The product is not
    commutative: q_mul(a, b) != q_mul(b, a) in general.
    """
    a0, a1, a2, a3 = a.unbind(-1)
    b0, b1, b2, b3 = b.unbind(-1)

    real = a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3
    i_part = a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2
    j_part = a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1
    k_part = a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0

    return torch.stack((real, i_part, j_part, k_part), dim=-1)


# ============================================
# Shared Bottleneck Preprocessor: 784 → 16
# ============================================

class SharedPreprocessor(nn.Module):
    """Classical bottleneck shared by all models: flatten → Linear → tanh.

    Each sample (any trailing shape whose element count equals ``input_dim``)
    becomes ``bottleneck_dim`` features in (-1, 1).
    """
    def __init__(self, input_dim=784, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        batch = x.size(0)
        flat = x.view(batch, -1)
        return self.fc(flat).tanh()


# ============================================
# Real-valued Head and Network
# ============================================

class RealHead(nn.Module):
    """Two-layer ReLU MLP head: bottleneck_dim → hidden_dim → num_classes logits."""
    def __init__(self, bottleneck_dim=16, hidden_dim=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(bottleneck_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


class RealNet(nn.Module):
    """Baseline model: shared 784→16 bottleneck followed by a 16→64→10 MLP head."""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(784, 16)
        self.head = RealHead(16, 64, 10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Quaternion Head and Network
# ============================================

class QuaternionLinear(nn.Module):
    """Fully connected layer over quaternion-valued features.

    ``in_features`` / ``out_features`` count quaternions (groups of four
    reals). Output unit ``o`` is ``sum_i W[o, i] ⊗ x[i] + bias[o]`` using the
    Hamilton product. Weights start as random unit quaternions and biases as
    the identity quaternion (1, 0, 0, 0).
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        # weight: (out_features, in_features, 4); bias: (out_features, 4)
        self.weight = nn.Parameter(torch.empty(out_features, in_features, 4))
        self.bias = nn.Parameter(torch.empty(out_features, 4))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0.0, std=0.1)
        with torch.no_grad():
            # Project every weight quaternion onto the unit 3-sphere.
            self.weight.copy_(q_normalize(self.weight))
            self.bias.zero_()
            self.bias[..., 0] = 1.0

    def forward(self, x):
        """Map x of shape (B, in_features, 4) to (B, out_features, 4)."""
        # (1, out, in, 4) ⊗ (B, 1, in, 4) → (B, out, in, 4), then reduce over inputs.
        products = q_mul(self.weight[None], x[:, None])
        return products.sum(dim=2) + self.bias


class QuatHead(nn.Module):
    """Quaternion head: 16 reals → 4 quats → 16 quats → 10 quats → 10 logits.

    The bottleneck features are regrouped into four (w, x, y, z) quaternions
    and passed through a quaternion linear layer. The hidden quaternions are
    normalised to unit length and squashed elementwise with tanh, then
    projected to one quaternion per class. The real (w) component of each
    class quaternion is its logit.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.quat_fc1 = QuaternionLinear(4, 16)
        self.quat_fc2 = QuaternionLinear(16, num_classes)

    def real_to_quat(self, x):
        """Reshape (B, 16) real features into (B, 4, 4): four quaternions per sample."""
        return x.view(x.size(0), 4, 4)

    def quat_to_real(self, q):
        """Keep only the real (w) component of each quaternion: (..., 4) → (...)."""
        return q.select(-1, 0)

    def forward(self, x):
        hidden = self.quat_fc1(self.real_to_quat(x))
        hidden = torch.tanh(q_normalize(hidden))
        return self.quat_to_real(self.quat_fc2(hidden))


class QuatNet(nn.Module):
    """Quaternion model: shared 784→16 bottleneck followed by ``QuatHead``."""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(784, 16)
        self.head = QuatHead(10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Quantum Head (Lightning-accelerated)
# ============================================

if PENNYLANE_AVAILABLE:
    class QuantumHead(nn.Module):
        """
        Variational quantum circuit (VQC) head: bottleneck features → logits.

        Per sample:
          1. ``feature_select`` maps the ``bottleneck_dim`` features to one
             angle per qubit, squashed to (-1, 1) with ``tanh``.
          2. ``n_layers`` repetitions of: RY data re-uploading, trainable
             RY/RZ on every qubit, and (if ``use_entanglement``) a CNOT chain
             that is closed into a ring when there are more than two qubits.
          3. Readout: <Z_i> for every qubit plus <Z_a Z_b> for each disjoint
             adjacent pair (0,1), (2,3), ...; a linear layer maps these to
             ``num_classes`` logits.

        With the default 4 qubits the readout has 4 + 2 = 6 features. Other
        qubit counts size the readout accordingly.

        Gradients use adjoint differentiation on Lightning devices and
        parameter-shift otherwise.

        Args:
            n_qubits: number of wires (>= 1).
            n_layers: number of re-uploading layers.
            num_classes: number of output logits.
            use_entanglement: add the CNOT chain/ring after each layer.
            device_name: PennyLane device name; defaults to ``QUANTUM_DEVICE``.
            bottleneck_dim: width of the incoming feature vector.
        """
        def __init__(self, n_qubits=4, n_layers=3, num_classes=10,
                     use_entanglement=True, device_name=None, bottleneck_dim=16):
            super().__init__()
            if n_qubits < 1:
                raise ValueError(f"n_qubits must be >= 1, got {n_qubits}")
            self.n_qubits = n_qubits
            self.n_layers = n_layers
            self.use_entanglement = use_entanglement

            # Map bottleneck features → one rotation angle per qubit
            self.feature_select = nn.Linear(bottleneck_dim, n_qubits)

            if device_name is None:
                device_name = QUANTUM_DEVICE
            self.dev = qml.device(device_name, wires=n_qubits)

            # Adjoint differentiation is much faster on Lightning simulators.
            diff_method = "adjoint" if "lightning" in device_name else "parameter-shift"

            # Disjoint neighbouring wire pairs for the two-body ZZ readout.
            pair_wires = [(i, i + 1) for i in range(0, n_qubits - 1, 2)]
            self.n_readout = n_qubits + len(pair_wires)

            @qml.qnode(self.dev, interface="torch", diff_method=diff_method)
            def quantum_circuit(inputs, weights):
                """
                Single-sample circuit.
                inputs: (n_qubits,) angles
                weights: (n_layers, n_qubits, 2)
                """
                for layer in range(n_layers):
                    # Data re-uploading
                    for i in range(n_qubits):
                        qml.RY(inputs[i], wires=i)

                    # Trainable rotations
                    for i in range(n_qubits):
                        qml.RY(weights[layer, i, 0], wires=i)
                        qml.RZ(weights[layer, i, 1], wires=i)

                    # Optional entanglement. The flag is read from `self` at
                    # call time, so toggling it after construction takes effect.
                    if self.use_entanglement:
                        for i in range(n_qubits - 1):
                            qml.CNOT(wires=[i, i + 1])
                        if n_qubits > 2:
                            qml.CNOT(wires=[n_qubits - 1, 0])

                singles = [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
                pairs = [qml.expval(qml.PauliZ(a) @ qml.PauliZ(b))
                         for a, b in pair_wires]
                return singles + pairs

            self.quantum_circuit = quantum_circuit

            weight_shape = (n_layers, n_qubits, 2)
            self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
            self.fc_out = nn.Linear(self.n_readout, num_classes)

        def forward(self, x):
            """
            x: (batch, bottleneck_dim) real features → (batch, num_classes) logits.

            The circuit is evaluated one sample at a time. Simulator outputs
            (typically float64) are cast to ``fc_out``'s dtype and device.
            """
            angles = torch.tanh(self.feature_select(x))

            rows = []
            for sample in angles:
                raw = self.quantum_circuit(sample, self.q_weights)
                rows.append(torch.stack(raw) if isinstance(raw, (list, tuple)) else raw)

            if rows:
                readout = torch.stack(rows)
            else:
                # Empty batch: keep shapes consistent instead of crashing in torch.stack.
                readout = angles.new_zeros((0, self.n_readout))

            target = self.fc_out.weight
            readout = readout.to(device=target.device, dtype=target.dtype)
            return self.fc_out(readout)


    class QuantumNet(nn.Module):
        """Complete Quantum network: Preprocessor + QuantumHead (4 qubits, 3 layers)."""
        def __init__(self, use_entanglement=True):
            super().__init__()
            self.preprocessor = SharedPreprocessor(784, 16)
            self.head = QuantumHead(
                n_qubits=4, n_layers=3, num_classes=10,
                use_entanglement=use_entanglement
            )

        def forward(self, x):
            features = self.preprocessor(x)
            return self.head(features)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=True):
    """
    Run one optimisation pass of cross-entropy training over ``loader``.

    Args:
        model: module mapping an input batch to class logits.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer over the parameters to update.
        device: where each batch is moved before the forward pass.
        show_progress: print a running batch counter every 50 batches.

    Returns:
        Mean per-sample loss over the samples actually processed. This stays
        correct when the loader drops or skips samples, e.g. ``drop_last=True``.
        Returns 0.0 if the loader yields nothing.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    n_batches = len(loader)

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        # cross_entropy averages over the batch; re-weight by batch size so
        # the epoch mean is per-sample even when the last batch is smaller.
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{n_batches}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / total_samples if total_samples else 0.0


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` (0.0 for an empty loader).

    The model is switched to eval mode and run without gradient tracking.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=8, name="Model"):
    """
    Train ``model`` until test accuracy stops improving or ``max_epochs`` is hit.

    Each epoch runs ``train_one_epoch`` and then ``evaluate`` on
    ``test_loader``. Training stops once ``patience`` consecutive epochs pass
    without a strictly higher accuracy than the best seen so far.

    NOTE(review): both the stopping decision and ``best_acc`` come from the
    test loader, so ``best_acc`` is a test-selected (optimistic) score rather
    than a held-out estimate.

    Args:
        device: ``torch.device`` or device string (e.g. ``"cpu"``).
        max_epochs: upper bound on epochs; values < 1 run no training.
        patience: epochs without improvement tolerated before stopping.
        name: label used in log lines. Names containing "Quantum" also
            enable per-batch progress output.

    Returns:
        dict with ``best_acc``, ``final_acc`` (last evaluated epoch),
        ``time`` (seconds), ``epochs`` (epochs actually run; 0 if none) and
        ``best_epoch`` (epoch that produced ``best_acc``; 0 if none).
    """
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0
    last_acc = 0.0
    epoch = 0  # stays 0 if the loop never runs (max_epochs < 1)
    start = time.time()

    # Per-batch progress only matters for the slow quantum models.
    show_progress = "Quantum" in name

    for epoch in range(1, max_epochs + 1):
        if show_progress:
            print(f"  [{name}] Epoch {epoch}/{max_epochs}")

        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
        "best_epoch": best_epoch,
    }


# ============================================
# Per-seed experiment
# ============================================

def _freeze(module):
    """Disable gradients for every parameter of ``module`` and return it."""
    for p in module.parameters():
        p.requires_grad = False
    return module


def _count_params(model, trainable_only=False):
    """Count scalar parameters of ``model``, optionally only those requiring grad."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


def _train_quantum_variant(use_entanglement, preprocessor, train_loader,
                           test_loader, seed):
    """Train one ``QuantumNet`` variant (head only) on top of a frozen CPU preprocessor."""
    label = "WITH" if use_entanglement else "NO"
    name = "QuantumEnt" if use_entanglement else "QuantumNoEnt"
    print(f"\n  Training QuantumNet ({label} entanglement, seed={seed}, {QUANTUM_DEVICE})...")

    model = QuantumNet(use_entanglement=use_entanglement)
    model.preprocessor = preprocessor
    optimizer = torch.optim.Adam(model.head.parameters(), lr=1e-3)

    # The torch side runs on CPU. The statevector simulation runs on
    # QUANTUM_DEVICE, which is a GPU only for lightning.gpu. A generous epoch
    # cap with patience-based stopping lets the slower quantum heads plateau.
    result = train_with_early_stopping(
        model, train_loader, test_loader, optimizer,
        device="cpu",
        max_epochs=200,
        patience=20,
        name=name,
    )
    result["params"] = _count_params(model, trainable_only=True)
    result["total_params"] = _count_params(model)
    return result


def run_single_seed(seed, train_loader, test_loader, use_quantum=True):
    """
    Train every model variant for one random seed.

    1. RealNet is trained end to end on ``device``.
    2. Its preprocessor is frozen and shared (same object) with a QuatNet,
       whose head alone is trained.
    3. If ``use_quantum`` is set and PennyLane is available, a CPU copy of the
       frozen preprocessor feeds two QuantumNet heads (no entanglement / CNOT
       ring). These are trained with smaller batches built from the same
       datasets.

    Returns:
        dict mapping "Real", "Quat" and, optionally, "QNoEnt"/"QEnt" to the
        ``train_with_early_stopping`` result, extended with ``params`` (all
        params for Real, trainable head params otherwise) and, for non-Real
        models, ``total_params``.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    results = {}

    # ---------------- RealNet ----------------
    print(f"\n  Training RealNet (seed={seed})...")
    real_model = RealNet().to(device)
    real_opt = torch.optim.Adam(real_model.parameters(), lr=1e-3)
    results["Real"] = train_with_early_stopping(
        real_model, train_loader, test_loader, real_opt, device,
        max_epochs=40, patience=8, name="Real"
    )
    results["Real"]["params"] = _count_params(real_model)

    # ---------------- Freeze preprocessor ----------------
    shared_preprocessor = _freeze(real_model.preprocessor)
    print(f"\n  → Preprocessor frozen with "
          f"{_count_params(shared_preprocessor)} params")

    # ---------------- QuatNet ----------------
    print(f"\n  Training QuatNet with frozen preprocessor (seed={seed})...")
    quat_model = QuatNet().to(device)
    quat_model.preprocessor = shared_preprocessor
    quat_opt = torch.optim.Adam(quat_model.head.parameters(), lr=1e-3)
    results["Quat"] = train_with_early_stopping(
        quat_model, train_loader, test_loader, quat_opt, device,
        max_epochs=40, patience=8, name="Quat"
    )
    results["Quat"]["params"] = _count_params(quat_model, trainable_only=True)
    results["Quat"]["total_params"] = _count_params(quat_model)

    # ---------------- QuantumNet (if available) ----------------
    if not (use_quantum and PENNYLANE_AVAILABLE):
        return results

    # Smaller batches for the per-sample quantum circuits.
    quantum_train_loader = DataLoader(
        train_loader.dataset, batch_size=32, shuffle=True, num_workers=0
    )
    quantum_test_loader = DataLoader(
        test_loader.dataset, batch_size=64, shuffle=False, num_workers=0
    )

    # Quantum models run their torch side on CPU, so copy the frozen weights there.
    shared_preprocessor_cpu = SharedPreprocessor(784, 16)
    shared_preprocessor_cpu.load_state_dict(shared_preprocessor.state_dict())
    _freeze(shared_preprocessor_cpu)

    results["QNoEnt"] = _train_quantum_variant(
        False, shared_preprocessor_cpu, quantum_train_loader, quantum_test_loader, seed
    )
    results["QEnt"] = _train_quantum_variant(
        True, shared_preprocessor_cpu, quantum_train_loader, quantum_test_loader, seed
    )
    return results


# ============================================
# Main Experiment
# ============================================

def stratified_sample(dataset, n_samples_per_class):
    """
    Pick up to ``n_samples_per_class`` indices from each class of ``dataset``.

    Selection is reproducible: every class draws from a fresh
    ``np.random.RandomState(42)``, so the same dataset always yields the same
    subset regardless of global RNG state. A class with fewer than
    ``n_samples_per_class`` items contributes all of its indices.

    Labels come from ``dataset.targets`` when the dataset exposes it and has
    no ``target_transform``, as torchvision's MNIST does. This avoids decoding
    and transforming every image just to read its label. Otherwise every
    ``(sample, label)`` item is iterated. Both paths select the same indices.

    Returns:
        list of indices grouped by class in ascending label order.
    """
    from collections import defaultdict

    targets = getattr(dataset, "targets", None)
    if targets is None or getattr(dataset, "target_transform", None) is not None:
        labels = [label for _, label in dataset]
    elif hasattr(targets, "tolist"):  # torch tensor / numpy array → Python ints
        labels = targets.tolist()
    else:
        labels = list(targets)

    # Group indices by class
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    # Sample from each class
    sampled_indices = []
    for class_label in sorted(class_indices):
        indices = class_indices[class_label]
        # Fresh fixed-seed RNG per class keeps the subset stable across runs.
        rng = np.random.RandomState(42)
        selected = rng.choice(indices, size=min(n_samples_per_class, len(indices)),
                              replace=False)
        sampled_indices.extend(selected)

    return sampled_indices


def _pct_gap(a, b):
    """Return ``mean(a) - mean(b)`` expressed in percentage points."""
    return (np.mean(a) - np.mean(b)) * 100.0


def main():
    """
    Run the full Real / Quaternion / Quantum comparison on stratified MNIST.

    Steps:
      1. Load MNIST into ./data, downloading it if needed.
      2. Draw a fixed stratified subset: 1,500 train and 300 test images per class.
      3. Run ``run_single_seed`` for each seed.
      4. Print per-model aggregates followed by pairwise accuracy gaps.
    """
    print("=" * 70)
    print("CLEAN SU(2) COMPARISON: Real vs Quaternion vs Quantum (LIGHTNING)")
    print("=" * 70)
    print("\nExperimental Design:")
    print("  • Dataset: MNIST (28×28 grayscale)")
    print("  • Stratified sampling: 15K train (1,500/class), 3K test (300/class)")
    print("    - Rationale: Models have ~10-15K parameters; 60K samples unnecessary")
    print("    - Enables 4-6x faster quantum training while maintaining valid comparison")
    print("    - All models train on identical stratified samples")
    print("  • Shared preprocessor: 784 → 16 features")
    print("    - Trained with RealNet, then frozen")
    print("    - Reused (frozen) for Quat and Quantum heads")
    print("  • Heads on identical 16-D frozen features:")
    print("    - Real head: 16 → 64 → 10 (standard MLP)")
    print("    - Quaternion head: 4 quats → 16 quats → 10 quats → 10")
    print("    - Quantum head (no ent): 4 qubits, 3 layers → 10")
    print("    - Quantum head (entangled): 4 qubits, 3 layers + CNOT ring → 10")
    print(f"\nQuantum device: {QUANTUM_DEVICE}")
    if not PENNYLANE_AVAILABLE:
        # No quantum models will run, so a quantum time estimate would mislead.
        print("  Quantum models disabled - only Real and Quat will be trained")
    elif "lightning.gpu" in str(QUANTUM_DEVICE):
        print("  Expected quantum training time: ~30-45 min per seed")
    elif "lightning" in str(QUANTUM_DEVICE):
        print("  Expected quantum training time: ~1-2 hours per seed")
    else:
        print("  Expected quantum training time: ~6-8 hours per seed")
    print("=" * 70)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load full datasets
    full_train_ds = datasets.MNIST(root="./data", train=True,
                                   download=True, transform=transform)
    full_test_ds = datasets.MNIST(root="./data", train=False,
                                  download=True, transform=transform)

    # Create stratified samples
    print("\nCreating stratified samples...")
    train_indices = stratified_sample(full_train_ds, n_samples_per_class=1500)  # 15K total
    test_indices = stratified_sample(full_test_ds, n_samples_per_class=300)     # 3K total

    from torch.utils.data import Subset
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)

    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=128,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=256,
                             shuffle=False, num_workers=0)

    seeds = [42, 123, 456]
    model_keys = ["Real", "Quat", "QNoEnt", "QEnt"]
    all_results = {key: [] for key in model_keys}

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        seed_results = run_single_seed(seed, train_loader, test_loader,
                                       use_quantum=PENNYLANE_AVAILABLE)

        for key in model_keys:
            if key in seed_results:
                all_results[key].append(seed_results[key])

    # Aggregate
    print("\n" + "=" * 70)
    print("AGGREGATED RESULTS (mean ± std over seeds)")
    print("=" * 70)

    def summarize(name):
        """Print mean/std accuracy, mean time/epochs and param counts for one model."""
        runs = all_results[name]
        if not runs:
            return
        accs = [r["best_acc"] for r in runs]
        times = [r["time"] for r in runs]
        epochs = [r["epochs"] for r in runs]

        if name == "Real":
            param_str = f"{runs[0]['params']:,}"
        else:
            trainable = runs[0]["params"]
            total = runs[0]["total_params"]
            param_str = f"{trainable:,} trainable (head), {total:,} total"

        # NOTE(review): np.std is the population std (ddof=0) over seeds.
        print(f"\n{name:8s}:")
        print(f"  Accuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
        print(f"  Time:       {np.mean(times):.1f}s (avg)")
        print(f"  Epochs:     {np.mean(epochs):.1f} (avg)")
        print(f"  Parameters: {param_str}")

    for name in model_keys:
        summarize(name)

    # Comparative analysis
    print("\n" + "=" * 70)
    print("COMPARATIVE ANALYSIS")
    print("=" * 70)

    real_accs = [r["best_acc"] for r in all_results["Real"]]
    quat_accs = [r["best_acc"] for r in all_results["Quat"]]
    real_mean = np.mean(real_accs)
    have_quantum = bool(all_results["QNoEnt"]) and bool(all_results["QEnt"])

    if have_quantum:
        qno_accs = [r["best_acc"] for r in all_results["QNoEnt"]]
        qent_accs = [r["best_acc"] for r in all_results["QEnt"]]

        print("\n1. Quat vs Real:")
        print(f"   Gap: {_pct_gap(real_accs, quat_accs):.2f} percentage points")
        print(f"   Quat captures {np.mean(quat_accs)/real_mean*100:.1f}% of Real performance")

        print("\n2. Quantum (no ent) vs Real:")
        print(f"   Gap: {_pct_gap(real_accs, qno_accs):.2f} percentage points")
        print(f"   QNoEnt captures {np.mean(qno_accs)/real_mean*100:.1f}% of Real performance")

        print("\n3. Quantum (ent) vs Real:")
        print(f"   Gap: {_pct_gap(real_accs, qent_accs):.2f} percentage points")
        print(f"   QEnt captures {np.mean(qent_accs)/real_mean*100:.1f}% of Real performance")

        print("\n4. Quantum (ent) vs Quantum (no ent):")
        print(f"   Gap: {_pct_gap(qno_accs, qent_accs):.2f} percentage points")
        if np.mean(qent_accs) > np.mean(qno_accs):
            print("   → Entanglement improves performance.")
        else:
            print("   → Entanglement does not improve performance here.")

        print("\n5. Quantum vs Quaternion:")
        print(f"   Gap (Quat - QNoEnt): {_pct_gap(quat_accs, qno_accs):.2f} points")
        print(f"   Gap (Quat - QEnt):   {_pct_gap(quat_accs, qent_accs):.2f} points")
    else:
        print("\nQuantum models not trained (PennyLane not available)")
        print("Only Real vs Quat comparison available:")
        print(f"   Gap: {_pct_gap(real_accs, quat_accs):.2f} percentage points")
        print(f"   Quat captures {np.mean(quat_accs)/real_mean*100:.1f}% of Real performance")

    print("\n" + "=" * 70)
    print("KEY INTERPRETATION HINTS")
    print("=" * 70)
    print("Check:")
    print("  • How close Quat is to Real (does classical SU(2) keep up with MLP?).")
    if have_quantum:
        print("  • Whether QNoEnt ≈ Quat (quantum SU(2) w/o entanglement vs quaternion).")
        print("  • Whether QEnt > QNoEnt (empirical value of entanglement on MNIST).")
    print("=" * 70)


# Script / notebook-cell entry point: run the whole experiment when this code
# is executed directly rather than imported as a module.
if __name__ == "__main__":
    main()

✓ Using lightning.gpu device (requires: pip install pennylane-lightning-gpu)
Using PyTorch device: cpu
CLEAN SU(2) COMPARISON: Real vs Quaternion vs Quantum (LIGHTNING)

Experimental Design:
  • Dataset: MNIST (28×28 grayscale)
  • Stratified sampling: 15K train (1,500/class), 3K test (300/class)
    - Rationale: Models have ~10-15K parameters; 60K samples unnecessary
    - Enables 4-6x faster quantum training while maintaining valid comparison
    - All models train on identical stratified samples
  • Shared preprocessor: 784 → 16 features
    - Trained with RealNet, then frozen
    - Reused (frozen) for Quat and Quantum heads
  • Heads on identical 16-D frozen features:
    - Real head: 16 → 64 → 10 (standard MLP)
    - Quaternion head: 4 quats → 16 quats → 10 quats → 10
    - Quantum head (no ent): 4 qubits, 3 layers → 10
    - Quantum head (entangled): 4 qubits, 3 layers + CNOT ring → 10

Quantum device: lightning.gpu
  Expected quantum training time: ~30-45 min per seed

Creating 