# Load QM9

In [4]:
"""
QM9 Download via PyTorch Geometric + Custom Feature Extraction
Compatible with PyG 2.x API
"""

import torch
import numpy as np
from torch_geometric.datasets import QM9
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm
import os

def download_qm9_via_pyg(root='./data'):
    """
    Download QM9 using PyTorch Geometric (handles mirrors/auth)
    Returns: PyG dataset object
    """
    print("Downloading QM9 via PyTorch Geometric...")
    print("(This handles authentication and mirrors automatically)")
    
    # No target parameter in newer PyG - it returns all properties
    dataset = QM9(root=root)
    
    print(f"✓ Downloaded {len(dataset)} molecules")
    print(f"✓ Properties available: {dataset.data.y.shape[1]} targets per molecule")
    print(f"   (Index 4 = HOMO-LUMO gap)")
    
    return dataset

def pyg_to_rdkit_mol(data):
    """
    Convert PyG Data object back to RDKit molecule
    (For consistent feature extraction)
    """
    # PyG stores atom types as atomic numbers (Z)
    atom_types = data.z.cpu().numpy()
    
    # Create RDKit molecule
    mol = Chem.RWMol()
    for z in atom_types:
        mol.AddAtom(Chem.Atom(int(z)))
    
    # Add bonds from edge_index
    edge_index = data.edge_index.cpu().numpy()
    bonds_added = set()
    
    for i in range(edge_index.shape[1]):
        src, dst = int(edge_index[0, i]), int(edge_index[1, i])
        
        # Add each bond only once (edges are bidirectional in PyG)
        bond_id = tuple(sorted([src, dst]))
        if bond_id not in bonds_added:
            mol.AddBond(src, dst, Chem.BondType.SINGLE)
            bonds_added.add(bond_id)
    
    # Add 3D coordinates if available
    if hasattr(data, 'pos'):
        pos = data.pos.cpu().numpy()
        conf = Chem.Conformer(len(atom_types))
        for i, xyz in enumerate(pos):
            conf.SetAtomPosition(i, tuple(xyz.tolist()))
        mol.AddConformer(conf)
    
    mol = mol.GetMol()
    
    # Sanitize
    try:
        Chem.SanitizeMol(mol)
        return mol
    except:
        return None

def mol_to_morgan_fp(mol, fp_size=512, radius=2):
    """
    Generate Morgan fingerprint (ECFP-like)
    """
    if mol is None:
        return None
    
    try:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=fp_size)
        arr = np.zeros(fp_size, dtype=np.float32)
        Chem.DataStructs.ConvertToNumpyArray(fp, arr)
        return arr
    except:
        return None

def preprocess_qm9_custom(fp_size=512, output_file='qm9_preprocessed_512d.pt'):
    """
    Complete QM9 preprocessing:
    1. Download via PyG (reliable)
    2. Extract features manually (Morgan FP)
    3. Create balanced 3-class labels from HOMO-LUMO gap
    """
    
    # Step 1: Download via PyG
    dataset = download_qm9_via_pyg(root='./data/qm9')
    
    # Step 2: Extract HOMO-LUMO gap values
    # In QM9, properties are stored as:
    # Index 0: dipole moment
    # Index 1: isotropic polarizability
    # Index 2: HOMO energy
    # Index 3: LUMO energy  
    # Index 4: HOMO-LUMO gap ← THIS ONE
    # Index 5: electronic spatial extent
    # ... etc
    
    print("\nExtracting HOMO-LUMO gap values...")
    gaps = []
    valid_indices = []
    
    for i, data in enumerate(tqdm(dataset)):
        try:
            # Extract gap (index 4 in y tensor)
            gap = float(data.y[0, 4])  # Shape is [1, num_properties]
            gaps.append(gap)
            valid_indices.append(i)
        except:
            # Skip molecules with missing properties
            continue
    
    gaps = np.array(gaps)
    print(f"✓ Extracted {len(gaps)} valid gap values")
    
    # Step 3: Create balanced class labels
    q33, q67 = np.percentile(gaps, [33, 67])
    labels = np.zeros(len(gaps), dtype=np.int64)
    labels[gaps > q33] = 1
    labels[gaps > q67] = 2
    
    print(f"\nClass distribution (HOMO-LUMO gap):")
    print(f"  Class 0 (Low,  ≤{q33:.3f} eV): {np.sum(labels==0)} molecules")
    print(f"  Class 1 (Med,  {q33:.3f}-{q67:.3f} eV): {np.sum(labels==1)} molecules")
    print(f"  Class 2 (High, >{q67:.3f} eV): {np.sum(labels==2)} molecules")
    
    # Step 4: Generate Morgan fingerprints
    print(f"\nGenerating {fp_size}-D Morgan fingerprints...")
    features = []
    final_labels = []
    final_gaps = []
    
    for idx, i in enumerate(tqdm(valid_indices)):
        data = dataset[i]
        
        # Convert PyG Data to RDKit molecule
        mol = pyg_to_rdkit_mol(data)
        
        if mol is None:
            continue
        
        # Generate fingerprint
        fp = mol_to_morgan_fp(mol, fp_size=fp_size)
        
        if fp is None:
            continue
        
        features.append(fp)
        final_labels.append(labels[idx])
        final_gaps.append(gaps[idx])
    
    features = np.stack(features)
    final_labels = np.array(final_labels)
    final_gaps = np.array(final_gaps)
    
    print(f"\n✓ Generated features: {features.shape}")
    
    # Step 5: Save preprocessed data
    torch.save({
        'features': torch.from_numpy(features),
        'labels': torch.from_numpy(final_labels),
        'gap_values': final_gaps,
        'class_thresholds': (q33, q67),
        'feature_type': f'Morgan_fp{fp_size}_radius2',
        'n_samples': len(features),
        'n_classes': 3
    }, output_file)
    
    print(f"\n✓ Saved preprocessed data to {output_file}")
    print(f"  Total samples: {len(features)}")
    print(f"  Feature dim: {fp_size}")
    print(f"  Classes: 3 (balanced)")
    
    # Verify class balance
    print(f"\nFinal class counts:")
    for c in range(3):
        count = np.sum(final_labels == c)
        pct = 100 * count / len(final_labels)
        print(f"  Class {c}: {count} ({pct:.1f}%)")
    
    return features, final_labels

