# Hyperparameter Search for SBI Trainer with sbi_delta

This notebook performs a hyperparameter search for the SBI Trainer using the sbi_delta package. We systematically vary key SNPE training parameters and evaluate their effect on model performance.

**Outline:**
- Import required libraries
- Define experiment configuration
- Define hyperparameter grid
- Run hyperparameter search loop
- Collect and analyze results
- Visualize hyperparameter effects
- Save results

In [None]:
# Section 1: Import Required Libraries
import os, sys
base_path = os.path.abspath("../..")
sys.path.append(base_path)  # Adjust path as needed
import numpy as np
import torch
import pandas as pd
import itertools
import matplotlib.pyplot as plt
from sbi.inference import SNPE
from sbi_delta.config import BaseConfig, ExcitationConfig, FilterConfig, PriorConfig
from sbi_delta.spectra_manager import SpectraManager
from sbi_delta.excitation_manager import ExcitationManager
from sbi_delta.filter_bank import FilterBank
from sbi_delta.prior_manager import PriorManager
from sbi_delta.simulator.emission_simulator import EmissionSimulator
from sbi_delta.trainer import Trainer

  from .autonotebook import tqdm as notebook_tqdm
INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [2]:
# Section 2: Define Experiment Configuration
# Define fluorophores, filters, and prior
fluorophore_names = ["JF479", "JF525", "JF552", "JF608", "JFX650", "JFX673"]

base_cfg = BaseConfig(
    min_wavelength=400,
    max_wavelength=750,
    wavelength_step=1,
    spectra_folder=os.path.join(base_path, "data/spectra_npz"),
    dye_names=fluorophore_names,
    bg_dye='AF_v1',
    photon_budget=1000,
)
exc_cfg = ExcitationConfig(excitation_mode="min_crosstalk")
filter_cfgs = [
    FilterConfig(start, stop, sharpness=1)
    for start, stop in zip([490, 530, 570, 620, 680], [530, 570, 620, 680, 740])
]
prior_cfg = PriorConfig(
    dirichlet_concentration=5.0,
    include_background_ratio=True,
    background_ratio_bounds=(0.05, 0.15),
)

spectra_mgr = SpectraManager(base_cfg)
spectra_mgr.load()
excitation_mgr = ExcitationManager(base_cfg, exc_cfg, spectra_mgr)
filter_bank = FilterBank(base_cfg, filter_cfgs)
prior_mgr = PriorManager(prior_cfg, base_cfg)
simulator = EmissionSimulator(
    spectra_manager=spectra_mgr,
    filter_bank=filter_bank,
    config=base_cfg,
    excitation_manager=excitation_mgr,
    prior_manager=prior_mgr,
)
print("Simulator and managers initialized.")

INFO:sbi_delta.spectra_manager:Initialized SpectraManager(folder=/groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz, dyes=['JF479', 'JF525', 'JF552', 'JF608', 'JFX650', 'JFX673'], bg_dye=AF_v1)
INFO:sbi_delta.spectra_manager:Starting load() of spectra
INFO:sbi_delta.spectra_manager:Found 22 .npz files in '/groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz'
INFO:sbi_delta.spectra_manager:Loading emission spectrum for dye 'JF479' from /groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz/JF479.npz
INFO:sbi_delta.spectra_manager:Completed processing for 'JF479'
INFO:sbi_delta.spectra_manager:Loading emission spectrum for dye 'JF525' from /groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz/JF525.npz
INFO:sbi_delta.spectra_manager:Completed processing for 'JF525'
INFO:sbi_delta.spectra_manager:Loading emission spectrum for dye 'JF552' from /groups/spruston/home/moharb/sbi-DELTA/data/spectra_npz/JF552.npz
INFO:sbi_delta.spectra_manager:Completed processing for 'JF552'
INFO:sb

Simulator and managers initialized.


In [3]:
# Section 3: Define Hyperparameter Grid
# Specify the hyperparameters to search over
hyperparams_grid = {
    'training_batch_size': [64, 128, 256],
    'learning_rate': [1e-3, 5e-4, 1e-4],
    'validation_fraction': [0.05, 0.1, 0.2],
    'stop_after_epochs': [5, 10, 20]
}

# Create all combinations of hyperparameters
grid_keys = list(hyperparams_grid.keys())
grid_combos = list(itertools.product(*[hyperparams_grid[k] for k in grid_keys]))
print(f"Total hyperparameter combinations: {len(grid_combos)}")

Total hyperparameter combinations: 81


In [4]:
# Section 4: Run Hyperparameter Search Loop
from tqdm import tqdm
results = []

n_train = 2000  # Reduce for speed; increase for real experiments
n_val = 500

