diff --git a/crates/solver/pkg/calab_solver_bg.wasm b/crates/solver/pkg/calab_solver_bg.wasm index 54fec6b1..82804492 100644 Binary files a/crates/solver/pkg/calab_solver_bg.wasm and b/crates/solver/pkg/calab_solver_bg.wasm differ diff --git a/crates/solver/src/banded.rs b/crates/solver/src/banded.rs index 04bbc3ab..647b1704 100644 --- a/crates/solver/src/banded.rs +++ b/crates/solver/src/banded.rs @@ -1,3 +1,5 @@ +use crate::kernel::clamp_tau_rise; + /// Banded AR(2) convolution engine — O(T) replacement for FFT-based O(T log T). /// /// The AR(2) model c[t] = g1*c[t-1] + g2*c[t-2] + s[t] defines a banded @@ -19,6 +21,7 @@ pub(crate) struct BandedAR2 { impl BandedAR2 { /// Create a new BandedAR2 with the given tau parameters. pub(crate) fn new(tau_rise: f64, tau_decay: f64, fs: f64) -> Self { + let tau_rise = clamp_tau_rise(tau_rise, tau_decay); let dt = 1.0 / fs; let d = (-dt / tau_decay).exp(); let r = (-dt / tau_rise).exp(); @@ -37,20 +40,13 @@ impl BandedAR2 { /// Recompute coefficients after parameter change. pub(crate) fn update(&mut self, tau_rise: f64, tau_decay: f64, fs: f64) { - let dt = 1.0 / fs; - let d = (-dt / tau_decay).exp(); - let r = (-dt / tau_rise).exp(); - self.g1 = d + r; - self.g2 = -(d * r); - self.impulse_peak = compute_impulse_peak(self.g1, self.g2, tau_decay, fs); - self.lipschitz = - compute_banded_lipschitz(self.g1, self.g2) / (self.impulse_peak * self.impulse_peak); + *self = Self::new(tau_rise, tau_decay, fs); } /// Forward convolution: s -> normalized AR2 output, O(T). /// - /// Runs the raw AR2 recursion then divides by the impulse peak so that - /// a single spike produces a peak of 1.0 at all sampling rates. + /// Pre-scales input by 1/peak so the AR2 recursion directly produces + /// a peak-normalized output — no second normalization pass needed. pub(crate) fn convolve_forward(&self, source: &[f32], output: &mut [f32]) { let n = source.len(); if n == 0 { @@ -61,23 +57,19 @@ impl BandedAR2 { let g2 = self.g2 as f32; let inv_peak = (1.0 / self.impulse_peak) as f32; - output[0] = source[0]; + output[0] = source[0] * inv_peak; if n > 1 { - output[1] = g1 * output[0] + source[1]; + output[1] = g1 * output[0] + source[1] * inv_peak; } for t in 2..n { - output[t] = g1 * output[t - 1] + g2 * output[t - 2] + source[t]; - } - - // Normalize by impulse peak - for v in &mut output[..n] { - *v *= inv_peak; + output[t] = g1 * output[t - 1] + g2 * output[t - 2] + source[t] * inv_peak; } } /// Adjoint convolution: normalized adjoint, O(T). /// - /// Adjoint of (K / peak) = K^T / peak. + /// Pre-scales input by 1/peak so the backward AR2 recursion directly + /// produces a peak-normalized output — no second normalization pass needed. pub(crate) fn convolve_adjoint(&self, source: &[f32], output: &mut [f32]) { let n = source.len(); if n == 0 { @@ -88,17 +80,12 @@ impl BandedAR2 { let g2 = self.g2 as f32; let inv_peak = (1.0 / self.impulse_peak) as f32; - output[n - 1] = source[n - 1]; + output[n - 1] = source[n - 1] * inv_peak; if n > 1 { - output[n - 2] = source[n - 2] + g1 * output[n - 1]; + output[n - 2] = source[n - 2] * inv_peak + g1 * output[n - 1]; } for t in (0..n.saturating_sub(2)).rev() { - output[t] = source[t] + g1 * output[t + 1] + g2 * output[t + 2]; - } - - // Normalize by impulse peak - for v in &mut output[..n] { - *v *= inv_peak; + output[t] = source[t] * inv_peak + g1 * output[t + 1] + g2 * output[t + 2]; } } diff --git a/crates/solver/src/baseline.rs b/crates/solver/src/baseline.rs index a61bd759..035344ed 100644 --- a/crates/solver/src/baseline.rs +++ b/crates/solver/src/baseline.rs @@ -45,12 +45,19 @@ impl Ord for OrderedF32 { /// Used for O(log M) k-th element queries via binary lifting. struct FenwickTree { tree: Vec, + msb: usize, // highest power of 2 <= (tree.len() - 1) } impl FenwickTree { fn new(size: usize) -> Self { + let mut msb = 1; + while msb <= size { + msb <<= 1; + } + msb >>= 1; Self { - tree: vec![0; size + 1], // 1-indexed + tree: vec![0; size + 1], + msb, } } @@ -66,14 +73,9 @@ impl FenwickTree { /// Find the 0-indexed position of the k-th element (1-based k). /// Uses binary lifting: O(log M) time. fn kth(&self, mut k: i32) -> usize { - let n = self.tree.len() - 1; // max 0-indexed position + 1 + let n = self.tree.len() - 1; let mut pos = 0; - // Find highest power of 2 <= n - let mut bit = 1; - while bit <= n { - bit <<= 1; - } - bit >>= 1; + let mut bit = self.msb; while bit > 0 { let next = pos + bit; @@ -83,7 +85,7 @@ impl FenwickTree { } bit >>= 1; } - pos // 0-indexed coordinate + pos } } diff --git a/crates/solver/src/biexp_fit.rs b/crates/solver/src/biexp_fit.rs index eec3e66c..1da14b19 100644 --- a/crates/solver/src/biexp_fit.rs +++ b/crates/solver/src/biexp_fit.rs @@ -82,6 +82,7 @@ /// the ceiling logic in eval_two_component can be simplified back to a /// plain `bs >= bf` gate. +#[derive(Clone)] #[cfg_attr(feature = "jsbindings", derive(serde::Serialize))] pub struct BiexpResult { pub tau_rise: f64, @@ -93,6 +94,25 @@ pub struct BiexpResult { pub beta_fast: f64, } +impl BiexpResult { + fn sentinel() -> Self { + BiexpResult { + tau_rise: 0.02, + tau_decay: 0.4, + beta: 0.0, + residual: f64::INFINITY, + tau_rise_fast: 0.0, + tau_decay_fast: 0.0, + beta_fast: 0.0, + } + } + + /// Returns true if the fit includes a fast component. + pub fn has_fast_component(&self) -> bool { + self.tau_rise_fast > 0.0 && self.tau_decay_fast > self.tau_rise_fast + } +} + /// Fit a two-component bi-exponential model to a free-form kernel. /// /// Uses a 20×20×(5×8+1) grid search over (tau_r, tau_d, tau_r_fast, tau_d_fast) @@ -116,15 +136,7 @@ pub fn fit_biexponential( let n = h_free.len(); let skip = skip.min(n.saturating_sub(1)); if n == 0 { - return BiexpResult { - tau_rise: 0.02, - tau_decay: 0.4, - beta: 0.0, - residual: f64::INFINITY, - tau_rise_fast: 0.0, - tau_decay_fast: 0.0, - beta_fast: 0.0, - }; + return BiexpResult::sentinel(); } let dt = 1.0 / fs; @@ -145,15 +157,7 @@ pub fn fit_biexponential( // additional candidate. This gives faster convergence when the kernel // is evolving smoothly between iterations. if let Some(warm) = warm_start { - let mut warm_candidate = BiexpResult { - tau_rise: warm.tau_rise, - tau_decay: warm.tau_decay, - beta: warm.beta, - residual: warm.residual, - tau_rise_fast: warm.tau_rise_fast, - tau_decay_fast: warm.tau_decay_fast, - beta_fast: warm.beta_fast, - }; + let mut warm_candidate = warm.clone(); // Re-evaluate on the CURRENT h_free (warm residual was from previous h_free) let (bs, bf, res) = eval_two_component( h_free, @@ -172,7 +176,7 @@ pub fn fit_biexponential( refine_candidate(h_free, &mut warm_candidate, dt, 40, skip); } // Warm candidate competes with the appropriate track - if warm_candidate.tau_rise_fast > 0.0 && warm_candidate.tau_decay_fast > warm_candidate.tau_rise_fast { + if warm_candidate.has_fast_component() { if warm_candidate.residual < best_two.residual { best_two = warm_candidate; } @@ -218,8 +222,15 @@ fn refine_candidate( ) { let (refined_tr, refined_td, refined_trf, refined_tdf) = golden_section_refine(h_free, candidate, dt, max_steps, skip); - let (beta_s, beta_f, residual) = - eval_two_component(h_free, refined_tr, refined_td, refined_trf, refined_tdf, dt, skip); + let (beta_s, beta_f, residual) = eval_two_component( + h_free, + refined_tr, + refined_td, + refined_trf, + refined_tdf, + dt, + skip, + ); if residual < candidate.residual { *candidate = BiexpResult { tau_rise: refined_tr, @@ -285,24 +296,8 @@ fn cold_grid_search(h_free: &[f32], fs: f64, dt: f64, skip: usize) -> (BiexpResu let tdf_lo = 0.5 * dt; let tdf_abs_hi = 8.0 * dt; - let mut best_slow = BiexpResult { - tau_rise: 0.02, - tau_decay: 0.4, - beta: 0.0, - residual: f64::INFINITY, - tau_rise_fast: 0.0, - tau_decay_fast: 0.0, - beta_fast: 0.0, - }; - let mut best_two = BiexpResult { - tau_rise: 0.02, - tau_decay: 0.4, - beta: 0.0, - residual: f64::INFINITY, - tau_rise_fast: 0.0, - tau_decay_fast: 0.0, - beta_fast: 0.0, - }; + let mut best_slow = BiexpResult::sentinel(); + let mut best_two = BiexpResult::sentinel(); for i in 0..grid_n { let log_tr = log_tr_lo + (log_tr_hi - log_tr_lo) * i as f64 / (grid_n - 1) as f64; @@ -334,7 +329,7 @@ fn cold_grid_search(h_free: &[f32], fs: f64, dt: f64, skip: usize) -> (BiexpResu // Inner grid: scan independent (tau_r_fast, tau_d_fast) // Upper bound for tau_d_fast is the tighter of the absolute cap - // (8×dt) and a relative cap (tau_d × 0.2) to prevent degeneracy. + // (8×dt) and a relative cap (tau_d × 0.15) to prevent degeneracy. let tdf_hi = tdf_abs_hi.min(tau_d * 0.15); if tdf_hi <= tdf_lo { continue; // tau_d too small for a distinct fast component @@ -343,8 +338,7 @@ fn cold_grid_search(h_free: &[f32], fs: f64, dt: f64, skip: usize) -> (BiexpResu let log_tdf_hi = tdf_hi.ln(); for ki in 0..trf_grid_n { - let tau_r_fast = - trf_lo + (trf_hi - trf_lo) * ki as f64 / (trf_grid_n - 1) as f64; + let tau_r_fast = trf_lo + (trf_hi - trf_lo) * ki as f64 / (trf_grid_n - 1) as f64; for kj in 0..tdf_grid_n { let log_tdf = log_tdf_lo @@ -356,15 +350,8 @@ fn cold_grid_search(h_free: &[f32], fs: f64, dt: f64, skip: usize) -> (BiexpResu continue; } - let (beta_s, beta_f, residual) = eval_two_component( - h_free, - tau_r, - tau_d, - tau_r_fast, - tau_d_fast, - dt, - skip, - ); + let (beta_s, beta_f, residual) = + eval_two_component(h_free, tau_r, tau_d, tau_r_fast, tau_d_fast, dt, skip); if residual < best_two.residual { best_two = BiexpResult { tau_rise: tau_r, @@ -527,6 +514,23 @@ fn eval_two_component( (best_bs, best_bf, best_res) } +/// Run one golden-section narrowing pass on a 1D interval [lo, hi]. +/// `cost` takes a candidate value and returns the residual. +/// Returns the midpoint of the narrowed interval. +fn golden_bracket(mut lo: f64, mut hi: f64, cost: impl Fn(f64) -> f64) -> f64 { + const PHI: f64 = 0.6180339887498949; // (sqrt(5) - 1) / 2 + for _ in 0..10 { + let x1 = hi - PHI * (hi - lo); + let x2 = lo + PHI * (hi - lo); + if cost(x1) < cost(x2) { + hi = x2; + } else { + lo = x1; + } + } + (lo + hi) / 2.0 +} + /// Golden-section refinement around the best grid point. /// Cycles through refining tau_r, tau_d, tau_r_fast, and tau_d_fast for `max_steps` total. fn golden_section_refine( @@ -536,109 +540,56 @@ fn golden_section_refine( max_steps: usize, skip: usize, ) -> (f64, f64, f64, f64) { - let phi = (5.0_f64.sqrt() - 1.0) / 2.0; // golden ratio conjugate - let mut tau_r = best.tau_rise; let mut tau_d = best.tau_decay; let mut tau_r_fast = best.tau_rise_fast; let mut tau_d_fast = best.tau_decay_fast; - // If fast component is zero, skip fast parameter refinement - let has_fast = tau_r_fast > 0.0 && tau_d_fast > tau_r_fast; + let has_fast = best.has_fast_component(); let n_phases = if has_fast { 4 } else { 2 }; for step in 0..max_steps { - let phase = step % n_phases; - - if phase == 0 { - // Refine tau_r - let mut lo = (tau_r * 0.5).max(dt); - let mut hi = (tau_r * 2.0).min(tau_d * 0.99); - if lo >= hi { - continue; - } - - for _ in 0..10 { - let x1 = hi - phi * (hi - lo); - let x2 = lo + phi * (hi - lo); - let (_, _, r1) = - eval_two_component(h_free, x1, tau_d, tau_r_fast, tau_d_fast, dt, skip); - let (_, _, r2) = - eval_two_component(h_free, x2, tau_d, tau_r_fast, tau_d_fast, dt, skip); - if r1 < r2 { - hi = x2; - } else { - lo = x1; + match step % n_phases { + 0 => { + // Refine tau_r + let lo = (tau_r * 0.5).max(dt); + let hi = (tau_r * 2.0).min(tau_d * 0.99); + if lo < hi { + tau_r = golden_bracket(lo, hi, |x| { + eval_two_component(h_free, x, tau_d, tau_r_fast, tau_d_fast, dt, skip).2 + }); } } - tau_r = (lo + hi) / 2.0; - } else if phase == 1 { - // Refine tau_d - let lo = (tau_d * 0.5).max(tau_r * 1.01); - let mut hi = tau_d * 2.0; - if lo >= hi { - continue; - } - - let mut lo = lo; - for _ in 0..10 { - let x1 = hi - phi * (hi - lo); - let x2 = lo + phi * (hi - lo); - let (_, _, r1) = - eval_two_component(h_free, tau_r, x1, tau_r_fast, tau_d_fast, dt, skip); - let (_, _, r2) = - eval_two_component(h_free, tau_r, x2, tau_r_fast, tau_d_fast, dt, skip); - if r1 < r2 { - hi = x2; - } else { - lo = x1; + 1 => { + // Refine tau_d + let lo = (tau_d * 0.5).max(tau_r * 1.01); + let hi = tau_d * 2.0; + if lo < hi { + tau_d = golden_bracket(lo, hi, |x| { + eval_two_component(h_free, tau_r, x, tau_r_fast, tau_d_fast, dt, skip).2 + }); } } - tau_d = (lo + hi) / 2.0; - } else if phase == 2 { - // Refine tau_r_fast - let mut lo = (tau_r_fast * 0.5).max(0.1 * dt); - let mut hi = (tau_r_fast * 2.0).min((2.0 * dt).min(tau_d_fast * 0.99)); - if lo >= hi { - continue; - } - - for _ in 0..10 { - let x1 = hi - phi * (hi - lo); - let x2 = lo + phi * (hi - lo); - let (_, _, r1) = - eval_two_component(h_free, tau_r, tau_d, x1, tau_d_fast, dt, skip); - let (_, _, r2) = - eval_two_component(h_free, tau_r, tau_d, x2, tau_d_fast, dt, skip); - if r1 < r2 { - hi = x2; - } else { - lo = x1; + 2 => { + // Refine tau_r_fast + let lo = (tau_r_fast * 0.5).max(0.1 * dt); + let hi = (tau_r_fast * 2.0).min((2.0 * dt).min(tau_d_fast * 0.99)); + if lo < hi { + tau_r_fast = golden_bracket(lo, hi, |x| { + eval_two_component(h_free, tau_r, tau_d, x, tau_d_fast, dt, skip).2 + }); } } - tau_r_fast = (lo + hi) / 2.0; - } else { - // Refine tau_d_fast — cap at min(8×dt, tau_d × 0.15) to prevent degeneracy - let mut lo = (tau_d_fast * 0.5).max(tau_r_fast * 1.01); - let mut hi = (tau_d_fast * 2.0).min((8.0 * dt).min(tau_d * 0.15)); - if lo >= hi { - continue; - } - - for _ in 0..10 { - let x1 = hi - phi * (hi - lo); - let x2 = lo + phi * (hi - lo); - let (_, _, r1) = - eval_two_component(h_free, tau_r, tau_d, tau_r_fast, x1, dt, skip); - let (_, _, r2) = - eval_two_component(h_free, tau_r, tau_d, tau_r_fast, x2, dt, skip); - if r1 < r2 { - hi = x2; - } else { - lo = x1; + _ => { + // Refine tau_d_fast + let lo = (tau_d_fast * 0.5).max(tau_r_fast * 1.01); + let hi = (tau_d_fast * 2.0).min((8.0 * dt).min(tau_d * 0.15)); + if lo < hi { + tau_d_fast = golden_bracket(lo, hi, |x| { + eval_two_component(h_free, tau_r, tau_d, tau_r_fast, x, dt, skip).2 + }); } } - tau_d_fast = (lo + hi) / 2.0; } } @@ -921,11 +872,7 @@ mod tests { #[test] fn fast_tau_in_valid_range() { // For various inputs, verify fast component time constants are in expected ranges - let test_cases = [ - (0.08, 0.5, 30.0), - (0.05, 0.3, 100.0), - (0.1, 2.0, 10.0), - ]; + let test_cases = [(0.08, 0.5, 30.0), (0.05, 0.3, 100.0), (0.1, 2.0, 10.0)]; for (tau_r, tau_d, fs) in test_cases { let h = make_biexp(tau_r, tau_d, 2.0, fs, 60); @@ -966,8 +913,7 @@ mod tests { // Case 1: Pure slow component — should yield beta_s > 0, beta_f ≈ 0 let h_slow = make_biexp(0.05, 0.5, 2.0, fs, n); - let (bs, bf, _) = - eval_two_component(&h_slow, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); + let (bs, bf, _) = eval_two_component(&h_slow, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); assert!(bs > 0.0, "beta_s should be positive for slow-only input"); assert!( bf < 0.1 * bs, @@ -985,22 +931,19 @@ mod tests { (3.0 * ((-t / tau_d_fast).exp() - (-t / tau_r_fast).exp())) as f32 }) .collect(); - let (bs, _bf, _) = - eval_two_component(&h_fast, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); + let (bs, _bf, _) = eval_two_component(&h_fast, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); // With no fast-only active set, the slow template absorbs what it can assert!(bs >= 0.0, "beta_s should be non-negative for any input"); // Case 3: Both components present let h_both = make_two_component(0.05, 0.5, 2.0, tau_r_fast, tau_d_fast, 1.5, fs, n); - let (bs, bf, _) = - eval_two_component(&h_both, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); + let (bs, bf, _) = eval_two_component(&h_both, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); assert!(bs > 0.0, "beta_s should be positive for mixed input"); assert!(bf > 0.0, "beta_f should be positive for mixed input"); // Case 4: Zero signal — both should be zero let h_zero = vec![0.0_f32; n]; - let (bs, bf, res) = - eval_two_component(&h_zero, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); + let (bs, bf, res) = eval_two_component(&h_zero, 0.05, 0.5, tau_r_fast, tau_d_fast, dt, 0); assert_eq!(bs, 0.0, "beta_s should be zero for zero input"); assert_eq!(bf, 0.0, "beta_f should be zero for zero input"); assert!(res < 1e-20, "residual should be ~0 for zero input"); diff --git a/crates/solver/src/fft.rs b/crates/solver/src/fft.rs index 2525f0c1..25513eb6 100644 --- a/crates/solver/src/fft.rs +++ b/crates/solver/src/fft.rs @@ -150,41 +150,7 @@ impl FftConvolver { signal_len: usize, output: &mut [f32], ) { - let padded_len = self.fft_len; - let spectrum_len = padded_len / 2 + 1; - - // Zero-pad source into fft_input - self.fft_input[..signal_len].copy_from_slice(&source[..signal_len]); - self.fft_input[signal_len..padded_len].fill(0.0); - - // Forward FFT of source - let fwd = self.plan_fwd.as_ref().expect("plans not initialized"); - fwd.process_with_scratch( - &mut self.fft_input[..padded_len], - &mut self.fft_spectrum[..spectrum_len], - &mut self.fft_scratch_fwd, - ) - .unwrap(); - - // Pointwise multiply with kernel FFT - for i in 0..spectrum_len { - self.fft_spectrum[i] *= self.kernel_fft[i]; - } - - // Inverse FFT - let inv = self.plan_inv.as_ref().expect("plans not initialized"); - inv.process_with_scratch( - &mut self.fft_spectrum[..spectrum_len], - &mut self.fft_output[..padded_len], - &mut self.fft_scratch_inv, - ) - .unwrap(); - - // Normalize and copy first signal_len samples to output - let scale = 1.0 / padded_len as f32; - for i in 0..signal_len { - output[i] = self.fft_output[i] * scale; - } + self.convolve_impl(source, signal_len, output, false); } /// FFT-based adjoint convolution (correlation): output[..signal_len] = (K^T * source)[..signal_len]. @@ -193,6 +159,18 @@ impl FftConvolver { source: &[f32], signal_len: usize, output: &mut [f32], + ) { + self.convolve_impl(source, signal_len, output, true); + } + + /// Shared FFT convolution implementation. + /// `use_conjugate` selects `kernel_conj_fft` (adjoint) or `kernel_fft` (forward). + fn convolve_impl( + &mut self, + source: &[f32], + signal_len: usize, + output: &mut [f32], + use_conjugate: bool, ) { let padded_len = self.fft_len; let spectrum_len = padded_len / 2 + 1; @@ -210,9 +188,15 @@ impl FftConvolver { ) .unwrap(); - // Pointwise multiply with conjugate kernel FFT - for i in 0..spectrum_len { - self.fft_spectrum[i] *= self.kernel_conj_fft[i]; + // Pointwise multiply with kernel spectrum + if use_conjugate { + for i in 0..spectrum_len { + self.fft_spectrum[i] *= self.kernel_conj_fft[i]; + } + } else { + for i in 0..spectrum_len { + self.fft_spectrum[i] *= self.kernel_fft[i]; + } } // Inverse FFT diff --git a/crates/solver/src/filter.rs b/crates/solver/src/filter.rs index 52d0fcf5..1173a4ac 100644 --- a/crates/solver/src/filter.rs +++ b/crates/solver/src/filter.rs @@ -206,17 +206,8 @@ impl BandpassFilter { } } - /// Apply bandpass filter in-place. Caches power spectrum. Returns false if skipped. - pub fn apply(&mut self, trace: &mut [f32]) -> bool { - if !self.is_enabled() || !self.valid || trace.len() < 8 { - return false; - } - - // Mode-specific validity: HP+LP requires f_hp < f_lp - if self.hp_enabled && self.lp_enabled && self.f_hp >= self.f_lp { - return false; - } - + /// Perform forward FFT and cache power spectrum. Used by both `apply` and `compute_spectrum_only`. + fn forward_fft_and_cache_power(&mut self, trace: &[f32]) { let n = trace.len(); self.ensure_buffers(n); let spectrum_len = n / 2 + 1; @@ -224,7 +215,7 @@ impl BandpassFilter { // Copy trace into fft_input self.fft_input[..n].copy_from_slice(trace); - // Forward FFT (use cached plan — no hash-map lookup) + // Forward FFT let fwd = self.plan_fwd.as_ref().expect("plans not initialized"); fwd.process_with_scratch( &mut self.fft_input[..n], @@ -234,14 +225,35 @@ impl BandpassFilter { .unwrap(); // Cache pre-filter power spectrum - for i in 0..spectrum_len { - let c = self.spectrum[i]; - self.power_spectrum[i] = c.re * c.re + c.im * c.im; + for (ps, c) in self.power_spectrum[..spectrum_len] + .iter_mut() + .zip(&self.spectrum[..spectrum_len]) + { + *ps = c.re * c.re + c.im * c.im; } + } + + /// Apply bandpass filter in-place. Caches power spectrum. Returns false if skipped. + pub fn apply(&mut self, trace: &mut [f32]) -> bool { + if !self.is_enabled() || !self.valid || trace.len() < 8 { + return false; + } + + // Mode-specific validity: HP+LP requires f_hp < f_lp + if self.hp_enabled && self.lp_enabled && self.f_hp >= self.f_lp { + return false; + } + + let n = trace.len(); + self.forward_fft_and_cache_power(trace); + let spectrum_len = n / 2 + 1; // Apply gain curve - for i in 0..spectrum_len { - self.spectrum[i] *= self.gain_curve[i]; + for (s, &g) in self.spectrum[..spectrum_len] + .iter_mut() + .zip(&self.gain_curve[..spectrum_len]) + { + *s *= g; } // Inverse FFT (use cached plan — no hash-map lookup) @@ -255,8 +267,8 @@ impl BandpassFilter { // Normalize (realfft doesn't normalize) let scale = 1.0 / n as f32; - for i in 0..n { - trace[i] = self.fft_input[i] * scale; + for (t, &f) in trace.iter_mut().zip(&self.fft_input[..n]) { + *t = f * scale; } true @@ -267,25 +279,7 @@ impl BandpassFilter { if trace.len() < 8 { return; } - - let n = trace.len(); - self.ensure_buffers(n); - let spectrum_len = n / 2 + 1; - - self.fft_input[..n].copy_from_slice(trace); - - let fwd = self.plan_fwd.as_ref().expect("plans not initialized"); - fwd.process_with_scratch( - &mut self.fft_input[..n], - &mut self.spectrum[..spectrum_len], - &mut self.scratch_fwd, - ) - .unwrap(); - - for i in 0..spectrum_len { - let c = self.spectrum[i]; - self.power_spectrum[i] = c.re * c.re + c.im * c.im; - } + self.forward_fft_and_cache_power(trace); } /// Get power spectrum (N/2+1 bins of |FFT|²). diff --git a/crates/solver/src/fista.rs b/crates/solver/src/fista.rs index 185b1a65..6669e024 100644 --- a/crates/solver/src/fista.rs +++ b/crates/solver/src/fista.rs @@ -55,20 +55,9 @@ impl Solver { // mathematically cancels in the gradient (residual = mean-centered signals). // Computing it anyway would produce pure momentum-oscillation noise. if !self.filtered { - let mut sum = 0.0_f64; - for i in 0..n { - sum += (self.trace[i] - self.reconvolution[i]) as f64; - } - let raw_baseline = sum / n as f64; - self.baseline = raw_baseline; - - // Per-iteration EMA smoothing for display baseline - if !self.baseline_ema_init { - self.baseline_ema = raw_baseline; - self.baseline_ema_init = true; - } else { - self.baseline_ema = 0.3 * raw_baseline + 0.7 * self.baseline_ema; - } + let raw = + crate::compute_raw_baseline(&self.trace[..n], &self.reconvolution[..n], n); + self.update_baseline_ema(raw); } // 2. Compute residual = K * y_k + b - trace diff --git a/crates/solver/src/indeca.rs b/crates/solver/src/indeca.rs index fc857f27..503c0620 100644 --- a/crates/solver/src/indeca.rs +++ b/crates/solver/src/indeca.rs @@ -138,6 +138,19 @@ fn solve_upsampled( (solution, filtered, iterations, converged) } +/// Return the interior slice of `s` excluding `pad` samples from each end. +/// Falls back to the full slice when the interior is empty. +fn interior_slice(s: &[f32], pad: usize) -> &[f32] { + let n = s.len(); + let lo = pad.min(n); + let hi = n.saturating_sub(pad).max(lo); + if hi > lo { + &s[lo..hi] + } else { + s + } +} + /// Estimate alpha from the interior of the trace (excluding boundary padding). /// /// Uses peak-to-trough of the inner region to avoid edge artifacts that occur @@ -145,17 +158,7 @@ fn solve_upsampled( /// Since the kernel is peak-normalized, peak-to-trough >= alpha, making this /// a safe overestimate. Returns 1.0 for flat traces. fn estimate_alpha_interior(trace: &[f32], pad: usize) -> f64 { - let n = trace.len(); - let lo_idx = pad.min(n); - let hi_idx = n.saturating_sub(pad).max(lo_idx); - let inner = &trace[lo_idx..hi_idx]; - if inner.is_empty() { - // Trace too short for padding — fall back to full trace - let lo = trace.iter().copied().fold(f32::INFINITY, f32::min); - let hi = trace.iter().copied().fold(f32::NEG_INFINITY, f32::max); - let ptp = (hi - lo) as f64; - return if ptp < 1e-10 { 1.0 } else { ptp }; - } + let inner = interior_slice(trace, pad); let lo = inner.iter().copied().fold(f32::INFINITY, f32::min); let hi = inner.iter().copied().fold(f32::NEG_INFINITY, f32::max); let ptp = (hi - lo) as f64; @@ -170,11 +173,10 @@ fn estimate_alpha_interior(trace: &[f32], pad: usize) -> f64 { /// /// Falls back to the full slice when the interior is empty. fn interior_peak(s: &[f32], pad: usize) -> f32 { - let n = s.len(); - let lo = pad.min(n); - let hi = n.saturating_sub(pad).max(lo); - let region = if hi > lo { &s[lo..hi] } else { s }; - region.iter().copied().fold(0.0_f32, f32::max) + interior_slice(s, pad) + .iter() + .copied() + .fold(0.0_f32, f32::max) } /// Full InDeCa trace processing pipeline with scale iteration. @@ -214,25 +216,17 @@ pub fn solve_trace( let mut solver = Solver::new(); // ── Step 1: Apply optional bandpass filter + rolling baseline subtraction ── - // Run a throwaway FISTA just to get the filtered trace (if HP/LP), then + // Apply bandpass filter directly (if HP/LP enabled), then // subtract the rolling-percentile baseline so the floor is ~0. let mut working_trace = if hp_enabled || lp_enabled { - let (_, filtered_up, _, _) = solve_upsampled( - &mut solver, - &upsampled, - tau_r, - tau_d, - fs_up, - 1, // only 1 iteration — we just need the filtered trace - tol, - None, - hp_enabled, - lp_enabled, - Constraint::Box01, - false, - 0.0, // no sparsity for filter pass - ); - filtered_up.unwrap() + // Apply bandpass filter directly — no need for a full FISTA solve + solver.set_conv_mode(ConvMode::BandedAR2); + solver.set_params(tau_r, tau_d, 0.0, fs_up); + solver.set_trace(&upsampled); + solver.set_hp_filter_enabled(hp_enabled); + solver.set_lp_filter_enabled(lp_enabled); + solver.apply_filter(); + solver.get_trace() } else { upsampled }; @@ -412,7 +406,9 @@ mod tests { #[test] fn outputs_in_range() { let trace = make_trace(0.02, 0.4, 30.0, 300, &[20, 80, 150, 220]); - let result = solve_trace(&trace, 0.02, 0.4, 30.0, 1, 500, 1e-4, None, false, false, 0.0); + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 1, 500, 1e-4, None, false, false, 0.0, + ); // Spike counts should be non-negative for (i, &v) in result.s_counts.iter().enumerate() { @@ -438,7 +434,9 @@ mod tests { } } } - let result = solve_trace(&trace, 0.02, 0.4, 30.0, 1, 1000, 1e-4, None, false, false, 0.0); + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 1, 1000, 1e-4, None, false, false, 0.0, + ); // Check that spikes are detected near the true positions let mut detected = 0; @@ -491,7 +489,9 @@ mod tests { #[test] fn upsampled_output_length() { let trace = make_trace(0.02, 0.4, 30.0, 100, &[20, 50]); - let result = solve_trace(&trace, 0.02, 0.4, 30.0, 10, 200, 1e-3, None, false, false, 0.0); + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 10, 200, 1e-3, None, false, false, 0.0, + ); // Output should be same length as input regardless of upsample factor assert_eq!( @@ -504,7 +504,9 @@ mod tests { #[test] fn zero_trace() { let trace = vec![0.0_f32; 100]; - let result = solve_trace(&trace, 0.02, 0.4, 30.0, 1, 100, 1e-4, None, false, false, 0.0); + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 1, 100, 1e-4, None, false, false, 0.0, + ); let total_spikes: f32 = result.s_counts.iter().sum(); assert!( total_spikes < 1e-6, @@ -538,7 +540,9 @@ mod tests { } } - let result = solve_trace(&trace, tau_r, tau_d, fs, 10, 500, 1e-4, None, false, false, 0.0); + let result = solve_trace( + &trace, tau_r, tau_d, fs, 10, 500, 1e-4, None, false, false, 0.0, + ); let total_counts: f32 = result.s_counts.iter().sum(); @@ -602,7 +606,9 @@ mod tests { let subset_end = 400; let subset = &full_trace[subset_start..subset_end]; - let result = solve_trace(subset, tau_r, tau_d, fs, 1, 1000, 1e-4, None, false, false, 0.0); + let result = solve_trace( + subset, tau_r, tau_d, fs, 1, 1000, 1e-4, None, false, false, 0.0, + ); let total_spikes: f32 = result.s_counts.iter().sum(); // Should detect interior spikes, not just the edge artifact @@ -638,7 +644,9 @@ mod tests { } } - let result = solve_trace(&trace, tau_r, tau_d, fs, 1, 1000, 1e-4, None, false, false, 0.0); + let result = solve_trace( + &trace, tau_r, tau_d, fs, 1, 1000, 1e-4, None, false, false, 0.0, + ); let total_spikes: f32 = result.s_counts.iter().sum(); assert!( @@ -647,4 +655,120 @@ mod tests { total_spikes, result.alpha, result.threshold, result.pve ); } + + /// HP+LP filter path should produce valid results and return a filtered trace. + #[test] + fn filter_path_hp_lp() { + let spike_positions = [30, 100, 200]; + let alpha_true = 10.0_f32; + let baseline_true = 2.0_f32; + let kernel = build_kernel(0.02, 0.4, 30.0); + let n = 300; + let mut trace = vec![baseline_true; n]; + for &pos in &spike_positions { + for (k, &kv) in kernel.iter().enumerate() { + if pos + k < n { + trace[pos + k] += alpha_true * kv; + } + } + } + + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 1, 1000, 1e-4, None, true, true, 0.0, + ); + + // Output length should match input + assert_eq!(result.s_counts.len(), trace.len()); + + // Spike counts should be non-negative + for (i, &v) in result.s_counts.iter().enumerate() { + assert!(v >= 0.0, "Negative spike count at {}: {}", i, v); + } + + // Filtered trace should be returned and have the correct length + let filtered = result + .filtered_trace + .as_ref() + .expect("filtered_trace should be Some when filters are enabled"); + assert_eq!(filtered.len(), trace.len()); + + // Should still detect spikes through the filter + let total_spikes: f32 = result.s_counts.iter().sum(); + assert!( + total_spikes >= 1.0, + "Should detect at least 1 spike with HP+LP filter, got {} (pve={:.4})", + total_spikes, + result.pve + ); + } + + /// HP-only filter path should remove DC and still detect spikes. + #[test] + fn filter_path_hp_only() { + let spike_positions = [30, 100, 200]; + let alpha_true = 10.0_f32; + let baseline_true = 50.0_f32; // high DC offset + let kernel = build_kernel(0.02, 0.4, 30.0); + let n = 300; + let mut trace = vec![baseline_true; n]; + for &pos in &spike_positions { + for (k, &kv) in kernel.iter().enumerate() { + if pos + k < n { + trace[pos + k] += alpha_true * kv; + } + } + } + + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 1, 1000, 1e-4, None, true, false, 0.0, + ); + + assert_eq!(result.s_counts.len(), trace.len()); + + // Filtered trace should be returned + assert!(result.filtered_trace.is_some()); + + // Should still detect spikes + let total_spikes: f32 = result.s_counts.iter().sum(); + assert!( + total_spikes >= 1.0, + "Should detect at least 1 spike with HP-only filter, got {} (pve={:.4})", + total_spikes, + result.pve + ); + } + + /// LP-only filter path should preserve DC and detect spikes. + #[test] + fn filter_path_lp_only() { + let spike_positions = [30, 100, 200]; + let alpha_true = 10.0_f32; + let baseline_true = 2.0_f32; + let kernel = build_kernel(0.02, 0.4, 30.0); + let n = 300; + let mut trace = vec![baseline_true; n]; + for &pos in &spike_positions { + for (k, &kv) in kernel.iter().enumerate() { + if pos + k < n { + trace[pos + k] += alpha_true * kv; + } + } + } + + let result = solve_trace( + &trace, 0.02, 0.4, 30.0, 1, 1000, 1e-4, None, false, true, 0.0, + ); + + assert_eq!(result.s_counts.len(), trace.len()); + assert!(result.filtered_trace.is_some()); + + // Should still detect spikes + let total_spikes: f32 = result.s_counts.iter().sum(); + assert!( + total_spikes >= 1.0, + "Should detect at least 1 spike with LP-only filter, got {} (pve={:.4})", + total_spikes, + result.pve + ); + } } diff --git a/crates/solver/src/js_indeca.rs b/crates/solver/src/js_indeca.rs index b33d0cb9..9f25bdd3 100644 --- a/crates/solver/src/js_indeca.rs +++ b/crates/solver/src/js_indeca.rs @@ -128,8 +128,7 @@ pub fn indeca_fit_biexponential( } else { None }; - let result = - biexp_fit::fit_biexponential(h_free, fs, refine, skip, warm_start.as_ref()); + let result = biexp_fit::fit_biexponential(h_free, fs, refine, skip, warm_start.as_ref()); serde_wasm_bindgen::to_value(&result).unwrap_or(JsValue::NULL) } diff --git a/crates/solver/src/kernel.rs b/crates/solver/src/kernel.rs index 1561bbb2..36675085 100644 --- a/crates/solver/src/kernel.rs +++ b/crates/solver/src/kernel.rs @@ -1,15 +1,20 @@ +/// Clamp tau_rise away from tau_decay to prevent degenerate zero kernels. +/// When tau_rise ≈ tau_decay, the biexponential exp(-t/τ_d) - exp(-t/τ_r) collapses to zero. +pub(crate) fn clamp_tau_rise(tau_rise: f64, tau_decay: f64) -> f64 { + if (tau_rise - tau_decay).abs() < 1e-6 * tau_decay.max(tau_rise).max(1e-12) { + tau_decay * 0.5 + } else { + tau_rise + } +} + /// Build a double-exponential calcium kernel normalized to peak = 1.0. /// /// h(t) = exp(-t/tau_decay) - exp(-t/tau_rise), normalized so max(h) = 1.0. /// Kernel length extends until the decay envelope drops below 1e-6 of peak. /// Computed in f64 for precision, returned as Vec. pub fn build_kernel(tau_rise: f64, tau_decay: f64, fs: f64) -> Vec { - // Guard: tau_rise too close to tau_decay produces a degenerate zero kernel - let tau_rise = if (tau_rise - tau_decay).abs() < 1e-6 * tau_decay.max(tau_rise).max(1e-12) { - tau_decay * 0.5 - } else { - tau_rise - }; + let tau_rise = clamp_tau_rise(tau_rise, tau_decay); let dt = 1.0 / fs; @@ -49,12 +54,7 @@ pub fn build_kernel(tau_rise: f64, tau_decay: f64, fs: f64) -> Vec { /// Used by BandedAR2 tests and the TypeScript port in src/lib/ar2.ts. #[allow(dead_code)] pub fn tau_to_ar2(tau_rise: f64, tau_decay: f64, fs: f64) -> (f64, f64) { - // Guard: tau_rise too close to tau_decay produces a degenerate zero kernel - let tau_rise = if (tau_rise - tau_decay).abs() < 1e-6 * tau_decay.max(tau_rise).max(1e-12) { - tau_decay * 0.5 - } else { - tau_rise - }; + let tau_rise = clamp_tau_rise(tau_rise, tau_decay); let dt = 1.0 / fs; let d = (-dt / tau_decay).exp(); // decay eigenvalue @@ -84,16 +84,19 @@ pub fn compute_lipschitz(kernel: &[f32]) -> f64 { let inv = 2.0 * std::f64::consts::PI / (fft_len as f64); let mut max_power = 0.0_f64; - for w in 0..fft_len { + // Pre-cast kernel to f64 once instead of per-frequency + let kernel_f64: Vec = kernel.iter().map(|&k| k as f64).collect(); + + // Real kernel has symmetric spectrum: only need 0..=fft_len/2 + for w in 0..=fft_len / 2 { let freq = inv * (w as f64); let mut re = 0.0_f64; let mut im = 0.0_f64; - for (k, &hk) in kernel.iter().enumerate() { - let hk64 = hk as f64; + for (k, &hk) in kernel_f64.iter().enumerate() { let angle = freq * (k as f64); let (s, c) = angle.sin_cos(); - re += hk64 * c; - im -= hk64 * s; + re += hk * c; + im -= hk * s; } let power = re * re + im * im; if power > max_power { diff --git a/crates/solver/src/kernel_est.rs b/crates/solver/src/kernel_est.rs index 89e3943c..6aa3a98e 100644 --- a/crates/solver/src/kernel_est.rs +++ b/crates/solver/src/kernel_est.rs @@ -169,25 +169,18 @@ pub fn estimate_free_kernel( let mut stv = vec![0.0_f64; kernel_length]; // S^T S v let mut eigenvalue = 1.0_f64; + let mut v_f32 = vec![0.0_f32; kernel_length]; + for _ in 0..20 { - // S*v: convolve spikes with v (cast to f32) - let v_f32: Vec = v.iter().map(|&x| x as f32).collect(); + // Cast v to f32 into pre-allocated buffer + for (dst, &src) in v_f32.iter_mut().zip(v.iter()) { + *dst = src as f32; + } + // S*v: convolve spikes with v convolve_spikes_kernel(spike_trains, trace_lengths, &v_f32, &mut sv); // S^T (S*v) - stv.fill(0.0); - let mut off = 0; - for i in 0..n_traces { - let len = trace_lengths[i]; - for t in 0..len { - let val = sv[off + t] as f64; - let k_max = kernel_length.min(t + 1); - for k in 0..k_max { - stv[k] += val * spike_trains[off + t - k] as f64; - } - } - off += len; - } + adjoint_spikes_kernel(&sv, spike_trains, trace_lengths, kernel_length, &mut stv); // eigenvalue estimate = ||S^T S v|| eigenvalue = stv.iter().map(|&x| x * x).sum::().sqrt(); @@ -219,30 +212,28 @@ pub fn estimate_free_kernel( // Working buffer for S*h (convolution result) let mut sh = vec![0.0_f32; total_len]; + let mut z = vec![0.0_f64; kernel_length]; + for iter in 0..max_iters { // Forward: S*h (convolve each trace's spikes with h) convolve_spikes_kernel(spike_trains, trace_lengths, &h_prev, &mut sh); - // Residual: r = S*h - y_adj - // Gradient: S^T * r - gradient.fill(0.0); - offset = 0; - for i in 0..n_traces { - let len = trace_lengths[i]; - for t in 0..len { - let r = sh[offset + t] as f64 - y_adj[offset + t] as f64; - // S^T contribution: h[k] gets r * s[t-k] - let k_max = kernel_length.min(t + 1); - for k in 0..k_max { - gradient[k] += r * spike_trains[offset + t - k] as f64; - } - } - offset += len; + // Residual: r = S*h - y_adj (compute in-place in sh) + for i in 0..total_len { + sh[i] -= y_adj[i]; } + // Gradient: S^T * r + adjoint_spikes_kernel( + &sh, + spike_trains, + trace_lengths, + kernel_length, + &mut gradient, + ); + // Proximal gradient step: gradient descent on data-fidelity, then // TV proximal operator, then non-negativity projection. - let mut z = vec![0.0_f64; kernel_length]; for k in 0..kernel_length { z[k] = h_prev[k] as f64 - step_size * gradient[k]; } @@ -284,6 +275,29 @@ pub fn estimate_free_kernel( h } +/// Adjoint of spike convolution: output[k] += sum_t input[t] * s[t-k]. +/// This is S^T * input, the transpose of convolve_spikes_kernel. +fn adjoint_spikes_kernel( + input: &[f32], + spikes: &[f32], + trace_lengths: &[usize], + kernel_length: usize, + output: &mut [f64], +) { + output[..kernel_length].fill(0.0); + let mut offset = 0; + for &len in trace_lengths { + for t in 0..len { + let val = input[offset + t] as f64; + let k_max = kernel_length.min(t + 1); + for k in 0..k_max { + output[k] += val * spikes[offset + t - k] as f64; + } + } + offset += len; + } +} + /// Convolve spike trains with kernel h: output[t] = sum_k h[k] * s[t-k]. fn convolve_spikes_kernel(spikes: &[f32], trace_lengths: &[usize], h: &[f32], output: &mut [f32]) { let k_len = h.len(); diff --git a/crates/solver/src/lib.rs b/crates/solver/src/lib.rs index 57f7dd9c..3d27d24d 100644 --- a/crates/solver/src/lib.rs +++ b/crates/solver/src/lib.rs @@ -162,8 +162,15 @@ impl Solver { self.kernel_dc_gain = self.kernel.iter().map(|&k| k as f64).sum(); self.bandpass.update_cutoffs(tau_rise, tau_decay, fs); - // Update both convolution engines - self.banded.update(tau_rise, tau_decay, fs); + // Update convolution engines (only the active one + compute Lipschitz) + match self.conv_mode { + ConvMode::BandedAR2 => { + self.banded.update(tau_rise, tau_decay, fs); + } + ConvMode::Fft => { + // banded will be updated lazily if conv_mode switches + } + } self.lipschitz_constant = self.current_lipschitz(); // Update kernel FFT if buffers are already set up and large enough. @@ -304,11 +311,19 @@ impl Solver { /// Does NOT reset solution/iteration state — warm-start is preserved. pub fn set_conv_mode(&mut self, mode: ConvMode) { self.conv_mode = mode; - self.lipschitz_constant = self.current_lipschitz(); - // Ensure FFT buffers exist if switching to FFT mode with an active trace - if mode == ConvMode::Fft && self.active_len > 0 { - self.fft.ensure_buffers(self.active_len, &self.kernel); + match mode { + ConvMode::BandedAR2 => { + // Ensure banded coefficients are current (may have been skipped in set_params) + self.banded.update(self.tau_rise, self.tau_decay, self.fs); + } + ConvMode::Fft => { + // Ensure FFT buffers exist if switching to FFT mode with an active trace + if self.active_len > 0 { + self.fft.ensure_buffers(self.active_len, &self.kernel); + } + } } + self.lipschitz_constant = self.current_lipschitz(); } /// Set the constraint type (NonNegative or Box01). @@ -333,8 +348,7 @@ impl Solver { /// Format: [active_len (u32)] [t_fista (f64)] [iteration (u32)] [baseline (f64)] [solution f32...] [solution_prev f32...] pub fn export_state(&self) -> Vec { let n = self.active_len; - // 4 bytes active_len + 8 bytes t_fista + 4 bytes iteration + 8 bytes baseline + 2*n*4 bytes solutions (f32) - let mut buf = Vec::with_capacity(4 + 8 + 4 + 8 + 2 * n * 4); + let mut buf = Vec::with_capacity(state_byte_len(n)); buf.extend_from_slice(&(n as u32).to_le_bytes()); buf.extend_from_slice(&self.t_fista.to_le_bytes()); @@ -386,26 +400,24 @@ impl Solver { // Recompute baseline at current solution for display alignment. // In step_batch, baseline is skipped when filtered (cancels in gradient), // but the display path always needs it to align fit with trace. - { - let mut sum = 0.0_f64; - for i in 0..n { - sum += (self.trace[i] - self.reconvolution[i]) as f64; - } - let raw_baseline = sum / n as f64; - self.baseline = raw_baseline; - - // EMA smoothing for display (damps momentum-induced oscillation) - if !self.baseline_ema_init { - self.baseline_ema = raw_baseline; - self.baseline_ema_init = true; - } else { - self.baseline_ema = 0.3 * raw_baseline + 0.7 * self.baseline_ema; - } - } + let raw = compute_raw_baseline(&self.trace[..n], &self.reconvolution[..n], n); + self.update_baseline_ema(raw); self.reconvolution_stale = false; } + /// Update the baseline EMA from a raw baseline estimate. + /// Called by both `step_batch` (per-iteration) and `compute_reconvolution` (lazy display path). + fn update_baseline_ema(&mut self, raw_baseline: f64) { + self.baseline = raw_baseline; + if !self.baseline_ema_init { + self.baseline_ema = raw_baseline; + self.baseline_ema_init = true; + } else { + self.baseline_ema = 0.3 * raw_baseline + 0.7 * self.baseline_ema; + } + } + // --- Bandpass filter methods --- /// Convenience: set both HP and LP together (used by CaTune's single toggle). @@ -495,7 +507,7 @@ impl Solver { let mut cur = Cursor::new(state); let saved_len = read_u32_le(&mut cur) as usize; - let expected_size = 4 + 8 + 4 + 8 + 2 * saved_len * 4; + let expected_size = state_byte_len(saved_len); if state.len() != expected_size || saved_len != self.active_len { return; // size mismatch, cold start @@ -516,6 +528,20 @@ impl Solver { } } +/// Compute the mean residual (trace - reconvolution) as the raw baseline estimate. +pub(crate) fn compute_raw_baseline(trace: &[f32], reconvolution: &[f32], n: usize) -> f64 { + let mut sum = 0.0_f64; + for i in 0..n { + sum += (trace[i] - reconvolution[i]) as f64; + } + sum / n as f64 +} + +/// Byte length of serialized solver state for a trace of length `n`. +fn state_byte_len(n: usize) -> usize { + 4 + 8 + 4 + 8 + 2 * n * 4 // u32 + f64 + u32 + f64 + 2×n×f32 +} + // --- Little-endian cursor read helpers --- // These wrap the repetitive read_exact + from_le_bytes pattern used by load_state. // Each panics on short reads, which cannot occur when the caller has already diff --git a/crates/solver/src/peak_seed.rs b/crates/solver/src/peak_seed.rs index 8317eb7b..583911da 100644 --- a/crates/solver/src/peak_seed.rs +++ b/crates/solver/src/peak_seed.rs @@ -8,7 +8,6 @@ /// 5. Feed into estimate_free_kernel() → fit_biexponential() (already exist) /// /// The result provides initial tau_rise, tau_decay for the normal iterative pipeline. - use crate::biexp_fit::{fit_biexponential, BiexpResult}; use crate::kernel_est::estimate_free_kernel; @@ -28,7 +27,7 @@ pub struct SeedTraceResult { /// into the kernel estimation step as a replacement for FISTA trace inference. pub fn seed_trace(trace: &[f32], fs: f64) -> SeedTraceResult { let n = trace.len(); - let bl = median(trace); + let (bl, _) = median_and_mad(trace); let onsets = find_seed_spikes(trace, fs, 5.0); let mut s_counts = vec![0.0_f32; n]; @@ -58,7 +57,8 @@ pub struct SeedKernelResult { } /// Median of a slice (copies + sorts). Returns 0.0 for empty input. -pub(crate) fn median(data: &[f32]) -> f32 { +#[cfg(test)] +fn median(data: &[f32]) -> f32 { if data.is_empty() { return 0.0; } @@ -73,7 +73,8 @@ pub(crate) fn median(data: &[f32]) -> f32 { } /// Median absolute deviation of a slice. -pub(crate) fn mad(data: &[f32], median_val: f32) -> f32 { +#[cfg(test)] +fn mad(data: &[f32], median_val: f32) -> f32 { if data.is_empty() { return 0.0; } @@ -81,6 +82,32 @@ pub(crate) fn mad(data: &[f32], median_val: f32) -> f32 { median(&deviations) } +/// Compute median and MAD in one sort pass. +pub(crate) fn median_and_mad(data: &[f32]) -> (f32, f32) { + if data.is_empty() { + return (0.0, 0.0); + } + let mut sorted: Vec = data.to_vec(); + sorted.sort_unstable_by(|a, b| a.total_cmp(b)); + let n = sorted.len(); + let med = if n % 2 == 0 { + (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0 + } else { + sorted[n / 2] + }; + // Compute deviations in-place (reuse sorted buffer) + for v in sorted.iter_mut() { + *v = (*v - med).abs(); + } + sorted.sort_unstable_by(|a, b| a.total_cmp(b)); + let mad_val = if n % 2 == 0 { + (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0 + } else { + sorted[n / 2] + }; + (med, mad_val) +} + /// Find seed spike onset locations from a single trace. /// /// 1. Compute baseline = median(trace) @@ -96,8 +123,7 @@ pub fn find_seed_spikes(trace: &[f32], fs: f64, min_peak_distance_s: f64) -> Vec return Vec::new(); } - let baseline = median(trace); - let mad_val = mad(trace, baseline); + let (baseline, mad_val) = median_and_mad(trace); if mad_val < 1e-10 { return Vec::new(); @@ -191,7 +217,7 @@ pub fn seed_kernel_estimate( let mut offset = 0; for &len in trace_lengths { let trace = &traces_flat[offset..offset + len]; - let bl = median(trace); + let (bl, _) = median_and_mad(trace); let onsets = find_seed_spikes(trace, fs, min_peak_distance_s); for &onset in &onsets { diff --git a/crates/solver/src/threshold.rs b/crates/solver/src/threshold.rs index 44595b35..eda60526 100644 --- a/crates/solver/src/threshold.rs +++ b/crates/solver/src/threshold.rs @@ -51,7 +51,7 @@ pub fn threshold_search( // Collect sorted unique non-zero values for threshold candidates let mut vals: Vec = s_relaxed.iter().copied().filter(|&v| v > 1e-10).collect(); - vals.sort_by(|a, b| a.partial_cmp(b).unwrap()); + vals.sort_unstable_by(|a, b| a.total_cmp(b)); vals.dedup_by(|a, b| (*a - *b).abs() < 1e-10); if vals.is_empty() { @@ -71,7 +71,7 @@ pub fn threshold_search( let mut conv_buf = vec![0.0_f32; n]; let mut best = ThresholdResult { - s_binary: vec![0.0; n], + s_binary: Vec::new(), alpha: 0.0, baseline: 0.0, threshold: 0.0, @@ -172,29 +172,28 @@ pub fn threshold_search( let (alpha, baseline) = lstsq_alpha_baseline(&conv_buf, y, pad, max_alpha); best.alpha = alpha; best.baseline = baseline; - best.s_binary = s_bin.clone(); + best.s_binary = s_bin; // Compute PVE (proportion of variance explained) let inner_range = pad..n.saturating_sub(pad); let inner_len = inner_range.len(); if inner_len > 0 { - let y_mean: f64 = inner_range.clone().map(|i| y[i] as f64).sum::() / inner_len as f64; - - let ss_tot: f64 = inner_range - .clone() - .map(|i| { - let d = y[i] as f64 - y_mean; - d * d - }) - .sum(); - - let ss_res: f64 = inner_range - .map(|i| { - let pred = alpha * conv_buf[i] as f64 + baseline; - let d = y[i] as f64 - pred; - d * d - }) - .sum(); + let mut y_sum = 0.0_f64; + for i in inner_range.clone() { + y_sum += y[i] as f64; + } + let y_mean = y_sum / inner_len as f64; + + let mut ss_tot = 0.0_f64; + let mut ss_res = 0.0_f64; + for i in inner_range { + let yi = y[i] as f64; + let d = yi - y_mean; + ss_tot += d * d; + let pred = alpha * conv_buf[i] as f64 + baseline; + let r = yi - pred; + ss_res += r * r; + } best.pve = if ss_tot > 1e-20 { 1.0 - ss_res / ss_tot