if __name__ == "__main__":
    preprocess_qm9_custom(fp_size=512, output_file='qm9_preprocessed_512d.pt')

Downloading QM9 via PyTorch Geometric...
(This handles authentication and mirrors automatically)


Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting data/qm9/raw/qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|█████████████████████████████████████████████████████████████████████████| 133885/133885 [01:24<00:00, 1593.34it/s]
Done!
  print(f"✓ Properties available: {dataset.data.y.shape[1]} targets per molecule")


✓ Downloaded 130831 molecules
✓ Properties available: 19 targets per molecule
   (Index 4 = HOMO-LUMO gap)

Extracting HOMO-LUMO gap values...


100%|████████████████████████████████████████████████████████████████████████| 130831/130831 [00:07<00:00, 16862.40it/s]


✓ Extracted 130831 valid gap values

Class distribution (HOMO-LUMO gap):
  Class 0 (Low,  ≤6.131 eV): 43177 molecules
  Class 1 (Med,  6.131-7.494 eV): 44534 molecules
  Class 2 (High, >7.494 eV): 43120 molecules

Generating 512-D Morgan fingerprints...


100%|█████████████████████████████████████████████████████████████████████████| 130831/130831 [00:15<00:00, 8618.27it/s]



✓ Generated features: (129227, 512)

✓ Saved preprocessed data to qm9_preprocessed_512d.pt
  Total samples: 129227
  Feature dim: 512
  Classes: 3 (balanced)

Final class counts:
  Class 0: 42728 (33.1%)
  Class 1: 44187 (34.2%)
  Class 2: 42312 (32.7%)


# RealNet

In [17]:
"""
Block 1: RealNet Training on QM9
=================================
Trains real-valued MLP baseline on QM9 HOMO-LUMO gap classification.
Creates frozen feature extractor (identity - features already frozen).

Requirements:
- qm9_preprocessed_512d.pt (from Block 0)

Outputs:
- realnet_qm9_results.pt: Contains RealNet results
"""
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================
# Seed setting
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ============================================
# RealNet Model
# ============================================

class RealNet(nn.Module):
    """
    Real-valued MLP classifier for QM9
    512-D features → 64 hidden → 3 classes
    """
    def __init__(self, input_dim=512, hidden_dim=64, num_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    
    return total_loss / len(loader.dataset)

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    
    return correct / total if total > 0 else 0.0

def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=200, patience=10, name="Model"):
    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0
    
    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device)
        acc = evaluate(model, test_loader, device)
        last_acc = acc
        
        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} | "
              f"test_acc={acc:.4f} | time={elapsed:.1f}s")
        
        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        
        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break
    
    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }

# ============================================
# Data utilities
# ============================================

def stratified_sample(labels, n_per_class, seed=42):
    """Create stratified sample with n_per_class from each class"""
    rng = np.random.RandomState(seed)
    n_classes = int(labels.max().item()) + 1
    
    sampled_indices = []
    for c in range(n_classes):
        idx_c = torch.where(labels == c)[0].numpy()
        k = min(n_per_class, len(idx_c))
        selected = rng.choice(idx_c, size=k, replace=False).tolist()
        sampled_indices.extend(selected)
    
    rng.shuffle(sampled_indices)
    return sampled_indices

# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 1: RealNet Training on QM9")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: QM9 HOMO-LUMO gap classification")
    print("  • Features: 512-D Morgan fingerprints (frozen)")
    print("  • Classes: 3 (Low/Med/High gap)")
    print("  • Architecture: 512 → 64 → 3")
    print("  • Batch size: 128 (train), 256 (test)")
    print("  • Max epochs: 200")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("=" * 70)
    
    # Load preprocessed data
    print("\nLoading preprocessed QM9 data...")
    data = torch.load("qm9_preprocessed_512d.pt", weights_only=False)
    features = data['features']
    labels = data['labels']
    
    print(f"  Loaded {len(features)} molecules")
    print(f"  Feature dim: {features.shape[1]}")
    print(f"  Classes: {len(torch.unique(labels))}")
    
    # Create stratified train/test splits
    print("\nCreating stratified samples...")
    train_indices = stratified_sample(labels, n_per_class=1500, seed=42)
    test_indices = stratified_sample(labels, n_per_class=300, seed=43)
    
    # Remove overlap
    test_indices = [i for i in test_indices if i not in train_indices]
    
    print(f"  Train samples: {len(train_indices)} (1500 per class)")
    print(f"  Test samples: {len(test_indices)} (300 per class)")
    
    # Create datasets
    full_dataset = TensorDataset(features, labels)
    train_ds = Subset(full_dataset, train_indices)
    test_ds = Subset(full_dataset, test_indices)
    
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
    
    seeds = [42, 123, 456]
    all_results = []
    
    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)
        
        print(f"\n  Training RealNet (seed={seed})...")
        model = RealNet(input_dim=512, hidden_dim=64, num_classes=3).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            model, train_loader, test_loader, optimizer, device,
            max_epochs=200, patience=10, name="RealNet"
        )
        
        trainable_params = sum(p.numel() for p in model.parameters())
        result["trainable_params"] = trainable_params
        result["seed"] = seed
        
        all_results.append(result)
    
    # Summary
    print("\n" + "=" * 70)
    print("REALNET QM9 SUMMARY")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["trainable_params"]
    
    print(f"\nAccuracy:     {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:         {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:       {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters:   {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")
    
    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": params
        },
        "train_indices": train_indices,
        "test_indices": test_indices
    }
    
    torch.save(save_dict, "realnet_qm9_results.pt")
    print(f"\n✓ Saved results to: realnet_qm9_results.pt")
    print("=" * 70)

if __name__ == "__main__":
    main()

Using device: cuda
BLOCK 1: RealNet Training on QM9

Configuration:
  • Dataset: QM9 HOMO-LUMO gap classification
  • Features: 512-D Morgan fingerprints (frozen)
  • Classes: 3 (Low/Med/High gap)
  • Architecture: 512 → 64 → 3
  • Batch size: 128 (train), 256 (test)
  • Max epochs: 200
  • Patience: 10
  • Seeds: [42, 123, 456]

Loading preprocessed QM9 data...
  Loaded 129227 molecules
  Feature dim: 512
  Classes: 3

Creating stratified samples...
  Train samples: 4500 (1500 per class)
  Test samples: 875 (300 per class)

SEED 42

  Training RealNet (seed=42)...
  [RealNet] Epoch  1 | loss=1.0426 | test_acc=0.6491 | time=0.1s
  [RealNet] Epoch  2 | loss=0.8262 | test_acc=0.6754 | time=0.1s
  [RealNet] Epoch  3 | loss=0.6712 | test_acc=0.7120 | time=0.2s
  [RealNet] Epoch  4 | loss=0.5893 | test_acc=0.7177 | time=0.2s
  [RealNet] Epoch  5 | loss=0.5453 | test_acc=0.7154 | time=0.3s
  [RealNet] Epoch  6 | loss=0.5134 | test_acc=0.7211 | time=0.3s
  [RealNet] Epoch  7 | loss=0.4938 | t

# QuatNet

