In [None]:
"""
# Single Experiment Demo

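Run a single quantum state tomography experiment step by step: generate data, train a model, evaluate it on a held-out test set, and visualize the results.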
"""

In [None]:
# Cell 1: Setup
import sys
sys.path.insert(0, '..')

from src.data_generation import generate_dataset
from src.models import TomographyNet
from src.training import QuantumDataset, train_model
from src.evaluation import evaluate_model
from src.visualization import plot_training_curves, plot_bloch_sphere_failures
from torch.utils.data import DataLoader
import torch

In [None]:
# Cell 2: Generate data
# Every split uses the same measurement scheme. Each split gets its own seed
# so the three sets differ from each other and are reproducible on re-run.
MEASUREMENT_TYPE = 'baseline'
SPLIT_CONFIG = {
    'train': {'n_states': 10000, 'seed': 42},
    'val': {'n_states': 2000, 'seed': 43},
    'test': {'n_states': 2000, 'seed': 44},
}

print("Generating dataset...")
# generate_dataset returns a (measurements, bloch_vectors) pair for each split.
datasets = {
    split: generate_dataset(measurement_type=MEASUREMENT_TYPE, **cfg)
    for split, cfg in SPLIT_CONFIG.items()
}
train_meas, train_bloch = datasets['train']
val_meas, val_bloch = datasets['val']
test_meas, test_bloch = datasets['test']

# Invariant: every measurement row must have a matching Bloch-vector label.
for split, (meas, bloch) in datasets.items():
    assert len(meas) == len(bloch), (
        f"{split}: {len(meas)} measurements vs {len(bloch)} Bloch-vector labels"
    )

print(f"Training set: {train_meas.shape[0]} states")
print(f"Validation set: {val_meas.shape[0]} states")
print(f"Test set: {test_meas.shape[0]} states")

In [None]:
# Cell 3: Create dataloaders
BATCH_SIZE = 512
# A dedicated, seeded generator makes the training shuffle order reproducible
# across Restart & Run All. It also stops the order from depending on how much
# the global torch RNG was consumed earlier.
SHUFFLE_SEED = 42
shuffle_generator = torch.Generator().manual_seed(SHUFFLE_SEED)

train_dataset = QuantumDataset(train_meas, train_bloch)
val_dataset = QuantumDataset(val_meas, val_bloch)
test_dataset = QuantumDataset(test_meas, test_bloch)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, generator=shuffle_generator
)
# Evaluation loaders keep the dataset order (no shuffle).
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# Cell 4: Create and train model
# Seed torch's global RNG before building the model. This makes weight
# initialisation reproducible across Restart & Run All, along with any other
# randomness drawn from the global RNG during training.
torch.manual_seed(42)

# Take the input width from the generated measurements instead of a
# hard-coded 3, so that changing `measurement_type` in the data cell cannot
# silently mismatch the network.
# NOTE(review): assumes train_meas is shaped (n_states, n_features); confirm.
model = TomographyNet(input_dim=train_meas.shape[1], hidden_dims=[256, 128, 64, 32])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# NOTE(review): Cell 6 evaluates `model` as train_model leaves it. Confirm
# whether train_model restores the best-epoch weights when early stopping
# (patience) triggers.
history = train_model(
    model, train_loader, val_loader,
    epochs=500, lr=1e-3, patience=50,
    device=device
)

In [None]:
# Cell 5: Visualize training
# Save the training curves to disk, then report the summary values that
# train_model stored in `history`.
curves_path = 'training_curves.png'
plot_training_curves(history, curves_path)

best_epoch = history['best_epoch']
print(f"Best epoch: {best_epoch}")
final_val_fidelity = history['final_fidelity']
print(f"Final validation fidelity: {final_val_fidelity:.4f}")

In [None]:
# Cell 6: Evaluate on test set
# Score the trained model on the held-out test loader and print a short report.
metrics = evaluate_model(model, test_loader, device=device)

report_header = "\nTest Set Results:"
print(report_header)
print(f"  Mean Fidelity: {metrics['mean_fidelity']:.4f} ± {metrics['std_fidelity']:.4f}")
rmse_xyz = ", ".join(f"{metrics[axis_key]:.4f}" for axis_key in ('rmse_x', 'rmse_y', 'rmse_z'))
print(f"  RMSE (x,y,z): ({rmse_xyz})")
print(f"  Frac > 0.95: {metrics['frac_above_95']:.4f}")

In [None]:
# Cell 7: Visualize failures
# Pull the per-sample outputs of evaluate_model into named variables, then
# plot 5 examples with plot_bloch_sphere_failures (presumably the
# lowest-fidelity test states; confirm in src.visualization).
predictions = metrics['predictions']
targets = metrics['true_values']
fidelities = metrics['fidelity_distribution']
failures_path = 'failure_examples.png'

plot_bloch_sphere_failures(predictions, targets, fidelities, failures_path, n_examples=5)