# RealNet CNN CIFAR

In [1]:
"""
Block 1: RealNet Training with Frozen CNN (CIFAR-10)
====================================================
Trains the baseline Real MLP head on top of frozen ResNet18 features.
Uses 3 seeds on CIFAR-10 with stratified sampling.
Saves the frozen CNN state for reuse in subsequent blocks.

Outputs:
- realnet_cnn_cifar10_results.pt: Contains results dict and frozen CNN state
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass


# ============================================
# Frozen CNN Feature Extractor (ResNet18)
# ============================================

class FrozenCNNExtractor(nn.Module):
    """
    Frozen ResNet18 feature extractor.
    Outputs 512-dimensional features from penultimate layer.
    Pre-trained on ImageNet.
    """
    def __init__(self):
        super().__init__()
        # Load pre-trained ResNet18
        weights = ResNet18_Weights.IMAGENET1K_V1
        resnet = resnet18(weights=weights)
        
        # Remove the final FC layer (we only want features)
        # ResNet18 structure: conv layers -> avgpool -> fc
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
        # Freeze all parameters
        for param in self.features.parameters():
            param.requires_grad = False
        
        self.eval()  # Set to eval mode permanently
    
    def forward(self, x):
        with torch.no_grad():  # No gradients needed
            features = self.features(x)
            # Output shape: (batch, 512, 1, 1)
            features = features.view(features.size(0), -1)  # Flatten to (batch, 512)
        return features


# ============================================
# Real-valued Head
# ============================================

class RealHead(nn.Module):
    """Standard MLP: 512 → 128 → 10"""
    def __init__(self, input_dim=512, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class RealNetCNN(nn.Module):
    """Complete Real network: Frozen CNN + RealHead"""
    def __init__(self):
        super().__init__()
        self.cnn_extractor = FrozenCNNExtractor()
        self.head = RealHead(512, 128, 10)

    def forward(self, x):
        features = self.cnn_extractor(x)
        return self.head(features)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Uses dataset.targets directly - NO image loading during sampling.
    """
    if hasattr(dataset, 'targets'):
        targets = np.array(dataset.targets)
    else:
        targets = np.array([dataset[i][1] for i in range(len(dataset))])
    
    rng = np.random.RandomState(seed)
    
    sampled_indices = []
    for c in range(10):
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 1: RealNet Training with Frozen CNN (CIFAR-10)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • CNN Backbone: ResNet18 (ImageNet pre-trained, FROZEN)")
    print("  • Feature dimension: 512")
    print("  • Batch size: 128")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: CNN(frozen) → 512 → 128 → 10")
    print("=" * 70)

    # ImageNet normalization (required for pre-trained ResNet)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet stats
            std=[0.229, 0.224, 0.225]
        )
    ])

    # Load full datasets
    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=128,
                              shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256,
                             shuffle=False, num_workers=4)

    seeds = [42, 123, 456]
    all_results = []
    frozen_cnn_state = None

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training RealNet-CNN (seed={seed})...")
        real_model = RealNetCNN().to(device)
        
        # Only optimize the head parameters (CNN is frozen)
        real_opt = torch.optim.Adam(real_model.head.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            real_model, train_loader, test_loader, real_opt, device,
            max_epochs=40, patience=10, name="Real-CNN"
        )
        
        # Count only trainable parameters
        trainable_params = sum(p.numel() for p in real_model.head.parameters())
        total_params = sum(p.numel() for p in real_model.parameters())
        
        result["trainable_params"] = trainable_params
        result["total_params"] = total_params
        result["seed"] = seed
        
        all_results.append(result)
        
        # Save frozen CNN state from first seed
        if frozen_cnn_state is None:
            frozen_cnn_state = real_model.cnn_extractor.state_dict()
            print(f"\n  → Saved frozen CNN state from seed {seed}")

    # Summary
    print("\n" + "=" * 70)
    print("REALNET-CNN SUMMARY (CIFAR-10)")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    trainable = all_results[0]["trainable_params"]
    total = all_results[0]["total_params"]
    
    print(f"\nAccuracy:          {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:              {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:            {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Trainable params:  {trainable:,}")
    print(f"Total params:      {total:,}")
    print(f"Frozen params:     {total - trainable:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results and frozen CNN
    save_dict = {
        "results": all_results,
        "frozen_cnn_state": frozen_cnn_state,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": trainable,
            "total_params": total
        }
    }
    
    torch.save(save_dict, "realnet_cnn_cifar10_results.pt")
    print(f"\n✓ Saved results to: realnet_cnn_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 1: RealNet Training with Frozen CNN (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • CNN Backbone: ResNet18 (ImageNet pre-trained, FROZEN)
  • Feature dimension: 512
  • Batch size: 128
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: CNN(frozen) → 512 → 128 → 10

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training RealNet-CNN (seed=42)...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/dustoff06/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:01<00:00, 30.3MB/s]


  [Real-CNN] Epoch  1 | loss=1.8288 | test_acc=0.4087 | time=1.2s
  [Real-CNN] Epoch  2 | loss=1.5903 | test_acc=0.4397 | time=1.9s
  [Real-CNN] Epoch  3 | loss=1.5181 | test_acc=0.4413 | time=2.5s
  [Real-CNN] Epoch  4 | loss=1.4819 | test_acc=0.4503 | time=3.2s
  [Real-CNN] Epoch  5 | loss=1.4372 | test_acc=0.4463 | time=3.8s
  [Real-CNN] Epoch  6 | loss=1.4145 | test_acc=0.4527 | time=4.4s
  [Real-CNN] Epoch  7 | loss=1.3820 | test_acc=0.4633 | time=5.0s
  [Real-CNN] Epoch  8 | loss=1.3512 | test_acc=0.4660 | time=5.6s
  [Real-CNN] Epoch  9 | loss=1.3328 | test_acc=0.4543 | time=6.2s
  [Real-CNN] Epoch 10 | loss=1.2916 | test_acc=0.4593 | time=6.8s
  [Real-CNN] Epoch 11 | loss=1.2570 | test_acc=0.4657 | time=7.4s
  [Real-CNN] Epoch 12 | loss=1.2416 | test_acc=0.4533 | time=8.0s
  [Real-CNN] Epoch 13 | loss=1.2096 | test_acc=0.4527 | time=8.6s
  [Real-CNN] Epoch 14 | loss=1.1876 | test_acc=0.4543 | time=9.2s
  [Real-CNN] Epoch 15 | loss=1.1723 | test_acc=0.4560 | time=9.8s
  [Real-CN

# Quaternion with CNN Head

In [2]:
"""
Block 2: QuatNet Training with Frozen CNN (CIFAR-10)
====================================================
Loads frozen CNN from Block 1 and trains quaternion head on 3 seeds.

Requirements:
- realnet_cnn_cifar10_results.pt (from Block 1)

Outputs:
- quatnet_cnn_cifar10_results.pt: Contains quaternion head results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass


# ============================================
# Quaternion utilities (PyTorch tensors)
# ============================================

def q_normalize(q):
    norm = torch.linalg.norm(q, dim=-1, keepdim=True) + 1e-8
    return q / norm

def q_conj(q):
    w, x, y, z = torch.unbind(q, dim=-1)
    return torch.stack([w, -x, -y, -z], dim=-1)

def q_mul(a, b):
    """Hamilton product of two quaternions"""
    aw, ax, ay, az = torch.unbind(a, dim=-1)
    bw, bx, by, bz = torch.unbind(b, dim=-1)

    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw

    return torch.stack([w, x, y, z], dim=-1)


# ============================================
# Frozen CNN Feature Extractor (from Block 1)
# ============================================

class FrozenCNNExtractor(nn.Module):
    """
    Frozen ResNet18 feature extractor.
    Outputs 512-dimensional features from penultimate layer.
    """
    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.IMAGENET1K_V1
        resnet = resnet18(weights=weights)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
        for param in self.features.parameters():
            param.requires_grad = False
        
        self.eval()
    
    def forward(self, x):
        with torch.no_grad():
            features = self.features(x)
            features = features.view(features.size(0), -1)
        return features


# ============================================
# Quaternion Head and Network
# ============================================

class QuaternionLinear(nn.Module):
    def __init__(self, in_features, out_features):
        """
        in_features, out_features are in "quaternion units".
        Internally weight: (out_features, in_features, 4)
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, 4))
        self.bias = nn.Parameter(torch.empty(out_features, 4))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0.0, std=0.1)
        with torch.no_grad():
            self.weight[:] = q_normalize(self.weight)
            nn.init.constant_(self.bias[..., 0], 1.0)
            nn.init.constant_(self.bias[..., 1:], 0.0)

    def forward(self, x):
        """
        x: (B, in_features, 4)
        Returns: (B, out_features, 4)
        """
        w = self.weight.unsqueeze(0)
        x_exp = x.unsqueeze(1)
        prod = q_mul(w, x_exp)
        out = prod.sum(dim=2) + self.bias
        return out