In [21]:
"""
Block 2: QuatNet Training on QM9
=================================
Trains quaternion-valued classifier on QM9 HOMO-LUMO gap classification.
Uses frozen molecular fingerprints from Block 0.

Requirements:
- qm9_preprocessed_512d.pt (from Block 0)
- realnet_qm9_results.pt (from Block 1) - for train/test indices

Outputs:
- quatnet_qm9_results.pt: Contains QuatNet results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================
# Seed setting
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ============================================
# Quaternion Layer Implementation
# ============================================

class QuaternionLinear(nn.Module):
    """
    Quaternion linear layer with Hamilton product (VECTORIZED)
    Input: (batch, n_in, 4) - n_in quaternions
    Output: (batch, n_out, 4) - n_out quaternions
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Weights: (out_features, in_features, 4)
        self.weight = nn.Parameter(torch.randn(out_features, in_features, 4) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_features, 4))
        
        # Initialize as unit quaternions
        with torch.no_grad():
            self._normalize_weights()
    
    def _normalize_weights(self):
        """Normalize weights to unit quaternions"""
        norms = torch.sqrt((self.weight ** 2).sum(dim=2, keepdim=True))
        self.weight.data = self.weight.data / (norms + 1e-8)
    
    def hamilton_product_vectorized(self, q1, q2):
        """
        Vectorized Hamilton product supporting arbitrary leading dimensions
        q1, q2: (..., 4) tensors [w, x, y, z]
        Returns: (..., 4) tensor
        """
        w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
        w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]
        
        w = w1*w2 - x1*x2 - y1*y2 - z1*z2
        x = w1*x2 + x1*w2 + y1*z2 - z1*y2
        y = w1*y2 - x1*z2 + y1*w2 + z1*x2
        z = w1*z2 + x1*y2 - y1*x2 + z1*w2
        
        return torch.stack([w, x, y, z], dim=-1)
    
    def forward(self, x):
        """
        Vectorized forward pass
        x: (batch, in_features, 4)
        Returns: (batch, out_features, 4)
        """
        # Normalize weights
        self._normalize_weights()
        
        batch_size = x.size(0)
        
        # Expand dimensions for broadcasting
        # x: (batch, 1, in_features, 4)
        # weight: (1, out_features, in_features, 4)
        x_expanded = x.unsqueeze(1)  # (batch, 1, in_features, 4)
        w_expanded = self.weight.unsqueeze(0)  # (1, out_features, in_features, 4)
        
        # Vectorized Hamilton product
        # Computes all products at once: (batch, out_features, in_features, 4)
        products = self.hamilton_product_vectorized(w_expanded, x_expanded)
        
        # Sum over input features
        output = products.sum(dim=2)  # (batch, out_features, 4)
        
        # Add bias
        output = output + self.bias.unsqueeze(0)  # (batch, out_features, 4)
        
        # Normalize outputs
        norms = torch.sqrt((output ** 2).sum(dim=2, keepdim=True))
        output = output / (norms + 1e-8)
        
        return output


class QuatNet(nn.Module):
    """
    Quaternion classifier for QM9 with MATCHED capacity to RealNet
    512-D → 64 quats → 3 quats → 3 classes (~33K params, same as RealNet)
    """
    def __init__(self, input_dim=512, num_classes=3):
        super().__init__()
        
        # input_dim must be divisible by 4
        assert input_dim % 4 == 0, "input_dim must be divisible by 4"
        
        self.n_input_quats = input_dim // 4  # 128 quaternions
        
        # Match RealNet structure: 512 → 64 → 3
        # In quaternion space: 128 quats → 64 quats → 3 quats
        self.quat1 = QuaternionLinear(self.n_input_quats, 64)  # 128×64×4 = 32,768
        self.quat2 = QuaternionLinear(64, num_classes)         # 64×3×4 = 768
        # Total: ~33,536 params (matches RealNet's 33,027)
    
    def forward(self, x):
        """
        x: (batch, 512) real features
        Returns: (batch, 3) class logits
        """
        # Reshape to quaternions: (batch, 128, 4)
        batch_size = x.size(0)
        x = x.view(batch_size, self.n_input_quats, 4)
        
        # First quaternion layer
        x = self.quat1(x)  # (batch, 64, 4)
        x = torch.tanh(x)
        
        # Second quaternion layer
        x = self.quat2(x)  # (batch, 3, 4)
        
        # Extract real part for classification
        logits = x[..., 0]  # (batch, 3)
        
        return logits

# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    
    return total_loss / len(loader.dataset)

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    
    return correct / total if total > 0 else 0.0

def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=200, patience=10, name="Model"):
    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0
    
    for epoch in range(1, max_epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device)
        acc = evaluate(model, test_loader, device)
        last_acc = acc
        
        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} | "
              f"test_acc={acc:.4f} | time={elapsed:.1f}s")
        
        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        
        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break
    
    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }

# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 2: QuatNet Training on QM9 (MATCHED CAPACITY)")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: QM9 HOMO-LUMO gap classification")
    print("  • Features: 512-D Morgan fingerprints (frozen)")
    print("  • Classes: 3 (Low/Med/High gap)")
    print("  • Architecture: 128 quats → 64 quats → 3 quats")
    print("  • Target params: ~33K (matching RealNet)")
    print("  • Batch size: 128 (train), 256 (test)")
    print("  • Max epochs: 200")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print("=" * 70)
    
    # Load preprocessed data
    print("\nLoading preprocessed QM9 data...")
    data = torch.load("qm9_preprocessed_512d.pt", weights_only=False)
    features = data['features']
    labels = data['labels']
    
    print(f"  Loaded {len(features)} molecules")
    print(f"  Feature dim: {features.shape[1]}")
    
    # Load train/test indices from Block 1
    print("\nLoading train/test split from Block 1...")
    realnet_data = torch.load("realnet_qm9_results.pt", weights_only=False)
    train_indices = realnet_data['train_indices']
    test_indices = realnet_data['test_indices']
    
    print(f"  Train samples: {len(train_indices)}")
    print(f"  Test samples: {len(test_indices)}")
    
    # Create datasets
    full_dataset = TensorDataset(features, labels)
    train_ds = Subset(full_dataset, train_indices)
    test_ds = Subset(full_dataset, test_indices)
    
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
    
    seeds = [42, 123, 456]
    all_results = []
    
    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)
        
        print(f"\n  Training QuatNet (matched capacity, seed={seed})...")
        model = QuatNet(input_dim=512, num_classes=3).to(device)
        
        # Verify parameter count
        trainable_params = sum(p.numel() for p in model.parameters())
        print(f"  QuatNet parameters: {trainable_params:,} (RealNet: 33,027)")
        
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            model, train_loader, test_loader, optimizer, device,
            max_epochs=200, patience=10, name="QuatNet"
        )
        
        result["trainable_params"] = trainable_params
        result["seed"] = seed
        
        all_results.append(result)
    
    # Summary
    print("\n" + "=" * 70)
    print("QUATNET QM9 SUMMARY (MATCHED CAPACITY)")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["trainable_params"]
    
    print(f"\nAccuracy:     {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:         {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:       {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters:   {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")
    
    # Compare to RealNet
    realnet_acc = 0.7280  # From Block 1
    quat_acc = np.mean(accs)
    gap = quat_acc - realnet_acc
    
    print(f"\n{'=' * 70}")
    print("COMPARISON TO REALNET (MATCHED CAPACITY)")
    print("=" * 70)
    print(f"RealNet:  72.80% (33,027 params)")
    print(f"QuatNet:  {quat_acc*100:.2f}% ({params:,} params)")
    print(f"Gap:      {gap*100:+.2f}pp")
    
    if abs(gap) < 0.01:
        print("\n✓ QuatNet MATCHES RealNet with fair comparison!")
        print("  Conclusion: Earlier degradation was parameter starvation")
    else:
        print(f"\n✗ QuatNet still underperforms by {abs(gap)*100:.2f}pp")
        print("  Conclusion: Geometric mismatch confirmed with matched capacity")
    
    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": params
        }
    }
    
    torch.save(save_dict, "quatnet_qm9_results.pt")
    print(f"\n✓ Saved results to: quatnet_qm9_results.pt")
    print("=" * 70)

if __name__ == "__main__":
    main()




Using device: cuda
BLOCK 2: QuatNet Training on QM9 (MATCHED CAPACITY)

Configuration:
  • Dataset: QM9 HOMO-LUMO gap classification
  • Features: 512-D Morgan fingerprints (frozen)
  • Classes: 3 (Low/Med/High gap)
  • Architecture: 128 quats → 64 quats → 3 quats
  • Target params: ~33K (matching RealNet)
  • Batch size: 128 (train), 256 (test)
  • Max epochs: 200
  • Patience: 10
  • Seeds: [42, 123, 456]

Loading preprocessed QM9 data...
  Loaded 129227 molecules
  Feature dim: 512

Loading train/test split from Block 1...
  Train samples: 4500
  Test samples: 875

SEED 42

  Training QuatNet (matched capacity, seed=42)...
  QuatNet parameters: 33,804 (RealNet: 33,027)
  [QuatNet] Epoch  1 | loss=1.1132 | test_acc=0.3886 | time=0.1s
  [QuatNet] Epoch  2 | loss=1.0424 | test_acc=0.4560 | time=0.3s
  [QuatNet] Epoch  3 | loss=0.9892 | test_acc=0.4949 | time=0.4s
  [QuatNet] Epoch  4 | loss=0.9448 | test_acc=0.5269 | time=0.6s
  [QuatNet] Epoch  5 | loss=0.9072 | test_acc=0.5543 | time

# Quantum No Entanglement

