# In [None]:
"""
Block 4: Quantum (WITH Entanglement) Training
==============================================
Loads frozen preprocessor from Block 1 and trains quantum head WITH entanglement.
Uses Lightning-GPU acceleration.

Requirements:
- realnet_results.pt (from Block 1)
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_ent_results.pt: Contains quantum (with ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from collections import defaultdict

# PennyLane is mandatory for this block: abort the script if it is missing.
try:
    import pennylane as qml
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    # ``exit`` is injected by the ``site`` module and is not guaranteed to
    # exist (e.g. under ``python -S``); SystemExit is what sys.exit raises.
    raise SystemExit(1)
else:
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    # NOTE: only the pennylane import is checked here; whether the
    # lightning.gpu plugin actually loads is only known when qml.device()
    # is first called inside QuantumHead.
    print("✓ Using lightning.gpu device")

# Informational only: the training loop below keeps tensors on the CPU.
# Fall back to CPU instead of reporting a CUDA device that does not exist.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Comprehensive seed setting for reproducibility
# ============================================

def set_all_seeds(seed):
    """
    Seed every RNG this script may touch so that runs are repeatable.

    Covers Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus all
    CUDA devices when CUDA is available), and switches cuDNN to
    deterministic, non-benchmarking mode. CuPy and PennyLane's NumPy wrapper
    are seeded too when they can be imported; if they are missing, they are
    skipped silently.
    """
    import random
    import numpy as np
    import torch

    # Core generators: stdlib, NumPy global state, PyTorch CPU generator.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # GPU generators only matter when CUDA is usable.
    if torch.cuda.is_available():
        for seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Optional backends: ignored when not installed.
    try:
        import cupy
        cupy.random.seed(seed)
    except ImportError:
        pass

    try:
        import pennylane
        pennylane.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Shared Preprocessor (must match Block 1)
# ============================================

class SharedPreprocessor(nn.Module):
    """
    Classical encoder: flatten each sample, apply one Linear layer, squash with tanh.

    The submodule must stay named ``fc``: main() loads Block 1's saved
    ``preprocessor_state`` into this class with ``load_state_dict``, so the
    parameter names have to line up.
    """
    def __init__(self, input_dim=784, bottleneck_dim=16):
        super().__init__()
        self.fc = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, x):
        # Collapse every non-batch dimension (e.g. 1x28x28 -> 784).
        flat = x.view(x.size(0), -1)
        return self.fc(flat).tanh()


# ============================================
# Quantum Head (WITH Entanglement)
# ============================================

class QuantumHead(nn.Module):
    """
    Variational quantum classifier head WITH entanglement.

    Per sample: ``in_features`` real features -> Linear -> tanh -> one angle
    per qubit. Then ``n_layers`` repetitions of
    [RY(angle) data re-upload, trainable RY/RZ on each qubit, CNOT ring].
    Readout is <Z_i> on every qubit plus <Z_a Z_b> on the disjoint adjacent
    pairs (0,1), (2,3), ...; a final Linear maps these to ``num_classes``
    logits.

    With the default 4 qubits the readout is exactly the six observables
    Z0, Z1, Z2, Z3, Z0Z1, Z2Z3 feeding ``Linear(6, num_classes)``.

    Args:
        n_qubits: number of wires; the readout size adapts to it.
        n_layers: number of re-upload + rotation + entangling layers.
        num_classes: number of output logits.
        device_name: PennyLane device name; defaults to ``QUANTUM_DEVICE``.
            Devices whose name contains "lightning" use adjoint
            differentiation, all others use parameter-shift.
        in_features: width of the incoming feature vector (16 matches the
            SharedPreprocessor bottleneck).

    Raises:
        ValueError: if ``n_qubits < 1``.
    """
    def __init__(self, n_qubits=4, n_layers=3, num_classes=10, device_name=None,
                 in_features=16):
        super().__init__()
        if n_qubits < 1:
            raise ValueError(f"n_qubits must be >= 1, got {n_qubits}")
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Map the classical features to one rotation angle per qubit.
        self.feature_select = nn.Linear(in_features, n_qubits)

        if device_name is None:
            device_name = QUANTUM_DEVICE

        self.dev = qml.device(device_name, wires=n_qubits)

        # Adjoint differentiation is much faster on lightning simulators.
        diff_method = "adjoint" if "lightning" in device_name else "parameter-shift"

        # Disjoint adjacent pairs (0,1), (2,3), ...; for 4 qubits this is
        # (0,1), (2,3), reproducing the original Z0Z1 / Z2Z3 readout.
        zz_pairs = [(q, q + 1) for q in range(0, n_qubits - 1, 2)]
        self.n_observables = n_qubits + len(zz_pairs)

        @qml.qnode(self.dev, interface="torch", diff_method=diff_method)
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITH entanglement.
            inputs: (n_qubits,) rotation angles
            weights: (n_layers, n_qubits, 2) trainable RY/RZ angles
            """
            for layer in range(n_layers):
                # Data re-uploading: encode the same angles in every layer.
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)

                # Trainable single-qubit rotations.
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)

                # ENTANGLEMENT: CNOT chain, closed into a ring for >2 qubits
                # (for 2 qubits the closing CNOT would just reverse the chain).
                for i in range(n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                if n_qubits > 2:
                    qml.CNOT(wires=[n_qubits - 1, 0])

            singles = [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
            pairs = [qml.expval(qml.PauliZ(a) @ qml.PauliZ(b)) for a, b in zz_pairs]
            return singles + pairs

        self.quantum_circuit = quantum_circuit

        # Registration order (feature_select, q_weights, fc_out) is kept so that
        # seeded initialisation draws from the RNG in the same sequence.
        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        self.fc_out = nn.Linear(self.n_observables, num_classes)

    def forward(self, x):
        """
        Evaluate the circuit once per sample and classify the expectation values.

        x: (batch, in_features) real features.
        Returns: (batch, num_classes) logits.
        """
        angles = torch.tanh(self.feature_select(x))
        out_dtype = self.fc_out.weight.dtype

        # The QNode is called one sample at a time; it has no batch dimension.
        quantum_outputs = []
        for sample in angles:
            q_raw = self.quantum_circuit(sample, self.q_weights)
            # Multiple measurements come back as a sequence of scalars.
            if isinstance(q_raw, (list, tuple)):
                q_raw = torch.stack(q_raw)
            quantum_outputs.append(q_raw)

        if not quantum_outputs:
            # Empty batch: torch.stack([]) would raise, so return empty logits.
            empty = angles.new_zeros((0, self.n_observables), dtype=out_dtype)
            return self.fc_out(empty)

        # Simulator results are typically float64; match the Linear's dtype.
        features = torch.stack(quantum_outputs).to(out_dtype)
        return self.fc_out(features)


class QuantumNet(nn.Module):
    """
    End-to-end model: SharedPreprocessor (784 -> 16) followed by an entangled
    QuantumHead (4 qubits, 3 layers) producing 10 logits.

    main() swaps ``preprocessor`` for Block 1's frozen copy right after
    construction, so only ``head`` is trained there.
    """
    def __init__(self):
        super().__init__()
        # Construction order matters for seeded runs: the preprocessor draws
        # its initial weights from the RNG before the head does.
        self.preprocessor = SharedPreprocessor(784, 16)
        self.head = QuantumHead(n_qubits=4, n_layers=3, num_classes=10)

    def forward(self, x):
        return self.head(self.preprocessor(x))


# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, show_progress=True):
    """
    Run one optimisation pass over ``loader`` with cross-entropy loss.

    Args:
        model: module mapping a batch ``x`` to class logits.
        loader: iterable of ``(x, y)`` batches. Batches are used as-is and are
            NOT moved to another device (the quantum model runs on CPU).
        optimizer: optimizer over the parameters that should be updated.
        show_progress: print a carriage-return progress line every 20 batches.

    Returns:
        Mean per-sample loss over the samples actually processed this epoch,
        or 0.0 if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for batch_idx, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        # cross_entropy averages over the batch; re-weight by batch size so
        # that a short final batch counts proportionally.
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}, samples: {total_samples}", end="\r")

    if show_progress:
        print()
    # Divide by what was actually seen rather than len(loader.dataset): the two
    # differ with drop_last=True or a subsampling sampler, and an empty dataset
    # would otherwise raise ZeroDivisionError.
    return total_loss / total_samples if total_samples else 0.0


def evaluate(model, loader):
    """
    Return top-1 accuracy of ``model`` on ``loader``, or 0.0 if it is empty.

    The model is switched to eval mode and no gradients are tracked. Batches
    are used as-is (no device transfer).
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            hits = model(inputs).argmax(dim=1).eq(targets)
            n_correct += hits.sum().item()
            n_seen += inputs.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              max_epochs=200, patience=10, name="Model"):
    """
    Train until test accuracy has not improved for ``patience`` epochs.

    Each epoch runs ``train_one_epoch`` and then ``evaluate`` on
    ``test_loader``. An epoch counts as an improvement only if its accuracy
    is strictly greater than the best so far (which starts at 0.0).

    Notes:
        * Stopping and ``best_acc`` are both driven by ``test_loader``, so
          ``best_acc`` is an optimistic (selected-on-test) number.
        * The model is left in its LAST-epoch state; the best weights are
          not restored.

    Returns:
        dict with ``best_acc``, ``final_acc`` (last epoch's accuracy),
        ``time`` (wall-clock seconds) and ``epochs`` (epochs actually run;
        0 when ``max_epochs < 1``).
    """
    best_acc = 0.0
    epochs_without_improvement = 0
    last_acc = 0.0
    # Tracked separately so the return value is defined even if the loop
    # body never runs (the loop variable would be unbound).
    epochs_run = 0
    start = time.time()

    for epoch in range(1, max_epochs + 1):
        epochs_run = epoch
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, show_progress=True)
        acc = evaluate(model, test_loader)
        last_acc = acc

        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} "
              f"| test_acc={acc:.4f} | time={elapsed:.1f}s")

        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break

    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epochs_run,
    }