class QuatHead(nn.Module):
    """
    Quaternion head: 128 quats → 32 quats → 10 quats → 10 logits.
    Input: 512 real features → 128 quaternions
    """
    def __init__(self, input_dim=512, num_classes=10):
        super().__init__()
        # 512 real features = 128 quaternions (512/4)
        self.input_quats = input_dim // 4
        self.quat_fc1 = QuaternionLinear(self.input_quats, 32)
        self.quat_fc2 = QuaternionLinear(32, num_classes)

    def real_to_quat(self, x):
        """Convert 512 real features to 128 quaternions"""
        B = x.size(0)
        return x.view(B, self.input_quats, 4)

    def quat_to_real(self, q):
        """Extract real part of quaternions for classification"""
        return q[..., 0]

    def forward(self, x):
        q_in = self.real_to_quat(x)
        hq = self.quat_fc1(q_in)
        hq = q_normalize(hq)
        hq = torch.tanh(hq)
        q_out = self.quat_fc2(hq)
        logits = self.quat_to_real(q_out)
        return logits


class QuatNetCNN(nn.Module):
    """Complete Quaternion network: Frozen CNN + QuatHead"""
    def __init__(self):
        super().__init__()
        self.cnn_extractor = FrozenCNNExtractor()
        self.head = QuatHead(512, 10)

    def forward(self, x):
        features = self.cnn_extractor(x)
        return self.head(features)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=False):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 50 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=40, patience=10, name="Model"):
    if isinstance(device, str):
        device = torch.device(device)

    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=False)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Uses dataset.targets directly - NO image loading during sampling.
    """
    if hasattr(dataset, 'targets'):
        targets = np.array(dataset.targets)
    else:
        targets = np.array([dataset[i][1] for i in range(len(dataset))])
    
    rng = np.random.RandomState(seed)
    
    sampled_indices = []
    for c in range(10):
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 2: QuatNet Training with Frozen CNN (CIFAR-10)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • CNN Backbone: ResNet18 (FROZEN)")
    print("  • Batch size: 128")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: CNN(frozen) → 512 → 128 quats → 32 quats → 10")
    print("=" * 70)

    # Load frozen CNN from Block 1
    print("\nLoading frozen CNN from Block 1...")
    try:
        realnet_data = torch.load("realnet_cnn_cifar10_results.pt", weights_only=False)
        frozen_cnn_state = realnet_data["frozen_cnn_state"]
        print("✓ Loaded frozen CNN state from realnet_cnn_cifar10_results.pt")
    except FileNotFoundError:
        print("✗ ERROR: realnet_cnn_cifar10_results.pt not found!")
        print("  Run Block 1 first to train RealNet and save frozen CNN.")
        return

    # Create frozen CNN
    shared_cnn = FrozenCNNExtractor().to(device)
    shared_cnn.load_state_dict(frozen_cnn_state)
    for p in shared_cnn.parameters():
        p.requires_grad = False
    
    cnn_params = sum(p.numel() for p in shared_cnn.parameters())
    print(f"  CNN frozen with {cnn_params:,} params")

    # Load data - ImageNet normalization (required for pre-trained CNN)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    train_loader = DataLoader(train_ds, batch_size=128,
                              shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256,
                             shuffle=False, num_workers=4)

    seeds = [42, 123, 456]
    all_results = []

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training QuatNet-CNN (seed={seed})...")
        quat_model = QuatNetCNN().to(device)
        quat_model.cnn_extractor = shared_cnn  # Use frozen CNN
        
        # Only optimize head parameters
        quat_opt = torch.optim.Adam(quat_model.head.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            quat_model, train_loader, test_loader, quat_opt, device,
            max_epochs=40, patience=10, name="Quat-CNN"
        )
        
        trainable_params = sum(p.numel() for p in quat_model.head.parameters())
        total_params = sum(p.numel() for p in quat_model.parameters())
        
        result["trainable_params"] = trainable_params
        result["total_params"] = total_params
        result["seed"] = seed
        
        all_results.append(result)

    # Summary
    print("\n" + "=" * 70)
    print("QUATNET-CNN SUMMARY (CIFAR-10)")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    trainable = all_results[0]["trainable_params"]
    total = all_results[0]["total_params"]
    
    print(f"\nAccuracy:          {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:              {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:            {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Trainable params:  {trainable:,}")
    print(f"Total params:      {total:,}")
    print(f"Frozen params:     {total - trainable:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": trainable,
            "total_params": total
        }
    }
    
    torch.save(save_dict, "quatnet_cnn_cifar10_results.pt")
    print(f"\n✓ Saved results to: quatnet_cnn_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

Using PyTorch device: cuda
BLOCK 2: QuatNet Training with Frozen CNN (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • CNN Backbone: ResNet18 (FROZEN)
  • Batch size: 128
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Architecture: CNN(frozen) → 512 → 128 quats → 32 quats → 10

Loading frozen CNN from Block 1...
✓ Loaded frozen CNN state from realnet_cnn_cifar10_results.pt
  CNN frozen with 11,176,512 params

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training QuatNet-CNN (seed=42)...
  [Quat-CNN] Epoch  1 | loss=3.2419 | test_acc=0.1780 | time=1.1s
  [Quat-CNN] Epoch  2 | loss=2.4684 | test_acc=0.2443 | time=1.9s
  [Quat-CNN] Epoch  3 | loss=2.1628 | test_acc=0.2850 | time=2.9s
  [Quat-CNN] Epoch  4 | loss=1.9959 | test_acc=0.3163 | time=3.7s
  [Quat-CNN] Epoch  5 | loss=1.9119 | test_acc=0.3347 | time=4.7s
  [Quat-CNN] Epoch  6 | loss=1.8316 | tes

# Quantum No Ent CNN

In [4]:
"""
Block 3: Quantum (No Entanglement) Training with Frozen CNN (CIFAR-10)
=======================================================================
Loads frozen CNN from Block 1 and trains quantum head WITHOUT entanglement.
Uses Lightning-GPU acceleration.

