# Enhanced SBI Simulator - Prior Demonstration

This notebook demonstrates the new prior functionality in the enhanced SBI simulator, including:
- Custom flat prior implementation
- Background parameter integration
- Prior visualization tools
- Parameter exploration

In [1]:
import json
import os
import sys
from pathlib import Path

# Make the package one directory up (multiplex_sim) importable from this notebook.
sys.path.append('..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from multiplex_sim.sbi_simulator_with_filters import (
    EnhancedSBISimulator, 
    EnhancedSBIConfig, 
    create_enhanced_sbi_simulator,
    CustomFlatPrior
)
from multiplex_sim.prior_visualization import PriorVisualizer

  from .autonotebook import tqdm as notebook_tqdm


## 1. Setup Enhanced Simulator with Background Parameters

In [2]:
# Simulator configuration: filter centers/bandwidths and a background level are
# switched on as parameters in addition to the fluorophore concentrations.
# NOTE(review): the saved output of this cell is
#   TypeError: EnhancedSBIConfig.__init__() got an unexpected keyword argument 'include_background_params'
# i.e. the EnhancedSBIConfig loaded in that kernel did not accept the background
# options. Confirm the installed multiplex_sim version exposes them (and restart
# the kernel after updating) — later cells use `config` and `simulator` from here.
config = EnhancedSBIConfig(
    n_channels=5,
    include_filter_params=True,
    center_wavelength_bounds=(500, 800),
    bandwidth_bounds=(10, 50),
    include_background_params=True,
    background_bounds=(10.0, 100.0),
)

# The simulator is built from this fluorophore list plus `config`.
fluorophore_names = ['AF488', 'AF555', 'AF594', 'AF647', 'AF680']
simulator = create_enhanced_sbi_simulator(fluorophore_names=fluorophore_names, config=config)

# Summarise how the parameter vector is partitioned.
print(f"Simulator created with {len(fluorophore_names)} fluorophores")
for label, attr_name in (
    ("Total parameters", "total_params"),
    ("Concentration params", "n_concentration_params"),
    ("Filter params", "n_filter_params"),
    ("Background params", "n_background_params"),
):
    print(f"{label}: {getattr(simulator, attr_name)}")

TypeError: EnhancedSBIConfig.__init__() got an unexpected keyword argument 'include_background_params'

## 2. Create Custom Prior

In [None]:
# Prior settings reused by the rest of the notebook. 'concentration' is the
# Dirichlet concentration for the fluorophore mix; the *_low/*_high pairs bound
# filter centers, filter bandwidths and the background level (the center and
# bandwidth ranges sit inside the config's 500-800 / 10-50 bounds).
prior_config = dict(
    concentration=1.5,
    center_low=520,
    center_high=780,
    bandwidth_low=15,
    bandwidth_high=45,
    background_low=10.0,
    background_high=100.0,
)

prior = simulator.create_custom_prior(prior_config=prior_config)

n_prior_params = prior.total_params
print(f"Custom prior created with {n_prior_params} total parameters")
print(f"Prior includes background: {prior.include_background}")

## 3. Generate and Visualize Samples

In [None]:
# Fix the RNG state before the prior draws so Restart & Run All reproduces the
# samples, figures and training data below.
# NOTE(review): assumes prior sampling / simulation use torch's and/or numpy's
# global RNGs — confirm against multiplex_sim.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Draw from the prior and split each draw into its named parameter groups.
n_samples = 5000
samples = prior.sample((n_samples,))
params = prior.extract_parameters(samples)

print("Sample shapes:")
for key, value in params.items():
    print(f"{key}: {value.shape}")

# The visualizer used in the plotting sections below is bound to this prior.
visualizer = PriorVisualizer(simulator)
visualizer.set_prior(prior)

## 4. Visualize Prior Distributions

In [None]:
# Prior distributions as drawn by the visualizer, using 2000 prior samples.
fig1 = visualizer.plot_prior_distributions(
    figsize=(15, 12),
    n_samples=2000,
)
plt.show()

## 5. Visualize Parameter Correlations

In [None]:
# Correlation matrix across all prior parameters, estimated from 1000 draws.
fig2 = visualizer.plot_parameter_correlations(
    figsize=(14, 12),
    n_samples=1000,
)
plt.show()

## 6. Explore Filter Configuration Space

In [None]:
# Filter configurations (centers / bandwidths) sampled from the prior, 1000 draws.
fig3 = visualizer.plot_filter_configuration_space(
    figsize=(12, 8),
    n_samples=1000,
)
plt.show()

## 7. Generate Training Data

