Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified crates/solver/pkg/calab_solver_bg.wasm
Binary file not shown.
41 changes: 14 additions & 27 deletions crates/solver/src/banded.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::kernel::clamp_tau_rise;

/// Banded AR(2) convolution engine — O(T) replacement for FFT-based O(T log T).
///
/// The AR(2) model c[t] = g1*c[t-1] + g2*c[t-2] + s[t] defines a banded
Expand All @@ -19,6 +21,7 @@ pub(crate) struct BandedAR2 {
impl BandedAR2 {
/// Create a new BandedAR2 with the given tau parameters.
pub(crate) fn new(tau_rise: f64, tau_decay: f64, fs: f64) -> Self {
let tau_rise = clamp_tau_rise(tau_rise, tau_decay);
let dt = 1.0 / fs;
let d = (-dt / tau_decay).exp();
let r = (-dt / tau_rise).exp();
Expand All @@ -37,20 +40,13 @@ impl BandedAR2 {

/// Recompute coefficients after parameter change.
pub(crate) fn update(&mut self, tau_rise: f64, tau_decay: f64, fs: f64) {
let dt = 1.0 / fs;
let d = (-dt / tau_decay).exp();
let r = (-dt / tau_rise).exp();
self.g1 = d + r;
self.g2 = -(d * r);
self.impulse_peak = compute_impulse_peak(self.g1, self.g2, tau_decay, fs);
self.lipschitz =
compute_banded_lipschitz(self.g1, self.g2) / (self.impulse_peak * self.impulse_peak);
*self = Self::new(tau_rise, tau_decay, fs);
}

/// Forward convolution: s -> normalized AR2 output, O(T).
///
/// Runs the raw AR2 recursion then divides by the impulse peak so that
/// a single spike produces a peak of 1.0 at all sampling rates.
/// Pre-scales input by 1/peak so the AR2 recursion directly produces
/// a peak-normalized output — no second normalization pass needed.
pub(crate) fn convolve_forward(&self, source: &[f32], output: &mut [f32]) {
let n = source.len();
if n == 0 {
Expand All @@ -61,23 +57,19 @@ impl BandedAR2 {
let g2 = self.g2 as f32;
let inv_peak = (1.0 / self.impulse_peak) as f32;

output[0] = source[0];
output[0] = source[0] * inv_peak;
if n > 1 {
output[1] = g1 * output[0] + source[1];
output[1] = g1 * output[0] + source[1] * inv_peak;
}
for t in 2..n {
output[t] = g1 * output[t - 1] + g2 * output[t - 2] + source[t];
}

// Normalize by impulse peak
for v in &mut output[..n] {
*v *= inv_peak;
output[t] = g1 * output[t - 1] + g2 * output[t - 2] + source[t] * inv_peak;
}
}

/// Adjoint convolution: normalized adjoint, O(T).
///
/// Adjoint of (K / peak) = K^T / peak.
/// Pre-scales input by 1/peak so the backward AR2 recursion directly
/// produces a peak-normalized output — no second normalization pass needed.
pub(crate) fn convolve_adjoint(&self, source: &[f32], output: &mut [f32]) {
let n = source.len();
if n == 0 {
Expand All @@ -88,17 +80,12 @@ impl BandedAR2 {
let g2 = self.g2 as f32;
let inv_peak = (1.0 / self.impulse_peak) as f32;

output[n - 1] = source[n - 1];
output[n - 1] = source[n - 1] * inv_peak;
if n > 1 {
output[n - 2] = source[n - 2] + g1 * output[n - 1];
output[n - 2] = source[n - 2] * inv_peak + g1 * output[n - 1];
}
for t in (0..n.saturating_sub(2)).rev() {
output[t] = source[t] + g1 * output[t + 1] + g2 * output[t + 2];
}

// Normalize by impulse peak
for v in &mut output[..n] {
*v *= inv_peak;
output[t] = source[t] * inv_peak + g1 * output[t + 1] + g2 * output[t + 2];
}
}

Expand Down
20 changes: 11 additions & 9 deletions crates/solver/src/baseline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,19 @@ impl Ord for OrderedF32 {
/// Used for O(log M) k-th element queries via binary lifting.
struct FenwickTree {
    /// Backing array of counts; `new` allocates it with `size + 1` entries.
    tree: Vec<i32>,
    /// Highest power of two that is <= `tree.len() - 1` (0 when the tree
    /// covers no positions). Computed once in `new` and used by `kth` as the
    /// initial step of its binary-lifting descent.
    msb: usize,
}

impl FenwickTree {
/// Build an empty tree covering `size` positions.
///
/// Every count starts at zero. The highest power of two not exceeding
/// `size` is computed here once and stored in `msb`, so `kth` does not
/// have to rediscover it on every query.
fn new(size: usize) -> Self {
    // Overshoot to the first power of two above `size`, then step back one
    // bit. For `size == 0` the loop never runs and `msb` ends up 0.
    let mut msb = 1;
    while msb <= size {
        msb <<= 1;
    }
    msb >>= 1;
    Self {
        // One extra slot because the tree is 1-indexed. The field must be
        // initialised exactly once: a repeated `tree:` entry is a compile error.
        tree: vec![0; size + 1],
        msb,
    }
}

Expand All @@ -66,14 +73,9 @@ impl FenwickTree {
/// Find the 0-indexed position of the k-th element (1-based k).
/// Uses binary lifting: O(log M) time.
fn kth(&self, mut k: i32) -> usize {
let n = self.tree.len() - 1; // max 0-indexed position + 1
let n = self.tree.len() - 1;
let mut pos = 0;
// Find highest power of 2 <= n
let mut bit = 1;
while bit <= n {
bit <<= 1;
}
bit >>= 1;
let mut bit = self.msb;

while bit > 0 {
let next = pos + bit;
Expand All @@ -83,7 +85,7 @@ impl FenwickTree {
}
bit >>= 1;
}
pos // 0-indexed coordinate
pos
}
}

Expand Down
Loading
Loading