Requirements:
- realnet_cnn_cifar10_results.pt (from Block 1)
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_noent_cnn_cifar10_results.pt: Contains quantum (no ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights

try:
    import pennylane as qml
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    print("✓ Using lightning.gpu device")
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    exit(1)

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass


# ============================================
# Frozen CNN Feature Extractor (from Block 1)
# ============================================

class FrozenCNNExtractor(nn.Module):
    """
    Frozen ResNet18 feature extractor.
    Outputs 512-dimensional features from penultimate layer.
    """
    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.IMAGENET1K_V1
        resnet = resnet18(weights=weights)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
        for param in self.features.parameters():
            param.requires_grad = False
        
        self.eval()
    
    def forward(self, x):
        with torch.no_grad():
            features = self.features(x)
            features = features.view(features.size(0), -1)
        return features


# ============================================
# Quantum Head (No Entanglement)
# ============================================

class QuantumHead(nn.Module):
    """
    VQC with 8 qubits, 3 layers, NO entanglement → 10 classes.
    Uses Lightning acceleration (GPU).
    Maps 512 CNN features → 8 qubits.
    """
    def __init__(self, input_dim=512, n_qubits=8, n_layers=3, num_classes=10, device_name=None):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Map 512 features → n_qubits
        self.feature_select = nn.Linear(input_dim, n_qubits)

        # Use specified device or default
        if device_name is None:
            device_name = QUANTUM_DEVICE
        
        # Quantum device
        self.dev = qml.device(device_name, wires=n_qubits)

        # Use adjoint differentiation for lightning (much faster)
        diff_method = "adjoint" if "lightning" in device_name else "parameter-shift"

        @qml.qnode(self.dev, interface="torch", diff_method=diff_method)
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITHOUT entanglement.
            inputs: (n_qubits,)
            weights: (n_layers, n_qubits, 2)
            """
            for layer in range(n_layers):
                # Data re-uploading
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)

                # Trainable rotations
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)

                # NO ENTANGLEMENT

            # Measure more observables for richer output (12 measurements)
            measurements = []
            # Single-qubit Z measurements
            for i in range(n_qubits):
                measurements.append(qml.expval(qml.PauliZ(i)))
            # Two-qubit ZZ measurements (pairs)
            for i in range(0, n_qubits-1, 2):
                measurements.append(qml.expval(qml.PauliZ(i) @ qml.PauliZ(i+1)))
            
            return measurements

        self.quantum_circuit = quantum_circuit

        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        
        # Output layer: 12 measurements → 10 classes
        n_measurements = n_qubits + (n_qubits // 2)
        self.fc_out = nn.Linear(n_measurements, num_classes)

    def forward(self, x):
        """
        Process samples with progress tracking.
        x: (batch, 512) CNN features
        """
        batch_size = x.size(0)
        x = torch.tanh(self.feature_select(x))

        # Process in chunks for memory management
        chunk_size = 32
        quantum_outputs = []

        for start_idx in range(0, batch_size, chunk_size):
            end_idx = min(start_idx + chunk_size, batch_size)
            chunk = x[start_idx:end_idx]

            chunk_outputs = []
            for i in range(chunk.size(0)):
                q_raw = self.quantum_circuit(chunk[i], self.q_weights)
                if isinstance(q_raw, (list, tuple)):
                    q_out = torch.stack(q_raw)
                else:
                    q_out = q_raw
                chunk_outputs.append(q_out)

            quantum_outputs.extend(chunk_outputs)

        # Convert to tensor (cast to float32)
        quantum_outputs = torch.stack(quantum_outputs).float()
        quantum_outputs = quantum_outputs.to(self.fc_out.weight.dtype)

        output = self.fc_out(quantum_outputs)
        return output


class QuantumNetCNN(nn.Module):
    """Complete Quantum network: Frozen CNN + QuantumHead (no ent)"""
    def __init__(self):
        super().__init__()
        self.cnn_extractor = FrozenCNNExtractor()
        self.head = QuantumHead(input_dim=512, n_qubits=8, n_layers=3, num_classes=10)

    def forward(self, x):
        # CNN on GPU
        features = self.cnn_extractor(x)
        # Move to CPU for quantum processing
        features_cpu = features.cpu()
        return self.head(features_cpu)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=True):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        # Move images to GPU for CNN
        x = x.to(device)
        # Keep labels on CPU
        
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            # Move images to GPU for CNN
            x = x.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=200, patience=10, name="Model"):
    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=True)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Uses dataset.targets directly - NO image loading during sampling.
    """
    if hasattr(dataset, 'targets'):
        targets = np.array(dataset.targets)
    else:
        targets = np.array([dataset[i][1] for i in range(len(dataset))])
    
    rng = np.random.RandomState(seed)
    
    sampled_indices = []
    for c in range(10):
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 3: Quantum (NO Entanglement) Training with Frozen CNN (CIFAR-10)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • CNN Backbone: ResNet18 (FROZEN)")
    print("  • Batch size: 32")
    print("  • Patience: 10")
    print("  • Max epochs: 200")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: CNN(frozen) → 512 → 8 qubits (3 layers, NO ent) → 10")
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print("=" * 70)

    # Load frozen CNN from Block 1
    print("\nLoading frozen CNN from Block 1...")
    try:
        realnet_data = torch.load("realnet_cnn_cifar10_results.pt", weights_only=False)
        frozen_cnn_state = realnet_data["frozen_cnn_state"]
        print("✓ Loaded frozen CNN state from realnet_cnn_cifar10_results.pt")
    except FileNotFoundError:
        print("✗ ERROR: realnet_cnn_cifar10_results.pt not found!")
        print("  Run Block 1 first to train RealNet and save frozen CNN.")
        return

    # Create frozen CNN
    shared_cnn = FrozenCNNExtractor().to(device)
    shared_cnn.load_state_dict(frozen_cnn_state)
    for p in shared_cnn.parameters():
        p.requires_grad = False
    
    cnn_params = sum(p.numel() for p in shared_cnn.parameters())
    print(f"  CNN frozen with {cnn_params:,} params")

    # Load data - ImageNet normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    # Smaller batches for quantum
    train_loader = DataLoader(train_ds, batch_size=32,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64,
                             shuffle=False, num_workers=0)

    seeds = [42, 123, 456]
    all_results = []

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training QuantumNet-CNN (NO entanglement, seed={seed})...")
        quantum_model = QuantumNetCNN()
        quantum_model.cnn_extractor = shared_cnn  # Use frozen CNN
        
        # Only optimize quantum head parameters
        quantum_opt = torch.optim.Adam(quantum_model.head.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            quantum_model, train_loader, test_loader, quantum_opt, device,
            max_epochs=200, patience=10, name="QuantumNoEnt-CNN"
        )
        
        trainable_params = sum(p.numel() for p in quantum_model.head.parameters())
        total_params = sum(p.numel() for p in quantum_model.parameters())
        
        result["trainable_params"] = trainable_params
        result["total_params"] = total_params
        result["seed"] = seed
        
        all_results.append(result)

    # Summary
    print("\n" + "=" * 70)
    print("QUANTUM (NO ENTANGLEMENT) + CNN SUMMARY (CIFAR-10)")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    trainable = all_results[0]["trainable_params"]
    total = all_results[0]["total_params"]
    
    print(f"\nAccuracy:          {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:              {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:            {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Trainable params:  {trainable:,}")
    print(f"Total params:      {total:,}")
    print(f"Frozen params:     {total - trainable:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": trainable,
            "total_params": total
        }
    }
    
    torch.save(save_dict, "quantum_noent_cnn_cifar10_results.pt")
    print(f"\n✓ Saved results to: quantum_noent_cnn_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

✓ Using lightning.gpu device
Using PyTorch device: cuda
BLOCK 3: Quantum (NO Entanglement) Training with Frozen CNN (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • CNN Backbone: ResNet18 (FROZEN)
  • Batch size: 32
  • Patience: 10
  • Max epochs: 200
  • Seeds: [42, 123, 456]
  • Architecture: CNN(frozen) → 512 → 8 qubits (3 layers, NO ent) → 10
  • Quantum device: lightning.gpu

Loading frozen CNN from Block 1...
✓ Loaded frozen CNN state from realnet_cnn_cifar10_results.pt
  CNN frozen with 11,176,512 params

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training QuantumNet-CNN (NO entanglement, seed=42)...
  [QuantumNoEnt-CNN] Epoch 1/200
    Batch 460/469, samples: 14752
  [QuantumNoEnt-CNN] Epoch  1 | loss=2.1942 | test_acc=0.2683 | time=1263.8s
  [QuantumNoEnt-CNN] Epoch 2/200
    Batch 460/469, samples: 14752
  [QuantumNoEnt-CNN] Epoch  2 | l

# Quantum Ent CNN

In [5]:
"""
Block 4: Quantum (WITH Entanglement) Training with Frozen CNN (CIFAR-10)
=========================================================================
Loads frozen CNN from Block 1 and trains quantum head WITH entanglement.
Uses Lightning-GPU acceleration.

Requirements:
- realnet_cnn_cifar10_results.pt (from Block 1)
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_ent_cnn_cifar10_results.pt: Contains quantum (with ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights

try:
    import pennylane as qml
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    print("✓ Using lightning.gpu device")
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    exit(1)

device = torch.device("cuda")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    import numpy as np
    import torch
    
    # Python
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    # PyTorch deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CuPy (if available)
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    # PennyLane (if available)
    try:
        import pennylane as qml
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass


# ============================================
# Frozen CNN Feature Extractor (from Block 1)
# ============================================

class FrozenCNNExtractor(nn.Module):
    """
    Frozen ResNet18 feature extractor.
    Outputs 512-dimensional features from penultimate layer.
    """
    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.IMAGENET1K_V1
        resnet = resnet18(weights=weights)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
        for param in self.features.parameters():
            param.requires_grad = False
        
        self.eval()
    
    def forward(self, x):
        with torch.no_grad():
            features = self.features(x)
            features = features.view(features.size(0), -1)
        return features


# ============================================
# Quantum Head (WITH Entanglement)
# ============================================

class QuantumHead(nn.Module):
    """
    VQC with 8 qubits, 3 layers, WITH entanglement → 10 classes.
    Uses Lightning acceleration (GPU).
    Maps 512 CNN features → 8 qubits.
    """
    def __init__(self, input_dim=512, n_qubits=8, n_layers=3, num_classes=10, device_name=None):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Map 512 features → n_qubits
        self.feature_select = nn.Linear(input_dim, n_qubits)

        # Use specified device or default
        if device_name is None:
            device_name = QUANTUM_DEVICE
        
        # Quantum device
        self.dev = qml.device(device_name, wires=n_qubits)

        # Use adjoint differentiation for lightning (much faster)
        diff_method = "adjoint" if "lightning" in device_name else "parameter-shift"

        @qml.qnode(self.dev, interface="torch", diff_method=diff_method)
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITH entanglement.
            inputs: (n_qubits,)
            weights: (n_layers, n_qubits, 2)
            """
            for layer in range(n_layers):
                # Data re-uploading
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)

                # Trainable rotations
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)

                # ENTANGLEMENT: CNOT ring
                for i in range(n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                if n_qubits > 2:
                    qml.CNOT(wires=[n_qubits - 1, 0])

            # Measure more observables for richer output (12 measurements)
            measurements = []
            # Single-qubit Z measurements
            for i in range(n_qubits):
                measurements.append(qml.expval(qml.PauliZ(i)))
            # Two-qubit ZZ measurements (pairs)
            for i in range(0, n_qubits-1, 2):
                measurements.append(qml.expval(qml.PauliZ(i) @ qml.PauliZ(i+1)))
            
            return measurements

        self.quantum_circuit = quantum_circuit

        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        
        # Output layer: 12 measurements → 10 classes
        n_measurements = n_qubits + (n_qubits // 2)
        self.fc_out = nn.Linear(n_measurements, num_classes)

    def forward(self, x):
        """
        Process samples with progress tracking.
        x: (batch, 512) CNN features
        """
        batch_size = x.size(0)
        x = torch.tanh(self.feature_select(x))

        # Process in chunks for memory management
        chunk_size = 32
        quantum_outputs = []

        for start_idx in range(0, batch_size, chunk_size):
            end_idx = min(start_idx + chunk_size, batch_size)
            chunk = x[start_idx:end_idx]

            chunk_outputs = []
            for i in range(chunk.size(0)):
                q_raw = self.quantum_circuit(chunk[i], self.q_weights)
                if isinstance(q_raw, (list, tuple)):
                    q_out = torch.stack(q_raw)
                else:
                    q_out = q_raw
                chunk_outputs.append(q_out)

            quantum_outputs.extend(chunk_outputs)

        # Convert to tensor (cast to float32)
        quantum_outputs = torch.stack(quantum_outputs).float()
        quantum_outputs = quantum_outputs.to(self.fc_out.weight.dtype)

        output = self.fc_out(quantum_outputs)
        return output


class QuantumNetCNN(nn.Module):
    """Complete Quantum network: Frozen CNN + QuantumHead (with ent)"""
    def __init__(self):
        super().__init__()
        self.cnn_extractor = FrozenCNNExtractor()
        self.head = QuantumHead(input_dim=512, n_qubits=8, n_layers=3, num_classes=10)

    def forward(self, x):
        # CNN on GPU
        features = self.cnn_extractor(x)
        # Move to CPU for quantum processing
        features_cpu = features.cpu()
        return self.head(features_cpu)


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=True):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        # Move images to GPU for CNN
        x = x.to(device)
        # Keep labels on CPU
        
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            # Move images to GPU for CNN
            x = x.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=200, patience=10, name="Model"):
    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0

    for epoch in range(1, max_epochs + 1):
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=True)
        acc = evaluate(model, test_loader, device)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }


# ============================================
# Data utilities
# ============================================

def stratified_sample_from_targets(dataset, n_samples_per_class, seed=42):
    """
    Create a stratified sample with n_samples_per_class from each class.
    Uses dataset.targets directly - NO image loading during sampling.
    """
    if hasattr(dataset, 'targets'):
        targets = np.array(dataset.targets)
    else:
        targets = np.array([dataset[i][1] for i in range(len(dataset))])
    
    rng = np.random.RandomState(seed)
    
    sampled_indices = []
    for c in range(10):
        idx_c = np.where(targets == c)[0]
        k = min(n_samples_per_class, len(idx_c))
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    rng.shuffle(sampled_indices)
    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 4: Quantum (WITH Entanglement) Training with Frozen CNN (CIFAR-10)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: CIFAR-10 (32×32 RGB)")
    print("  • CNN Backbone: ResNet18 (FROZEN)")
    print("  • Batch size: 32")
    print("  • Patience: 10")
    print("  • Max epochs: 200")
    print("  • Seeds: [42, 123, 456]")
    print("  • Architecture: CNN(frozen) → 512 → 8 qubits (3 layers, WITH ent) → 10")
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print("=" * 70)

    # Load frozen CNN from Block 1
    print("\nLoading frozen CNN from Block 1...")
    try:
        realnet_data = torch.load("realnet_cnn_cifar10_results.pt", weights_only=False)
        frozen_cnn_state = realnet_data["frozen_cnn_state"]
        print("✓ Loaded frozen CNN state from realnet_cnn_cifar10_results.pt")
    except FileNotFoundError:
        print("✗ ERROR: realnet_cnn_cifar10_results.pt not found!")
        print("  Run Block 1 first to train RealNet and save frozen CNN.")
        return

    # Create frozen CNN
    shared_cnn = FrozenCNNExtractor().to(device)
    shared_cnn.load_state_dict(frozen_cnn_state)
    for p in shared_cnn.parameters():
        p.requires_grad = False
    
    cnn_params = sum(p.numel() for p in shared_cnn.parameters())
    print(f"  CNN frozen with {cnn_params:,} params")

    # Load data - ImageNet normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    full_train_ds = datasets.CIFAR10(root="./data", train=True,
                                     download=True, transform=transform)
    full_test_ds = datasets.CIFAR10(root="./data", train=False,
                                    download=True, transform=transform)

    # Create stratified samples
    print("\nCreating stratified samples...")
    t0 = time.time()
    train_indices = stratified_sample_from_targets(full_train_ds, n_samples_per_class=1500, seed=42)
    test_indices = stratified_sample_from_targets(full_test_ds, n_samples_per_class=300, seed=42)
    print(f"  Sampling took {time.time()-t0:.2f}s")
    
    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)
    
    print(f"  Train samples: {len(train_ds)} (stratified, 1500 per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, 300 per class)")

    # Smaller batches for quantum
    train_loader = DataLoader(train_ds, batch_size=32,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64,
                             shuffle=False, num_workers=0)

    seeds = [42, 123, 456]
    all_results = []

    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)

        print(f"\n  Training QuantumNet-CNN (WITH entanglement, seed={seed})...")
        quantum_model = QuantumNetCNN()
        quantum_model.cnn_extractor = shared_cnn  # Use frozen CNN
        
        # Only optimize quantum head parameters
        quantum_opt = torch.optim.Adam(quantum_model.head.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            quantum_model, train_loader, test_loader, quantum_opt, device,
            max_epochs=200, patience=10, name="QuantumEnt-CNN"
        )
        
        trainable_params = sum(p.numel() for p in quantum_model.head.parameters())
        total_params = sum(p.numel() for p in quantum_model.parameters())
        
        result["trainable_params"] = trainable_params
        result["total_params"] = total_params
        result["seed"] = seed
        
        all_results.append(result)

    # Summary
    print("\n" + "=" * 70)
    print("QUANTUM (WITH ENTANGLEMENT) + CNN SUMMARY (CIFAR-10)")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    trainable = all_results[0]["trainable_params"]
    total = all_results[0]["total_params"]
    
    print(f"\nAccuracy:          {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:              {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:            {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Trainable params:  {trainable:,}")
    print(f"Total params:      {total:,}")
    print(f"Frozen params:     {total - trainable:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": trainable,
            "total_params": total
        }
    }
    
    torch.save(save_dict, "quantum_ent_cnn_cifar10_results.pt")
    print(f"\n✓ Saved results to: quantum_ent_cnn_cifar10_results.pt")
    print("=" * 70)