# ============================================
# Data utilities
# ============================================

def _dataset_labels(dataset):
    """Return the label of every sample, reading ``.targets`` directly when that is safe."""
    targets = getattr(dataset, "targets", None)
    # Fast path (e.g. torchvision MNIST): labels are stored in ``.targets``, so
    # the image transform never has to run. Only valid when no target_transform
    # would change what __getitem__ returns as the label.
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        return targets.tolist() if hasattr(targets, "tolist") else list(targets)
    # Generic path: load every sample (applies the transform) and keep the label.
    return [label for _, label in dataset]


def stratified_sample(dataset, n_samples_per_class, seed=42):
    """
    Pick up to ``n_samples_per_class`` indices from each class of ``dataset``.

    Classes are visited in ascending label order. Within a class, indices are
    chosen without replacement by a fresh ``np.random.RandomState(seed)``
    created for EACH class. This per-class re-seeding is intentional and is
    kept so the default ``seed=42`` reproduces the original selection; main()
    relies on this split matching Block 1's.

    Args:
        dataset: indexable dataset yielding ``(x, label)``; a ``.targets``
            attribute is used when available to avoid loading samples.
        n_samples_per_class: maximum number of indices per class; smaller
            classes contribute all of their indices.
        seed: seed for the per-class RandomState.

    Returns:
        List of int dataset indices, grouped by class.
    """
    class_indices = defaultdict(list)
    for idx, label in enumerate(_dataset_labels(dataset)):
        class_indices[label].append(idx)

    sampled_indices = []
    for class_label in sorted(class_indices):
        indices = class_indices[class_label]
        rng = np.random.RandomState(seed)
        selected = rng.choice(indices, size=min(n_samples_per_class, len(indices)),
                              replace=False)
        sampled_indices.extend(selected.tolist())

    return sampled_indices


