# Parameter Sweep for Network Architectures: R² Improvement Analysis

This notebook sweeps the network architecture of the SBI `Trainer` (density estimator type, number of hidden features, and number of transforms) using synthetic spectra. Each configuration is trained and validated, and the goal is to identify which settings yield the highest mean validation R².

In [1]:
# Import Required Libraries
import numpy as np
import matplotlib.pyplot as plt
import itertools
import pandas as pd
from pathlib import Path

from sbi_delta.trainer import Trainer
from sbi_delta.simulator.emission_simulator import EmissionSimulator
from sbi_delta.spectra_manager import SpectraManager
from sbi_delta.filter_bank import FilterBank
from sbi_delta.config import BaseConfig, FilterConfig, ExcitationConfig
from sbi_delta.excitation_manager import ExcitationManager

In [2]:
# Set Up Real Data and Configuration (matching sbi_hyperparameter_search.ipynb)
#
# Builds the shared `simulator` used by every run of the sweep: a base config
# for six dyes plus a background dye, five detection filters, and a
# "min_crosstalk" excitation configuration.
import os

base_path = os.path.abspath(".")
fluorophore_names = ["JF479", "JF525", "JF552", "JF608", "JFX650", "JFX673"]

# Wavelength grid, spectra location and dye selection.
config = BaseConfig(
    min_wavelength=400,
    max_wavelength=750,
    wavelength_step=1,
    spectra_folder=os.path.join(base_path, "data/spectra_npz"),
    dye_names=fluorophore_names,
    bg_dye='AF_v1',
    photon_budget=1000,
)

# One FilterConfig per (start, stop) pair; each band begins where the previous one ends.
filter_cfgs = [
    FilterConfig(band_start, band_stop, sharpness=1)
    for band_start, band_stop in [
        (490, 530),
        (530, 570),
        (570, 620),
        (620, 680),
        (680, 740),
    ]
]
excitation_cfg = ExcitationConfig(excitation_mode="min_crosstalk")

# Spectra are loaded before the excitation manager is built on top of them.
spectra_manager = SpectraManager(config)
spectra_manager.load()
excitation_manager = ExcitationManager(config, excitation_cfg, spectra_manager)
filter_bank = FilterBank(config, filter_cfgs)

# The simulator ties the spectra, filters and excitation together.
simulator = EmissionSimulator(
    spectra_manager=spectra_manager,
    filter_bank=filter_bank,
    config=config,
    excitation_manager=excitation_manager,
)

In [3]:
# Define Architecture Parameter Sweep Grid
# Only the network architecture is swept: density estimator type, hidden
# feature count and number of transforms.
param_grid = {
    'density_estimator': ['maf', 'nsf', 'mdn'],
    'hidden_features': [32, 64, 128],
    'num_transforms': [2, 4],
}

# Expand the grid into one dict per combination (Cartesian product of the
# value lists, in the grid's key order).
keys = tuple(param_grid.keys())
values = tuple(param_grid.values())
param_combinations = []
for combo in itertools.product(*values):
    param_combinations.append(dict(zip(keys, combo)))
print(f"Total architecture parameter combinations: {len(param_combinations)}")

Total architecture parameter combinations: 18


In [4]:
# Run Trainer for Each Architecture Parameter Combination (skip known slow config)
#
# Every attempted combination appends exactly one row to `results`:
#   * status 'ok'          -> `mean_r2` and `r2_scores` hold the validation scores
#   * status 'skipped'     -> the combination matches SKIP_CONFIG and was not run
#   * status 'failed'      -> construction, training or validation raised; `error`
#                             holds the exception type and message
#   * status 'interrupted' -> the run was stopped with Ctrl-C / kernel interrupt
# Unscored rows use NaN for `mean_r2` so the column stays numeric.
# `timed_out` is kept for backward compatibility; as before, it is True for
# every row that has no score (including skipped ones).
results = []
n_train = 2000  # Use a realistic value for architecture sweep
n_val = 500
SKIP_CONFIG = {'density_estimator': 'mdn', 'hidden_features': 32, 'num_transforms': 4}


def _unscored_row(params, status, error=None):
    """Return a result row for a configuration that produced no R² score."""
    return {
        **params,
        'mean_r2': np.nan,
        'r2_scores': None,
        'timed_out': True,
        'status': status,
        'error': error,
    }