In [13]:
"""
Block 3: Quantum (No Entanglement) Training on QM9
===================================================
Trains quantum classifier WITHOUT entanglement on QM9 HOMO-LUMO gap.
Uses frozen molecular fingerprints from Block 0.

Requirements:
- qm9_preprocessed_512d.pt (from Block 0)
- realnet_qm9_results.pt (from Block 1) - for train/test indices
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_noent_qm9_results.pt: Contains quantum (no ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset

try:
    import pennylane as qml
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    print("✓ Using lightning.gpu device")
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    exit(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Seed setting
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    try:
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Quantum Head (No Entanglement)
# ============================================

class QuantumHead(nn.Module):
    """
    VQC with 4 qubits, 3 layers, NO entanglement → 3 classes
    Uses Lightning acceleration (GPU)
    Maps 512 features → 4 qubits
    """
    def __init__(self, input_dim=512, n_qubits=4, n_layers=3, num_classes=3):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        
        # Map 512 features → n_qubits
        self.feature_select = nn.Linear(input_dim, n_qubits)
        
        # Quantum device
        self.dev = qml.device(QUANTUM_DEVICE, wires=n_qubits)
        
        # Use adjoint differentiation for lightning
        @qml.qnode(self.dev, interface="torch", diff_method="adjoint")
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITHOUT entanglement
            inputs: (n_qubits,)
            weights: (n_layers, n_qubits, 2)
            """
            for layer in range(n_layers):
                # Data re-uploading
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)
                
                # Trainable rotations
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)
                
                # NO ENTANGLEMENT
            
            # Measurements
            measurements = []
            # Single-qubit Z
            for i in range(n_qubits):
                measurements.append(qml.expval(qml.PauliZ(i)))
            # Two-qubit ZZ (pairs)
            for i in range(0, n_qubits-1, 2):
                measurements.append(qml.expval(qml.PauliZ(i) @ qml.PauliZ(i+1)))
            
            return measurements
        
        self.quantum_circuit = quantum_circuit
        
        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        
        # Output: 6 measurements → 3 classes
        n_measurements = n_qubits + (n_qubits // 2)
        self.fc_out = nn.Linear(n_measurements, num_classes)
    
    def forward(self, x):
        """
        x: (batch, 512) features
        Returns: (batch, 3) logits
        """
        batch_size = x.size(0)
        x = torch.tanh(self.feature_select(x))
        
        # Process samples one by one
        quantum_outputs = []
        for i in range(batch_size):
            q_raw = self.quantum_circuit(x[i], self.q_weights)
            if isinstance(q_raw, (list, tuple)):
                q_out = torch.stack(q_raw)
            else:
                q_out = q_raw
            quantum_outputs.append(q_out)
        
        quantum_outputs = torch.stack(quantum_outputs).float()
        quantum_outputs = quantum_outputs.to(self.fc_out.weight.device)
        
        output = self.fc_out(quantum_outputs)
        return output

# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=True):
    model.train()
    total_loss = 0.0
    
    for batch_idx, (x, y) in enumerate(loader):
        x = x.to(device)
        # y stays on CPU for quantum model
        
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        
        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}", end="\r")
    
    if show_progress:
        print()
    return total_loss / len(loader.dataset)

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    
    return correct / total if total > 0 else 0.0

def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=200, patience=10, name="Model"):
    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0
    
    for epoch in range(1, max_epochs + 1):
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=True)
        acc = evaluate(model, test_loader, device)
        last_acc = acc
        
        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} | "
              f"test_acc={acc:.4f} | time={elapsed:.1f}s")
        
        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        
        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break
    
    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }

# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 3: Quantum (NO Entanglement) Training on QM9")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: QM9 HOMO-LUMO gap classification")
    print("  • Features: 512-D Morgan fingerprints (frozen)")
    print("  • Classes: 3 (Low/Med/High gap)")
    print("  • Architecture: 512 → 4 qubits (3 layers, NO ent) → 6 meas → 3")
    print("  • Batch size: 32 (train), 64 (test)")
    print("  • Max epochs: 200")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print("=" * 70)
    
    # Load preprocessed data
    print("\nLoading preprocessed QM9 data...")
    data = torch.load("qm9_preprocessed_512d.pt", weights_only=False)
    features = data['features']
    labels = data['labels']
    
    print(f"  Loaded {len(features)} molecules")
    print(f"  Feature dim: {features.shape[1]}")
    
    # Load train/test indices from Block 1
    print("\nLoading train/test split from Block 1...")
    realnet_data = torch.load("realnet_qm9_results.pt")
    train_indices = realnet_data['train_indices']
    test_indices = realnet_data['test_indices']
    
    print(f"  Train samples: {len(train_indices)}")
    print(f"  Test samples: {len(test_indices)}")
    
    # Create datasets
    full_dataset = TensorDataset(features, labels)
    train_ds = Subset(full_dataset, train_indices)
    test_ds = Subset(full_dataset, test_indices)
    
    # Smaller batches for quantum
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)
    
    seeds = [42, 123, 456]
    all_results = []
    
    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)
        
        print(f"\n  Training QuantumNet (NO entanglement, seed={seed})...")
        model = QuantumHead(input_dim=512, n_qubits=4, n_layers=3, num_classes=3)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            model, train_loader, test_loader, optimizer, device,
            max_epochs=200, patience=10, name="QuantumNoEnt"
        )
        
        trainable_params = sum(p.numel() for p in model.parameters())
        result["trainable_params"] = trainable_params
        result["seed"] = seed
        
        all_results.append(result)
    
    # Summary
    print("\n" + "=" * 70)
    print("QUANTUM (NO ENTANGLEMENT) QM9 SUMMARY")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["trainable_params"]
    
    print(f"\nAccuracy:     {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:         {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:       {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters:   {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")
    
    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": params
        }
    }
    
    torch.save(save_dict, "quantum_noent_qm9_results.pt")
    print(f"\n✓ Saved results to: quantum_noent_qm9_results.pt")
    print("=" * 70)

if __name__ == "__main__":
    main()

✓ Using lightning.gpu device
Using PyTorch device: cuda
BLOCK 3: Quantum (NO Entanglement) Training on QM9

Configuration:
  • Dataset: QM9 HOMO-LUMO gap classification
  • Features: 512-D Morgan fingerprints (frozen)
  • Classes: 3 (Low/Med/High gap)
  • Architecture: 512 → 4 qubits (3 layers, NO ent) → 6 meas → 3
  • Batch size: 32 (train), 64 (test)
  • Max epochs: 200
  • Patience: 10
  • Seeds: [42, 123, 456]
  • Quantum device: lightning.gpu

