# Filter Optimization using SBI-DELTA

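This notebook demonstrates how to use SBI-DELTA for filter optimization. It includes the filter parameters (start wavelength and width of each filter) in the prior together with the fluorophore concentrations. It then visualizes the joint prior, trains a neural posterior estimator on simulated data, and checks whether the posterior recovers the parameters of a known simulated observation.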

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch

# Make the parent of the current working directory (normally the folder that
# holds this notebook) importable, so the local ``sbi_delta`` package below
# resolves from a source checkout. The entry is appended at most once, so
# re-running this cell does not grow sys.path.
module_path = os.path.abspath(os.pardir)
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

from sbi_delta.config import BaseConfig, ExcitationConfig, PriorConfig
from sbi_delta.spectra_manager import SpectraManager
from sbi_delta.excitation_manager import ExcitationManager
from sbi_delta.prior_manager import PriorManager
from sbi_delta.simulator.filter_prior_simulator import FilterPriorSimulator
from sbi_delta.trainer import Trainer

## Setup Configuration

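Configure the wavelength grid, dyes, excitation mode and prior. Setting `include_filter_params=True` adds a start wavelength and a width for each of the `n_filters` filters to the prior, alongside the dye concentrations.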

In [None]:
# Seed the RNGs used downstream so Restart & Run All reproduces every prior
# draw, simulation and training run. Torch covers the prior sampling and the
# network; numpy covers any numpy-based randomness in the simulator.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Base configuration: wavelength grid (nm), spectra location, dye set and
# photon budget.
base_cfg = BaseConfig(
    min_wavelength=400,
    max_wavelength=750,
    wavelength_step=1,
    spectra_folder="../data/spectra_npz",
    dye_names=["JF479", "JF525", "JF552", "JF608", "JFX673"],
    photon_budget=100,
)

# Excitation configuration
exc_cfg = ExcitationConfig(
    excitation_mode="min_crosstalk"
)

# Prior configuration. include_filter_params=True adds a (start, width) pair
# per filter to the prior, next to the Dirichlet-distributed concentrations.
prior_cfg = PriorConfig(
    dirichlet_concentration=5.0,
    include_filter_params=True,
    n_filters=5,
    max_filter_width=50.0,
    min_filter_width=10.0
)

# Initialize managers. Spectra must be loaded before the excitation manager
# is built on top of them.
spectra_mgr = SpectraManager(base_cfg)
spectra_mgr.load()
excitation_mgr = ExcitationManager(base_cfg, exc_cfg, spectra_mgr)
prior_mgr = PriorManager(prior_cfg, base_cfg, excitation_mgr)

# Create the simulator. Filter count and maximum width are taken from prior_cfg
# so the simulator and the prior agree.
sim = FilterPriorSimulator(
    spectra_manager=spectra_mgr,
    config=base_cfg,
    excitation_manager=excitation_mgr,
    prior_manager=prior_mgr,
    n_filters=prior_cfg.n_filters,
    max_filter_width=prior_cfg.max_filter_width
)

## Visualize Prior Distribution

Sample the joint prior and inspect the marginal distributions of the dye concentrations, filter start wavelengths and filter widths.

In [None]:
# Draw samples from the joint prior over concentrations and filter parameters.
joint_prior = prior_mgr.get_joint_prior()
prior_samples = joint_prior.sample((1000,)).numpy()

n_dyes = len(base_cfg.dye_names)
n_filters = prior_cfg.n_filters

# The plots below index the parameter vector as
# [c_1 .. c_n_dyes, start_1, width_1, ..., start_n_filters, width_n_filters].
# Fail loudly if the prior's dimensionality disagrees with that layout.
n_params = prior_samples.shape[1]
assert n_params == n_dyes + 2 * n_filters, (
    f"expected {n_dyes + 2 * n_filters} prior dimensions "
    f"({n_dyes} concentrations + 2 x {n_filters} filter params), got {n_params}"
)

# One panel per parameter group: (title, x label, sample columns, legend labels).
start_idx = n_dyes
filter_labels = [f'Filter {i+1}' for i in range(n_filters)]
panels = [
    ('Dye Concentration Distributions', 'Concentration',
     list(range(n_dyes)), list(base_cfg.dye_names)),
    ('Filter Start Wavelength Distributions', 'Wavelength (nm)',
     [start_idx + 2 * i for i in range(n_filters)], filter_labels),
    ('Filter Width Distributions', 'Width (nm)',
     [start_idx + 2 * i + 1 for i in range(n_filters)], filter_labels),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(15, 12))
for ax, (title, xlabel, columns, labels) in zip(axes, panels):
    for col, label in zip(columns, labels):
        ax.hist(prior_samples[:, col], bins=30, alpha=0.6, label=label)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.legend()

plt.tight_layout()
plt.show()

## Visualize Example Filters

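Draw a few parameter vectors from the joint prior. Plot their filters as idealized rectangular passbands, with the excitation wavelengths shown as dashed red lines.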

In [None]:
def plot_filter_set(params, title='Filter Configurations'):
    """Plot the filters encoded in one joint-prior parameter vector.

    Each filter is drawn as an idealized rectangular passband: transmission 1
    inside ``[start, start + width]`` and 0 elsewhere. The excitation
    wavelengths are overlaid as dashed red vertical lines.

    Parameters
    ----------
    params : array-like
        One parameter vector: ``len(base_cfg.dye_names)`` concentrations
        followed by interleaved ``(start, width)`` pairs, one pair per filter.
    title : str, optional
        Figure title. Defaults to ``'Filter Configurations'``.

    Notes
    -----
    Reads the notebook globals ``base_cfg`` and ``excitation_mgr``. The number
    of filters is inferred from the length of ``params``.
    """
    n_dyes = len(base_cfg.dye_names)
    filter_params = np.asarray(params, dtype=float)[n_dyes:]
    # Derive the filter count locally instead of reading a global defined in
    # another (optional) cell.
    n_filters = len(filter_params) // 2

    # Float grid that honours the configured step and includes max_wavelength.
    # np.arange is half-open, hence the half-step padding on the stop value.
    step = getattr(base_cfg, 'wavelength_step', 1)
    wavelengths = np.arange(
        base_cfg.min_wavelength,
        base_cfg.max_wavelength + 0.5 * step,
        step,
        dtype=float,
    )

    _, ax = plt.subplots(figsize=(12, 6))

    # Excitation wavelengths for reference.
    for wl in excitation_mgr.get_wavelengths():
        ax.axvline(wl, color='r', linestyle='--', alpha=0.3)

    for i in range(n_filters):
        start = filter_params[2 * i]
        width = filter_params[2 * i + 1]
        stop = start + width
        # 0/1 float mask = idealized rectangular transmission profile.
        profile = ((wavelengths >= start) & (wavelengths <= stop)).astype(float)
        ax.plot(wavelengths, profile, label=f'Filter {i+1} ({start:.0f}-{stop:.0f}nm)')

    ax.set_xlabel('Wavelength (nm)')
    ax.set_ylabel('Transmission')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

# Visualise three independent draws from the joint prior, one figure each.
# The generator interleaves sampling and plotting: draw, plot, draw, plot, ...
for params in (joint_prior.sample() for _ in range(3)):
    plot_filter_set(params.numpy())

## Setup SBI Training

Create the trainer, simulate a training set from the joint prior, and train the neural density estimator used for posterior estimation.

In [None]:
# Number of (parameter, observation) pairs to simulate for training.
n_simulations = 1000

# Couple the simulator with the prior manager. training_batch_size and
# num_workers are forwarded unchanged; their exact meaning is defined in
# sbi_delta.trainer (TODO confirm whether they apply to simulation,
# network training, or both).
trainer = Trainer(
    prior_manager=prior_mgr,
    simulator=sim,
    num_workers=4,
    training_batch_size=100,
)

# Build the training set, then fit the neural density estimator on it.
trainer.simulate_for_sbi(n_simulations)
density_estimator = trainer.train_density_estimator()

## Test Inference

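Simulate an observation from a parameter vector drawn from the prior. Then check whether the posterior recovers the true concentrations and filter parameters.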

In [None]:
# Draw a ground-truth parameter vector from the prior and simulate its observation.
true_params = joint_prior.sample((1,))
observation = sim.simulate(true_params.squeeze().numpy())

# Condition the trained estimator on the observation and draw posterior samples.
posterior = trainer.build_posterior(observation)
posterior_samples = posterior.sample((1000,))

# Convert to numpy once for plotting; detach/cpu in case tensors live on a GPU.
true_theta = true_params[0].detach().cpu().numpy()
posterior_np = posterior_samples.detach().cpu().numpy()

# Layout: [c_1 .. c_n_dyes, start_1, width_1, ..., start_n, width_n].
n_dyes = len(base_cfg.dye_names)
n_filters = prior_cfg.n_filters
start_idx = n_dyes
start_cols = [start_idx + 2 * i for i in range(n_filters)]
width_cols = [start_idx + 2 * i + 1 for i in range(n_filters)]

# One panel per parameter group: (title, x label, [(column, legend label), ...]).
panels = [
    ('Inferred Dye Concentrations', 'Concentration',
     [(i, f'{name} (true={true_theta[i]:.2f})')
      for i, name in enumerate(base_cfg.dye_names)]),
    ('Inferred Filter Start Wavelengths', 'Wavelength (nm)',
     [(col, f'Filter {i+1} (true={true_theta[col]:.0f}nm)')
      for i, col in enumerate(start_cols)]),
    ('Inferred Filter Widths', 'Width (nm)',
     [(col, f'Filter {i+1} (true={true_theta[col]:.0f}nm)')
      for i, col in enumerate(width_cols)]),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(15, 12))
for ax, (title, xlabel, series) in zip(axes, panels):
    for col, label in series:
        ax.hist(posterior_np[:, col], bins=30, alpha=0.6, label=label)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.legend()

plt.tight_layout()
plt.show()

# Compare the true filter set with one posterior draw. plot_filter_set opens
# its own figure, so no plt.figure() call is needed here (one would leave an
# empty, leaked figure behind).
print("True filter configuration:")
plot_filter_set(true_theta)

print("\nExample inferred filter configuration:")
plot_filter_set(posterior_np[0])