In [1]:
# In [1]: Imports and Setup
import sys
import os
sys.path.append(os.path.abspath(".."))  # Adjust path as needed
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from sbi import utils as sbi_utils
from sbi import inference as sbi_inference
from sbi.inference import SNPE, simulate_for_sbi
from multiplex_sim import Microscope, io, plotting
# Directory holding one `<dye>.npz` spectra file per fluorophore; a relative path,
# so it resolves against the notebook's current working directory.
npz_folder='../data/spectra_npz'

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
import numpy as np
import torch
from pathlib import Path
from sbi import utils as sbi_utils
from scipy.interpolate import interp1d

# ===== 1. Load and interpolate emission spectra =====
def load_emission_spectra(fluor_names, spectra_dir):
    """Load and peak-normalise each dye's emission spectrum from disk.

    Each dye ``name`` is read from ``<spectra_dir>/<name>.npz``. The archive
    must contain the arrays ``wavelengths_emission`` and ``emission``.

    Parameters
    ----------
    fluor_names : iterable of str
        Dye names; each one maps to a file ``<name>.npz``.
    spectra_dir : str or os.PathLike
        Directory that holds the ``.npz`` files.

    Returns
    -------
    dict[str, tuple[np.ndarray, np.ndarray]]
        ``name -> (wavelengths, emission / emission.max())``, so every returned
        emission curve peaks at exactly 1.

    Raises
    ------
    FileNotFoundError
        If a dye's ``.npz`` file does not exist.
    ValueError
        If a dye's emission curve has no positive maximum and therefore cannot
        be normalised (previously this silently produced NaN/inf values).
    """
    spectra = {}
    for name in fluor_names:
        path = Path(spectra_dir) / f"{name}.npz"
        if not path.exists():
            raise FileNotFoundError(f"Missing file for dye: {name}")
        # np.load on an .npz returns an NpzFile that keeps the file handle open;
        # the context manager closes it once the arrays have been read out.
        with np.load(path) as data:
            wl = data["wavelengths_emission"]
            em = data["emission"]
        peak = em.max()
        # `not peak > 0` also rejects a NaN peak.
        if not peak > 0:
            raise ValueError(f"Emission spectrum for dye {name} has no positive peak")
        spectra[name] = (wl, em / peak)
    return spectra

def interpolate_emissions(spectra_dict, λ_grid):
    """Resample every spectrum in ``spectra_dict`` onto the shared ``λ_grid``.

    ``spectra_dict`` maps a dye name to a ``(wavelengths, emission)`` pair.
    Linear interpolation is used. Grid points outside a spectrum's sampled
    wavelength range are filled with 0 instead of raising an error.

    Returns a new dict mapping each dye name to its resampled emission array.
    """
    def _resample(wavelengths, emission):
        interpolator = interp1d(
            wavelengths, emission, kind="linear", bounds_error=False, fill_value=0
        )
        return interpolator(λ_grid)

    return {
        dye: _resample(wavelengths, emission)
        for dye, (wavelengths, emission) in spectra_dict.items()
    }