if __name__ == "__main__":
    main()

✓ Using lightning.gpu device
Using PyTorch device: cuda
BLOCK 4: Quantum (WITH Entanglement) Training with Frozen CNN (CIFAR-10)

Configuration:
  • Dataset: CIFAR-10 (32×32 RGB)
  • CNN Backbone: ResNet18 (FROZEN)
  • Batch size: 32
  • Patience: 10
  • Max epochs: 200
  • Seeds: [42, 123, 456]
  • Architecture: CNN(frozen) → 512 → 8 qubits (3 layers, WITH ent) → 10
  • Quantum device: lightning.gpu

Loading frozen CNN from Block 1...
✓ Loaded frozen CNN state from realnet_cnn_cifar10_results.pt
  CNN frozen with 11,176,512 params

Creating stratified samples...
  Sampling took 0.00s
  Train samples: 15000 (stratified, 1500 per class)
  Test samples:  3000 (stratified, 300 per class)

SEED 42

  Training QuantumNet-CNN (WITH entanglement, seed=42)...
  [QuantumEnt-CNN] Epoch 1/200
    Batch 460/469, samples: 14752
  [QuantumEnt-CNN] Epoch  1 | loss=2.2425 | test_acc=0.2273 | time=1439.6s
  [QuantumEnt-CNN] Epoch 2/200
    Batch 460/469, samples: 14752
  [QuantumEnt-CNN] Epoch  2 | los