# ============================================
# Main
# ============================================

def main():
    """
    Train the entangled quantum head on top of Block 1's frozen preprocessor.

    Loads ``preprocessor_state`` from ``realnet_results.pt``, builds stratified
    MNIST train/test subsets, and for each seed re-seeds all RNGs, builds a
    fresh QuantumNet with the frozen preprocessor swapped in, and trains only
    the head with early stopping. Per-seed results and a summary are written
    to ``quantum_ent_results.pt``. Returns early (without raising) if Block 1's
    output is missing or lacks the preprocessor state.
    """
    # Single source of truth for the run configuration. The banner below is
    # generated from these values so it cannot drift from what is used.
    train_batch_size = 32
    test_batch_size = 64
    patience = 10
    max_epochs = 200
    seeds = [42, 123, 456]
    train_per_class = 1500
    test_per_class = 300
    learning_rate = 1e-3
    realnet_path = "realnet_results.pt"
    results_path = "quantum_ent_results.pt"

    print("=" * 70)
    print("BLOCK 4: Quantum (WITH Entanglement) Training")
    print("=" * 70)
    print("\nConfiguration:")
    print(f"  • Batch size: {train_batch_size}")
    print(f"  • Patience: {patience}")
    print(f"  • Max epochs: {max_epochs}")
    print(f"  • Seeds: {seeds}")
    print("  • Architecture: [frozen 784→16] → 4 qubits (3 layers, WITH ent) → 10")
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print("=" * 70)

    # Load frozen preprocessor from Block 1
    print("\nLoading frozen preprocessor from Block 1...")
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load a results file produced by this project's Block 1.
        realnet_data = torch.load(realnet_path, weights_only=False)
        preprocessor_state = realnet_data["preprocessor_state"]
    except FileNotFoundError:
        print(f"✗ ERROR: {realnet_path} not found!")
        print("  Run Block 1 first to train RealNet and save preprocessor.")
        return
    except KeyError:
        print(f"✗ ERROR: {realnet_path} has no 'preprocessor_state' entry!")
        print("  Re-run Block 1 so that it saves the preprocessor state.")
        return
    print(f"✓ Loaded preprocessor state from {realnet_path}")

    # Frozen preprocessor, shared by every seed (it is never updated).
    shared_preprocessor = SharedPreprocessor(784, 16)
    shared_preprocessor.load_state_dict(preprocessor_state)
    for p in shared_preprocessor.parameters():
        p.requires_grad = False

    preprocessor_params = sum(p.numel() for p in shared_preprocessor.parameters())
    print(f"  Preprocessor frozen with {preprocessor_params:,} params")

    # Load data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    full_train_ds = datasets.MNIST(root="./data", train=True,
                                   download=True, transform=transform)
    full_test_ds = datasets.MNIST(root="./data", train=False,
                                  download=True, transform=transform)

    # Stratified subsets; must match Block 1's selection.
    print("\nCreating stratified samples...")
    train_indices = stratified_sample(full_train_ds, n_samples_per_class=train_per_class)
    test_indices = stratified_sample(full_test_ds, n_samples_per_class=test_per_class)

    train_ds = Subset(full_train_ds, train_indices)
    test_ds = Subset(full_test_ds, test_indices)

    print(f"  Train samples: {len(train_ds)} (stratified, {train_per_class} per class)")
    print(f"  Test samples:  {len(test_ds)} (stratified, {test_per_class} per class)")

    # Small batches: the quantum head evaluates one circuit per sample.
    train_loader = DataLoader(train_ds, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=test_batch_size,
                             shuffle=False, num_workers=0)

    all_results = []
    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)

        set_all_seeds(seed)

        print(f"\n  Training QuantumNet (WITH entanglement, seed={seed})...")
        quantum_model = QuantumNet()
        # Swap in the frozen Block 1 preprocessor; only the head is optimised.
        quantum_model.preprocessor = shared_preprocessor

        quantum_opt = torch.optim.Adam(quantum_model.head.parameters(), lr=learning_rate)

        result = train_with_early_stopping(
            quantum_model, train_loader, test_loader, quantum_opt,
            max_epochs=max_epochs, patience=patience, name="QuantumEnt",
        )

        params = list(quantum_model.parameters())
        result["trainable_params"] = sum(p.numel() for p in params if p.requires_grad)
        result["total_params"] = sum(p.numel() for p in params)
        result["seed"] = seed

        all_results.append(result)

    # Summary
    print("\n" + "=" * 70)
    print("QUANTUM (WITH ENTANGLEMENT) SUMMARY")
    print("=" * 70)

    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    trainable = all_results[0]["trainable_params"]
    total = all_results[0]["total_params"]

    # Plain Python floats (not np.float64) keep the saved file loadable
    # without unpickling NumPy scalar types.
    mean_acc, std_acc = float(np.mean(accs)), float(np.std(accs))
    mean_time, std_time = float(np.mean(times)), float(np.std(times))
    mean_epochs, std_epochs = float(np.mean(epochs)), float(np.std(epochs))

    print(f"\nAccuracy:   {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Time:       {mean_time:.1f}s ± {std_time:.1f}s")
    print(f"Epochs:     {mean_epochs:.1f} ± {std_epochs:.1f}")
    print(f"Parameters: {trainable:,} trainable (head), {total:,} total")

    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")

    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": mean_acc,
            "std_acc": std_acc,
            "mean_time": mean_time,
            "trainable_params": trainable,
            "total_params": total,
        },
    }

    torch.save(save_dict, results_path)
    print(f"\n✓ Saved results to: {results_path}")
    print("=" * 70)


if __name__ == "__main__":
    main()

# ---- Captured cell output (as exported from the notebook) ----
# ✓ Using lightning.gpu device
# Using PyTorch device: cuda
# BLOCK 4: Quantum (WITH Entanglement) Training
#
# Configuration:
#   • Batch size: 32
#   • Patience: 10
#   • Max epochs: 200
#   • Seeds: [42, 123, 456]
#   • Architecture: [frozen 784→16] → 4 qubits (3 layers, WITH ent) → 10
#   • Quantum device: lightning.gpu
#
# Loading frozen preprocessor from Block 1...
# ✓ Loaded preprocessor state from realnet_results.pt
#   Preprocessor frozen with 12,560 params
#
# Creating stratified samples...
#   Train samples: 15000 (stratified, 1500 per class)
#   Test samples:  3000 (stratified, 300 per class)
#
# SEED 42
#
#   Training QuantumNet (WITH entanglement, seed=42)...
#   [QuantumEnt] Epoch 1/200
#     Batch 40/469, samples: 1312