# ===== 2. Define SBI-compatible simulator =====
def make_sbi_simulator(fluor_names, spectra_dir, num_channels=5, bandwidth=30.0,
                       wavelength_range=(500, 800), channel_range=(540, 760)):
    """Build a noise-free linear-mixing simulator suitable for ``sbi``.

    Each dye's peak-normalised emission spectrum is resampled onto an integer
    wavelength grid spanning ``wavelength_range`` (inclusive, step 1). It is then
    integrated against ``num_channels`` Gaussian detection filters whose centres
    are evenly spaced over ``channel_range``. The simulated signal is the
    amplitude-weighted sum of those per-dye channel responses. Wavelengths are
    presumably in nm (TODO confirm against the .npz files).

    Parameters
    ----------
    fluor_names : iterable of str
        Dye names, in the order the simulator's amplitude vector uses.
    spectra_dir : str or os.PathLike
        Directory containing ``<name>.npz`` spectra files.
    num_channels : int
        Number of detection channels.
    bandwidth : float
        Standard deviation (not FWHM) of each Gaussian channel filter.
    wavelength_range : (int, int)
        Inclusive ``(start, stop)`` of the integration grid; the default
        reproduces the original ``np.arange(500, 801, 1)``.
    channel_range : (float, float)
        First and last channel centre wavelength.

    Returns
    -------
    callable
        ``simulator(amps) -> torch.Tensor`` (float32). ``amps`` has shape
        ``(n_dyes,)`` and produces an output of shape ``(num_channels,)``.
        A batch of shape ``(batch, n_dyes)`` produces ``(batch, num_channels)``.
    """
    # Materialise once: the names are needed both for loading and in the closure,
    # and a generator would otherwise be exhausted by the loader.
    fluor_names = list(fluor_names)
    n_dyes = len(fluor_names)

    λ_grid = np.arange(wavelength_range[0], wavelength_range[1] + 1, 1)
    spectra_dict = load_emission_spectra(fluor_names, spectra_dir)
    interpolated = interpolate_emissions(spectra_dict, λ_grid)

    center_wavelengths = np.linspace(channel_range[0], channel_range[1], num_channels)
    channel_filters = np.stack([
        np.exp(-0.5 * ((λ_grid - cw) / bandwidth)**2)
        for cw in center_wavelengths
    ])  # shape: (num_channels, len(λ_grid))

    # response[i, c] = signal of dye i in channel c at unit amplitude. It does not
    # depend on the amplitudes, so compute it once instead of on every call.
    response = np.array(
        [channel_filters @ interpolated[name] for name in fluor_names],
        dtype=np.float64,
    ).reshape(n_dyes, num_channels)

    def simulator(amps: torch.Tensor) -> torch.Tensor:
        """Map dye amplitudes (tensor or array-like) to per-channel signals."""
        # detach/cpu so tensors that require grad or live on a GPU are accepted.
        a = np.asarray(torch.as_tensor(amps).detach().cpu(), dtype=np.float64)
        if a.ndim == 0 or a.shape[-1] != n_dyes:
            raise ValueError(
                f"expected {n_dyes} amplitudes per sample, got shape {tuple(a.shape)}"
            )
        return torch.tensor(a @ response, dtype=torch.float32)

    return simulator

# ===== 3. User-defined mixture to simulate measurement =====
def simulate_measurement(simulator, mixture_dict, fluor_names):
    """Simulate the observation produced by a known dye mixture.

    Parameters
    ----------
    simulator : callable
        Maps a float32 amplitude tensor (ordered like ``fluor_names``) to an
        observation, e.g. the closure returned by ``make_sbi_simulator``.
    mixture_dict : dict[str, float]
        Amplitude per dye. Dyes in ``fluor_names`` that are absent here get 0.
    fluor_names : sequence of str
        Dye order used by ``simulator``.

    Returns
    -------
    Whatever ``simulator`` returns for the assembled amplitude vector.

    Raises
    ------
    ValueError
        If ``mixture_dict`` names a dye that is not in ``fluor_names``. Such an
        entry (typically a typo) was previously dropped silently.
    """
    unknown = set(mixture_dict) - set(fluor_names)
    if unknown:
        raise ValueError(
            f"mixture_dict contains dyes not in fluor_names: {sorted(map(str, unknown))}"
        )
    amp_vector = torch.tensor(
        [mixture_dict.get(name, 0.0) for name in fluor_names], dtype=torch.float32
    )
    return simulator(amp_vector)