# Comparison

In [6]:
"""
Block 5: Aggregate Results and Comparative Analysis with CNN (CIFAR-10)
========================================================================
Loads results from Blocks 1-4 and performs comprehensive comparison.

Requirements:
- realnet_cnn_cifar10_results.pt (from Block 1)
- quatnet_cnn_cifar10_results.pt (from Block 2)
- quantum_noent_cnn_cifar10_results.pt (from Block 3)
- quantum_ent_cnn_cifar10_results.pt (from Block 4)

Outputs:
- Comprehensive comparison tables
- Statistical analysis
- Research question answers
"""

import numpy as np
import torch

def load_results():
    """Load all results files"""
    results = {}
    
    try:
        real_data = torch.load("realnet_cnn_cifar10_results.pt", weights_only=False)
        results["Real"] = real_data["results"]
        print("✓ Loaded RealNet-CNN (CIFAR-10) results")
    except FileNotFoundError:
        print("✗ realnet_cnn_cifar10_results.pt not found")
        return None
    
    try:
        quat_data = torch.load("quatnet_cnn_cifar10_results.pt", weights_only=False)
        results["Quat"] = quat_data["results"]
        print("✓ Loaded QuatNet-CNN (CIFAR-10) results")
    except FileNotFoundError:
        print("✗ quatnet_cnn_cifar10_results.pt not found")
        return None
    
    try:
        qno_data = torch.load("quantum_noent_cnn_cifar10_results.pt", weights_only=False)
        results["QNoEnt"] = qno_data["results"]
        print("✓ Loaded Quantum-NoEnt-CNN (CIFAR-10) results")
    except FileNotFoundError:
        print("⚠ quantum_noent_cnn_cifar10_results.pt not found (skipping)")
        results["QNoEnt"] = []
    
    try:
        qent_data = torch.load("quantum_ent_cnn_cifar10_results.pt", weights_only=False)
        results["QEnt"] = qent_data["results"]
        print("✓ Loaded Quantum-Ent-CNN (CIFAR-10) results")
    except FileNotFoundError:
        print("⚠ quantum_ent_cnn_cifar10_results.pt not found (skipping)")
        results["QEnt"] = []
    
    return results


def print_summary_table(results):
    """Print aggregated summary table"""
    print("\n" + "=" * 95)
    print("AGGREGATED RESULTS (CIFAR-10 + CNN) (mean ± std over 3 seeds)")
    print("=" * 95)
    print(f"{'Model':<15} {'Accuracy':<20} {'Time (s)':<20} {'Epochs':<15} {'Trainable Params':<20}")
    print("-" * 95)
    
    for name in ["Real", "Quat", "QNoEnt", "QEnt"]:
        if not results[name]:
            continue
            
        accs = [r["best_acc"] for r in results[name]]
        times = [r["time"] for r in results[name]]
        epochs = [r["epochs"] for r in results[name]]
        
        acc_str = f"{np.mean(accs):.4f} ± {np.std(accs):.4f}"
        time_str = f"{np.mean(times):.1f} ± {np.std(times):.1f}"
        epoch_str = f"{np.mean(epochs):.1f} ± {np.std(epochs):.1f}"
        
        trainable = results[name][0]["trainable_params"]
        param_str = f"{trainable:,}"
        
        print(f"{name:<15} {acc_str:<20} {time_str:<20} {epoch_str:<15} {param_str:<20}")
    
    print("-" * 95)
    print(f"{'Frozen CNN':<15} {'(shared)':<20} {'-':<20} {'-':<15} {'~11M (ResNet18)':<20}")
    print("=" * 95)


def print_per_seed_table(results):
    """Print detailed per-seed results"""
    print("\n" + "=" * 70)
    print("PER-SEED RESULTS (CIFAR-10 + CNN)")
    print("=" * 70)
    
    seeds = [42, 123, 456]
    
    for seed_idx, seed in enumerate(seeds):
        print(f"\nSeed {seed}:")
        print(f"{'Model':<15} {'Accuracy':<12} {'Time (s)':<12} {'Epochs':<10}")
        print("-" * 50)
        
        for name in ["Real", "Quat", "QNoEnt", "QEnt"]:
            if not results[name]:
                continue
            
            r = results[name][seed_idx]
            print(f"{name:<15} {r['best_acc']:<12.4f} {r['time']:<12.1f} {r['epochs']:<10}")