for i, params in enumerate(param_combinations):
    # Compare only the keys named in SKIP_CONFIG so extra grid keys cannot
    # accidentally disable the skip.
    if all(params[key] == value for key, value in SKIP_CONFIG.items()):
        print(f"SKIPPING: {params} (known slow configuration)")
        results.append(_unscored_row(params, 'skipped'))
        continue
    print(f"Running combination {i+1}/{len(param_combinations)}: {params}")
    network_architecture = {
        'density_estimator': params['density_estimator'],
        'hidden_features': params['hidden_features'],
        'num_transforms': params['num_transforms']
    }
    save_dir = f"arch_sweep_run_{i}"
    try:
        # Construction is inside the try so that one bad configuration cannot
        # abort the remaining runs of the sweep.
        trainer = Trainer(simulator, n_train=n_train, n_val=n_val, save_dir=save_dir, network_architecture=network_architecture)
        posterior = trainer.train()
        r2_scores, _, _ = trainer.validate()
        mean_r2 = float(np.mean(r2_scores))
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass; record the row so the
        # partial sweep is complete, then let the interrupt propagate.
        print(f"INTERRUPTED: {params}")
        results.append(_unscored_row(params, 'interrupted', 'KeyboardInterrupt'))
        raise
    except Exception as e:
        # Some exceptions carry an empty message, so always report the type too.
        error_text = f"{type(e).__name__}: {e}"
        print(f"FAILED: {params} with error: {error_text}")
        results.append(_unscored_row(params, 'failed', error_text))
    else:
        results.append({
            **params,
            'mean_r2': mean_r2,
            'r2_scores': r2_scores,
            'timed_out': False,
            'status': 'ok',
            'error': None,
        })

Running combination 1/18: {'density_estimator': 'maf', 'hidden_features': 32, 'num_transforms': 2}
 Neural network successfully converged after 125 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 125
        Best validation performance: -7.4717
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 125
        Best validation performance: -7.4717
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:13<00:00, 37.22it/s]



Validation mean R^2: 0.641, RMSE: 0.1131
Running combination 2/18: {'density_estimator': 'maf', 'hidden_features': 32, 'num_transforms': 4}
 Neural network successfully converged after 102 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 102
        Best validation performance: -7.4412
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 102
        Best validation performance: -7.4412
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:14<00:00, 35.28it/s]



Validation mean R^2: 0.657, RMSE: 0.1209
Running combination 3/18: {'density_estimator': 'maf', 'hidden_features': 64, 'num_transforms': 2}
 Neural network successfully converged after 128 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 128
        Best validation performance: -7.4010
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 128
        Best validation performance: -7.4010
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:14<00:00, 35.29it/s]



Validation mean R^2: 0.686, RMSE: 0.1205
Running combination 4/18: {'density_estimator': 'maf', 'hidden_features': 64, 'num_transforms': 4}
 Neural network successfully converged after 154 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 154
        Best validation performance: -6.8626
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 154
        Best validation performance: -6.8626
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:15<00:00, 31.75it/s]



Validation mean R^2: 0.589, RMSE: 0.1360
Running combination 5/18: {'density_estimator': 'maf', 'hidden_features': 128, 'num_transforms': 2}
 Neural network successfully converged after 184 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 184
        Best validation performance: -7.6356
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 184
        Best validation performance: -7.6356
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:13<00:00, 36.24it/s]



Validation mean R^2: 0.387, RMSE: 0.1242
Running combination 6/18: {'density_estimator': 'maf', 'hidden_features': 128, 'num_transforms': 4}
 Neural network successfully converged after 100 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 100
        Best validation performance: -7.2245
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 100
        Best validation performance: -7.2245
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:14<00:00, 33.86it/s]



Validation mean R^2: 0.590, RMSE: 0.1206
Running combination 7/18: {'density_estimator': 'nsf', 'hidden_features': 32, 'num_transforms': 2}
 Neural network successfully converged after 86 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 86
        Best validation performance: -7.5807
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 86
        Best validation performance: -7.5807
        -------------------------
        


torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2190.)
  outputs, _ = torch.triangular_solve(
torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2190.)
  outputs, _ = torch.triangular_solve(
Validating: 100%|██████████| 500/500 [00:11<00:00, 44.30it/s]



Validation mean R^2: 0.605, RMSE: 0.1144
Running combination 8/18: {'density_estimator': 'nsf', 'hidden_features': 32, 'num_transforms': 4}
 Neural network successfully converged after 92 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 92
        Best validation performance: -7.4033
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 92
        Best validation performance: -7.4033
        -------------------------
        


Validating:  77%|███████▋  | 385/500 [03:04<00:55,  2.09it/s]



FAILED: {'density_estimator': 'nsf', 'hidden_features': 32, 'num_transforms': 4} with error: 
Running combination 9/18: {'density_estimator': 'nsf', 'hidden_features': 64, 'num_transforms': 2}
 Neural network successfully converged after 82 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 82
        Best validation performance: -7.4813
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 82
        Best validation performance: -7.4813
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:11<00:00, 45.42it/s]



Validation mean R^2: 0.642, RMSE: 0.1121
Running combination 10/18: {'density_estimator': 'nsf', 'hidden_features': 64, 'num_transforms': 4}
 Neural network successfully converged after 87 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 87
        Best validation performance: -7.3840
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 87
        Best validation performance: -7.3840
        -------------------------
        


Validating:  11%|█▏        | 57/500 [00:09<01:10,  6.28it/s]



FAILED: {'density_estimator': 'nsf', 'hidden_features': 64, 'num_transforms': 4} with error: 
Running combination 11/18: {'density_estimator': 'nsf', 'hidden_features': 128, 'num_transforms': 2}
 Neural network successfully converged after 72 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 72
        Best validation performance: -7.1729
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 72
        Best validation performance: -7.1729
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:11<00:00, 43.63it/s]



Validation mean R^2: 0.640, RMSE: 0.1153
Running combination 12/18: {'density_estimator': 'nsf', 'hidden_features': 128, 'num_transforms': 4}
 Neural network successfully converged after 100 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 100
        Best validation performance: -7.3888
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 100
        Best validation performance: -7.3888
        -------------------------
        


Validating: 100%|██████████| 500/500 [00:11<00:00, 43.83it/s]



Validation mean R^2: 0.592, RMSE: 0.1147
Running combination 13/18: {'density_estimator': 'mdn', 'hidden_features': 32, 'num_transforms': 2}
 Neural network successfully converged after 167 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 167
        Best validation performance: -6.8300
        -------------------------
        

        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 167
        Best validation performance: -6.8300
        -------------------------
        


                    accepted. It may take a long time to collect the remaining
                    -1 samples. Consider interrupting (Ctrl-C) and switching to
                    `build_posterior(..., sample_with='mcmc')`.
                    accepted. It may take a long time to collect the remaining
                    -1 samples. Consider interrupting (Ctrl-C) and switching to
                    `build_posterior(..., sample_with='mcmc')`.
Validating:  88%|████████▊ | 438/500 [14:21:15<2:01:54, 117.98s/it]



KeyboardInterrupt: 

In [None]:
# Collect and Store R² Scores
# Build the results table in one chain: one row per sweep entry, ordered from
# highest to lowest mean R², with a fresh 0..n-1 index.
results_df = (
    pd.DataFrame(results)
    .sort_values('mean_r2', ascending=False)
    .reset_index(drop=True)
)
results_df.head()

In [None]:
# Visualize R² Score Improvements Across Architecture Parameters
#
# Draws one line per (density estimator, num_transforms) pair, with hidden
# feature count on the x-axis and mean R² on the y-axis.
fig, ax = plt.subplots(figsize=(12, 7))
for de in param_grid['density_estimator']:
    for nt in param_grid['num_transforms']:
        selected = (results_df['density_estimator'] == de) & (results_df['num_transforms'] == nt)
        # results_df is ranked by score, and plot() joins points in row order,
        # so order each series by its x-value to avoid zig-zag lines.
        subset = results_df[selected].sort_values('hidden_features')
        # Coerce to float so missing scores (None/NaN) become gaps in the line.
        scores = pd.to_numeric(subset['mean_r2'], errors='coerce')
        if not scores.notna().any():
            # Nothing was scored for this pair (skipped, failed or never run):
            # do not add an empty legend entry.
            continue
        ax.plot(subset['hidden_features'], scores, marker='o', label=f"{de}, num_transforms={nt}")
ax.set_xlabel('Hidden Features')
ax.set_ylabel('Mean R² Score')
ax.set_title('Mean R² Score vs. Hidden Features for Each Architecture')
if ax.lines:
    # legend() warns when there are no labelled artists, so only add it when
    # at least one series was drawn.
    ax.legend(title='Architecture (Density Estimator, Num Transforms)')
ax.grid(True)
plt.show()