In [None]:
# Generate (theta, x) training pairs for SBI, with theta drawn from the custom prior.
n_training = 10000
theta, x = simulator.generate_training_data(
    n_samples=n_training,
    prior_config=prior_config,
    use_custom_prior=True
)

# Report what was actually returned, not just what was requested.
# NOTE(review): confirm whether generate_training_data can drop draws.
print(f"Generated {theta.shape[0]} training samples (requested {n_training})")
print(f"Parameters shape: {theta.shape}")
print(f"Observations shape: {x.shape}")

# Split theta back into parameter groups with the same prior layout.
train_params = prior.extract_parameters(theta)
print("\nTraining parameter statistics (per component):")
for key, value in train_params.items():
    # Summarise each column separately: a single flattened mean hides the
    # per-fluorophore / per-channel structure (e.g. rows of Dirichlet
    # concentrations sum to 1, so their flattened mean is always 1/K).
    values_2d = torch.as_tensor(value, dtype=torch.float32).reshape(value.shape[0], -1)
    mean_str = ", ".join(f"{m:.3f}" for m in values_2d.mean(dim=0).tolist())
    std_str = ", ".join(f"{s:.3f}" for s in values_2d.std(dim=0).tolist())
    print(f"{key}: mean=[{mean_str}], std=[{std_str}]")

## 8. Advanced Prior Configuration

In [None]:
# Compare how the Dirichlet concentration shapes the sampled fluorophore
# fractions. For a symmetric Dirichlet, alpha < 1 favours sparse mixtures,
# alpha = 1 is flat over the simplex, and alpha > 1 pulls draws toward equal
# fractions. NOTE(review): assumes 'concentration' is that symmetric alpha, as
# the prior_config comment ("Dirichlet concentration parameter") suggests.
configs = [
    {"concentration": 0.5, "name": "Sparse concentrations"},
    {"concentration": 2.0, "name": "Balanced concentrations"},
    {"concentration": 5.0, "name": "Concentrated concentrations"},
]

# squeeze=False keeps a 2-D axes array so indexing works for any number of configs.
fig, axes = plt.subplots(1, len(configs), figsize=(15, 4), squeeze=False)

for i, (ax, sweep_entry) in enumerate(zip(axes[0], configs)):
    # 'name' is only a display label; every other key overrides prior_config.
    overrides = {k: v for k, v in sweep_entry.items() if k != "name"}
    sweep_prior = simulator.create_custom_prior(prior_config={**prior_config, **overrides})
    # Separate names so the section-3 `samples` global is not overwritten.
    sweep_samples = sweep_prior.sample((1000,))
    sweep_concentrations = sweep_prior.extract_parameters(sweep_samples)['concentrations']

    # One histogram per concentration column.
    for j in range(sweep_concentrations.shape[1]):
        ax.hist(sweep_concentrations[:, j], bins=30, alpha=0.7, label=f'Fluor {j+1}', density=True)
    ax.set_title(f"{sweep_entry['name']} (α={sweep_entry['concentration']})")
    ax.set_xlabel('Concentration')
    ax.set_ylabel('Density')
    if i == 0:
        ax.legend()

fig.tight_layout()
plt.show()

## 9. Save and Load Configuration

In [None]:
# Persist the run configuration for later use, then read it back to confirm it.
# (json and Path come from the top-of-notebook import cell.)
CONFIG_PATH = Path('enhanced_sbi_config.json')

run_config = {
    'fluorophore_names': fluorophore_names,
    'n_channels': config.n_channels,
    'prior_config': prior_config,
    'total_params': simulator.total_params
}

with CONFIG_PATH.open('w', encoding='utf-8') as f:
    json.dump(run_config, f, indent=2)

print(f"Configuration saved to {CONFIG_PATH}")

# Load configuration
with CONFIG_PATH.open('r', encoding='utf-8') as f:
    loaded_config = json.load(f)

# All values written above are lists/dicts of str, int and float, so a JSON
# round-trip should reproduce them exactly.
assert loaded_config == run_config, "reloaded configuration differs from what was saved"

print("Loaded configuration:")
for key, value in loaded_config.items():
    print(f"{key}: {value}")

## Summary

This notebook demonstrated:

1. **Enhanced SBI Simulator**: Extended with background parameters
2. **Custom Flat Prior**: Unified prior for all parameter types
3. **Prior Visualization**: Comprehensive plotting tools for parameter exploration
4. **Parameter Extraction**: Easy access to individual parameter groups
5. **Training Data Generation**: Seamless integration with SBI workflows

The enhanced simulator now supports:
- Concentration parameters (Dirichlet prior)
- Filter center wavelengths (Uniform prior)
- Filter bandwidths (Uniform prior)
- Background amplitude (Uniform prior)
- Custom prior configuration
- Comprehensive visualization tools