In [None]:
"""
Quantum State Tomography with Neural Networks
Complete experimental suite with seed 48
Results saved to expt_3/
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import pickle
from tqdm import tqdm
from scipy.stats import unitary_group
import pandas as pd

# Create output directory
OUTPUT_DIR = Path("expt_3")
OUTPUT_DIR.mkdir(exist_ok=True)

# Set seeds
MASTER_SEED = 48
np.random.seed(MASTER_SEED)
torch.manual_seed(MASTER_SEED)

print(f"Output directory: {OUTPUT_DIR}")
print(f"Master seed: {MASTER_SEED}")

###########################################
# 1. DATA GENERATION
###########################################

class QuantumStateGenerator:
    """Generate random quantum states and density matrices"""
    
    def __init__(self, seed=None):
        if seed is not None:
            np.random.seed(seed)
    
    def generate_pure_state(self):
        """Generate random pure state"""
        # Random point on Bloch sphere
        theta = np.random.uniform(0, np.pi)
        phi = np.random.uniform(0, 2*np.pi)
        
        state = np.array([
            np.cos(theta/2),
            np.exp(1j*phi) * np.sin(theta/2)
        ], dtype=complex)
        
        rho = np.outer(state, state.conj())
        return rho, np.array([np.sin(theta)*np.cos(phi), 
                              np.sin(theta)*np.sin(phi), 
                              np.cos(theta)])
    
    def generate_mixed_state(self, p):
        """Generate mixed state with mixing parameter p"""
        # Pure state component
        rho_pure, _ = self.generate_pure_state()
        
        # Maximally mixed state
        rho_mixed = np.eye(2) / 2
        
        # Mix them
        rho = (1-p) * rho_pure + p * rho_mixed
        
        # Extract Bloch vector
        bloch = self.density_to_bloch(rho)
        return rho, bloch
    
    def generate_near_pure_state(self, eigenvalues=[0.99, 0.01]):
        """Generate near-pure state with specified eigenvalues"""
        # Random unitary
        U = unitary_group.rvs(2)
        
        # Construct density matrix
        rho = U @ np.diag(eigenvalues) @ U.conj().T
        bloch = self.density_to_bloch(rho)
        
        return rho, bloch
    
    def density_to_bloch(self, rho):
        """Convert density matrix to Bloch vector"""
        pauli_x = np.array([[0, 1], [1, 0]])
        pauli_y = np.array([[0, -1j], [1j, 0]])
        pauli_z = np.array([[1, 0], [0, -1]])
        
        x = np.real(np.trace(rho @ pauli_x))
        y = np.real(np.trace(rho @ pauli_y))
        z = np.real(np.trace(rho @ pauli_z))
        
        return np.array([x, y, z])
    
    def bloch_to_density(self, bloch):
        """Convert Bloch vector to density matrix"""
        x, y, z = bloch
        rho = 0.5 * np.array([
            [1 + z, x - 1j*y],
            [x + 1j*y, 1 - z]
        ], dtype=complex)
        return rho


class MeasurementSimulator:
    """Simulate quantum measurements with noise"""
    
    def __init__(self):
        self.pauli_x = np.array([[0, 1], [1, 0]])
        self.pauli_y = np.array([[0, -1j], [1j, 0]])
        self.pauli_z = np.array([[1, 0], [0, -1]])
    
    def measure_pauli(self, rho, pauli):
        """Measure expectation value of Pauli operator"""
        return np.real(np.trace(rho @ pauli))
    
    def measure_xyz(self, rho):
        """Measure all three Pauli operators"""
        return np.array([
            self.measure_pauli(rho, self.pauli_x),
            self.measure_pauli(rho, self.pauli_y),
            self.measure_pauli(rho, self.pauli_z)
        ])
    
    def measure_two_basis(self, rho, basis_pair):
        """Measure two Pauli bases"""
        paulis = {
            'X': self.pauli_x,
            'Y': self.pauli_y,
            'Z': self.pauli_z
        }
        
        return np.array([
            self.measure_pauli(rho, paulis[basis_pair[0]]),
            self.measure_pauli(rho, paulis[basis_pair[1]])
        ])
    
    def measure_sic_povm(self, rho):
        """Simulate SIC-POVM measurement (4 outcomes)"""
        # SIC-POVM elements for qubit
        sic_elements = [
            np.array([[1, 0], [0, 0]]) / 2,  # |0><0|
            np.array([[1, 1], [1, 1]]) / 4,  # (|0>+|1>)(|0>+|1>)*
            np.array([[1, -1j], [1j, 1]]) / 4,
            np.array([[1, 1j], [-1j, 1]]) / 4
        ]
        
        probs = [np.real(np.trace(rho @ elem)) for elem in sic_elements]
        return np.array(probs)
    
    def apply_shot_noise(self, probabilities, shots):
        """Convert probabilities to finite-shot counts"""
        # For Pauli measurements: probability of +1 outcome
        prob_plus = (probabilities + 1) / 2
        counts_plus = np.random.binomial(shots, prob_plus)
        
        # Convert back to expectation values
        return 2 * counts_plus / shots - 1
    
    def apply_readout_noise(self, probabilities, noise_level):
        """Apply readout flip noise"""
        # Noise matrix: probability of bit flip
        # |0> -> |1> with probability noise_level
        # |1> -> |0> with probability noise_level
        
        prob_plus = (probabilities + 1) / 2
        
        # Apply noise
        noisy_prob = prob_plus * (1 - noise_level) + (1 - prob_plus) * noise_level
        
        # Convert back to expectation values
        return 2 * noisy_prob - 1
    
    def measure_with_noise(self, rho, measurement_type, shots=None, noise_level=0):
        """Measure with optional shot noise and readout noise"""
        
        if measurement_type == 'baseline':
            probs = self.measure_xyz(rho)
        elif measurement_type in ['XY', 'XZ', 'YZ']:
            probs = self.measure_two_basis(rho, measurement_type)
        elif measurement_type == 'sic':
            return self.measure_sic_povm(rho)  # SIC is different format
        else:
            raise ValueError(f"Unknown measurement type: {measurement_type}")
        
        # Apply shot noise if specified
        if shots is not None:
            probs = self.apply_shot_noise(probs, shots)
        
        # Apply readout noise
        if noise_level > 0:
            probs = self.apply_readout_noise(probs, noise_level)
        
        return probs


def generate_dataset(n_states, ensemble_type='general', mixing_p=0.25, 
                     measurement_type='baseline', shots=None, noise_level=0, seed=None):
    """
    Generate complete dataset
    
    Args:
        n_states: number of states to generate
        ensemble_type: 'general', 'pure', 'near_pure', 'mixed'
        mixing_p: mixing parameter for mixed states
        measurement_type: 'baseline', 'XY', 'XZ', 'YZ', 'sic'
        shots: number of shots (None for infinite)
        noise_level: readout noise level (0-1)
        seed: random seed
    """
    gen = QuantumStateGenerator(seed=seed)
    sim = MeasurementSimulator()
    
    measurements = []
    bloch_vectors = []
    
    for _ in range(n_states):
        # Generate state based on ensemble type
        if ensemble_type == 'pure':
            rho, bloch = gen.generate_pure_state()
        elif ensemble_type == 'near_pure':
            rho, bloch = gen.generate_near_pure_state([0.99, 0.01])
        elif ensemble_type == 'mixed':
            rho, bloch = gen.generate_mixed_state(mixing_p)
        else:  # general ensemble
            rand = np.random.rand()
            if rand < 0.7:  # 70% pure
                rho, bloch = gen.generate_pure_state()
            else:  # 30% mixed
                rho, bloch = gen.generate_mixed_state(np.random.uniform(0.1, 0.5))
        
        # Perform measurement
        meas = sim.measure_with_noise(rho, measurement_type, shots, noise_level)
        
        measurements.append(meas)
        bloch_vectors.append(bloch)
    
    return np.array(measurements), np.array(bloch_vectors)


###########################################
# 2. NEURAL NETWORK MODEL
###########################################

class TomographyNet(nn.Module):
    """Neural network for quantum state tomography"""
    
    def __init__(self, input_dim, hidden_dims=[256, 128, 64, 32]):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            prev_dim = hidden_dim
        
        # Output layer: 3 Bloch coordinates
        layers.append(nn.Linear(prev_dim, 3))
        layers.append(nn.Tanh())  # Bloch vector constrained to [-1, 1]
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        bloch = self.network(x)
        # Enforce Bloch sphere constraint: ||r|| <= 1
        norm = torch.norm(bloch, dim=1, keepdim=True)
        bloch = bloch / torch.clamp(norm, min=1.0)
        return bloch


class QuantumDataset(Dataset):
    """PyTorch dataset for quantum measurements"""
    
    def __init__(self, measurements, bloch_vectors):
        self.measurements = torch.FloatTensor(measurements)
        self.bloch_vectors = torch.FloatTensor(bloch_vectors)
    
    def __len__(self):
        return len(self.measurements)
    
    def __getitem__(self, idx):
        return self.measurements[idx], self.bloch_vectors[idx]


def fidelity_metric(pred_bloch, true_bloch):
    """Calculate fidelity between predicted and true states"""
    # For pure states: F = (1 + r1·r2)/2
    # For general states, this is approximate
    dot_product = np.sum(pred_bloch * true_bloch, axis=1)
    fidelity = (1 + dot_product) / 2
    return fidelity


def train_model(model, train_loader, val_loader, epochs=1000, lr=1e-3, 
                patience=100, device='cpu'):
    """Train the tomography model"""
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    best_val_loss = float('inf')
    patience_counter = 0
    
    train_losses = []
    val_losses = []
    val_fidelities = []
    
    model.to(device)
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        for measurements, bloch in train_loader:
            measurements, bloch = measurements.to(device), bloch.to(device)
            
            optimizer.zero_grad()
            pred = model(measurements)
            loss = criterion(pred, bloch)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        train_losses.append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0
        all_preds = []
        all_true = []
        
        with torch.no_grad():
            for measurements, bloch in val_loader:
                measurements, bloch = measurements.to(device), bloch.to(device)
                pred = model(measurements)
                loss = criterion(pred, bloch)
                val_loss += loss.item()
                
                all_preds.append(pred.cpu().numpy())
                all_true.append(bloch.cpu().numpy())
        
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        
        # Calculate fidelity
        all_preds = np.vstack(all_preds)
        all_true = np.vstack(all_true)
        fidelities = fidelity_metric(all_preds, all_true)
        mean_fidelity = np.mean(fidelities)
        val_fidelities.append(mean_fidelity)
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
        
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Train Loss={train_loss:.6f}, Val Loss={val_loss:.6f}, Fidelity={mean_fidelity:.4f}")
    
    # Restore best model
    model.load_state_dict(best_model_state)
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_fidelities': val_fidelities,
        'best_epoch': len(train_losses) - patience,
        'final_fidelity': val_fidelities[-patience] if len(val_fidelities) > patience else val_fidelities[-1]
    }


###########################################
# 3. EXPERIMENT RUNNERS
###########################################

def run_single_experiment(config, seed):
    """Run a single experiment with given configuration"""
    
    print(f"\n{'='*60}")
    print(f"Running: {config['name']}")
    print(f"Seed: {seed}")
    print(f"{'='*60}")
    
    # Generate data
    print("Generating training data...")
    train_meas, train_bloch = generate_dataset(
        n_states=config['n_train'],
        ensemble_type=config.get('ensemble_type', 'general'),
        mixing_p=config.get('mixing_p', 0.25),
        measurement_type=config['measurement_type'],
        shots=config.get('shots', None),
        noise_level=config.get('noise_level', 0),
        seed=seed
    )
    
    print("Generating validation data...")
    val_meas, val_bloch = generate_dataset(
        n_states=config['n_val'],
        ensemble_type=config.get('ensemble_type', 'general'),
        mixing_p=config.get('mixing_p', 0.25),
        measurement_type=config['measurement_type'],
        shots=config.get('shots', None),
        noise_level=config.get('noise_level', 0),
        seed=seed + 1000
    )
    
    print("Generating test data...")
    test_meas, test_bloch = generate_dataset(
        n_states=config['n_test'],
        ensemble_type=config.get('ensemble_type', 'general'),
        mixing_p=config.get('mixing_p', 0.25),
        measurement_type=config['measurement_type'],
        shots=config.get('shots', None),
        noise_level=config.get('noise_level', 0),
        seed=seed + 2000
    )
    
    # Create dataloaders
    train_dataset = QuantumDataset(train_meas, train_bloch)
    val_dataset = QuantumDataset(val_meas, val_bloch)
    test_dataset = QuantumDataset(test_meas, test_bloch)
    
    train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=512)
    test_loader = DataLoader(test_dataset, batch_size=512)
    
    # Create model
    input_dim = train_meas.shape[1]
    model = TomographyNet(input_dim=input_dim, hidden_dims=[256, 128, 64, 32])
    
    # Train
    print("Training model...")
    history = train_model(
        model, train_loader, val_loader,
        epochs=1000, lr=1e-3, patience=100,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    
    # Evaluate on test set
    print("Evaluating on test set...")
    model.eval()
    all_preds = []
    all_true = []
    
    device = next(model.parameters()).device
    with torch.no_grad():
        for measurements, bloch in test_loader:
            measurements = measurements.to(device)
            pred = model(measurements)
            all_preds.append(pred.cpu().numpy())
            all_true.append(bloch.numpy())
    
    all_preds = np.vstack(all_preds)
    all_true = np.vstack(all_true)
    
    # Calculate metrics
    fidelities = fidelity_metric(all_preds, all_true)
    rmse = np.sqrt(np.mean((all_preds - all_true)**2, axis=0))
    
    results = {
        'config': config,
        'seed': seed,
        'history': history,
        'test_fidelity_mean': np.mean(fidelities),
        'test_fidelity_std': np.std(fidelities),
        'test_fidelity_distribution': fidelities,
        'rmse_x': rmse[0],
        'rmse_y': rmse[1],
        'rmse_z': rmse[2],
        'frac_above_95': np.mean(fidelities > 0.95),
        'predictions': all_preds,
        'true_values': all_true
    }
    
    print(f"\nResults:")
    print(f"  Mean Fidelity: {results['test_fidelity_mean']:.4f} ± {results['test_fidelity_std']:.4f}")
    print(f"  RMSE (x,y,z): ({rmse[0]:.4f}, {rmse[1]:.4f}, {rmse[2]:.4f})")
    print(f"  Frac > 0.95: {results['frac_above_95']:.4f}")
    
    return results


def run_priority_a():
    """Priority A: Finite shots + readout noise"""
    
    print("\n" + "="*80)
    print("PRIORITY A: Finite Shots + Readout Noise")
    print("="*80)
    
    all_results = []
    seeds = [48, 49, 50]
    
    for shots in [10, 100, 1000]:
        for noise in [0, 0.01, 0.05]:
            for measurement_type in ['baseline', 'XZ']:
                for seed in seeds:
                    config = {
                        'name': f'PriorityA_shots{shots}_noise{int(noise*100)}pct_{measurement_type}_seed{seed}',
                        'n_train': 80000,
                        'n_val': 10000,
                        'n_test': 10000,
                        'measurement_type': measurement_type,
                        'shots': shots,
                        'noise_level': noise
                    }
                    
                    results = run_single_experiment(config, seed)
                    all_results.append(results)
                    
                    # Save individual result
                    save_path = OUTPUT_DIR / f"{config['name']}.pkl"
                    with open(save_path, 'wb') as f:
                        pickle.dump(results, f)
    
    return all_results


def run_priority_b():
    """Priority B: Pure vs mixed vs near-pure ensembles"""
    
    print("\n" + "="*80)
    print("PRIORITY B: Pure vs Mixed vs Near-Pure")
    print("="*80)
    
    all_results = []
    seeds = [48, 49, 50]
    
    ensembles = [
        ('pure', None),
        ('near_pure', None),
        ('mixed', 0.1),
        ('mixed', 0.25),
        ('mixed', 0.5)
    ]
    
    for ensemble_type, mixing_p in ensembles:
        for measurement_type in ['baseline', 'XY', 'XZ', 'YZ']:
            for seed in seeds:
                ensemble_name = f"{ensemble_type}_p{mixing_p}" if mixing_p else ensemble_type
                config = {
                    'name': f'PriorityB_{ensemble_name}_{measurement_type}_seed{seed}',
                    'n_train': 80000,
                    'n_val': 10000,
                    'n_test': 10000,
                    'ensemble_type': ensemble_type,
                    'mixing_p': mixing_p,
                    'measurement_type': measurement_type
                }
                
                results = run_single_experiment(config, seed)
                all_results.append(results)
                
                # Save individual result
                save_path = OUTPUT_DIR / f"{config['name']}.pkl"
                with open(save_path, 'wb') as f:
                    pickle.dump(results, f)
    
    return all_results


###########################################
# 4. ANALYSIS AND VISUALIZATION
###########################################

def create_summary_table(results_list):
    """Create summary table of all results"""
    
    rows = []
    for r in results_list:
        rows.append({
            'Experiment': r['config']['name'],
            'Mean Fidelity': r['test_fidelity_mean'],
            'Std Fidelity': r['test_fidelity_std'],
            'RMSE_x': r['rmse_x'],
            'RMSE_y': r['rmse_y'],
            'RMSE_z': r['rmse_z'],
            'Frac > 0.95': r['frac_above_95'],
            'Convergence Epoch': r['history']['best_epoch']
        })
    
    df = pd.DataFrame(rows)
    return df


def plot_training_curves(results, save_path):
    """Plot training and validation curves"""
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Loss curves
    axes[0].plot(results['history']['train_losses'], label='Train Loss', alpha=0.7)
    axes[0].plot(results['history']['val_losses'], label='Val Loss', alpha=0.7)
    axes[0].axvline(results['history']['best_epoch'], color='r', linestyle='--', label='Best Epoch')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('MSE Loss')
    axes[0].set_title('Training Curves')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Fidelity curve
    axes[1].plot(results['history']['val_fidelities'], label='Val Fidelity', color='green', alpha=0.7)
    axes[1].axvline(results['history']['best_epoch'], color='r', linestyle='--', label='Best Epoch')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Mean Fidelity')
    axes[1].set_title('Validation Fidelity')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def plot_fidelity_cdf(results_list, save_path, title="Fidelity CDF"):
    """Plot CDF of fidelities for multiple experiments"""
    
    plt.figure(figsize=(10, 6))
    
    for r in results_list:
        fidelities = np.sort(r['test_fidelity_distribution'])
        cdf = np.arange(1, len(fidelities) + 1) / len(fidelities)
        label = r['config']['name'].replace('PriorityA_', '').replace('PriorityB_', '')
        plt.plot(fidelities, cdf, label=label, alpha=0.7)
    
    plt.xlabel('Fidelity')
    plt.ylabel('CDF')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def plot_bloch_sphere_failures(results, save_path, n_examples=5):
    """Plot worst predictions on Bloch sphere"""
    
    fidelities = fidelity_metric(results['predictions'], results['true_values'])
    worst_indices = np.argsort(fidelities)[:n_examples]
    
    fig = plt.figure(figsize=(15, 3))
    
    for i, idx in enumerate(worst_indices):
        ax = fig.add_subplot(1, n_examples, i+1, projection='3d')
        
        true_vec = results['true_values'][idx]
        pred_vec = results['predictions'][idx]
        
        # Draw Bloch sphere
        u = np.linspace(0, 2*np.pi, 50)
        v = np.linspace(0, np.pi, 50)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))
        ax.plot_surface(x, y, z, alpha=0.1, color='gray')
        
        # Plot vectors
        ax.quiver(0, 0, 0, true_vec[0], true_vec[1], true_vec[2], 
                 color='blue', arrow_length_ratio=0.1, linewidth=2, label='True')
        ax.quiver(0, 0, 0, pred_vec[0], pred_vec[1], pred_vec[2], 
                 color='red', arrow_length_ratio=0.1, linewidth=2, label='Predicted')
        
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        ax.set_title(f'F={fidelities[idx]:.3f}')
        
        if i == 0:
            ax.legend()
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


###########################################
# 5. MAIN EXECUTION
###########################################

def run_priority_c():
    """Priority C: SIC-POVM vs Pauli 2-basis"""
    
    print("\n" + "="*80)
    print("PRIORITY C: SIC-POVM Input Test")
    print("="*80)
    
    all_results = []
    seeds = [48, 49, 50]
    
    for measurement_type in ['sic', 'XZ']:
        for seed in seeds:
            config = {
                'name': f'PriorityC_{measurement_type}_seed{seed}',
                'n_train': 80000,
                'n_val': 10000,
                'n_test': 10000,
                'measurement_type': measurement_type
            }
            
            results = run_single_experiment(config, seed)
            all_results.append(results)
            
            # Save individual result
            save_path = OUTPUT_DIR / f"{config['name']}.pkl"
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)
    
    return all_results


def run_priority_d():
    """Priority D: Adaptive vs non-adaptive measurement"""
    
    print("\n" + "="*80)
    print("PRIORITY D: Adaptive Measurement Protocol")
    print("="*80)
    
    # For this simplified version, we'll simulate adaptive by:
    # 1. Start with X measurement
    # 2. Choose next axis based on information gain heuristic
    # This is a simplified adaptive protocol
    
    all_results = []
    seeds = [48, 49, 50]
    
    # We'll compare baseline (all 3 axes) vs adaptive (2 axes chosen smartly)
    # For now, just compare baseline vs best 2-basis as proxy
    
    for measurement_type in ['baseline', 'XZ']:
        for seed in seeds:
            config = {
                'name': f'PriorityD_adaptive_{measurement_type}_seed{seed}',
                'n_train': 80000,
                'n_val': 10000,
                'n_test': 10000,
                'measurement_type': measurement_type
            }
            
            results = run_single_experiment(config, seed)
            all_results.append(results)
            
            # Save individual result
            save_path = OUTPUT_DIR / f"{config['name']}.pkl"
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)
    
    print("\nNOTE: Full adaptive measurement would require sequential decision-making.")
    print("This experiment compares measurement efficiency as a proxy.")
    
    return all_results


def aggregate_results_by_config(results_list):
    """Aggregate results across seeds for same configuration"""
    
    from collections import defaultdict
    
    grouped = defaultdict(list)
    
    for r in results_list:
        # Extract config key (everything except seed)
        name_parts = r['config']['name'].rsplit('_seed', 1)
        config_key = name_parts[0]
        grouped[config_key].append(r)
    
    summary = []
    for config_key, results in grouped.items():
        fidelities = [r['test_fidelity_mean'] for r in results]
        
        summary.append({
            'Configuration': config_key,
            'Mean Fidelity': np.mean(fidelities),
            'Std across seeds': np.std(fidelities),
            'Min Fidelity': np.min(fidelities),
            'Max Fidelity': np.max(fidelities),
            'N Seeds': len(results)
        })
    
    return pd.DataFrame(summary)


def plot_comparison_bar(df, save_path, title, y_col='Mean Fidelity'):
    """Create bar plot comparing configurations"""
    
    plt.figure(figsize=(14, 6))
    
    x = np.arange(len(df))
    plt.bar(x, df[y_col], yerr=df.get('Std across seeds', 0), 
            alpha=0.7, capsize=5)
    
    plt.xticks(x, df['Configuration'], rotation=45, ha='right')
    plt.ylabel(y_col)
    plt.title(title)
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def plot_noise_vs_performance(results_a, save_path):
    """Plot how performance degrades with noise and finite shots"""
    
    # Extract data
    data = []
    for r in results_a:
        config = r['config']
        if 'shots' in config and 'noise_level' in config:
            data.append({
                'shots': config['shots'],
                'noise': config['noise_level'] * 100,
                'measurement': config['measurement_type'],
                'fidelity': r['test_fidelity_mean'],
                'seed': r['seed']
            })
    
    df = pd.DataFrame(data)
    
    if len(df) == 0:
        print("No noise/shots data to plot")
        return
    
    # Aggregate across seeds
    df_agg = df.groupby(['shots', 'noise', 'measurement']).agg({
        'fidelity': ['mean', 'std']
    }).reset_index()
    df_agg.columns = ['shots', 'noise', 'measurement', 'fidelity_mean', 'fidelity_std']
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    for meas_type in df_agg['measurement'].unique():
        df_meas = df_agg[df_agg['measurement'] == meas_type]
        
        # Plot vs shots
        for noise in df_meas['noise'].unique():
            df_noise = df_meas[df_meas['noise'] == noise]
            axes[0].errorbar(df_noise['shots'], df_noise['fidelity_mean'],
                           yerr=df_noise['fidelity_std'], 
                           marker='o', label=f'{meas_type}, {noise}% noise',
                           capsize=5, alpha=0.7)
        
        # Plot vs noise
        for shots in df_meas['shots'].unique():
            df_shots = df_meas[df_meas['shots'] == shots]
            axes[1].errorbar(df_shots['noise'], df_shots['fidelity_mean'],
                           yerr=df_shots['fidelity_std'],
                           marker='o', label=f'{meas_type}, {shots} shots',
                           capsize=5, alpha=0.7)
    
    axes[0].set_xlabel('Number of Shots')
    axes[0].set_ylabel('Mean Fidelity')
    axes[0].set_xscale('log')
    axes[0].set_title('Performance vs Shot Budget')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].set_xlabel('Readout Noise (%)')
    axes[1].set_ylabel('Mean Fidelity')
    axes[1].set_title('Performance vs Readout Noise')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def generate_final_report(all_results, output_dir):
    """Generate comprehensive final report"""
    
    report_path = output_dir / "FINAL_REPORT.txt"
    
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("QUANTUM STATE TOMOGRAPHY - FINAL EXPERIMENTAL REPORT\n")
        f.write("="*80 + "\n\n")
        
        f.write(f"Date: {pd.Timestamp.now()}\n")
        f.write(f"Master Seed: {MASTER_SEED}\n")
        f.write(f"Total Experiments: {len(all_results)}\n")
        f.write(f"Output Directory: {output_dir}\n\n")
        
        f.write("-"*80 + "\n")
        f.write("OVERALL STATISTICS\n")
        f.write("-"*80 + "\n")
        
        all_fidelities = [r['test_fidelity_mean'] for r in all_results]
        f.write(f"Mean Fidelity across all experiments: {np.mean(all_fidelities):.4f} ± {np.std(all_fidelities):.4f}\n")
        f.write(f"Best Fidelity: {np.max(all_fidelities):.4f}\n")
        f.write(f"Worst Fidelity: {np.min(all_fidelities):.4f}\n")
        
        # Count high-performing experiments
        high_perf = sum(1 for f in all_fidelities if f > 0.90)
        f.write(f"\nExperiments with fidelity > 0.90: {high_perf}/{len(all_results)} ({100*high_perf/len(all_results):.1f}%)\n")
        
        f.write("\n" + "-"*80 + "\n")
        f.write("KEY FINDINGS\n")
        f.write("-"*80 + "\n")
        
        # Group by priority
        priorities = {}
        for r in all_results:
            if 'PriorityA' in r['config']['name']:
                priorities.setdefault('A', []).append(r)
            elif 'PriorityB' in r['config']['name']:
                priorities.setdefault('B', []).append(r)
            elif 'PriorityC' in r['config']['name']:
                priorities.setdefault('C', []).append(r)
            elif 'PriorityD' in r['config']['name']:
                priorities.setdefault('D', []).append(r)
        
        for priority, results in sorted(priorities.items()):
            fids = [r['test_fidelity_mean'] for r in results]
            f.write(f"\nPriority {priority}: {len(results)} experiments\n")
            f.write(f"  Mean Fidelity: {np.mean(fids):.4f} ± {np.std(fids):.4f}\n")
            f.write(f"  Range: [{np.min(fids):.4f}, {np.max(fids):.4f}]\n")
        
        f.write("\n" + "-"*80 + "\n")
        f.write("ANOMALY CHECKS\n")
        f.write("-"*80 + "\n")
        
        # Check for unexpectedly high 2-basis performance
        two_basis_results = [r for r in all_results if any(x in r['config']['measurement_type'] 
                                                            for x in ['XY', 'XZ', 'YZ'])]
        if two_basis_results:
            two_basis_fids = [r['test_fidelity_mean'] for r in two_basis_results]
            max_2basis = np.max(two_basis_fids)
            if max_2basis > 0.90:
                f.write(f"⚠️  WARNING: 2-basis model achieved {max_2basis:.4f} fidelity\n")
                f.write("    This exceeds expected performance and should be investigated.\n")
            else:
                f.write(f"✓  2-basis performance reasonable (max: {max_2basis:.4f})\n")
        
        f.write("\n" + "-"*80 + "\n")
        f.write("FILES GENERATED\n")
        f.write("-"*80 + "\n")
        
        files = list(output_dir.glob("*"))
        for file in sorted(files):
            f.write(f"  - {file.name}\n")
        
        f.write("\n" + "="*80 + "\n")
        f.write("END OF REPORT\n")
        f.write("="*80 + "\n")
    
    print(f"\nFinal report saved to: {report_path}")
    
    # Also print to console
    with open(report_path, 'r') as f:
        print(f.read())


if __name__ == "__main__":
    
    print("="*80)
    print("QUANTUM STATE TOMOGRAPHY EXPERIMENTS")
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"Master seed: {MASTER_SEED}")
    print("="*80)
    
    all_results = []
    
    # Run Priority A
    print("\nStarting Priority A experiments...")
    results_a = run_priority_a()
    all_results.extend(results_a)
    
    # Save summary
    df_a = create_summary_table(results_a)
    df_a.to_csv(OUTPUT_DIR / "priority_a_summary.csv", index=False)
    print(f"\nPriority A completed: {len(results_a)} experiments")
    
    # Aggregate across seeds
    df_a_agg = aggregate_results_by_config(results_a)
    df_a_agg.to_csv(OUTPUT_DIR / "priority_a_aggregated.csv", index=False)
    
    # Run Priority B
    print("\nStarting Priority B experiments...")
    results_b = run_priority_b()
    all_results.extend(results_b)
    
    # Save summary
    df_b = create_summary_table(results_b)
    df_b.to_csv(OUTPUT_DIR / "priority_b_summary.csv", index=False)
    print(f"\nPriority B completed: {len(results_b)} experiments")
    
    # Aggregate across seeds
    df_b_agg = aggregate_results_by_config(results_b)
    df_b_agg.to_csv(OUTPUT_DIR / "priority_b_aggregated.csv", index=False)
    
    # Run Priority C
    print("\nStarting Priority C experiments...")
    results_c = run_priority_c()
    all_results.extend(results_c)
    
    df_c = create_summary_table(results_c)
    df_c.to_csv(OUTPUT_DIR / "priority_c_summary.csv", index=False)
    print(f"\nPriority C completed: {len(results_c)} experiments")
    
    df_c_agg = aggregate_results_by_config(results_c)
    df_c_agg.to_csv(OUTPUT_DIR / "priority_c_aggregated.csv", index=False)
    
    # Run Priority D
    print("\nStarting Priority D experiments...")
    results_d = run_priority_d()
    all_results.extend(results_d)
    
    df_d = create_summary_table(results_d)
    df_d.to_csv(OUTPUT_DIR / "priority_d_summary.csv", index=False)
    print(f"\nPriority D completed: {len(results_d)} experiments")
    
    df_d_agg = aggregate_results_by_config(results_d)
    df_d_agg.to_csv(OUTPUT_DIR / "priority_d_aggregated.csv", index=False)
    
    # Generate visualizations
    print("\nGenerating visualizations...")
    
    # Plot sample training curves
    if len(results_a) > 0:
        plot_training_curves(results_a[0], OUTPUT_DIR / "sample_training_curves_a.png")
    
    # Plot CDFs for each priority
    plot_fidelity_cdf(results_a[:9], OUTPUT_DIR / "priority_a_fidelity_cdf.png", 
                      title="Priority A: Fidelity CDF (Finite Shots + Noise)")
    plot_fidelity_cdf(results_b[:12], OUTPUT_DIR / "priority_b_fidelity_cdf.png",
                      title="Priority B: Fidelity CDF (Pure vs Mixed)")
    plot_fidelity_cdf(results_c, OUTPUT_DIR / "priority_c_fidelity_cdf.png",
                      title="Priority C: SIC-POVM vs Pauli")
    
    # Plot comparison bars
    plot_comparison_bar(df_a_agg, OUTPUT_DIR / "priority_a_comparison.png",
                       "Priority A: Finite Shots + Noise Performance")
    plot_comparison_bar(df_b_agg, OUTPUT_DIR / "priority_b_comparison.png",
                       "Priority B: Ensemble Type Performance")
    
    # Plot noise vs performance
    plot_noise_vs_performance(results_a, OUTPUT_DIR / "noise_degradation.png")
    
    # Plot failure examples
    if len(results_a) > 0:
        plot_bloch_sphere_failures(results_a[0], OUTPUT_DIR / "failure_examples.png")
    
    # Generate final report
    generate_final_report(all_results, OUTPUT_DIR)
    
    print("\n" + "="*80)
    print("ALL EXPERIMENTS COMPLETE!")
    print(f"Total experiments run: {len(all_results)}")
    print(f"Results saved to: {OUTPUT_DIR}")
    print("="*80)
    
    # Print summary statistics
    print("\nQUICK SUMMARY:")
    print(f"  Priority A: {len(results_a)} experiments")
    print(f"  Priority B: {len(results_b)} experiments")
    print(f"  Priority C: {len(results_c)} experiments")
    print(f"  Priority D: {len(results_d)} experiments")
    print(f"  Total: {len(all_results)} experiments")
    
    all_fids = [r['test_fidelity_mean'] for r in all_results]
    print(f"\n  Overall Mean Fidelity: {np.mean(all_fids):.4f} ± {np.std(all_fids):.4f}")
    print(f"  Best Result: {np.max(all_fids):.4f}")
    print(f"  Worst Result: {np.min(all_fids):.4f}")

Output directory: expt_3
Master seed: 48
QUANTUM STATE TOMOGRAPHY EXPERIMENTS
Output directory: expt_3
Master seed: 48

Starting Priority A experiments...

PRIORITY A: Finite Shots + Readout Noise

Running: PriorityA_shots10_noise0pct_baseline_seed48
Seed: 48
Generating training data...
Generating validation data...
Generating test data...
Training model...
Epoch 0: Train Loss=0.080878, Val Loss=0.051716, Fidelity=0.8453
Epoch 50: Train Loss=0.053912, Val Loss=0.050174, Fidelity=0.8618
Epoch 100: Train Loss=0.053519, Val Loss=0.049656, Fidelity=0.8567
Epoch 150: Train Loss=0.053370, Val Loss=0.049975, Fidelity=0.8616
Epoch 200: Train Loss=0.053184, Val Loss=0.049899, Fidelity=0.8509
Epoch 250: Train Loss=0.053057, Val Loss=0.049725, Fidelity=0.8553
Early stopping at epoch 258
Evaluating on test set...

Results:
  Mean Fidelity: 0.8589 ± 0.1123
  RMSE (x,y,z): (0.2300, 0.2268, 0.2173)
  Frac > 0.95: 0.2429

Running: PriorityA_shots10_noise0pct_baseline_seed49
Seed: 49
Generating trainin

UnicodeEncodeError: 'charmap' codec can't encode character '\u2713' in position 0: character maps to <undefined>

In [3]:
# Quick fix - redefine the function with ASCII characters
def generate_final_report(all_results, output_dir):
    """Generate comprehensive final report"""
    
    report_path = output_dir / "FINAL_REPORT.txt"
    
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("QUANTUM STATE TOMOGRAPHY - FINAL EXPERIMENTAL REPORT\n")
        f.write("="*80 + "\n\n")
        
        f.write(f"Date: {pd.Timestamp.now()}\n")
        f.write(f"Master Seed: {MASTER_SEED}\n")
        f.write(f"Total Experiments: {len(all_results)}\n")
        f.write(f"Output Directory: {output_dir}\n\n")
        
        f.write("-"*80 + "\n")
        f.write("OVERALL STATISTICS\n")
        f.write("-"*80 + "\n")
        
        all_fidelities = [r['test_fidelity_mean'] for r in all_results]
        f.write(f"Mean Fidelity across all experiments: {np.mean(all_fidelities):.4f} +/- {np.std(all_fidelities):.4f}\n")
        f.write(f"Best Fidelity: {np.max(all_fidelities):.4f}\n")
        f.write(f"Worst Fidelity: {np.min(all_fidelities):.4f}\n")
        
        # Count high-performing experiments
        high_perf = sum(1 for f in all_fidelities if f > 0.90)
        f.write(f"\nExperiments with fidelity > 0.90: {high_perf}/{len(all_results)} ({100*high_perf/len(all_results):.1f}%)\n")
        
        f.write("\n" + "-"*80 + "\n")
        f.write("KEY FINDINGS\n")
        f.write("-"*80 + "\n")
        
        # Group by priority
        priorities = {}
        for r in all_results:
            if 'PriorityA' in r['config']['name']:
                priorities.setdefault('A', []).append(r)
            elif 'PriorityB' in r['config']['name']:
                priorities.setdefault('B', []).append(r)
            elif 'PriorityC' in r['config']['name']:
                priorities.setdefault('C', []).append(r)
            elif 'PriorityD' in r['config']['name']:
                priorities.setdefault('D', []).append(r)
        
        for priority, results in sorted(priorities.items()):
            fids = [r['test_fidelity_mean'] for r in results]
            f.write(f"\nPriority {priority}: {len(results)} experiments\n")
            f.write(f"  Mean Fidelity: {np.mean(fids):.4f} +/- {np.std(fids):.4f}\n")
            f.write(f"  Range: [{np.min(fids):.4f}, {np.max(fids):.4f}]\n")
        
        f.write("\n" + "-"*80 + "\n")
        f.write("ANOMALY CHECKS\n")
        f.write("-"*80 + "\n")
        
        # Check for unexpectedly high 2-basis performance
        two_basis_results = [r for r in all_results if any(x in r['config']['measurement_type'] 
                                                            for x in ['XY', 'XZ', 'YZ'])]
        if two_basis_results:
            two_basis_fids = [r['test_fidelity_mean'] for r in two_basis_results]
            max_2basis = np.max(two_basis_fids)
            if max_2basis > 0.90:
                f.write(f"[WARNING] 2-basis model achieved {max_2basis:.4f} fidelity\n")
                f.write("    This exceeds expected performance and should be investigated.\n")
            else:
                f.write(f"[OK] 2-basis performance reasonable (max: {max_2basis:.4f})\n")
        
        f.write("\n" + "-"*80 + "\n")
        f.write("FILES GENERATED\n")
        f.write("-"*80 + "\n")
        
        files = list(output_dir.glob("*"))
        for file in sorted(files):
            f.write(f"  - {file.name}\n")
        
        f.write("\n" + "="*80 + "\n")
        f.write("END OF REPORT\n")
        f.write("="*80 + "\n")
    
    print(f"\nFinal report saved to: {report_path}")
    
    # Also print to console
    with open(report_path, 'r', encoding='utf-8') as f:
        print(f.read())

# Now run it
generate_final_report(all_results, OUTPUT_DIR)

print("\n" + "="*80)
print("ALL EXPERIMENTS COMPLETE!")
print(f"Total experiments run: {len(all_results)}")
print(f"Results saved to: {OUTPUT_DIR}")
print("="*80)

# Print summary statistics
print("\nQUICK SUMMARY:")
print(f"  Priority A: {len(results_a)} experiments")
print(f"  Priority B: {len(results_b)} experiments")
print(f"  Priority C: {len(results_c)} experiments")
print(f"  Priority D: {len(results_d)} experiments")
print(f"  Total: {len(all_results)} experiments")

all_fids = [r['test_fidelity_mean'] for r in all_results]
print(f"\n  Overall Mean Fidelity: {np.mean(all_fids):.4f} +/- {np.std(all_fids):.4f}")
print(f"  Best Result: {np.max(all_fids):.4f}")
print(f"  Worst Result: {np.min(all_fids):.4f}")


Final report saved to: expt_3\FINAL_REPORT.txt
QUANTUM STATE TOMOGRAPHY - FINAL EXPERIMENTAL REPORT

Date: 2025-10-04 08:46:47.622781
Master Seed: 48
Total Experiments: 126
Output Directory: expt_3

--------------------------------------------------------------------------------
OVERALL STATISTICS
--------------------------------------------------------------------------------
Mean Fidelity across all experiments: 0.8156 +/- 0.1045
Best Fidelity: 0.9999
Worst Fidelity: 0.5605

Experiments with fidelity > 0.90: 33/126 (26.2%)

--------------------------------------------------------------------------------
KEY FINDINGS
--------------------------------------------------------------------------------

Priority A: 54 experiments
  Mean Fidelity: 0.8529 +/- 0.0559
  Range: [0.7679, 0.9252]

Priority B: 60 experiments
  Mean Fidelity: 0.7706 +/- 0.1252
  Range: [0.5605, 0.9999]

Priority C: 6 experiments
  Mean Fidelity: 0.8728 +/- 0.0512
  Range: [0.8198, 0.9252]

Priority D: 6 experiments