# Parameter Sweep for Network Architectures: R² Improvement Analysis

This notebook performs a parameter sweep over different network architectures and hyperparameters for the SBI Trainer, using synthetic spectra. The goal is to identify which settings yield the highest mean validation R² scores.

In [1]:
# Import Required Libraries
import numpy as np
import matplotlib.pyplot as plt
import itertools
import pandas as pd
from pathlib import Path

from sbi_delta.trainer import Trainer
from sbi_delta.simulator.emission_simulator import EmissionSimulator
from sbi_delta.spectra_manager import SpectraManager
from sbi_delta.filter_bank import FilterBank
from sbi_delta.config import BaseConfig, FilterConfig, ExcitationConfig
from sbi_delta.excitation_manager import ExcitationManager

In [2]:
# Build the simulator stack used by the sweep (settings are meant to match
# sbi_hyperparameter_search.ipynb): config -> spectra -> excitation -> filters
# -> emission simulator.
import os

# Spectra are looked up relative to the directory the notebook is run from.
base_path = os.path.abspath(".")
fluorophore_names = ["JF479", "JF525", "JF552", "JF608", "JFX650", "JFX673"]

# Shared configuration: wavelength grid, dye set, background dye, photon budget.
config = BaseConfig(
    min_wavelength=400,
    max_wavelength=750,
    wavelength_step=1,
    spectra_folder=os.path.join(base_path, "data/spectra_npz"),
    dye_names=fluorophore_names,
    bg_dye='AF_v1',
    photon_budget=1000,
)

# Five contiguous filter bands, written as explicit (start, stop) pairs.
filter_cfgs = [
    FilterConfig(band_start, band_stop, sharpness=1)
    for band_start, band_stop in [
        (490, 530),
        (530, 570),
        (570, 620),
        (620, 680),
        (680, 740),
    ]
]
excitation_cfg = ExcitationConfig(excitation_mode="min_crosstalk")

# The spectra must be loaded before the excitation manager is built from them.
spectra_manager = SpectraManager(config)
spectra_manager.load()
excitation_manager = ExcitationManager(config, excitation_cfg, spectra_manager)
filter_bank = FilterBank(config, filter_cfgs)

# The simulator ties all of the components above together; it is what the
# Trainer in the sweep cell receives.
simulator = EmissionSimulator(
    spectra_manager=spectra_manager,
    filter_bank=filter_bank,
    config=config,
    excitation_manager=excitation_manager
)

In [3]:
# Architecture search space. Only network-architecture settings are swept:
# density estimator type, hidden features, and number of transforms.
# NOTE(review): the grid also crosses 'num_transforms' with 'mdn'; confirm the
# Trainer actually uses that setting for 'mdn', otherwise those runs are duplicates.
param_grid = dict(
    density_estimator=['maf', 'nsf', 'mdn'],
    hidden_features=[32, 64, 128],
    num_transforms=[2, 4],
)

# Full Cartesian product of the grid: one dict per combination, with the keys
# in the same order as in `param_grid`.
keys = tuple(param_grid.keys())
values = tuple(param_grid.values())
param_combinations = [
    dict(zip(keys, combo))
    for combo in itertools.product(*values)
]
print(f"Total architecture parameter combinations: {len(param_combinations)}")

Total architecture parameter combinations: 18


In [None]:
# Run Trainer for Each Architecture Parameter Combination (skip known slow config)
#
# For every entry of `param_combinations` this cell builds a Trainer with that
# architecture, trains it, validates it, and appends one record to `results`:
#   <grid keys> : the architecture settings of the run
#   mean_r2     : np.mean of the R² scores returned by trainer.validate(),
#                 or None when the run was skipped or failed
#   r2_scores   : the scores returned by trainer.validate(), or None
#   timed_out   : True for skipped AND failed runs (kept unchanged for
#                 backward compatibility; use `status` to tell them apart)
#   status      : 'ok', 'skipped' or 'failed'
#   error       : repr() of the exception for failed runs, otherwise None
#
# KeyboardInterrupt is not a subclass of Exception, so interrupting the kernel
# still stops the sweep; records appended before the interrupt stay in `results`.
# Re-running this cell starts from an empty `results` list.
results = []
n_train = 2000  # Use a realistic value for architecture sweep
n_val = 500
SKIP_CONFIG = {'density_estimator': 'mdn', 'hidden_features': 32, 'num_transforms': 4}


def _unfinished_record(params, status, error=None):
    """Return the placeholder result row for a run that produced no scores.

    Args:
        params: The architecture settings of the run.
        status: 'skipped' or 'failed'.
        error: repr() of the exception that ended the run, if any.
    """
    return {
        **params,
        'mean_r2': None,
        'r2_scores': None,
        'timed_out': True,
        'status': status,
        'error': error,
    }


for i, params in enumerate(param_combinations):
    # A combination is skipped when it matches every entry of SKIP_CONFIG.
    if all(params.get(key) == value for key, value in SKIP_CONFIG.items()):
        print(f"SKIPPING: {params} (known slow configuration)")
        results.append(_unfinished_record(params, 'skipped'))
        continue
    print(f"Running combination {i+1}/{len(param_combinations)}: {params}")
    network_architecture = {
        'density_estimator': params['density_estimator'],
        'hidden_features': params['hidden_features'],
        'num_transforms': params['num_transforms']
    }
    # Each run gets its own output directory, keyed by its index in the grid.
    save_dir = f"arch_sweep_run_{i}"
    try:
        # Construction happens inside the try block so that a configuration
        # rejected by the Trainer is recorded as a failure instead of
        # aborting the remaining combinations.
        trainer = Trainer(simulator, n_train=n_train, n_val=n_val, save_dir=save_dir, network_architecture=network_architecture)
        # `posterior` is not used below; it stays in the notebook namespace
        # (holding the most recent run) for interactive inspection.
        posterior = trainer.train()
        r2_scores, _, _ = trainer.validate()
        mean_r2 = np.mean(r2_scores)
        results.append({
            **params,
            'mean_r2': mean_r2,
            'r2_scores': r2_scores,
            'timed_out': False,
            'status': 'ok',
            'error': None,
        })
    except Exception as e:
        print(f"FAILED: {params} with error: {e}")
        results.append(_unfinished_record(params, 'failed', error=repr(e)))