for idx, combo in enumerate(tqdm(grid_combos, desc='Hyperparameter Search')):
    params = dict(zip(grid_keys, combo))
    print(f"\nRunning combo {idx+1}/{len(grid_combos)}: {params}")
    trainer = Trainer(
        simulator,
        n_train=n_train,
        n_val=n_val,
        save_dir=None  # Don't save intermediate results
    )
    # Patch the train method to use current hyperparameters
    def train_with_params(self):
        train_theta, train_x = self.generate_training_data()
        inference = SNPE(prior=self.prior)
        inference.append_simulations(train_theta, train_x)
        density_estimator = inference.train(
            training_batch_size=params['training_batch_size'],
            learning_rate=params['learning_rate'],
            validation_fraction=params['validation_fraction'],
            stop_after_epochs=params['stop_after_epochs'],
            show_train_summary=False
        )
        self.posterior = inference.build_posterior(density_estimator)
        return self.posterior
    import types
    trainer.train = types.MethodType(train_with_params, trainer)
    try:
        posterior = trainer.train()
        r2_scores, rmse_scores, rmse = trainer.validate()
        mean_r2 = np.mean(r2_scores)
        mean_rmse = np.mean(rmse_scores)
        mean_width = np.mean(trainer.results['posterior_width'])
        results.append({
            **params,
            'mean_r2': mean_r2,
            'mean_rmse': mean_rmse,
            'mean_width': mean_width
        })
    except Exception as e:
        print(f"Failed for params {params}: {e}")
        results.append({**params, 'mean_r2': np.nan, 'mean_rmse': np.nan, 'mean_width': np.nan})

Hyperparameter Search:   0%|          | 0/81 [00:00<?, ?it/s]


Running combo 1/81: {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.05, 'stop_after_epochs': 5}


Hyperparameter Search:   1%|          | 1/81 [00:03<04:11,  3.14s/it]

Failed for params {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.05, 'stop_after_epochs': 5}: name 'SNPE' is not defined

Running combo 2/81: {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.05, 'stop_after_epochs': 10}


Hyperparameter Search:   2%|▏         | 2/81 [00:06<04:07,  3.13s/it]

Failed for params {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.05, 'stop_after_epochs': 10}: name 'SNPE' is not defined

Running combo 3/81: {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.05, 'stop_after_epochs': 20}


Hyperparameter Search:   4%|▎         | 3/81 [00:09<04:05,  3.14s/it]

Failed for params {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.05, 'stop_after_epochs': 20}: name 'SNPE' is not defined

Running combo 4/81: {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.1, 'stop_after_epochs': 5}


Hyperparameter Search:   5%|▍         | 4/81 [00:12<04:02,  3.15s/it]

Failed for params {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.1, 'stop_after_epochs': 5}: name 'SNPE' is not defined

Running combo 5/81: {'training_batch_size': 64, 'learning_rate': 0.001, 'validation_fraction': 0.1, 'stop_after_epochs': 10}


Hyperparameter Search:   5%|▍         | 4/81 [00:15<04:54,  3.82s/it]



KeyboardInterrupt: 

In [None]:
# Section 5: Collect and Analyze Results
results_df = pd.DataFrame(results)
print(f"Total runs: {len(results_df)}")
display(results_df.head())

# Compute best settings by mean_r2
best_idx = results_df['mean_r2'].idxmax()
best_params = results_df.loc[best_idx]
print("Best hyperparameters by mean R^2:")
display(best_params)

In [None]:
# Section 6: Visualize Hyperparameter Effects
import seaborn as sns

# Plot mean R^2 vs each hyperparameter
for param in grid_keys:
    plt.figure(figsize=(6, 4))
    sns.lineplot(x=param, y='mean_r2', data=results_df, marker='o')
    plt.title(f'Mean R^2 vs {param}')
    plt.ylabel('Mean R^2')
    plt.xlabel(param)
    plt.tight_layout()
    plt.show()

# Heatmap for two hyperparameters (example: batch size vs learning rate)
pivot = results_df.pivot_table(index='training_batch_size', columns='learning_rate', values='mean_r2')
plt.figure(figsize=(8, 6))
sns.heatmap(pivot, annot=True, fmt='.3f', cmap='viridis')
plt.title('Mean R^2: Batch Size vs Learning Rate')
plt.ylabel('Batch Size')
plt.xlabel('Learning Rate')
plt.tight_layout()
plt.show()

In [None]:
# Section 7: Save Hyperparameter Search Results
results_df.to_csv('sbi_hyperparam_search_results.csv', index=False)
print('Results saved to sbi_hyperparam_search_results.csv')

with open('sbi_hyperparam_search_best.json', 'w') as f:
    import json
    json.dump(best_params.to_dict(), f, indent=2)
print('Best hyperparameters saved to sbi_hyperparam_search_best.json')