# RealNet CIFAR

In [1]:
"""
Block 1: RealNet Training (CIFAR-10)
=====================================
Trains the baseline Real MLP on 3 seeds using CIFAR-10.
Saves the trained preprocessor for reuse in subsequent blocks.

Outputs:
- realnet_cifar10_results.pt: Contains results dict and trained preprocessor state
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Seed every random number generator this notebook may touch.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, plus all
    CUDA devices when CUDA is available). It also seeds CuPy and PennyLane's
    NumPy wrapper when they can be imported. cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Optional libraries: silently skipped when not installed.
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass

    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Shared Bottleneck Preprocessor: 3072 → 16
# ============================================

class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 3072 → 16 (with the defaults).

    Flattens every dimension after the batch axis, then applies
    ``tanh(Linear(input_dim -> bottleneck_dim))``, so outputs lie in (-1, 1).
    """
    def __init__(self, input_dim=3072, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        # reshape (not view): non-contiguous inputs such as permuted or
        # channels-last tensors are flattened instead of raising.
        x = x.reshape(x.size(0), -1)
        return torch.tanh(self.fc(x))


# ============================================
# Real-valued Head and Network
# ============================================

class RealHead(nn.Module):
    """Two-layer real-valued MLP head: 16 → 64 → 10 with the defaults.

    ``bottleneck_dim`` inputs pass through a ReLU hidden layer of width
    ``hidden_dim`` and are projected to ``num_classes`` logits.
    """
    def __init__(self, bottleneck_dim=16, hidden_dim=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(bottleneck_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


class RealNet(nn.Module):
    """Full real-valued model: SharedPreprocessor(3072 → 16) then RealHead(16 → 64 → 10)."""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(3072, 16)
        self.head = RealHead(16, 64, 10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer that is stepped once per batch.
        device: Device each batch is moved to before the forward pass.
        show_progress: If True, print a running batch counter every 50 batches.

    Returns:
        Cross-entropy averaged over the samples actually processed this epoch,
        or 0.0 if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        # cross_entropy returns a batch mean; weight by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    # Divide by samples actually seen, not len(loader.dataset): the two differ
    # with drop_last=True or a sampler that visits only part of the dataset.
    return total_loss / total_samples if total_samples > 0 else 0.0


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model runs in eval mode without gradient tracking, and each
    ``(inputs, labels)`` batch is moved to ``device`` first. Returns 0.0 when
    the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_seen += inputs.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    """Train for up to ``max_epochs`` epochs with patience-based early stopping.

    Each epoch calls ``train_one_epoch`` and then ``evaluate`` on ``test_loader``.
    Training stops once ``patience`` consecutive epochs pass without a strictly
    higher test accuracy.

    NOTE(review): stopping and ``best_acc`` are both driven by the test loader,
    so ``best_acc`` is a test-selected (optimistic) figure. Model weights are
    left at the last epoch run, not restored to the best one.

    Args:
        device: ``torch.device`` or device string (converted here).
        name: Label used in the per-epoch log lines.

    Returns:
        dict with ``best_acc`` (highest test accuracy seen), ``final_acc``
        (accuracy after the last epoch run), ``time`` (wall-clock seconds)
        and ``epochs`` (epochs actually run; 0 when ``max_epochs < 1``).
    """
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0
    epoch = 0  # stays 0 if the loop never runs, instead of being unbound

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42, num_classes=10):
    """
    Create a stratified sample with up to n_samples_per_class indices per class.

    Labels are read from ``dataset.targets``. For a Subset-style wrapper that
    exposes ``.dataset`` and ``.indices``, they are read from the parent's
    ``targets``. Neither path loads images. Any other dataset falls back to
    indexing every item, which loads every image and is slow.

    Args:
        dataset: Dataset to sample from.
        n_samples_per_class: Maximum indices drawn per class, without
            replacement. Smaller classes contribute all of their indices.
        seed: Seed for a private ``np.random.RandomState``; the global RNG is
            not touched.
        num_classes: Labels ``0 .. num_classes - 1`` are sampled; other labels
            are ignored. Defaults to 10 (CIFAR-10).

    Returns:
        list[int]: Indices into ``dataset``, shuffled across classes.
    """
    targets = _labels_without_loading(dataset)
    rng = np.random.RandomState(seed)

    sampled_indices = []
    for c in range(num_classes):
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        sampled_indices.extend(rng.choice(idx_c, size=k, replace=False).tolist())

    # Mix the classes so consecutive indices are not grouped by label.
    rng.shuffle(sampled_indices)
    return sampled_indices


def _labels_without_loading(dataset):
    """Return the label of every item in ``dataset`` as an array, avoiding image loads when possible."""
    if hasattr(dataset, 'targets'):
        return np.array(dataset.targets)
    parent = getattr(dataset, 'dataset', None)
    if parent is not None and hasattr(dataset, 'indices') and hasattr(parent, 'targets'):
        # Subset-style wrapper: map its indices onto the parent's labels.
        return np.array(parent.targets)[np.asarray(dataset.indices, dtype=np.int64)]
    # Generic fallback: fetches every item (slow for image datasets).
    return np.array([dataset[i][1] for i in range(len(dataset))])


# ============================================
# Main
# ============================================

def main():
    """Train RealNet on stratified CIFAR-10 subsets for three seeds and save results.

    Builds stratified subsets with sampling seed 42: 1500 images/class for
    training and 300/class for testing. For each of seeds 42, 123 and 456 it
    trains a fresh RealNet with early stopping, then prints a summary. It
    writes ``realnet_cifar10_results.pt`` containing the per-seed results, a
    summary, and the preprocessor weights from the first seed.
    """
    print("=" * 70)
    print("BLOCK 1: RealNet Training (CIFAR-10)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • Input dimension: 3072 (32×32×3 flattened)")
    print("  • Batch size: 128")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: 3072 → 16 → 64 → 10")
    print("=" * 70)

    # CIFAR-10 normalization (channel-wise means and stds)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        )
    ])

    # Load full datasets
    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples (fast - uses targets only, no image loading)
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")

    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)

    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=128,
                              shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256,
                             shuffle=False, num_workers=4)

    seeds = [42, 123, 456]
    all_results = []
    trained_preprocessor_state = None

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)

        set_all_seeds(seed)

        print(f"\n  Training RealNet (seed={seed})...")
        real_model = RealNet().to(device)
        real_opt = torch.optim.Adam(real_model.parameters(), lr=1e-3)

        result = train_with_early_stopping(
            real_model, train_loader, test_loader, real_opt, device,
            max_epochs=40, patience=10, name="Real"
        )
        result["params"] = sum(p.numel() for p in real_model.parameters())
        result["seed"] = seed

        all_results.append(result)

        # Save preprocessor from first seed
        if trained_preprocessor_state is None:
            # Detached CPU copies: the saved file loads without CUDA, and the
            # snapshot no longer aliases the live model parameters.
            live_state = real_model.preprocessor.state_dict()
            trained_preprocessor_state = {
                key: tensor.detach().cpu().clone()
                for key, tensor in live_state.items()
            }
            print(f"\n  → Saved preprocessor state from seed {seed}")

    # Summary
    print("\n" + "=" * 70)
    print("REALNET SUMMARY (CIFAR-10)")
    print("=" * 70)

    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["params"]

    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {params:,}")

    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results and preprocessor
    save_dict = {
        "results": all_results,
        "preprocessor_state": trained_preprocessor_state,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "params": params
        }
    }

    torch.save(save_dict, "realnet_cifar10_results.pt")
    print(f"\n✓ Saved results to: realnet_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 1: RealNet Training (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • Input dimension: 3072 (32×32×3 flattened)
  • Batch size: 128
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: 3072 → 16 → 64 → 10

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training RealNet (seed=42)...
  [Real] Epoch  1 | loss=2.0209 | test_acc=0.3273 | time=0.7s
  [Real] Epoch  2 | loss=1.8281 | test_acc=0.3470 | time=1.1s
  [Real] Epoch  3 | loss=1.7635 | test_acc=0.3577 | time=1.6s
  [Real] Epoch  4 | loss=1.7250 | test_acc=0.3617 | time=1.9s
  [Real] Epoch  5 | loss=1.6946 | test_acc=0.3710 | time=2.3s
  [Real] Epoch  6 | loss=1.6776 | test_acc=0.3710 | time=2.7s
  [Real] Epoch  7 | loss=1.6463 | test_acc=0.3630 | time=3.1s
  [Real] Epoch  8 | loss=1.6154 | test_acc=0.3730 | time=3.5s
  [Real] Epoch  9 | loss=1.6088 | test_acc=0.

# QuatNet

In [2]:
"""
Block 2: QuatNet Training (CIFAR-10)
=====================================
Loads frozen preprocessor from Block 1 and trains quaternion head on 3 seeds.

Requirements:
- realnet_cifar10_results.pt (from Block 1)

Outputs:
- quatnet_cifar10_results.pt: Contains quaternion head results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Seed every random number generator this notebook may touch.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, plus all
    CUDA devices when CUDA is available). It also seeds CuPy and PennyLane's
    NumPy wrapper when they can be imported. cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Optional libraries: silently skipped when not installed.
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass

    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Quaternion utilities (PyTorch tensors)
# ============================================

def q_normalize(q):
    """Scale each quaternion (last axis) to unit Euclidean norm.

    A 1e-8 term is added to the norm so all-zero quaternions map to zero
    instead of producing NaN.
    """
    magnitude = torch.linalg.norm(q, dim=-1, keepdim=True)
    return q / (magnitude + 1e-8)

def q_conj(q):
    """Return the conjugate w - xi - yj - zk of quaternions stored as (w, x, y, z) on the last axis."""
    real, i_part, j_part, k_part = q.unbind(-1)
    return torch.stack((real, i_part.neg(), j_part.neg(), k_part.neg()), dim=-1)

def q_mul(a, b):
    """Hamilton product ``a * b`` of quaternions stored as (w, x, y, z) on the last axis.

    Leading axes broadcast like ordinary elementwise tensor ops. The result
    has a trailing axis of size 4. The product is not commutative.
    """
    a_w, a_x, a_y, a_z = a.unbind(-1)
    b_w, b_x, b_y, b_z = b.unbind(-1)

    real = a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z
    i_part = a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y
    j_part = a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x
    k_part = a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w

    return torch.stack((real, i_part, j_part, k_part), dim=-1)


# ============================================
# Shared Preprocessor (must match Block 1 - CIFAR-10)
# ============================================

class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 3072 → 16 (with the defaults).

    Flattens every dimension after the batch axis, then applies
    ``tanh(Linear(input_dim -> bottleneck_dim))``, so outputs lie in (-1, 1).
    Parameter names (``fc.weight``, ``fc.bias``) match Block 1 so its saved
    state dict loads directly.
    """
    def __init__(self, input_dim=3072, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        # reshape (not view): non-contiguous inputs such as permuted or
        # channels-last tensors are flattened instead of raising.
        x = x.reshape(x.size(0), -1)
        return torch.tanh(self.fc(x))


# ============================================
# Quaternion Head and Network
# ============================================

class QuaternionLinear(nn.Module):
    """Fully connected layer over quaternion-valued features.

    Sizes are counted in quaternions. ``weight`` has shape
    ``(out_features, in_features, 4)`` and ``bias`` has shape
    ``(out_features, 4)``. Output quaternion ``o`` is
    ``sum_i weight[o, i] * x[:, i] + bias[o]``, where ``*`` is the Hamilton
    product.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, 4))
        self.bias = nn.Parameter(torch.empty(out_features, 4))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw unit-norm weight quaternions; set every bias quaternion to 1 + 0i + 0j + 0k."""
        nn.init.normal_(self.weight, mean=0.0, std=0.1)
        with torch.no_grad():
            unit_weight = q_normalize(self.weight)
            self.weight.copy_(unit_weight)
            self.bias[..., 0].fill_(1.0)
            self.bias[..., 1:].fill_(0.0)

    def forward(self, x):
        """Map x of shape (B, in_features, 4) to (B, out_features, 4)."""
        # Broadcast (1, out, in, 4) against (B, 1, in, 4) -> (B, out, in, 4).
        products = q_mul(self.weight[None], x[:, None])
        return products.sum(dim=2) + self.bias


class QuatHead(nn.Module):
    """
    Quaternion head: 4 quats → 16 quats → num_classes quats → num_classes logits.

    Hidden quaternions are unit-normalised and then passed through an
    element-wise tanh. The logit for each class is the real (w) part of the
    corresponding output quaternion.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.quat_fc1 = QuaternionLinear(4, 16)
        self.quat_fc2 = QuaternionLinear(16, num_classes)

    def real_to_quat(self, x):
        """Regroup 16 real features per sample into 4 quaternions: (B, 16) -> (B, 4, 4)."""
        return x.view(x.size(0), 4, 4)

    def quat_to_real(self, q):
        """Keep only the real (w) component of each quaternion."""
        return q[..., 0]

    def forward(self, x):
        hidden = self.quat_fc1(self.real_to_quat(x))
        hidden = torch.tanh(q_normalize(hidden))
        return self.quat_to_real(self.quat_fc2(hidden))


class QuatNet(nn.Module):
    """Full quaternion model: SharedPreprocessor(3072 → 16) followed by QuatHead(10)."""
    def __init__(self):
        super().__init__()
        self.preprocessor = SharedPreprocessor(3072, 16)
        self.head = QuatHead(10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer that is stepped once per batch.
        device: Device each batch is moved to before the forward pass.
        show_progress: If True, print a running batch counter every 50 batches.

    Returns:
        Cross-entropy averaged over the samples actually processed this epoch,
        or 0.0 if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        # cross_entropy returns a batch mean; weight by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    # Divide by samples actually seen, not len(loader.dataset): the two differ
    # with drop_last=True or a sampler that visits only part of the dataset.
    return total_loss / total_samples if total_samples > 0 else 0.0


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model runs in eval mode without gradient tracking, and each
    ``(inputs, labels)`` batch is moved to ``device`` first. Returns 0.0 when
    the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_seen += inputs.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    """Train for up to ``max_epochs`` epochs with patience-based early stopping.

    Each epoch calls ``train_one_epoch`` and then ``evaluate`` on ``test_loader``.
    Training stops once ``patience`` consecutive epochs pass without a strictly
    higher test accuracy.

    NOTE(review): stopping and ``best_acc`` are both driven by the test loader,
    so ``best_acc`` is a test-selected (optimistic) figure. Model weights are
    left at the last epoch run, not restored to the best one.

    Args:
        device: ``torch.device`` or device string (converted here).
        name: Label used in the per-epoch log lines.

    Returns:
        dict with ``best_acc`` (highest test accuracy seen), ``final_acc``
        (accuracy after the last epoch run), ``time`` (wall-clock seconds)
        and ``epochs`` (epochs actually run; 0 when ``max_epochs < 1``).
    """
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0
    epoch = 0  # stays 0 if the loop never runs, instead of being unbound

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42, num_classes=10):
    """
    Create a stratified sample with up to n_samples_per_class indices per class.

    Labels are read from ``dataset.targets``. For a Subset-style wrapper that
    exposes ``.dataset`` and ``.indices``, they are read from the parent's
    ``targets``. Neither path loads images. Any other dataset falls back to
    indexing every item, which loads every image and is slow.

    Args:
        dataset: Dataset to sample from.
        n_samples_per_class: Maximum indices drawn per class, without
            replacement. Smaller classes contribute all of their indices.
        seed: Seed for a private ``np.random.RandomState``; the global RNG is
            not touched.
        num_classes: Labels ``0 .. num_classes - 1`` are sampled; other labels
            are ignored. Defaults to 10 (CIFAR-10).

    Returns:
        list[int]: Indices into ``dataset``, shuffled across classes.
    """
    targets = _labels_without_loading(dataset)
    rng = np.random.RandomState(seed)

    sampled_indices = []
    for c in range(num_classes):
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        sampled_indices.extend(rng.choice(idx_c, size=k, replace=False).tolist())

    # Mix the classes so consecutive indices are not grouped by label.
    rng.shuffle(sampled_indices)
    return sampled_indices


def _labels_without_loading(dataset):
    """Return the label of every item in ``dataset`` as an array, avoiding image loads when possible."""
    if hasattr(dataset, 'targets'):
        return np.array(dataset.targets)
    parent = getattr(dataset, 'dataset', None)
    if parent is not None and hasattr(dataset, 'indices') and hasattr(parent, 'targets'):
        # Subset-style wrapper: map its indices onto the parent's labels.
        return np.array(parent.targets)[np.asarray(dataset.indices, dtype=np.int64)]
    # Generic fallback: fetches every item (slow for image datasets).
    return np.array([dataset[i][1] for i in range(len(dataset))])


# ============================================
# Main
# ============================================

def main():
    """Train the quaternion head on top of Block 1's frozen preprocessor.

    Loads ``preprocessor_state`` from ``realnet_cifar10_results.pt`` and
    freezes it. It then builds the same stratified CIFAR-10 subsets as
    Block 1: 1500/class for training and 300/class for testing, with sampling
    seed 42. For each of seeds 42, 123 and 456 it trains a fresh quaternion
    head with early stopping, prints a summary, and saves the results to
    ``quatnet_cifar10_results.pt``. Returns early, after printing an error,
    if Block 1's file is missing.
    """
    print("=" * 70)
    print("BLOCK 2: QuatNet Training (CIFAR-10, Frozen Preprocessor)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • Batch size: 128")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: [frozen 3072→16] → 4 quats → 16 quats → 10")
    print("=" * 70)

    # Load frozen preprocessor from Block 1
    print("\nLoading frozen preprocessor from Block 1...")
    try:
        # map_location: tensors saved on one device (e.g. CUDA) load onto
        # whatever `device` this block runs on.
        # SECURITY: weights_only=False unpickles arbitrary objects. This is
        # acceptable only because the file is produced locally by Block 1;
        # never point it at untrusted input.
        realnet_data = torch.load("realnet_cifar10_results.pt",
                                  map_location=device, weights_only=False)
        preprocessor_state = realnet_data["preprocessor_state"]
        print("✓ Loaded preprocessor state from realnet_cifar10_results.pt")
    except FileNotFoundError:
        print("✗ ERROR: realnet_cifar10_results.pt not found!")
        print("  Run Block 1 first to train RealNet and save preprocessor.")
        return

    # Create frozen preprocessor
    shared_preprocessor = SharedPreprocessor(3072, 16).to(device)
    shared_preprocessor.load_state_dict(preprocessor_state)
    for p in shared_preprocessor.parameters():
        p.requires_grad = False

    preprocessor_params = sum(p.numel() for p in shared_preprocessor.parameters())
    print(f"  Preprocessor frozen with {preprocessor_params:,} params")

    # Load data - CIFAR-10 normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        )
    ])

    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples (fast - uses targets only, no image loading)
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")

    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)

    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=128,
                              shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256,
                             shuffle=False, num_workers=4)

    seeds = [42, 123, 456]
    all_results = []

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)

        set_all_seeds(seed)

        print(f"\n  Training QuatNet (seed={seed})...")
        quat_model = QuatNet().to(device)
        quat_model.preprocessor = shared_preprocessor  # Use frozen preprocessor

        # Only the head is optimised; the shared preprocessor stays frozen.
        quat_opt = torch.optim.Adam(quat_model.head.parameters(), lr=1e-3)

        result = train_with_early_stopping(
            quat_model, train_loader, test_loader, quat_opt, device,
            max_epochs=40, patience=10, name="Quat"
        )

        result["trainable_params"] = sum(p.numel() for p in quat_model.parameters()
                                         if p.requires_grad)
        result["total_params"] = sum(p.numel() for p in quat_model.parameters())
        result["seed"] = seed

        all_results.append(result)

    # Summary
    print("\n" + "=" * 70)
    print("QUATNET SUMMARY (CIFAR-10)")
    print("=" * 70)

    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    trainable = all_results[0]["trainable_params"]
    total = all_results[0]["total_params"]

    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {trainable:,} trainable (head), {total:,} total")

    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": trainable,
            "total_params": total
        }
    }

    torch.save(save_dict, "quatnet_cifar10_results.pt")
    print(f"\n✓ Saved results to: quatnet_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 2: QuatNet Training (CIFAR-10, Frozen Preprocessor)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • Batch size: 128
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: [frozen 3072→16] → 4 quats → 16 quats → 10

Loading frozen preprocessor from Block 1...
✓ Loaded preprocessor state from realnet_cifar10_results.pt
  Preprocessor frozen with 49,168 params

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training QuatNet (seed=42)...
  [Quat] Epoch  1 | loss=2.7936 | test_acc=0.2123 | time=0.9s
  [Quat] Epoch  2 | loss=2.0419 | test_acc=0.2943 | time=1.6s
  [Quat] Epoch  3 | loss=1.8016 | test_acc=0.3283 | time=2.2s
  [Quat] Epoch  4 | loss=1.6987 | test_acc=0.3450 | time=2.9s
  [Quat] Epoch  5 | loss=1.6418 | test_acc=0.3553 | time=3.5s
  [Quat] Epoch  6 | loss=1.6071 | test_acc=0.3593 | time=4.4s
  [Quat] Epoch  7 | loss=

# Quantum (No Entanglement)

In [3]:
"""
Block 3: Quantum (No Entanglement) Training (CIFAR-10)
========================================================
Loads frozen preprocessor from Block 1 and trains quantum head WITHOUT entanglement.
Uses Lightning-GPU acceleration.

Requirements:
- realnet_cifar10_results.pt (from Block 1)
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_noent_cifar10_results.pt: Contains quantum (no ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

try:
    import pennylane as qml
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    # NOTE(review): this only confirms PennyLane imported; whether the
    # lightning.gpu plugin actually works is first checked at qml.device(...).
    print("✓ Using lightning.gpu device")
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    # Raise SystemExit directly. The bare `exit` builtin is injected by the
    # `site` module (absent under `python -S`) and behaves differently inside
    # IPython/Jupyter kernels.
    raise SystemExit(1)

# Prefer the GPU, but fall back to CPU so the block does not fail outright on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Seed every random number generator this notebook may touch.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, plus all
    CUDA devices when CUDA is available). It also seeds CuPy and PennyLane's
    NumPy wrapper when they can be imported. cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Optional libraries: silently skipped when not installed.
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass

    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Shared Preprocessor (must match Block 1 - CIFAR-10)
# ============================================

class SharedPreprocessor(nn.Module):
    """Shared classical feature extractor: 3072 → 16 (with the defaults).

    Flattens every dimension after the batch axis, then applies
    ``tanh(Linear(input_dim -> bottleneck_dim))``, so outputs lie in (-1, 1).
    Parameter names (``fc.weight``, ``fc.bias``) match Block 1 so its saved
    state dict loads directly.
    """
    def __init__(self, input_dim=3072, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        # reshape (not view): non-contiguous inputs such as permuted or
        # channels-last tensors are flattened instead of raising.
        x = x.reshape(x.size(0), -1)
        return torch.tanh(self.fc(x))


# ============================================
# Quantum Head (No Entanglement)
# ============================================

class QuantumHead(nn.Module):
    """
    VQC with ``n_qubits`` qubits, ``n_layers`` layers, NO entanglement → ``num_classes``.

    For each sample, Linear(16 → n_qubits) + tanh produces the rotation
    angles. Every layer re-uploads them as RY rotations, then applies
    trainable RY/RZ rotations to each qubit. The circuit returns <Z_i> for
    every qubit plus <Z_{2m} Z_{2m+1}> for each disjoint adjacent pair. With
    the default 4 qubits that is Z0, Z1, Z2, Z3, Z0Z1, Z2Z3. A linear layer
    maps these expectation values to logits.
    Uses ``QUANTUM_DEVICE`` (Lightning) unless ``device_name`` is given.
    """
    def __init__(self, n_qubits=4, n_layers=3, num_classes=10, device_name=None):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Map 16 features → n_qubits
        self.feature_select = nn.Linear(16, n_qubits)

        # Use specified device or default
        if device_name is None:
            device_name = QUANTUM_DEVICE

        # Quantum device
        self.dev = qml.device(device_name, wires=n_qubits)

        # Use adjoint differentiation for lightning (much faster)
        diff_method = "adjoint" if "lightning" in device_name else "parameter-shift"

        # Disjoint neighbour pairs (0, 1), (2, 3), ... for the ZZ correlators.
        # For n_qubits=4 this reproduces the original Z0Z1 and Z2Z3 readout.
        zz_pairs = [(w, w + 1) for w in range(0, n_qubits - 1, 2)]
        n_observables = n_qubits + len(zz_pairs)

        @qml.qnode(self.dev, interface="torch", diff_method=diff_method)
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITHOUT entanglement.
            inputs: (n_qubits,)
            weights: (n_layers, n_qubits, 2)
            """
            for layer in range(n_layers):
                # Data re-uploading
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)

                # Trainable rotations
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)

                # NO ENTANGLEMENT

            single_z = [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
            pair_zz = [qml.expval(qml.PauliZ(p) @ qml.PauliZ(q)) for p, q in zz_pairs]
            return single_z + pair_zz

        self.quantum_circuit = quantum_circuit

        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        self.fc_out = nn.Linear(n_observables, num_classes)

    def forward(self, x):
        """
        x: (batch, 16) real bottleneck features.
        Returns logits of shape (batch, num_classes).

        The QNode is evaluated one sample at a time; no progress is reported.
        """
        angles = torch.tanh(self.feature_select(x))

        per_sample = []
        for sample in angles:
            q_raw = self.quantum_circuit(sample, self.q_weights)
            # Multi-measurement QNodes may return a list/tuple of scalars.
            if isinstance(q_raw, (list, tuple)):
                q_raw = torch.stack(q_raw)
            per_sample.append(q_raw)

        # Cast once, straight to the readout layer's dtype. The QNode output
        # dtype may differ (e.g. float64), and an intermediate .float() would
        # lose precision for a double-precision model.
        quantum_outputs = torch.stack(per_sample).to(self.fc_out.weight.dtype)
        return self.fc_out(quantum_outputs)


class QuantumNet(nn.Module):
    """End-to-end model: 3072→16 classical preprocessor, then the no-entanglement quantum head.

    ``main`` swaps ``self.preprocessor`` for the frozen Block 1 preprocessor
    after construction.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept (preprocessor first) so seeded parameter
        # initialisation is reproducible.
        self.preprocessor = SharedPreprocessor(input_dim=3072, bottleneck_dim=16)
        self.head = QuantumHead(n_qubits=4, n_layers=3, num_classes=10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, show_progress=True):
    """Run one optimisation pass over ``loader`` with cross-entropy loss.

    Args:
        model: module mapping a batch ``x`` to logits.
        loader: iterable of ``(x, y)`` batches (typically a DataLoader).
        optimizer: optimizer over the parameters to be updated.
        show_progress: print a carriage-return progress line every 20 batches.

    Returns:
        Sample-weighted mean loss over the samples actually processed this
        epoch, or 0.0 if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        # Keep data on CPU for quantum models
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    # Normalise by samples actually seen, not len(loader.dataset): the two
    # differ when the loader drops a partial last batch or uses a sampler.
    return total_loss / total_samples if total_samples else 0.0


def evaluate(model, loader):
    """Return top-1 accuracy of ``model`` over every ``(x, y)`` batch in ``loader``.

    The model is put in eval mode and autograd is disabled. Inputs are used
    exactly as the loader yields them (kept on CPU). An empty loader gives 0.0.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              max_epochs=200, patience=10, name="Model"):
    """Train epoch by epoch until test accuracy fails to improve ``patience`` times in a row.

    Args:
        model, optimizer: passed to ``train_one_epoch`` / ``evaluate``.
        train_loader: batches used for the optimisation pass.
        test_loader: batches scored after every epoch; drives early stopping.
        max_epochs: hard upper bound on epochs (values < 1 run no epochs).
        patience: consecutive non-improving epochs tolerated before stopping.
        name: label used in the progress messages.

    Returns:
        dict with ``best_acc`` (highest test accuracy seen), ``final_acc``
        (accuracy after the last epoch run), ``time`` (wall-clock seconds) and
        ``epochs`` (number of epochs actually run).

    NOTE(review): stopping and ``best_acc`` both use ``test_loader``. There is
    no separate validation split, so ``best_acc`` is an optimistically selected
    test score. The weights are also not rolled back to the best epoch.
    """
    best_acc = 0.0
    last_acc = 0.0
    epochs_without_improvement = 0
    epoch = 0  # stays 0 (instead of unbound) when max_epochs < 1
    start = time.time()

    for epoch in range(1, max_epochs + 1):
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, show_progress=True)
        acc = evaluate(model, test_loader)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42, classes=None):
    """
    Create a stratified sample with up to n_samples_per_class from each class.

    Uses ``dataset.targets`` directly when available, so no images are loaded
    while sampling. Datasets without ``targets`` fall back to indexing every item
    and reading its label (``dataset[i][1]``), which does load the data.

    Args:
        dataset: object exposing ``targets``, or an indexable sequence of
            ``(input, label)`` pairs.
        n_samples_per_class: maximum indices drawn per class. Classes with
            fewer examples contribute all of theirs.
        seed: seed for a local ``np.random.RandomState``; the global NumPy RNG
            is not touched.
        classes: labels to sample, in order. Defaults to the sorted unique labels
            present in the targets. For CIFAR-10 these are 0..9, so the draws
            are identical to the previous hard-coded ``range(10)``.

    Returns:
        Shuffled list of int indices into ``dataset``.
    """
    if hasattr(dataset, 'targets'):
        targets = np.asarray(dataset.targets)
    else:
        # Fallback for wrapped datasets (loads every item)
        targets = np.array([dataset[i][1] for i in range(len(dataset))])

    if classes is None:
        classes = np.unique(targets)

    rng = np.random.RandomState(seed)

    sampled_indices = []
    for c in classes:
        idx_c = np.flatnonzero(targets == c)
        k = min(n_samples_per_class, len(idx_c))
        # Sample without replacement
        sampled_indices.extend(rng.choice(idx_c, size=k, replace=False).tolist())

    # Shuffle so classes are interleaved
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    """Block 3 entry point: train the entanglement-free quantum head on CIFAR-10.

    Steps:
      1. Print the run configuration.
      2. Load the preprocessor weights saved by Block 1 and freeze them.
      3. Build stratified CIFAR-10 subsets (1500/class train, 300/class test).
      4. Train one QuantumNet per seed with early stopping.
      5. Print a summary and save everything to
         ``quantum_noent_cifar10_results.pt``.

    Returns early without raising if ``realnet_cifar10_results.pt`` is missing.
    """
    rule = "=" * 70
    for line in (
        rule,
        "BLOCK 3: Quantum (NO Entanglement) Training (CIFAR-10)",
        rule,
        "\nConfiguration:",
        "  • Dataset: CIFAR-10 (32×32 RGB)",
        "  • Batch size: 32",
        "  • Patience: 10",
        "  • Max epochs: 200",
        "  • Seeds: [42, 123, 456]",
        "  • Architecture: [frozen 3072→16] → 4 qubits (3 layers, NO ent) → 10",
    ):
        print(line)
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print(rule)

    # --- Frozen preprocessor from Block 1 ---------------------------------
    print("\nLoading frozen preprocessor from Block 1...")
    try:
        # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
        # because this file is produced locally by Block 1.
        checkpoint = torch.load("realnet_cifar10_results.pt", weights_only=False)
        preprocessor_state = checkpoint["preprocessor_state"]
    except FileNotFoundError:
        print("✗ ERROR: realnet_cifar10_results.pt not found!")
        print("  Run Block 1 first to train RealNet and save preprocessor.")
        return
    print("✓ Loaded preprocessor state from realnet_cifar10_results.pt")

    # Built on CPU and shared, frozen, by every seed's model below.
    frozen_pre = SharedPreprocessor(3072, 16)
    frozen_pre.load_state_dict(preprocessor_state)
    frozen_pre.requires_grad_(False)

    n_frozen = sum(t.numel() for t in frozen_pre.parameters())
    print(f"  Preprocessor frozen with {n_frozen:,} params")

    # --- CIFAR-10 data ------------------------------------------------------
    cifar_mean = [0.4914, 0.4822, 0.4465]
    cifar_std = [0.2470, 0.2435, 0.2616]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar_mean, std=cifar_std),
    ])

    full_train_ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    # Stratified subsets chosen from the label arrays only.
    print("\nCreating stratified samples...")
    sample_start = time.time()
    train_idx = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_idx = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-sample_start:.2f}s")

    train_ds, test_ds = Subset(full_train_ds, train_idx), Subset(full_test_ds, test_idx)
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    # Small batches: the quantum head evaluates one circuit per sample.
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)

    # --- One training run per seed -----------------------------------------
    all_results = []
    for seed in (42, 123, 456):
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print(rule)

        set_all_seeds(seed)

        print(f"\n  Training QuantumNet (NO entanglement, seed={seed})...")
        model = QuantumNet()
        model.preprocessor = frozen_pre  # replace the fresh preprocessor with the frozen one

        # Only the quantum head is optimised.
        optimizer = torch.optim.Adam(model.head.parameters(), lr=1e-3)

        run = train_with_early_stopping(
            model, train_loader, test_loader, optimizer,
            max_epochs=200, patience=10, name="QuantumNoEnt",
        )

        params = list(model.parameters())
        run["trainable_params"] = sum(p.numel() for p in params if p.requires_grad)
        run["total_params"] = sum(p.numel() for p in params)
        run["seed"] = seed
        all_results.append(run)

    # --- Summary ------------------------------------------------------------
    print("\n" + rule)
    print("QUANTUM (NO ENTANGLEMENT) SUMMARY (CIFAR-10)")
    print(rule)

    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    first = all_results[0]
    trainable, total = first["trainable_params"], first["total_params"]

    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {trainable:,} trainable (head), {total:,} total")

    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    summary = {
        "mean_acc": np.mean(accs),
        "std_acc": np.std(accs),
        "mean_time": np.mean(times),
        "trainable_params": trainable,
        "total_params": total,
    }
    torch.save({"results": all_results, "summary": summary}, "quantum_noent_cifar10_results.pt")
    print("\n✓ Saved results to: quantum_noent_cifar10_results.pt")
    print(rule)