Running combination 1/18: {'density_estimator': 'maf', 'hidden_features': 32, 'num_transforms': 2}
 Neural network successfully converged after 172 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 172
        Best validation performance: -7.7921
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 172
        Best validation performance: -7.7921
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:13<00:00, 37.39it/s]



Validation mean R^2: 0.630, RMSE: 0.1127
Running combination 2/18: {'density_estimator': 'maf', 'hidden_features': 32, 'num_transforms': 4}
 Neural network successfully converged after 125 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 125
        Best validation performance: -7.2994
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 125
        Best validation performance: -7.2994
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:14<00:00, 35.28it/s]



Validation mean R^2: 0.648, RMSE: 0.1141
Running combination 3/18: {'density_estimator': 'maf', 'hidden_features': 64, 'num_transforms': 2}
 Neural network successfully converged after 131 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 131
        Best validation performance: -7.2718
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 131
        Best validation performance: -7.2718
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:14<00:00, 35.69it/s]



Validation mean R^2: 0.615, RMSE: 0.1203
Running combination 4/18: {'density_estimator': 'maf', 'hidden_features': 64, 'num_transforms': 4}
 Neural network successfully converged after 132 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 132
        Best validation performance: -7.2259
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 132
        Best validation performance: -7.2259
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:13<00:00, 36.41it/s]



Validation mean R^2: 0.694, RMSE: 0.1120
Running combination 5/18: {'density_estimator': 'maf', 'hidden_features': 128, 'num_transforms': 2}
 Neural network successfully converged after 138 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 138
        Best validation performance: -7.6644
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 138
        Best validation performance: -7.6644
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:13<00:00, 37.45it/s]



Validation mean R^2: 0.685, RMSE: 0.1119
Running combination 6/18: {'density_estimator': 'maf', 'hidden_features': 128, 'num_transforms': 4}
 Neural network successfully converged after 135 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 135
        Best validation performance: -7.3980
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 135
        Best validation performance: -7.3980
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:13<00:00, 35.82it/s]



Validation mean R^2: 0.636, RMSE: 0.1209
Running combination 7/18: {'density_estimator': 'nsf', 'hidden_features': 32, 'num_transforms': 2}
 Neural network successfully converged after 75 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 75
        Best validation performance: -7.0412
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 75
        Best validation performance: -7.0412
        -------------------------
        


torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2190.)
  outputs, _ = torch.triangular_solve(
torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2190.)
  outputs, _ = torch.triangular_solve(
Validating: 100%|██████████| 500/500 [00:11<00:00, 44.22it/s]



Validation mean R^2: 0.565, RMSE: 0.1251
Running combination 8/18: {'density_estimator': 'nsf', 'hidden_features': 32, 'num_transforms': 4}
 Neural network successfully converged after 80 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 80
        Best validation performance: -7.2948
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 80
        Best validation performance: -7.2948
        -------------------------
        


Validating:  72%|███████▏  | 362/500 [03:39<01:23,  1.65it/s]



KeyboardInterrupt: 

In [None]:
# Tabulate the sweep: one row per recorded configuration, ordered from the
# highest mean R² to the lowest, with a fresh 0..n-1 index.
results_df = (
    pd.DataFrame(results)
    .sort_values('mean_r2', ascending=False)
    .reset_index(drop=True)
)
results_df.head()

In [None]:
# Visualize R² Score Improvements Across Architecture Parameters
#
# One line per (density estimator, num_transforms) pair, showing mean R² as a
# function of the number of hidden features.
fig, ax = plt.subplots(figsize=(12, 7))
for de in param_grid['density_estimator']:
    for nt in param_grid['num_transforms']:
        mask = (results_df['density_estimator'] == de) & (results_df['num_transforms'] == nt)
        subset = results_df.loc[mask, ['hidden_features', 'mean_r2']].copy()
        # Skipped/failed runs carry mean_r2=None; coerce to NaN and drop them so
        # they neither break the dtype nor leave gaps in the line.
        subset['mean_r2'] = pd.to_numeric(subset['mean_r2'], errors='coerce')
        # results_df is ordered by mean_r2, but the points of a line are
        # connected in the order given, so each curve must be re-sorted by its
        # x-axis value to avoid a zig-zag.
        subset = subset.dropna(subset=['mean_r2']).sort_values('hidden_features')
        if subset.empty:
            # Nothing recorded for this pair (e.g. a partial sweep): no legend entry.
            continue
        ax.plot(subset['hidden_features'], subset['mean_r2'], marker='o', label=f"{de}, num_transforms={nt}")
ax.set_xlabel('Hidden Features')
ax.set_ylabel('Mean R² Score')
ax.set_title('Mean R² Score vs. Hidden Features for Each Architecture')
ax.legend(title='Architecture (Density Estimator, Num Transforms)')
ax.grid(True)
plt.show()