def comparative_analysis(results):
    """Perform comparative analysis answering research questions"""
    print("\n" + "=" * 95)
    print("COMPARATIVE ANALYSIS (CIFAR-10 + CNN): Answering Research Questions")
    print("=" * 95)
    
    real_accs = [r["best_acc"] for r in results["Real"]]
    quat_accs = [r["best_acc"] for r in results["Quat"]]
    
    def pct_gap(a, b):
        """Percentage point gap (a - b)"""
        return (np.mean(a) - np.mean(b)) * 100.0
    
    def retention(a, b):
        """Percentage of performance retained"""
        return (np.mean(a) / np.mean(b)) * 100.0
    
    print("\n1. QUATERNION vs REAL MLP (with CNN features):")
    print("   " + "-" * 75)
    print(f"   Real MLP accuracy:        {np.mean(real_accs):.4f} ± {np.std(real_accs):.4f}")
    print(f"   Quaternion accuracy:      {np.mean(quat_accs):.4f} ± {np.std(quat_accs):.4f}")
    print(f"   Gap:                      {pct_gap(real_accs, quat_accs):+.2f} percentage points")
    print(f"   Performance retention:    {retention(quat_accs, real_accs):.1f}%")
    print(f"\n   → Classical SU(2) (quaternions) captures {retention(quat_accs, real_accs):.1f}% of")
    print(f"     standard MLP performance with structured algebraic constraints on")
    print(f"     rich CNN features (512-D from ImageNet-pretrained ResNet18).")
    
    if results["QNoEnt"]:
        qno_accs = [r["best_acc"] for r in results["QNoEnt"]]
        
        print("\n2. QUANTUM (no entanglement) vs REAL MLP (with CNN features):")
        print("   " + "-" * 75)
        print(f"   Real MLP accuracy:         {np.mean(real_accs):.4f} ± {np.std(real_accs):.4f}")
        print(f"   Quantum (no ent) accuracy: {np.mean(qno_accs):.4f} ± {np.std(qno_accs):.4f}")
        print(f"   Gap:                       {pct_gap(real_accs, qno_accs):+.2f} percentage points")
        print(f"   Performance retention:     {retention(qno_accs, real_accs):.1f}%")
        print(f"\n   → Quantum circuits WITHOUT entanglement capture {retention(qno_accs, real_accs):.1f}%")
        print(f"     of MLP performance on rich CNN features, suggesting limited benefit")
        print(f"     over classical rotation gates for CIFAR-10.")
    
    if results["QEnt"]:
        qent_accs = [r["best_acc"] for r in results["QEnt"]]
        
        print("\n3. QUANTUM (with entanglement) vs REAL MLP (with CNN features):")
        print("   " + "-" * 75)
        print(f"   Real MLP accuracy:         {np.mean(real_accs):.4f} ± {np.std(real_accs):.4f}")
        print(f"   Quantum (with ent) accuracy:{np.mean(qent_accs):.4f} ± {np.std(qent_accs):.4f}")
        print(f"   Gap:                       {pct_gap(real_accs, qent_accs):+.2f} percentage points")
        print(f"   Performance retention:     {retention(qent_accs, real_accs):.1f}%")
        print(f"\n   → Quantum circuits WITH entanglement capture {retention(qent_accs, real_accs):.1f}%")
        print(f"     of MLP performance on CNN features.")
    
    if results["QNoEnt"] and results["QEnt"]:
        print("\n4. ENTANGLEMENT EFFECT (with CNN features):")
        print("   " + "-" * 75)
        print(f"   Quantum (no ent) accuracy: {np.mean(qno_accs):.4f} ± {np.std(qno_accs):.4f}")
        print(f"   Quantum (with ent) accuracy:{np.mean(qent_accs):.4f} ± {np.std(qent_accs):.4f}")
        print(f"   Improvement:               {pct_gap(qent_accs, qno_accs):+.2f} percentage points")
        
        if np.mean(qent_accs) > np.mean(qno_accs):
            improvement_pct = ((np.mean(qent_accs) - np.mean(qno_accs)) / np.mean(qno_accs)) * 100
            print(f"\n   → Entanglement IMPROVES performance by {improvement_pct:.1f}%")
            print(f"     (relative improvement over non-entangled baseline)")
            print(f"     → Pattern consistent with paper findings (~0.6pp gain on FashionMNIST)")
        else:
            print(f"\n   → Entanglement DOES NOT improve performance on CNN-powered CIFAR-10")
            print(f"     → Rich features may saturate quantum advantage")
    
    if results["QNoEnt"]:
        print("\n5. QUATERNION vs QUANTUM (no entanglement) - Core Research Question:")
        print("   " + "-" * 75)
        print(f"   Quaternion accuracy:       {np.mean(quat_accs):.4f} ± {np.std(quat_accs):.4f}")
        print(f"   Quantum (no ent) accuracy: {np.mean(qno_accs):.4f} ± {np.std(qno_accs):.4f}")
        print(f"   Gap:                       {pct_gap(quat_accs, qno_accs):+.2f} percentage points")
        
        gap_abs = abs(pct_gap(quat_accs, qno_accs))
        if gap_abs < 2.0:
            print(f"\n   → Quaternion networks (classical SU(2)) CLOSELY APPROXIMATE quantum circuits")
            print(f"     without entanglement (quantum SU(2)). Gap < 2 percentage points.")
            print(f"\n   → This supports the hypothesis that classical quaternion algebra can serve")
            print(f"     as an effective implementation of SU(2) geometry on rich CNN features.")
        elif np.mean(quat_accs) > np.mean(qno_accs):
            print(f"\n   → Quaternion networks OUTPERFORM quantum circuits without entanglement")
            print(f"     by {gap_abs:.2f} percentage points on CNN features.")
            print(f"\n   → Classical SU(2) (quaternions) SUFFICES for learning tasks addressable")
            print(f"     by product-state quantum circuits, even with high-quality CNN features.")
            print(f"     → Consistent with paper's findings on FashionMNIST (~6pp advantage).")
        else:
            print(f"\n   → Quantum circuits without entanglement OUTPERFORM quaternions")
            print(f"     by {gap_abs:.2f} percentage points on CNN features.")
            print(f"     → Different pattern from paper's FashionMNIST results")
    
    if results["QEnt"]:
        print("\n6. QUATERNION vs QUANTUM (with entanglement) on CNN features:")
        print("   " + "-" * 75)
        print(f"   Quaternion accuracy:       {np.mean(quat_accs):.4f} ± {np.std(quat_accs):.4f}")
        print(f"   Quantum (with ent) accuracy:{np.mean(qent_accs):.4f} ± {np.std(qent_accs):.4f}")
        print(f"   Gap:                       {pct_gap(quat_accs, qent_accs):+.2f} percentage points")
        
        if np.mean(qent_accs) > np.mean(quat_accs):
            print(f"\n   → Entangled quantum circuits outperform quaternions, demonstrating")
            print(f"     the value of quantum correlations beyond classical SU(2) rotations")
            print(f"     when operating on rich CNN features.")
        else:
            print(f"\n   → Quaternions match or exceed entangled quantum performance,")
            print(f"     suggesting entanglement provides limited benefit on CNN features.")
            print(f"     → Measurement/optimization challenges may dominate on complex features.")