# Entry point: run the Block 3 pipeline when this cell/script executes as __main__.
if __name__ == "__main__":
    main()

✓ Using lightning.gpu device
Using PyTorch device: cuda
BLOCK 3: Quantum (NO Entanglement) Training (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • Batch size: 32
  • Patience: 10
  • Max epochs: 200
  • Seeds: [42, 123, 456]
  • Architecture: [frozen 3072→16] → 4 qubits (3 layers, NO ent) → 10
  • Quantum device: lightning.gpu

Loading frozen preprocessor from Block 1...
✓ Loaded preprocessor state from realnet_cifar10_results.pt
  Preprocessor frozen with 49,168 params

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training QuantumNet (NO entanglement, seed=42)...
  [QuantumNoEnt] Epoch 1/200
    Batch 460/469, samples: 14752
  [QuantumNoEnt] Epoch  1 | loss=2.2121 | test_acc=0.2337 | time=457.3s
  [QuantumNoEnt] Epoch 2/200
    Batch 460/469, samples: 14752
  [QuantumNoEnt] Epoch  2 | loss=1.9542 | test_acc=0.2880 | time=928.3s
  [QuantumNoEnt] Ep

# Quantum (With Entanglement)

In [8]:
"""
Block 4: Quantum (WITH Entanglement) Training (CIFAR-10)
==========================================================
Loads frozen preprocessor from Block 1 and trains quantum head WITH entanglement.
Uses Lightning-GPU acceleration.

Requirements:
- realnet_cifar10_results.pt (from Block 1)
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_ent_cifar10_results.pt: Contains quantum (with ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

try:
    import pennylane as qml
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    exit(1)
else:
    # Default simulator for every QuantumHead that is not given a device_name.
    # NOTE(review): this only proves PennyLane imports. Whether the
    # lightning.gpu plugin is installed is discovered later, at qml.device().
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    print("✓ Using lightning.gpu device")

# Reported for information; `device` is not referenced elsewhere in this block
# (data is kept on CPU for the quantum models).
device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Seed every random-number source this notebook may use.

    Covers Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus all
    CUDA devices when CUDA is available). It also switches cuDNN to
    deterministic, non-benchmarking mode. CuPy and PennyLane's NumPy wrapper are
    seeded too when they can be imported; otherwise they are silently skipped.
    """
    import random
    import numpy as np
    import torch

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass  # CuPy not installed

    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass  # PennyLane missing, or its numpy wrapper has no .random

# ============================================
# Shared Preprocessor (must match Block 1 - CIFAR-10)
# ============================================

class SharedPreprocessor(nn.Module):
    """Classical bottleneck: flatten each image, project 3072 → 16, squash with tanh.

    Must match Block 1's definition so that its saved ``state_dict`` (keys
    ``fc.weight`` / ``fc.bias``) loads into this module.
    """

    def __init__(self, input_dim=3072, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        flat = x.view(x.size(0), -1)  # (batch, input_dim)
        return torch.tanh(self.fc(flat))


# ============================================
# Quantum Head (WITH Entanglement)
# ============================================

class QuantumHead(nn.Module):
    """
    VQC WITH entanglement followed by a linear readout to ``num_classes``.

    Each of ``n_layers`` layers re-uploads the data as RY rotations, applies
    trainable RY/RZ rotations, then a CNOT ring. The readout features are
    <Z_i> for every qubit plus <Z_{2k} Z_{2k+1}> for each adjacent wire pair.
    The defaults (4 qubits, 3 layers, 10 classes, 16 inputs) reproduce the
    original fixed architecture exactly: <Z_0>..<Z_3>, <Z_0 Z_1>, <Z_2 Z_3>.
    Uses ``QUANTUM_DEVICE`` (lightning.gpu) unless ``device_name`` is given.
    """
    def __init__(self, n_qubits=4, n_layers=3, num_classes=10, device_name=None,
                 input_dim=16):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Map input_dim bottleneck features → one angle per qubit
        self.feature_select = nn.Linear(input_dim, n_qubits)

        # Use specified device or default
        if device_name is None:
            device_name = QUANTUM_DEVICE

        # Quantum device
        self.dev = qml.device(device_name, wires=n_qubits)

        # Use adjoint differentiation for lightning (much faster)
        diff_method = "adjoint" if "lightning" in device_name else "parameter-shift"

        # Readout size: one <Z_i> per qubit + one <Z Z> per adjacent pair.
        # Previously hard-coded to 6, which only worked for n_qubits == 4.
        n_pairs = n_qubits // 2
        self.n_outputs = n_qubits + n_pairs

        @qml.qnode(self.dev, interface="torch", diff_method=diff_method)
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITH entanglement.
            inputs: (n_qubits,)
            weights: (n_layers, n_qubits, 2)
            """
            for layer in range(n_layers):
                # Data re-uploading
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)

                # Trainable rotations
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)

                # ENTANGLEMENT: CNOT ring
                for i in range(n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                if n_qubits > 2:
                    qml.CNOT(wires=[n_qubits - 1, 0])

            singles = [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
            pairs = [
                qml.expval(qml.PauliZ(2 * k) @ qml.PauliZ(2 * k + 1))
                for k in range(n_pairs)
            ]
            return singles + pairs

        self.quantum_circuit = quantum_circuit

        # Created after feature_select, as before, so seeded initialisation
        # consumes the RNG in the same order.
        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        self.fc_out = nn.Linear(self.n_outputs, num_classes)

    def forward(self, x):
        """
        Map bottleneck features to class logits through the quantum circuit.

        x: (batch, input_dim) real bottleneck features
        Each sample is evaluated with its own QNode call. The expectation
        values are stacked to (batch, n_outputs), cast to ``fc_out``'s dtype and
        passed through the readout. Returns (batch, num_classes) logits. An
        empty batch yields (0, num_classes) instead of failing in
        ``torch.stack``.
        """
        angles = torch.tanh(self.feature_select(x))

        features = []
        for sample in angles:  # one circuit evaluation per sample
            q_raw = self.quantum_circuit(sample, self.q_weights)
            # A list of measurements comes back as a sequence of scalar tensors.
            if isinstance(q_raw, (list, tuple)):
                q_raw = torch.stack(q_raw)
            features.append(q_raw)

        if features:
            stacked = torch.stack(features)
        else:
            stacked = angles.new_zeros((0, self.fc_out.in_features))

        # Simulator outputs may be float64; match the readout layer's dtype.
        return self.fc_out(stacked.to(self.fc_out.weight.dtype))


class QuantumNet(nn.Module):
    """End-to-end model: 3072→16 classical preprocessor, then the entangling quantum head.

    ``main`` swaps ``self.preprocessor`` for the frozen Block 1 preprocessor
    after construction.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept (preprocessor first) so seeded parameter
        # initialisation is reproducible.
        self.preprocessor = SharedPreprocessor(input_dim=3072, bottleneck_dim=16)
        self.head = QuantumHead(n_qubits=4, n_layers=3, num_classes=10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, show_progress=True):
    """Run one optimisation pass over ``loader`` with cross-entropy loss.

    Args:
        model: module mapping a batch ``x`` to logits.
        loader: iterable of ``(x, y)`` batches (typically a DataLoader).
        optimizer: optimizer over the parameters to be updated.
        show_progress: print a carriage-return progress line every 20 batches.

    Returns:
        Sample-weighted mean loss over the samples actually processed this
        epoch, or 0.0 if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        # Keep data on CPU for quantum models
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    # Normalise by samples actually seen, not len(loader.dataset): the two
    # differ when the loader drops a partial last batch or uses a sampler.
    return total_loss / total_samples if total_samples else 0.0


def evaluate(model, loader):
    """Return top-1 accuracy of ``model`` over every ``(x, y)`` batch in ``loader``.

    The model is put in eval mode and autograd is disabled. Inputs are used
    exactly as the loader yields them (kept on CPU). An empty loader gives 0.0.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              max_epochs=200, patience=10, name="Model"):
    """Train epoch by epoch until test accuracy fails to improve ``patience`` times in a row.

    Args:
        model, optimizer: passed to ``train_one_epoch`` / ``evaluate``.
        train_loader: batches used for the optimisation pass.
        test_loader: batches scored after every epoch; drives early stopping.
        max_epochs: hard upper bound on epochs (values < 1 run no epochs).
        patience: consecutive non-improving epochs tolerated before stopping.
        name: label used in the progress messages.

    Returns:
        dict with ``best_acc`` (highest test accuracy seen), ``final_acc``
        (accuracy after the last epoch run), ``time`` (wall-clock seconds) and
        ``epochs`` (number of epochs actually run).

    NOTE(review): stopping and ``best_acc`` both use ``test_loader``. There is
    no separate validation split, so ``best_acc`` is an optimistically selected
    test score. The weights are also not rolled back to the best epoch.
    """
    best_acc = 0.0
    last_acc = 0.0
    epochs_without_improvement = 0
    epoch = 0  # stays 0 (instead of unbound) when max_epochs < 1
    start = time.time()

    for epoch in range(1, max_epochs + 1):
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, show_progress=True)
        acc = evaluate(model, test_loader)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42, classes=None):
    """
    Create a stratified sample with up to n_samples_per_class from each class.

    Uses ``dataset.targets`` directly when available, so no images are loaded
    while sampling. Datasets without ``targets`` fall back to indexing every item
    and reading its label (``dataset[i][1]``), which does load the data.

    Args:
        dataset: object exposing ``targets``, or an indexable sequence of
            ``(input, label)`` pairs.
        n_samples_per_class: maximum indices drawn per class. Classes with
            fewer examples contribute all of theirs.
        seed: seed for a local ``np.random.RandomState``; the global NumPy RNG
            is not touched.
        classes: labels to sample, in order. Defaults to the sorted unique labels
            present in the targets. For CIFAR-10 these are 0..9, so the draws
            are identical to the previous hard-coded ``range(10)``.

    Returns:
        Shuffled list of int indices into ``dataset``.
    """
    if hasattr(dataset, 'targets'):
        targets = np.asarray(dataset.targets)
    else:
        # Fallback for wrapped datasets (loads every item)
        targets = np.array([dataset[i][1] for i in range(len(dataset))])

    if classes is None:
        classes = np.unique(targets)

    rng = np.random.RandomState(seed)

    sampled_indices = []
    for c in classes:
        idx_c = np.flatnonzero(targets == c)
        k = min(n_samples_per_class, len(idx_c))
        # Sample without replacement
        sampled_indices.extend(rng.choice(idx_c, size=k, replace=False).tolist())

    # Shuffle so classes are interleaved
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    """Block 4 entry point: train the entangling quantum head on CIFAR-10.

    Steps:
      1. Print the run configuration.
      2. Load the preprocessor weights saved by Block 1 and freeze them.
      3. Build stratified CIFAR-10 subsets (1500/class train, 300/class test).
      4. Train one QuantumNet per seed with early stopping.
      5. Print a summary and save everything to
         ``quantum_ent_cifar10_results.pt``.

    Returns early without raising if ``realnet_cifar10_results.pt`` is missing.
    """
    rule = "=" * 70
    for line in (
        rule,
        "BLOCK 4: Quantum (WITH Entanglement) Training (CIFAR-10)",
        rule,
        "\nConfiguration:",
        "  • Dataset: CIFAR-10 (32×32 RGB)",
        "  • Batch size: 32",
        "  • Patience: 10",
        "  • Max epochs: 200",
        "  • Seeds: [42, 123, 456]",
        "  • Architecture: [frozen 3072→16] → 4 qubits (3 layers, WITH ent) → 10",
    ):
        print(line)
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print(rule)

    # --- Frozen preprocessor from Block 1 ---------------------------------
    print("\nLoading frozen preprocessor from Block 1...")
    try:
        # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
        # because this file is produced locally by Block 1.
        checkpoint = torch.load("realnet_cifar10_results.pt", weights_only=False)
        preprocessor_state = checkpoint["preprocessor_state"]
    except FileNotFoundError:
        print("✗ ERROR: realnet_cifar10_results.pt not found!")
        print("  Run Block 1 first to train RealNet and save preprocessor.")
        return
    print("✓ Loaded preprocessor state from realnet_cifar10_results.pt")

    # Built on CPU and shared, frozen, by every seed's model below.
    frozen_pre = SharedPreprocessor(3072, 16)
    frozen_pre.load_state_dict(preprocessor_state)
    frozen_pre.requires_grad_(False)

    n_frozen = sum(t.numel() for t in frozen_pre.parameters())
    print(f"  Preprocessor frozen with {n_frozen:,} params")

    # --- CIFAR-10 data ------------------------------------------------------
    cifar_mean = [0.4914, 0.4822, 0.4465]
    cifar_std = [0.2470, 0.2435, 0.2616]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar_mean, std=cifar_std),
    ])

    full_train_ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    # Stratified subsets chosen from the label arrays only.
    print("\nCreating stratified samples...")
    sample_start = time.time()
    train_idx = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_idx = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-sample_start:.2f}s")

    train_ds, test_ds = Subset(full_train_ds, train_idx), Subset(full_test_ds, test_idx)
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    # Small batches: the quantum head evaluates one circuit per sample.
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)

    # --- One training run per seed -----------------------------------------
    all_results = []
    for seed in (42, 123, 456):
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print(rule)

        set_all_seeds(seed)

        print(f"\n  Training QuantumNet (WITH entanglement, seed={seed})...")
        model = QuantumNet()
        model.preprocessor = frozen_pre  # replace the fresh preprocessor with the frozen one

        # Only the quantum head is optimised.
        optimizer = torch.optim.Adam(model.head.parameters(), lr=1e-3)

        run = train_with_early_stopping(
            model, train_loader, test_loader, optimizer,
            max_epochs=200, patience=10, name="QuantumEnt",
        )

        params = list(model.parameters())
        run["trainable_params"] = sum(p.numel() for p in params if p.requires_grad)
        run["total_params"] = sum(p.numel() for p in params)
        run["seed"] = seed
        all_results.append(run)

    # --- Summary ------------------------------------------------------------
    print("\n" + rule)
    print("QUANTUM (WITH ENTANGLEMENT) SUMMARY (CIFAR-10)")
    print(rule)

    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    first = all_results[0]
    trainable, total = first["trainable_params"], first["total_params"]

    print(f"\nAccuracy:   {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:       {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:     {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters: {trainable:,} trainable (head), {total:,} total")

    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    summary = {
        "mean_acc": np.mean(accs),
        "std_acc": np.std(accs),
        "mean_time": np.mean(times),
        "trainable_params": trainable,
        "total_params": total,
    }
    torch.save({"results": all_results, "summary": summary}, "quantum_ent_cifar10_results.pt")
    print("\n✓ Saved results to: quantum_ent_cifar10_results.pt")
    print(rule)