Loading preprocessed QM9 data...
  Loaded 129227 molecules
  Feature dim: 512

Loading train/test split from Block 1...


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy.core.multiarray.scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

# Quantum Entanglement

In [None]:
"""
Block 4: Quantum (WITH Entanglement) Training on QM9
=====================================================
Trains quantum classifier WITH entanglement on QM9 HOMO-LUMO gap.
Uses frozen molecular fingerprints from Block 0.

Requirements:
- qm9_preprocessed_512d.pt (from Block 0)
- realnet_qm9_results.pt (from Block 1) - for train/test indices
- pennylane, pennylane-lightning-gpu

Outputs:
- quantum_ent_qm9_results.pt: Contains quantum (with ent) results
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset

try:
    import pennylane as qml
    QUANTUM_DEVICE = "lightning.gpu"
    PENNYLANE_AVAILABLE = True
    print("✓ Using lightning.gpu device")
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("✗ PennyLane not installed")
    exit(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {device}")

# ============================================
# Seed setting
# ============================================

def set_all_seeds(seed):
    """Set seeds for all RNG sources for reproducibility"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    try:
        import cupy as cp
        cp.random.seed(seed)
    except ImportError:
        pass
    
    try:
        qml.numpy.random.seed(seed)
    except (ImportError, AttributeError):
        pass

# ============================================
# Quantum Head (WITH Entanglement)
# ============================================