def key_takeaways(results):
    """Summarize key takeaways for paper"""
    print("\n" + "=" * 95)
    print("KEY TAKEAWAYS FOR PAPER (CIFAR-10 + CNN)")
    print("=" * 95)
    
    real_accs = [r["best_acc"] for r in results["Real"]]
    quat_accs = [r["best_acc"] for r in results["Quat"]]
    
    def retention(a, b):
        return (np.mean(a) / np.mean(b)) * 100.0
    
    print("\n✓ EMPIRICAL FINDINGS:")
    print(f"  1. Quaternion networks achieve {retention(quat_accs, real_accs):.1f}% of Real MLP performance")
    print(f"     with structured SU(2) algebraic constraints on rich CNN features (512-D)")
    print(f"     → Absolute accuracy: ~{np.mean(quat_accs)*100:.1f}% on CIFAR-10")
    print(f"     → Tests quaternion geometry on high-quality ImageNet-pretrained features")
    
    if results["QNoEnt"]:
        qno_accs = [r["best_acc"] for r in results["QNoEnt"]]
        gap = (np.mean(quat_accs) - np.mean(qno_accs)) * 100.0
        gap_abs = abs(gap)
        print(f"\n  2. Quaternions vs Quantum (no ent): {gap:+.2f} percentage point gap")
        if gap_abs < 2.0:
            print(f"     → Classical SU(2) effectively approximates quantum SU(2) without entanglement")
            print(f"     → Pattern holds even with rich CNN features (not just learned bottlenecks)")
        elif gap > 0:
            print(f"     → Classical SU(2) OUTPERFORMS quantum product-state circuits")
            print(f"       (consistent with paper: quaternions >> quantum on FashionMNIST)")
            print(f"     → Pattern generalizes from simple bottleneck (16-D) to CNN features (512-D)")
        else:
            print(f"     → Quantum product-state circuits outperform classical SU(2)")
            print(f"       (different pattern from FashionMNIST - may reflect CNN feature quality)")
    
    if results["QEnt"] and results["QNoEnt"]:
        qent_accs = [r["best_acc"] for r in results["QEnt"]]
        qno_accs = [r["best_acc"] for r in results["QNoEnt"]]
        ent_improvement_pp = (np.mean(qent_accs) - np.mean(qno_accs)) * 100.0
        ent_improvement_pct = (ent_improvement_pp / (np.mean(qno_accs) * 100.0)) * 100.0
        
        if ent_improvement_pp > 0.3:
            print(f"\n  3. Entanglement improves quantum performance by {ent_improvement_pp:.2f}pp")
            print(f"     ({ent_improvement_pct:.1f}% relative improvement)")
            print(f"     → Demonstrates measurable value of quantum correlations on CNN features")
            if abs(ent_improvement_pp - 0.6) < 0.3:
                print(f"     → Magnitude consistent with paper (~0.6pp gain on FashionMNIST)")
        else:
            print(f"\n  3. Entanglement provides minimal/no benefit ({ent_improvement_pp:+.2f}pp)")
            print(f"     → Rich CNN features may saturate quantum advantage")
            print(f"     → Different pattern from FashionMNIST (0.6pp gain)")
    
    print("\n✓ IMPLICATIONS:")
    print("  • CNN-based comparison tests whether paper's findings hold with STRONG features")
    print("  • Moving from learned bottleneck (16-D) → pretrained CNN (512-D) tests robustness")
    print("  • If quaternions still match/beat quantum (like FashionMNIST):")
    print("    → Strengthens claim that classical SU(2) suffices for product-state VQCs")
    print("    → Shows pattern is NOT artifact of weak feature extraction")
    print("  • If entanglement benefit persists (~0.5-1pp):")
    print("    → Confirms entanglement as quantum boundary across feature qualities")
    print("  • Higher absolute accuracies (vs. bottleneck version) validate CNN upgrade")
    
    print("\n✓ COMPUTATIONAL EFFICIENCY:")
    real_time = np.mean([r["time"] for r in results["Real"]])
    quat_time = np.mean([r["time"] for r in results["Quat"]])
    print(f"  • Real MLP:   {real_time:.1f}s (baseline)")
    print(f"  • Quaternion: {quat_time:.1f}s ({quat_time/real_time:.2f}x Real)")
    
    if results["QNoEnt"]:
        qno_time = np.mean([r["time"] for r in results["QNoEnt"]])
        print(f"  • Quantum (no ent): {qno_time:.1f}s ({qno_time/real_time:.1f}x Real)")
        print(f"    → Even with lightning.gpu, quantum is {qno_time/quat_time:.1f}x slower than quaternions")
        print(f"    → Speedup gap narrows with CNN (GPU-accelerated forward pass)")
    
    print("\n✓ CONTROLLED EXPERIMENTAL DESIGN:")
    print("  • All models use IDENTICAL frozen ResNet18 CNN (~11M params)")
    print("  • CNN pretrained on ImageNet → high-quality 512-D features")
    print("  • Performance differences reflect HEAD ARCHITECTURE ONLY")
    print("  • Validates paper's methodology with industrial-strength feature extraction")
    
    print("\n✓ COMPARISON TO LEARNED BOTTLENECK:")
    print("  • Learned bottleneck (3072→16): Tests quaternion geometry with minimal features")
    print("  • CNN features (ResNet18→512): Tests quaternion geometry with rich features")
    print("  • Together: Shows whether findings depend on feature quality or are robust")
    
    print("=" * 95)


def main():
    print("=" * 95)
    print("BLOCK 5: AGGREGATE RESULTS AND COMPARATIVE ANALYSIS (CIFAR-10 + CNN)")
    print("=" * 95)
    
    results = load_results()
    
    if results is None:
        print("\n✗ ERROR: Required results files not found.")
        print("Run Blocks 1 and 2 at minimum (Real and Quat) with CNN on CIFAR-10.")
        return
    
    # Summary table
    print_summary_table(results)
    
    # Per-seed details
    print_per_seed_table(results)
    
    # Comparative analysis
    comparative_analysis(results)
    
    # Key takeaways
    key_takeaways(results)
    
    print("\n" + "=" * 95)
    print("ANALYSIS COMPLETE (CIFAR-10 + CNN)")
    print("=" * 95)
    print("\nAll results saved in:")
    print("  • realnet_cnn_cifar10_results.pt")
    print("  • quatnet_cnn_cifar10_results.pt")
    if results["QNoEnt"]:
        print("  • quantum_noent_cnn_cifar10_results.pt")
    if results["QEnt"]:
        print("  • quantum_ent_cnn_cifar10_results.pt")
    print("\n" + "=" * 95)
    print("COMPARISON TO PAPER (FashionMNIST with learned bottleneck)")
    print("=" * 95)
    print("\nPaper's key findings on FashionMNIST (learned 16-D bottleneck):")
    print("  • Quaternions ≈ Real MLP (within 0.2pp)")
    print("  • Quaternions >> Quantum-NoEnt (by ~2.4pp)")
    print("  • Entanglement helps modestly (0.6pp gain)")
    print("\nCIFAR-10 + CNN tests whether these patterns hold with:")
    print("  • Harder task (color images, more complex)")
    print("  • STRONG features (512-D ImageNet-pretrained CNN vs 16-D learned)")
    print("  • Same architectural comparison (controlled design)")
    print("  • Industrial-strength preprocessing (ResNet18)")
    print("\nKey question: Do quaternions still match/beat quantum without rich features?")
    print("  → If YES: Classical SU(2) sufficiency is robust finding")
    print("  → If NO:  Feature quality matters for quaternion vs quantum comparison")
    print("=" * 95)


if __name__ == "__main__":
    main()

BLOCK 5: AGGREGATE RESULTS AND COMPARATIVE ANALYSIS (CIFAR-10 + CNN)
✓ Loaded RealNet-CNN (CIFAR-10) results
✓ Loaded QuatNet-CNN (CIFAR-10) results
✓ Loaded Quantum-NoEnt-CNN (CIFAR-10) results
✓ Loaded Quantum-Ent-CNN (CIFAR-10) results

AGGREGATED RESULTS (CIFAR-10 + CNN) (mean ± std over 3 seeds)
Model           Accuracy             Time (s)             Epochs          Trainable Params    
-----------------------------------------------------------------------------------------------
Real            0.4713 ± 0.0070      14.2 ± 2.5           22.3 ± 4.0      66,954              
Quat            0.4521 ± 0.0038      36.1 ± 0.2           40.0 ± 0.0      17,832              
QNoEnt          0.4171 ± 0.0094      64739.2 ± 13399.7    48.7 ± 10.9     4,282               
QEnt            0.3246 ± 0.0318      66114.3 ± 5305.9     42.7 ± 3.9      4,282               
-----------------------------------------------------------------------------------------------
Frozen CNN      (shared)       