# Entry point: run the Block 4 pipeline when this cell/script executes as __main__.
if __name__ == "__main__":
    main()

✓ Using lightning.gpu device
Using PyTorch device: cuda
BLOCK 4: Quantum (WITH Entanglement) Training (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • Batch size: 32
  • Patience: 10
  • Max epochs: 200
  • Seeds: [42, 123, 456]
  • Architecture: [frozen 3072→16] → 4 qubits (3 layers, WITH ent) → 10
  • Quantum device: lightning.gpu

Loading frozen preprocessor from Block 1...
✓ Loaded preprocessor state from realnet_cifar10_results.pt
  Preprocessor frozen with 49,168 params

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training QuantumNet (WITH entanglement, seed=42)...
  [QuantumEnt] Epoch 1/200
    Batch 460/469, samples: 14752
  [QuantumEnt] Epoch  1 | loss=2.2983 | test_acc=0.2123 | time=529.6s
  [QuantumEnt] Epoch 2/200
    Batch 460/469, samples: 14752
  [QuantumEnt] Epoch  2 | loss=2.1265 | test_acc=0.2300 | time=1043.0s
  [QuantumEnt] Epoch

# Comparison

In [9]:
"""
Block 5: Aggregate Results and Comparative Analysis (CIFAR-10)
===============================================================
Loads results from Blocks 1-4 and performs comprehensive comparison.

Requirements:
- realnet_cifar10_results.pt (from Block 1)
- quatnet_cifar10_results.pt (from Block 2)
- quantum_noent_cifar10_results.pt (from Block 3)
- quantum_ent_cifar10_results.pt (from Block 4)

Outputs:
- Comprehensive comparison tables
- Statistical analysis
- Research question answers
"""

import numpy as np
import torch

def load_results():
    """Load the per-seed result lists written by Blocks 1-4.

    Returns:
        dict with keys "Real", "Quat", "QNoEnt", "QEnt", each mapped to the
        ``"results"`` entry of the corresponding ``.pt`` file. Returns None
        if either required file (RealNet or QuatNet) is missing. A missing
        quantum file is reported and mapped to an empty list.
    """
    # (key, path, message on success, message when missing, required?)
    sources = (
        ("Real", "realnet_cifar10_results.pt",
         "✓ Loaded RealNet (CIFAR-10) results",
         "✗ realnet_cifar10_results.pt not found", True),
        ("Quat", "quatnet_cifar10_results.pt",
         "✓ Loaded QuatNet (CIFAR-10) results",
         "✗ quatnet_cifar10_results.pt not found", True),
        ("QNoEnt", "quantum_noent_cifar10_results.pt",
         "✓ Loaded Quantum (no ent) (CIFAR-10) results",
         "⚠ quantum_noent_cifar10_results.pt not found (skipping)", False),
        ("QEnt", "quantum_ent_cifar10_results.pt",
         "✓ Loaded Quantum (with ent) (CIFAR-10) results",
         "⚠ quantum_ent_cifar10_results.pt not found (skipping)", False),
    )

    results = {}
    for key, path, loaded_msg, missing_msg, required in sources:
        try:
            # weights_only=False: these are locally produced result pickles.
            payload = torch.load(path, weights_only=False)
        except FileNotFoundError:
            print(missing_msg)
            if required:
                return None
            results[key] = []
            continue
        results[key] = payload["results"]
        print(loaded_msg)

    return results


def print_summary_table(results):
    """Print the aggregated (mean ± std) table, one row per model.

    Args:
        results: mapping from model key ("Real", "Quat", "QNoEnt", "QEnt") to a
            list of per-seed result dicts, as returned by ``load_results``.
            Keys that are missing or map to an empty list are skipped. Each
            dict must provide ``best_acc``, ``time`` and ``epochs``. The "Real"
            row additionally reads ``params``; every other row reads
            ``trainable_params``, both taken from the first run.

    Std values use NumPy's default population std (ddof=0).
    """
    rule = "=" * 90
    print("\n" + rule)
    print("AGGREGATED RESULTS (CIFAR-10) (mean ± std over 3 seeds)")
    print(rule)
    print(f"{'Model':<15} {'Accuracy':<20} {'Time (s)':<20} {'Epochs':<15} {'Parameters':<20}")
    print("-" * 90)

    for name in ["Real", "Quat", "QNoEnt", "QEnt"]:
        runs = results.get(name)
        if not runs:
            continue

        accs = [r["best_acc"] for r in runs]
        times = [r["time"] for r in runs]
        epochs = [r["epochs"] for r in runs]

        acc_str = f"{np.mean(accs):.4f} ± {np.std(accs):.4f}"
        time_str = f"{np.mean(times):.1f} ± {np.std(times):.1f}"
        epoch_str = f"{np.mean(epochs):.1f} ± {np.std(epochs):.1f}"

        if name == "Real":
            param_str = f"{runs[0]['params']:,}"
        else:
            # Report the trainable (head) parameter count.
            param_str = f"{runs[0]['trainable_params']:,} (head)"

        print(f"{name:<15} {acc_str:<20} {time_str:<20} {epoch_str:<15} {param_str:<20}")

    print(rule)


def print_per_seed_table(results):
    """Print one small table per seed with each model's accuracy, time and epochs.

    Runs are matched to seeds by position: ``results[name][i]`` is shown under
    ``seeds[i]``. A model with fewer runs than seeds (e.g. a quantum block that
    finished only its first seed) has no row for the missing seeds, instead of
    raising IndexError. Missing or empty model keys are skipped.
    """
    print("\n" + "=" * 70)
    print("PER-SEED RESULTS (CIFAR-10)")
    print("=" * 70)

    seeds = [42, 123, 456]

    for seed_idx, seed in enumerate(seeds):
        print(f"\nSeed {seed}:")
        print(f"{'Model':<15} {'Accuracy':<12} {'Time (s)':<12} {'Epochs':<10}")
        print("-" * 50)

        for name in ["Real", "Quat", "QNoEnt", "QEnt"]:
            runs = results.get(name) or []
            if seed_idx >= len(runs):
                continue

            r = runs[seed_idx]
            print(f"{name:<15} {r['best_acc']:<12.4f} {r['time']:<12.1f} {r['epochs']:<10}")


def comparative_analysis(results):
    """Print the pairwise model comparisons that answer the research questions.

    Section 1 (Quat vs Real) is always printed. Sections involving the
    quantum heads are printed only when the corresponding result lists are
    non-empty:
      2. QNoEnt vs Real       3. QEnt vs Real       4. QEnt vs QNoEnt
      5. Quat vs QNoEnt       6. Quat vs QEnt

    Args:
        results: dict mapping "Real", "Quat", "QNoEnt", "QEnt" to lists of
            per-seed result dicts; only the "best_acc" field is read here.
            Missing keys are treated the same as empty lists ("not run").
    """
    print("\n" + "=" * 90)
    print("COMPARATIVE ANALYSIS (CIFAR-10): Answering Research Questions")
    print("=" * 90)

    def best_accs(name):
        """Per-seed best accuracies for one model ([] if absent)."""
        return [r["best_acc"] for r in results.get(name) or []]

    def pct_gap(a, b):
        """Percentage point gap (a - b)"""
        return (np.mean(a) - np.mean(b)) * 100.0

    def retention(a, b):
        """Percentage of performance retained"""
        return (np.mean(a) / np.mean(b)) * 100.0

    def mean_std(accs):
        """Format accuracies as 'mean ± std' (np.std default, ddof=0)."""
        return f"{np.mean(accs):.4f} ± {np.std(accs):.4f}"

    real_accs = best_accs("Real")
    quat_accs = best_accs("Quat")
    # Bind the quantum lists up front (possibly empty) so the later
    # cross-comparison sections never depend on names that were only
    # created inside an earlier, conditionally executed section.
    qno_accs = best_accs("QNoEnt")
    qent_accs = best_accs("QEnt")

    print("\n1. QUATERNION vs REAL MLP:")
    print("   " + "-" * 70)
    print(f"   Real MLP accuracy:        {mean_std(real_accs)}")
    print(f"   Quaternion accuracy:      {mean_std(quat_accs)}")
    print(f"   Gap:                      {pct_gap(real_accs, quat_accs):.2f} percentage points")
    print(f"   Performance retention:    {retention(quat_accs, real_accs):.1f}%")
    print(f"\n   → Classical SU(2) (quaternions) captures {retention(quat_accs, real_accs):.1f}% of")
    print("     standard MLP performance with structured algebraic constraints on CIFAR-10.")

    if qno_accs:
        print("\n2. QUANTUM (no entanglement) vs REAL MLP:")
        print("   " + "-" * 70)
        print(f"   Real MLP accuracy:        {mean_std(real_accs)}")
        print(f"   Quantum (no ent) accuracy:{mean_std(qno_accs)}")
        print(f"   Gap:                      {pct_gap(real_accs, qno_accs):.2f} percentage points")
        print(f"   Performance retention:    {retention(qno_accs, real_accs):.1f}%")
        print(f"\n   → Quantum circuits WITHOUT entanglement capture {retention(qno_accs, real_accs):.1f}%")
        print("     of MLP performance, suggesting limited benefit over classical rotation gates")
        print("     on CIFAR-10's color images.")

    if qent_accs:
        print("\n3. QUANTUM (with entanglement) vs REAL MLP:")
        print("   " + "-" * 70)
        print(f"   Real MLP accuracy:        {mean_std(real_accs)}")
        print(f"   Quantum (with ent) accuracy:{mean_std(qent_accs)}")
        print(f"   Gap:                      {pct_gap(real_accs, qent_accs):.2f} percentage points")
        print(f"   Performance retention:    {retention(qent_accs, real_accs):.1f}%")
        print(f"\n   → Quantum circuits WITH entanglement capture {retention(qent_accs, real_accs):.1f}%")
        print("     of MLP performance on CIFAR-10.")

    if qno_accs and qent_accs:
        print("\n4. ENTANGLEMENT EFFECT:")
        print("   " + "-" * 70)
        print(f"   Quantum (no ent) accuracy:{mean_std(qno_accs)}")
        print(f"   Quantum (with ent) accuracy:{mean_std(qent_accs)}")
        print(f"   Improvement:              {pct_gap(qent_accs, qno_accs):.2f} percentage points")

        if np.mean(qent_accs) > np.mean(qno_accs):
            # Relative (not percentage-point) improvement over the non-entangled mean.
            improvement_pct = ((np.mean(qent_accs) - np.mean(qno_accs)) / np.mean(qno_accs)) * 100
            print(f"\n   → Entanglement IMPROVES performance by {improvement_pct:.1f}%")
            print("     (relative improvement over non-entangled baseline)")
            print("     → Pattern consistent with paper findings (~0.6pp gain on FashionMNIST)")
        else:
            print("\n   → Entanglement DOES NOT improve performance on CIFAR-10")
            print("     → Task difficulty may mask entanglement benefit")

    if qno_accs:
        print("\n5. QUATERNION vs QUANTUM (no entanglement) - Core Research Question:")
        print("   " + "-" * 70)
        print(f"   Quaternion accuracy:      {mean_std(quat_accs)}")
        print(f"   Quantum (no ent) accuracy:{mean_std(qno_accs)}")
        print(f"   Gap:                      {pct_gap(quat_accs, qno_accs):.2f} percentage points")

        # Gaps under 2 percentage points are reported as "closely approximate".
        gap_abs = abs(pct_gap(quat_accs, qno_accs))
        if gap_abs < 2.0:
            print("\n   → Quaternion networks (classical SU(2)) CLOSELY APPROXIMATE quantum circuits")
            print("     without entanglement (quantum SU(2)). Gap < 2 percentage points.")
            print("\n   → This supports the hypothesis that classical quaternion algebra can serve")
            print("     as an effective implementation of SU(2) geometry on CIFAR-10.")
        elif np.mean(quat_accs) > np.mean(qno_accs):
            print("\n   → Quaternion networks OUTPERFORM quantum circuits without entanglement")
            print(f"     by {gap_abs:.2f} percentage points.")
            print("\n   → Classical SU(2) (quaternions) SUFFICES for learning tasks addressable")
            print("     by product-state quantum circuits on CIFAR-10, consistent with")
            # NOTE(review): main() quotes this FashionMNIST gap as ~2.4pp, not ~6pp;
            # confirm which figure matches the paper.
            print("     the paper's findings on FashionMNIST (~6pp advantage).")
        else:
            print("\n   → Quantum circuits without entanglement OUTPERFORM quaternions")
            print(f"     by {gap_abs:.2f} percentage points on CIFAR-10.")

    if qent_accs:
        print("\n6. QUATERNION vs QUANTUM (with entanglement):")
        print("   " + "-" * 70)
        print(f"   Quaternion accuracy:      {mean_std(quat_accs)}")
        print(f"   Quantum (with ent) accuracy:{mean_std(qent_accs)}")
        print(f"   Gap:                      {pct_gap(quat_accs, qent_accs):.2f} percentage points")

        if np.mean(qent_accs) > np.mean(quat_accs):
            print("\n   → Entangled quantum circuits outperform quaternions, demonstrating")
            print("     the value of quantum correlations beyond classical SU(2) rotations.")
        else:
            print("\n   → Quaternions match or exceed entangled quantum performance on CIFAR-10,")
            print("     suggesting entanglement provides limited benefit on this harder task.")
            print("     → Measurement loss + optimization challenges may dominate on complex data.")


def key_takeaways(results):
    """Print the headline conclusions of the CIFAR-10 comparison.

    Reads the per-seed "best_acc" and "time" fields from ``results`` and
    prints four sections: empirical findings, implications, computational
    cost relative to the Real MLP, and notes on the experimental design.
    Findings that involve "QNoEnt"/"QEnt" appear only when those result
    lists are non-empty.
    """
    rule = "=" * 90
    print("\n" + rule)
    print("KEY TAKEAWAYS FOR PAPER (CIFAR-10)")
    print(rule)

    def best_accs(model):
        # Per-seed best accuracies of one model.
        return [run["best_acc"] for run in results[model]]

    def mean_time(model):
        # Mean wall-clock training time of one model across seeds.
        return np.mean([run["time"] for run in results[model]])

    real_accs = best_accs("Real")
    quat_accs = best_accs("Quat")
    quat_mean = np.mean(quat_accs)
    # Quaternion accuracy as a percentage of the Real MLP accuracy.
    quat_retained = (quat_mean / np.mean(real_accs)) * 100.0

    print("\n".join([
        "\n✓ EMPIRICAL FINDINGS:",
        f"  1. Quaternion networks achieve {quat_retained:.1f}% of Real MLP performance",
        "     with structured SU(2) algebraic constraints on CIFAR-10",
        f"     → Note: Lower absolute accuracy (~{quat_mean*100:.1f}%) reflects",
        "       16-D bottleneck limitation, not quaternion geometry weakness",
    ]))

    if results["QNoEnt"]:
        qno_mean = np.mean(best_accs("QNoEnt"))
        # Absolute gap in percentage points; the sign is resolved below.
        gap = abs((quat_mean - qno_mean) * 100.0)
        print(f"\n  2. Quaternions vs Quantum (no ent): {gap:.2f} percentage point gap")
        if gap < 2.0:
            verdict = [
                "     → Classical SU(2) effectively approximates quantum SU(2) without entanglement",
            ]
        elif quat_mean > qno_mean:
            verdict = [
                "     → Classical SU(2) OUTPERFORMS quantum product-state circuits",
                "       (consistent with paper: quaternions >> quantum on FashionMNIST)",
            ]
        else:
            verdict = [
                "     → Quantum product-state circuits outperform classical SU(2)",
                "       (different pattern from FashionMNIST)",
            ]
        print("\n".join(verdict))

    if results["QEnt"] and results["QNoEnt"]:
        qent_mean = np.mean(best_accs("QEnt"))
        qno_base = np.mean(best_accs("QNoEnt"))
        # Relative improvement (%) of the entangled head over the non-entangled one.
        ent_improvement = ((qent_mean - qno_base) / qno_base) * 100
        if ent_improvement > 0.5:
            print("\n".join([
                f"\n  3. Entanglement improves quantum performance by {ent_improvement:.1f}%",
                "     → Demonstrates measurable value of quantum correlations on CIFAR-10",
                "     → Pattern consistent with paper (~0.6pp gain on FashionMNIST)",
            ]))
        else:
            print("\n".join([
                f"\n  3. Entanglement provides minimal/no benefit ({ent_improvement:.1f}%) on CIFAR-10",
                "     → Task difficulty may mask entanglement advantage",
            ]))

    print("\n".join([
        "\n✓ IMPLICATIONS:",
        "  • CIFAR-10 results test generalization of paper's findings to harder tasks",
        "  • If quaternions still outperform quantum (like FashionMNIST):",
        "    → Strengthens claim that classical SU(2) suffices for product-state VQCs",
        "  • If entanglement benefit persists (~0.5-1pp):",
        "    → Confirms entanglement as quantum boundary across task difficulties",
        "  • Lower absolute accuracies (~30-40%) reflect bottleneck constraint,",
        "    NOT failure of any specific architecture",
    ]))

    print("\n✓ COMPUTATIONAL EFFICIENCY:")
    real_time = mean_time("Real")
    quat_time = mean_time("Quat")
    print(f"  • Real MLP:  {real_time:.1f}s (baseline)")
    print(f"  • Quaternion: {quat_time:.1f}s ({quat_time/real_time:.2f}x Real)")

    if results["QNoEnt"]:
        qno_time = mean_time("QNoEnt")
        slowdown = qno_time / quat_time
        print(f"  • Quantum (no ent): {qno_time:.1f}s ({qno_time/real_time:.1f}x Real)")
        print(f"    → Even with lightning.gpu, quantum is {slowdown:.1f}x slower than quaternions")
        print(f"    → Pattern consistent across datasets (FashionMNIST: ~500x, CIFAR-10: ~{slowdown:.0f}x)")

    print("\n".join([
        "\n✓ CONTROLLED EXPERIMENTAL DESIGN:",
        "  • All models use IDENTICAL frozen 3072→16 preprocessor",
        "  • Performance differences reflect HEAD ARCHITECTURE ONLY",
        "  • Validates paper's methodology on second dataset (CIFAR-10)",
    ]))

    print(rule)


def main():
    """Load the saved CIFAR-10 result sets and print the full analysis report.

    Prints an error and returns early when ``load_results()`` returns None.
    Otherwise prints, in order: the summary table, per-seed tables, the
    comparative analysis, the key takeaways, the list of result files and a
    short comparison against the FashionMNIST findings.
    """
    banner = "=" * 90
    print(banner)
    print("BLOCK 5: AGGREGATE RESULTS AND COMPARATIVE ANALYSIS (CIFAR-10)")
    print(banner)

    results = load_results()
    if results is None:
        print("\n✗ ERROR: Required results files not found.")
        print("Run Blocks 1 and 2 at minimum (Real and Quat) on CIFAR-10.")
        return

    # Report sections, in presentation order.
    for report in (
        print_summary_table,
        print_per_seed_table,
        comparative_analysis,
        key_takeaways,
    ):
        report(results)

    print("\n" + banner)
    print("ANALYSIS COMPLETE (CIFAR-10)")
    print(banner)
    print("\nAll results saved in:")
    print("  • realnet_cifar10_results.pt")
    print("  • quatnet_cifar10_results.pt")
    # The quantum result files are listed only when their results are non-empty.
    optional_files = {
        "QNoEnt": "quantum_noent_cifar10_results.pt",
        "QEnt": "quantum_ent_cifar10_results.pt",
    }
    for model, filename in optional_files.items():
        if results[model]:
            print(f"  • {filename}")

    print("\n" + banner)
    print("COMPARISON TO PAPER (FashionMNIST)")
    print(banner)
    print("\n".join([
        "\nPaper's key findings on FashionMNIST:",
        "  • Quaternions ≈ Real MLP (within 0.2pp)",
        "  • Quaternions >> Quantum-NoEnt (by ~2.4pp)",
        "  • Entanglement helps modestly (0.6pp gain)",
        "\nCIFAR-10 tests whether these patterns generalize to:",
        "  • Harder task (color images, more complex)",
        "  • Same frozen bottleneck constraint (16-D)",
        "  • Same architectural comparison (controlled design)",
    ]))
    print(banner)


# Run the full CIFAR-10 aggregation report when this file/cell is executed
# directly (not when imported).
if __name__ == "__main__":
    main()



BLOCK 5: AGGREGATE RESULTS AND COMPARATIVE ANALYSIS (CIFAR-10)
✓ Loaded RealNet (CIFAR-10) results
✓ Loaded QuatNet (CIFAR-10) results
✓ Loaded Quantum (no ent) (CIFAR-10) results
✓ Loaded Quantum (with ent) (CIFAR-10) results

AGGREGATED RESULTS (CIFAR-10) (mean ± std over 3 seeds)
Model           Accuracy             Time (s)             Epochs          Parameters          
------------------------------------------------------------------------------------------
Real            0.4023 ± 0.0044      11.0 ± 1.7           24.0 ± 3.6      50,906              
Quat            0.3792 ± 0.0033      24.3 ± 2.7           36.0 ± 5.7      1,000 (head)        
QNoEnt          0.3510 ± 0.0071      15503.8 ± 916.7      32.3 ± 1.9      162 (head)          
QEnt            0.3492 ± 0.0022      23117.3 ± 7497.6     43.3 ± 14.5     162 (head)          

PER-SEED RESULTS (CIFAR-10)

Seed 42:
Model           Accuracy     Time (s)     Epochs    
--------------------------------------------------
Real   