class QuantumHeadEnt(nn.Module):
    """
    VQC with 4 qubits, 3 layers, WITH CNOT ring entanglement → 3 classes
    Uses Lightning acceleration (GPU)
    Maps 512 features → 4 qubits
    """
    def __init__(self, input_dim=512, n_qubits=4, n_layers=3, num_classes=3):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        
        # Map 512 features → n_qubits
        self.feature_select = nn.Linear(input_dim, n_qubits)
        
        # Quantum device
        self.dev = qml.device(QUANTUM_DEVICE, wires=n_qubits)
        
        # Use adjoint differentiation for lightning
        @qml.qnode(self.dev, interface="torch", diff_method="adjoint")
        def quantum_circuit(inputs, weights):
            """
            Single-sample circuit WITH CNOT ring entanglement
            inputs: (n_qubits,)
            weights: (n_layers, n_qubits, 2)
            """
            for layer in range(n_layers):
                # Data re-uploading
                for i in range(n_qubits):
                    qml.RY(inputs[i], wires=i)
                
                # Trainable rotations
                for i in range(n_qubits):
                    qml.RY(weights[layer, i, 0], wires=i)
                    qml.RZ(weights[layer, i, 1], wires=i)
                
                # CNOT ring: 0→1→2→3→0
                for i in range(n_qubits):
                    qml.CNOT(wires=[i, (i + 1) % n_qubits])
            
            # Measurements
            measurements = []
            # Single-qubit Z
            for i in range(n_qubits):
                measurements.append(qml.expval(qml.PauliZ(i)))
            # Two-qubit ZZ (pairs)
            for i in range(0, n_qubits-1, 2):
                measurements.append(qml.expval(qml.PauliZ(i) @ qml.PauliZ(i+1)))
            
            return measurements
        
        self.quantum_circuit = quantum_circuit
        
        weight_shape = (n_layers, n_qubits, 2)
        self.q_weights = nn.Parameter(torch.randn(weight_shape) * 0.1)
        
        # Output: 6 measurements → 3 classes
        n_measurements = n_qubits + (n_qubits // 2)
        self.fc_out = nn.Linear(n_measurements, num_classes)
    
    def forward(self, x):
        """
        x: (batch, 512) features
        Returns: (batch, 3) logits
        """
        batch_size = x.size(0)
        x = torch.tanh(self.feature_select(x))
        
        # Process samples one by one
        quantum_outputs = []
        for i in range(batch_size):
            q_raw = self.quantum_circuit(x[i], self.q_weights)
            if isinstance(q_raw, (list, tuple)):
                q_out = torch.stack(q_raw)
            else:
                q_out = q_raw
            quantum_outputs.append(q_out)
        
        quantum_outputs = torch.stack(quantum_outputs).float()
        quantum_outputs = quantum_outputs.to(self.fc_out.weight.device)
        
        output = self.fc_out(quantum_outputs)
        return output

# ============================================
# Training and Evaluation
# ============================================

def train_one_epoch(model, loader, optimizer, device, show_progress=True):
    model.train()
    total_loss = 0.0
    
    for batch_idx, (x, y) in enumerate(loader):
        x = x.to(device)
        # y stays on CPU for quantum model
        
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        
        if show_progress and batch_idx % 20 == 0:
            print(f"    Batch {batch_idx}/{len(loader)}", end="\r")
    
    if show_progress:
        print()
    return total_loss / len(loader.dataset)

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    
    return correct / total if total > 0 else 0.0

def train_with_early_stopping(model, train_loader, test_loader, optimizer,
                              device, max_epochs=200, patience=10, name="Model"):
    best_acc = 0.0
    epochs_without_improvement = 0
    start = time.time()
    last_acc = 0.0
    
    for epoch in range(1, max_epochs + 1):
        print(f"  [{name}] Epoch {epoch}/{max_epochs}")
        loss = train_one_epoch(model, train_loader, optimizer, device, show_progress=True)
        acc = evaluate(model, test_loader, device)
        last_acc = acc
        
        elapsed = time.time() - start
        print(f"  [{name}] Epoch {epoch:2d} | loss={loss:.4f} | "
              f"test_acc={acc:.4f} | time={elapsed:.1f}s")
        
        if acc > best_acc:
            best_acc = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        
        if epochs_without_improvement >= patience:
            print(f"  [{name}] Early stop at epoch {epoch} "
                  f"(no improvement for {patience} epochs)")
            break
    
    total_time = time.time() - start
    return {
        "best_acc": best_acc,
        "final_acc": last_acc,
        "time": total_time,
        "epochs": epoch,
    }

# ============================================
# Main
# ============================================

def main():
    print("=" * 70)
    print("BLOCK 4: Quantum (WITH Entanglement) Training on QM9")
    print("=" * 70)
    print("\nConfiguration:")
    print("  • Dataset: QM9 HOMO-LUMO gap classification")
    print("  • Features: 512-D Morgan fingerprints (frozen)")
    print("  • Classes: 3 (Low/Med/High gap)")
    print("  • Architecture: 512 → 4 qubits (3 layers, CNOT ring) → 6 meas → 3")
    print("  • Batch size: 32 (train), 64 (test)")
    print("  • Max epochs: 200")
    print("  • Patience: 10")
    print("  • Seeds: [42, 123, 456]")
    print(f"  • Quantum device: {QUANTUM_DEVICE}")
    print("=" * 70)
    
    # Load preprocessed data
    print("\nLoading preprocessed QM9 data...")
    data = torch.load("qm9_preprocessed_512d.pt", weights_only=False)
    features = data['features']
    labels = data['labels']
    
    print(f"  Loaded {len(features)} molecules")
    print(f"  Feature dim: {features.shape[1]}")
    
    # Load train/test indices from Block 1
    print("\nLoading train/test split from Block 1...")
    realnet_data = torch.load("realnet_qm9_results.pt")
    train_indices = realnet_data['train_indices']
    test_indices = realnet_data['test_indices']
    
    print(f"  Train samples: {len(train_indices)}")
    print(f"  Test samples: {len(test_indices)}")
    
    # Create datasets
    full_dataset = TensorDataset(features, labels)
    train_ds = Subset(full_dataset, train_indices)
    test_ds = Subset(full_dataset, test_indices)
    
    # Smaller batches for quantum
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)
    
    seeds = [42, 123, 456]
    all_results = []
    
    for seed in seeds:
        print(f"\n{'=' * 70}")
        print(f"SEED {seed}")
        print("=" * 70)
        
        set_all_seeds(seed)
        
        print(f"\n  Training QuantumNet (WITH entanglement, seed={seed})...")
        model = QuantumHeadEnt(input_dim=512, n_qubits=4, n_layers=3, num_classes=3)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        
        result = train_with_early_stopping(
            model, train_loader, test_loader, optimizer, device,
            max_epochs=200, patience=10, name="QuantumEnt"
        )
        
        trainable_params = sum(p.numel() for p in model.parameters())
        result["trainable_params"] = trainable_params
        result["seed"] = seed
        
        all_results.append(result)
    
    # Summary
    print("\n" + "=" * 70)
    print("QUANTUM (WITH ENTANGLEMENT) QM9 SUMMARY")
    print("=" * 70)
    
    accs = [r["best_acc"] for r in all_results]
    times = [r["time"] for r in all_results]
    epochs = [r["epochs"] for r in all_results]
    params = all_results[0]["trainable_params"]
    
    print(f"\nAccuracy:     {np.mean(accs):.4f} ± {np.std(accs):.4f}")
    print(f"Time:         {np.mean(times):.1f}s ± {np.std(times):.1f}s")
    print(f"Epochs:       {np.mean(epochs):.1f} ± {np.std(epochs):.1f}")
    print(f"Parameters:   {params:,}")
    
    print("\nPer-seed results:")
    for r in all_results:
        print(f"  Seed {r['seed']}: acc={r['best_acc']:.4f}, "
              f"time={r['time']:.1f}s, epochs={r['epochs']}")
    
    # Save results
    save_dict = {
        "results": all_results,
        "summary": {
            "mean_acc": np.mean(accs),
            "std_acc": np.std(accs),
            "mean_time": np.mean(times),
            "trainable_params": params
        }
    }
    
    torch.save(save_dict, "quantum_ent_qm9_results.pt")
    print(f"\n✓ Saved results to: quantum_ent_qm9_results.pt")
    print("=" * 70)

if __name__ == "__main__":
    main()