# ===== 4. Full SBI Inference Pipeline =====
def run_sbi(fluor_names, spectra_dir, mixture_dict, num_channels=5, num_simulations=5000,
            num_posterior_samples=1000, seed=None):
    """Simulate a known mixture, train SNPE on it, and summarise the posterior.

    The steps are:

    1. Build the simulator and simulate the "measured" signal ``x_obs``.
    2. Draw ``num_simulations`` (theta, x) pairs from a uniform box prior on [0, 1]
       per dye.
    3. Train an SNPE density estimator on those pairs.
    4. Sample the posterior at ``x_obs``.

    Parameters
    ----------
    fluor_names : sequence of str
        Dye names (order of the amplitude vector).
    spectra_dir : str or os.PathLike
        Directory of ``<name>.npz`` spectra.
    mixture_dict : dict[str, float]
        Ground-truth amplitudes used to simulate ``x_obs``.
    num_channels : int
        Number of detection channels.
    num_simulations : int
        Number of training simulations.
    num_posterior_samples : int
        Number of posterior samples to draw (the original fixed value was 1000).
    seed : int or None
        If given, seeds torch's global RNG before any sampling/training so the
        run is repeatable. ``None`` keeps the previous, unseeded behaviour.

    Returns
    -------
    tuple
        ``(x_obs, samples, mean, std)``, where ``samples`` has shape
        ``(num_posterior_samples, n_dyes)`` and ``mean``/``std`` are per-dye
        statistics over the samples.
    """
    # Local imports keep this function independent of the first cell's
    # `sbi_inference` alias; these are the same objects that cell imports.
    from sbi.inference import SNPE, simulate_for_sbi

    if seed is not None:
        # sbi samples the BoxUniform prior and trains its network with torch.
        # NOTE(review): confirm the installed sbi version uses no other RNG.
        torch.manual_seed(seed)

    simulator = make_sbi_simulator(fluor_names, spectra_dir, num_channels=num_channels)

    # Simulated "measured" signal
    x_obs = simulate_measurement(simulator, mixture_dict, fluor_names)

    n_dyes = len(fluor_names)
    prior = sbi_utils.BoxUniform(low=torch.zeros(n_dyes), high=torch.ones(n_dyes))
    inference = SNPE(prior=prior)

    def sim_batch(theta_batch):
        # Apply the single-sample simulator row by row to a (batch, n_dyes) tensor.
        return torch.stack([simulator(p) for p in theta_batch])

    theta, x = simulate_for_sbi(sim_batch, prior, num_simulations=num_simulations)
    density_estimator = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior(density_estimator)

    samples = posterior.sample((num_posterior_samples,), x=x_obs)
    mean = samples.mean(dim=0)
    std = samples.std(dim=0)

    return x_obs, samples, mean, std

In [13]:
SEED = 0  # fixed seed so the posterior summary below is reproducible across runs
fluor_names = ['JF525','JF552','JF608','JFX673','JF722']
spectra_dir = npz_folder

# Define your known mixture. Amplitudes should lie inside the [0, 1] box prior
# that run_sbi uses, otherwise the posterior cannot cover the true values.
mixture_dict = {
    "JF525": 0.2,
    "JF552": 0.2,
    "JF608": 0.3,
    "JFX673": 0.1,
    "JF722": 0.2,
}
assert set(mixture_dict) <= set(fluor_names), "mixture_dict contains dyes not in fluor_names"

# Run simulation + SBI (prior sampling, training and posterior sampling all use torch's RNG)
torch.manual_seed(SEED)
x_obs, samples, mean, std = run_sbi(fluor_names, spectra_dir, mixture_dict)

print("Simulated Measurement (Observed):", x_obs)
print("Posterior Mean Amplitudes:")
for name, m, s in zip(fluor_names, mean.tolist(), std.tolist()):
    print(f"  {name:8s}: {m:.3f} ± {s:.3f}  (true {mixture_dict.get(name, 0.0):.3f})")

  0%|          | 0/5000 [00:00<?, ?it/s]

100%|██████████| 5000/5000 [00:00<00:00, 33418.57it/s]


 Neural network successfully converged after 316 epochs.

Drawing 1000 posterior samples: 100%|██████████| 1000/1000 [00:00<00:00, 74925.04it/s]

Simulated Measurement (Observed): tensor([10.4342, 17.4346, 15.2595, 11.0342,  8.6037])
Posterior Mean Amplitudes:
  JF525   : 0.206 ± 0.030
  JF552   : 0.191 ± 0.034
  JF608   : 0.304 ± 0.014
  JFX673  : 0.095 ± 0.009
  JF722   : 0.202 ± 0.002





In [7]:
# Display the ground-truth mixture, for comparison with the printed posterior means.
# NOTE(review): this cell's saved execution count (In [7]) is lower than the cells
# above it (In [12], In [13]); re-run the notebook top to bottom before sharing.
mixture_dict

{'JF525': 0.2, 'JF552': 0.2, 'JF608': 0.3, 'JFX673': 0.1, 'JF722': 0.2}