# SBI Training Workflow with sbi_delta

This notebook demonstrates how to set up, train, and evaluate a simulation-based inference (SBI) model for multiplexed fluorescence microscopy using the `Trainer` class from `sbi_delta.trainer` together with the sbi_delta configuration and manager classes.

We will:
- Set up experiment configuration (fluorophores, filters, priors, training parameters)
- Initialize the `Trainer`
- Set up the prior distribution
- Generate training data
- Train the SBI model
- Evaluate on validation data
- Analyze multiplexing capacity
- Save experiment results


In [1]:
# Section 1: Import Required Libraries
import os
import sys

# Make the repository root importable; adjust the relative path to match your checkout.
base_path = os.path.abspath("../..")
sys.path.append(base_path)
import numpy as np
import torch
from sbi_delta.config import BaseConfig, ExcitationConfig, FilterConfig, PriorConfig
from sbi_delta.spectra_manager import SpectraManager
from sbi_delta.excitation_manager import ExcitationManager
from sbi_delta.filter_bank import FilterBank
from sbi_delta.prior_manager import PriorManager
from sbi_delta.simulator.emission_simulator import EmissionSimulator
from sbi_delta.trainer import Trainer


INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [2]:
# Section 2: Define Experiment Configuration
# Define fluorophores, filters, and prior
fluorophore_names = ["JF479", "JF525", "JF552", "JF608", "JFX673"]
center_wavelengths = [535, 565, 595, 625, 670]
bandwidths = [30, 30, 30, 30, 60]

base_cfg = BaseConfig(
    min_wavelength=400,
    max_wavelength=750,
    wavelength_step=1,
    spectra_folder=os.path.join(base_path, "data/spectra_npz"),
    dye_names=fluorophore_names,
    bg_dye=None,
    photon_budget=1000,
)
exc_cfg = ExcitationConfig(excitation_mode="min_crosstalk")
filter_cfgs = [
    FilterConfig(start=wl - bw//2, stop=wl + bw//2, sharpness=2)
    for wl, bw in zip(center_wavelengths, bandwidths)
]
prior_cfg = PriorConfig(
    dirichlet_concentration=5.0,
    include_background_ratio=False,
    background_ratio_bounds=(0.1, 0.2)
)

spectra_mgr = SpectraManager(base_cfg)
spectra_mgr.load()
excitation_mgr = ExcitationManager(base_cfg, exc_cfg, spectra_mgr)
filter_bank = FilterBank(base_cfg, filter_cfgs)
prior_mgr = PriorManager(prior_cfg, base_cfg)
simulator = EmissionSimulator(
    spectra_manager=spectra_mgr,
    filter_bank=filter_bank,
    config=base_cfg,
    excitation_manager=excitation_mgr,
    prior_manager=prior_mgr,
)


INFO:sbi_delta.spectra_manager:Initialized SpectraManager(folder=/groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz, dyes=['JF479', 'JF525', 'JF552', 'JF608', 'JFX673'], bg_dye=None)
INFO:sbi_delta.spectra_manager:Starting load() of spectra
INFO:sbi_delta.spectra_manager:Found 22 .npz files in '/groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz'
INFO:sbi_delta.spectra_manager:Loading emission spectrum for dye 'JF479' from /groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz/JF479.npz
INFO:sbi_delta.spectra_manager:Completed processing for 'JF479'
INFO:sbi_delta.spectra_manager:Loading emission spectrum for dye 'JF525' from /groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz/JF525.npz
INFO:sbi_delta.spectra_manager:Completed processing for 'JF525'
INFO:sbi_delta.spectra_manager:Loading emission spectrum for dye 'JF552' from /groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz/JF552.npz
INFO:sbi_delta.spectra_manager:Completed processing for 'JF552'
INFO:sbi_delta.spe
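
The filter list in the configuration cell derives each band's edges from a center wavelength and a bandwidth. As a quick sanity check, the short sketch below repeats that arithmetic in plain Python and reports the gap between neighbouring bands. The helper name `band_edges` is hypothetical and not part of sbi_delta; only the numbers are taken from the cell above.

```python
# Band centers and widths (nm), copied from the configuration cell.
center_wavelengths = [535, 565, 595, 625, 670]
bandwidths = [30, 30, 30, 30, 60]


def band_edges(centers, widths):
    """Hypothetical helper: (start, stop) pairs using the same integer
    arithmetic as the FilterConfig list comprehension."""
    return [(wl - bw // 2, wl + bw // 2) for wl, bw in zip(centers, widths)]


edges = band_edges(center_wavelengths, bandwidths)
for start, stop in edges:
    print(f"{start}-{stop} nm")

# Gap between consecutive bands: 0 means the nominal edges touch,
# a negative value would mean the bands overlap.
gaps = [nxt[0] - cur[1] for cur, nxt in zip(edges, edges[1:])]
print("gaps between neighbouring bands (nm):", gaps)
```

With these values the nominal edges of adjacent bands coincide, and all bands lie inside the 400 to 750 nm range set in `BaseConfig`. How `sharpness=2` shapes the actual transmission near those edges is decided by `FilterBank` and is not modelled here.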

In [None]:
# Section 3: Initialize Trainer with Simulator
trainer = Trainer(simulator, n_train=2000, n_val=500, save_dir="sbi_training_demo_results")


In [4]:
# Section 4: Inspect the Prior
# The Trainer takes its prior from the simulator, so no separate setup step is needed.
print("Trainer prior:", trainer.prior)


Prior: BoxUniform(Uniform(low: torch.Size([5]), high: torch.Size([5])), 1)
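
The prior configuration above sets `dirichlet_concentration=5.0`, while the printed prior is a `BoxUniform` over five dimensions. How `PriorManager` and `Trainer` connect the two is not visible in this notebook. Purely as an illustration of what a symmetric Dirichlet with concentration 5 looks like for five dyes, here is a NumPy sketch; it assumes nothing about sbi_delta internals, and the variable names are illustrative.

```python
import numpy as np

n_dyes = 5            # number of fluorophores in this notebook
concentration = 5.0   # same value as dirichlet_concentration above

rng = np.random.default_rng(0)
samples = rng.dirichlet(np.full(n_dyes, concentration), size=20000)

# Closed-form moments of a symmetric Dirichlet(alpha, ..., alpha)
alpha0 = n_dyes * concentration
analytic_mean = concentration / alpha0
analytic_std = np.sqrt(
    concentration * (alpha0 - concentration) / (alpha0**2 * (alpha0 + 1))
)

print("each sample sums to 1:", np.allclose(samples.sum(axis=1), 1.0))
print(f"analytic mean per dye: {analytic_mean:.3f}")
print(f"analytic std per dye:  {analytic_std:.4f}")
print(f"empirical std per dye: {samples.std(axis=0).mean():.4f}")
```

A concentration above 1 pulls samples toward the even mixture (0.2 per dye here), so very unbalanced dye ratios are rare under such a prior.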


In [5]:
# Section 5: (Skipped) Training data is generated by Trainer.train()


Training data shapes: theta torch.Size([2000, 5]), x torch.Size([2000, 25])


In [6]:
# Section 6: Train the SBI Model
posterior = trainer.train()
print("Posterior trained.")


 Neural network successfully converged after 62 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 62
        Best validation performance: -6.5402
        -------------------------
        
Posterior trained.



In [7]:
# Section 7: Evaluate Model on Validation Data
r2_scores, rmse = trainer.validate()
print(f"Validation mean R^2: {np.mean(r2_scores):.3f}, RMSE: {rmse:.4f}")



Validation mean R^2: 0.457
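
The validation cell unpacks per-parameter R^2 scores and a single RMSE. The internals of `trainer.validate()` are not shown in this notebook, so the following is only a sketch of one common way to compute both quantities from true parameters and point estimates (for example posterior means). The function `r2_and_rmse` and the synthetic data are assumptions made for illustration.

```python
import numpy as np


def r2_and_rmse(theta_true, theta_pred):
    """Per-parameter R^2 (one value per column) and overall RMSE."""
    theta_true = np.asarray(theta_true, dtype=float)
    theta_pred = np.asarray(theta_pred, dtype=float)
    # Residual and total sums of squares, computed column by column
    ss_res = np.sum((theta_true - theta_pred) ** 2, axis=0)
    ss_tot = np.sum((theta_true - theta_true.mean(axis=0)) ** 2, axis=0)
    r2 = 1.0 - ss_res / ss_tot
    rmse = float(np.sqrt(np.mean((theta_true - theta_pred) ** 2)))
    return r2, rmse


# Synthetic stand-in for validation data: 500 samples, 5 dye fractions
rng = np.random.default_rng(0)
theta_true = rng.dirichlet(np.full(5, 5.0), size=500)
theta_pred = theta_true + rng.normal(scale=0.02, size=theta_true.shape)

r2_scores, rmse = r2_and_rmse(theta_true, theta_pred)
print(f"mean R^2: {np.mean(r2_scores):.3f}, RMSE: {rmse:.4f}")
```

Under this definition an R^2 of 1 means perfect recovery, and 0 means the estimate does no better than predicting the mean of the true values for that parameter.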


In [8]:
# Section 8: (Optional) Analyze Multiplexing Capacity
# Not implemented in Trainer class. You can add a method or do custom analysis here if needed.



Multiplexing: 62.0% of test samples have R^2 >= 0.8
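
Since the `Trainer` class does not provide this analysis, a custom version has to be written by hand. The sketch below shows one way to obtain a figure of the same kind as the line above: compute R^2 separately for each test sample (across its dye fractions) and report the share of samples at or above a threshold. Whether the original analysis defined R^2 this way is an assumption, and `per_sample_r2` is a hypothetical helper.

```python
import numpy as np


def per_sample_r2(theta_true, theta_pred):
    """R^2 computed row by row, i.e. across the dye fractions of each sample."""
    t = np.asarray(theta_true, dtype=float)
    p = np.asarray(theta_pred, dtype=float)
    ss_res = np.sum((t - p) ** 2, axis=1)
    ss_tot = np.sum((t - t.mean(axis=1, keepdims=True)) ** 2, axis=1)
    # R^2 is undefined when all true fractions in a row are equal; keep NaN there
    ratio = np.full(t.shape[0], np.nan)
    np.divide(ss_res, ss_tot, out=ratio, where=ss_tot > 0)
    return 1.0 - ratio


# Two toy samples: the first is recovered exactly, the second is
# "predicted" as a flat mixture equal to its own mean.
theta_true = np.array([[0.1, 0.2, 0.3, 0.4, 0.0],
                       [0.5, 0.1, 0.1, 0.2, 0.1]])
theta_pred = np.vstack([theta_true[0], np.full(5, theta_true[1].mean())])

r2 = per_sample_r2(theta_true, theta_pred)
fraction = float(np.mean(r2 >= 0.8))  # NaN rows compare as False
print(f"Multiplexing: {100 * fraction:.1f}% of test samples have R^2 >= 0.8")
```

In practice `theta_pred` would hold one point estimate per test sample, for example the mean of samples drawn from the trained posterior for that sample's observation.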





In [9]:
# Section 9: Save Experiment Results
trainer.save()


Experiment results saved to sbi_training_demo_results/results.pt
