From b4447fb07ac2aa465650d283725bcb2809dd838a Mon Sep 17 00:00:00 2001 From: daharoni Date: Sat, 21 Mar 2026 22:22:02 -0700 Subject: [PATCH 1/2] feat: add CaDecon Python bridge for automated deconvolution export Enable `calab.decon(traces, fs)` to open CaDecon in the browser and receive structured results (activity matrix + per-cell scalars + kernel params) back to Python via a two-POST binary transport pattern. - Add npy-writer.ts (inverse of npy-parser) for binary activity export - Extend bridge server with /api/v1/results/activity (.npy) and /api/v1/results (JSON) endpoints, plus configurable app param - Add decon() orchestrator, CaDeconResult NamedTuple, biexp waveform builder, and `calab cadecon` CLI subcommand - Wire ExportButton in CaDecon UI with loading/disabled states - 19 new tests (7 npy-writer, 5 bridge endpoints, 2 CaDeconResult, 5 waveform) Co-Authored-By: Claude Opus 4.6 (1M context) --- .../community/GroundTruthControls.tsx | 58 ++++++- apps/cadecon/src/lib/export-utils.ts | 88 +++++++++++ packages/io/src/__tests__/npy-writer.test.ts | 101 ++++++++++++ packages/io/src/bridge.ts | 55 +++++++ packages/io/src/index.ts | 4 + packages/io/src/npy-writer.ts | 62 ++++++++ python/src/calab/__init__.py | 5 +- python/src/calab/_bridge/__init__.py | 11 +- python/src/calab/_bridge/_apps.py | 149 +++++++++++++++++- python/src/calab/_bridge/_server.py | 59 ++++++- python/src/calab/_cli.py | 55 +++++++ python/src/calab/_compute.py | 45 ++++++ python/tests/test_bridge.py | 98 ++++++++++++ python/tests/test_decon.py | 95 +++++++++++ 14 files changed, 867 insertions(+), 18 deletions(-) create mode 100644 apps/cadecon/src/lib/export-utils.ts create mode 100644 packages/io/src/__tests__/npy-writer.test.ts create mode 100644 packages/io/src/npy-writer.ts create mode 100644 python/tests/test_decon.py diff --git a/apps/cadecon/src/components/community/GroundTruthControls.tsx b/apps/cadecon/src/components/community/GroundTruthControls.tsx index 04f6d5d0..58d21809 100644 
--- a/apps/cadecon/src/components/community/GroundTruthControls.tsx +++ b/apps/cadecon/src/components/community/GroundTruthControls.tsx @@ -1,6 +1,6 @@ /** Ground truth reveal/toggle controls and export button for CaDecon. */ -import { Show, type JSX } from 'solid-js'; +import { Show, createSignal, type JSX } from 'solid-js'; import { isDemo, groundTruthVisible, @@ -8,7 +8,12 @@ import { revealGroundTruth, toggleGroundTruthVisibility, bridgeUrl, + setBridgeExportDone, + bridgeExportDone, } from '../../lib/data-store.ts'; +import { runState } from '../../lib/iteration-store.ts'; +import { exportCaDeconToBridge } from '@calab/io'; +import { buildCaDeconActivityMatrix, buildCaDeconResultsPayload } from '../../lib/export-utils.ts'; export function GroundTruthControls(): JSX.Element { function handleToggle(): void { @@ -46,11 +51,58 @@ export function GroundTruthNotices(): JSX.Element { } export function ExportButton(): JSX.Element { + const [exporting, setExporting] = createSignal(false); + const [error, setError] = createSignal(null); + + const isComplete = () => runState() === 'complete'; + const isBridge = () => !!bridgeUrl(); + const isDisabled = () => !isComplete() || exporting() || bridgeExportDone(); + + async function handleExport(): Promise { + const url = bridgeUrl(); + if (!url) return; + + setExporting(true); + setError(null); + try { + const { data, shape } = buildCaDeconActivityMatrix(); + const results = buildCaDeconResultsPayload(); + await exportCaDeconToBridge(url, data, shape, results); + setBridgeExportDone(true); + } catch (e) { + setError(e instanceof Error ? e.message : 'Export failed'); + } finally { + setExporting(false); + } + } + return ( - + + {error()} + ); } diff --git a/apps/cadecon/src/lib/export-utils.ts b/apps/cadecon/src/lib/export-utils.ts new file mode 100644 index 00000000..2af37532 --- /dev/null +++ b/apps/cadecon/src/lib/export-utils.ts @@ -0,0 +1,88 @@ +/** + * Collects CaDecon iteration results for export to the Python bridge. 
+ */ + +import { cellResultLookup, convergenceHistory, convergedAtIteration } from './iteration-store.ts'; +import { samplingRate, numCells, numTimepoints } from './data-store.ts'; + +/** + * Build a contiguous Float32Array activity matrix from per-cell sCounts. + * Returns the flat array and its [n_cells, n_timepoints] shape. + */ +export function buildCaDeconActivityMatrix(): { + data: Float32Array; + shape: [number, number]; +} { + const lookup = cellResultLookup(); + const nCells = numCells() ?? 0; + const nTime = numTimepoints() ?? 0; + + const data = new Float32Array(nCells * nTime); + + // Sorted cell indices for deterministic row order + const sortedCells = [...lookup.keys()].sort((a, b) => a - b); + for (let row = 0; row < sortedCells.length; row++) { + const entry = lookup.get(sortedCells[row])!; + const offset = row * nTime; + const len = Math.min(entry.sCounts.length, nTime); + data.set(entry.sCounts.subarray(0, len), offset); + } + + return { data, shape: [sortedCells.length, nTime] }; +} + +/** + * Build the JSON results payload with per-cell scalars, kernel params, and metadata. + */ +export function buildCaDeconResultsPayload(): Record { + const lookup = cellResultLookup(); + const history = convergenceHistory(); + const fs = samplingRate() ?? 30; + + const sortedCells = [...lookup.keys()].sort((a, b) => a - b); + const alphas: number[] = []; + const baselines: number[] = []; + const pves: number[] = []; + + for (const cellIdx of sortedCells) { + const entry = lookup.get(cellIdx)!; + alphas.push(entry.alpha); + baselines.push(entry.baseline); + pves.push(entry.pve); + } + + // Kernel params from last convergence snapshot + const latest = history.length > 0 ? history[history.length - 1] : null; + const tauRise = latest?.tauRise ?? 0; + const tauDecay = latest?.tauDecay ?? 0; + const beta = latest?.beta ?? 1; + const tauRiseFast = latest?.tauRiseFast ?? 0; + const tauDecayFast = latest?.tauDecayFast ?? 0; + const betaFast = latest?.betaFast ?? 
0; + const residual = latest?.residual ?? 0; + + // h_free from first subset (data-driven kernel shape) + const hFree = latest && latest.subsets.length > 0 ? Array.from(latest.subsets[0].hFree) : []; + + const convergedAt = convergedAtIteration(); + + return { + alphas, + baselines, + pves, + fs, + tau_rise: tauRise, + tau_decay: tauDecay, + beta, + tau_rise_fast: tauRiseFast, + tau_decay_fast: tauDecayFast, + beta_fast: betaFast, + residual, + h_free: hFree, + num_iterations: history.length, + converged: convergedAt !== null, + converged_at_iteration: convergedAt, + schema_version: 1, + export_date: new Date().toISOString(), + }; +} diff --git a/packages/io/src/__tests__/npy-writer.test.ts b/packages/io/src/__tests__/npy-writer.test.ts new file mode 100644 index 00000000..e2c949fe --- /dev/null +++ b/packages/io/src/__tests__/npy-writer.test.ts @@ -0,0 +1,101 @@ +import { describe, it, expect } from 'vitest'; +import { writeNpy } from '../npy-writer.ts'; +import { parseNpy } from '../npy-parser.ts'; + +describe('writeNpy', () => { + describe('roundtrip with parseNpy', () => { + it('roundtrips a 2D float32 array', () => { + const data = new Float32Array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + const shape = [2, 3]; + + const buffer = writeNpy(data, shape); + const result = parseNpy(buffer); + + expect(result.shape).toEqual([2, 3]); + expect(result.dtype).toBe(' { + const data = new Float32Array([10.5, 20.5, 30.5]); + const shape = [3]; + + const buffer = writeNpy(data, shape); + const result = parseNpy(buffer); + + expect(result.shape).toEqual([3]); + expect(result.data.length).toBe(3); + expect(result.data[0]).toBeCloseTo(10.5, 5); + expect(result.data[2]).toBeCloseTo(30.5, 5); + }); + + it('roundtrips a larger matrix', () => { + const rows = 10; + const cols = 500; + const data = new Float32Array(rows * cols); + for (let i = 0; i < data.length; i++) { + data[i] = Math.sin(i * 0.01); + } + + const buffer = writeNpy(data, [rows, cols]); + const result = parseNpy(buffer); 
+ + expect(result.shape).toEqual([rows, cols]); + expect(result.data.length).toBe(rows * cols); + for (let i = 0; i < 10; i++) { + expect(result.data[i]).toBeCloseTo(data[i], 5); + } + }); + }); + + describe('binary format', () => { + it('starts with correct magic bytes', () => { + const data = new Float32Array([1.0]); + const buffer = writeNpy(data, [1]); + const bytes = new Uint8Array(buffer); + + // \x93NUMPY + expect(bytes[0]).toBe(0x93); + expect(bytes[1]).toBe(0x4e); + expect(bytes[2]).toBe(0x55); + expect(bytes[3]).toBe(0x4d); + expect(bytes[4]).toBe(0x50); + expect(bytes[5]).toBe(0x59); + }); + + it('uses version 1.0', () => { + const data = new Float32Array([1.0]); + const buffer = writeNpy(data, [1]); + const bytes = new Uint8Array(buffer); + + expect(bytes[6]).toBe(1); // major + expect(bytes[7]).toBe(0); // minor + }); + + it('header + preamble is 64-byte aligned', () => { + const data = new Float32Array([1.0, 2.0]); + const buffer = writeNpy(data, [2]); + const view = new DataView(buffer); + + const headerLen = view.getUint16(8, true); + const totalPreamble = 10 + headerLen; // magic(6) + version(2) + headerLen(2) + header + expect(totalPreamble % 64).toBe(0); + }); + + it('header terminates with newline', () => { + const data = new Float32Array([1.0]); + const buffer = writeNpy(data, [1]); + const view = new DataView(buffer); + const bytes = new Uint8Array(buffer); + + const headerLen = view.getUint16(8, true); + // Last byte of header should be newline + expect(bytes[10 + headerLen - 1]).toBe(0x0a); + }); + }); +}); diff --git a/packages/io/src/bridge.ts b/packages/io/src/bridge.ts index e502fafc..93206c8f 100644 --- a/packages/io/src/bridge.ts +++ b/packages/io/src/bridge.ts @@ -6,6 +6,7 @@ */ import { parseNpy } from './npy-parser.ts'; +import { writeNpy } from './npy-writer.ts'; import { processNpyResult } from './array-utils.ts'; import type { NpyResult } from '@calab/core'; @@ -92,3 +93,57 @@ export function stopBridgeHeartbeat(): void { 
heartbeatTimer = null; } } + +/** + * POST the activity matrix as .npy binary to the bridge server. + * Used by CaDecon to send the large activity array before the JSON results. + */ +export async function postActivityToBridge( + bridgeUrl: string, + activity: Float32Array, + shape: [number, number], +): Promise { + const npyBuffer = writeNpy(activity, shape); + const resp = await fetch(`${bridgeUrl}/api/v1/results/activity`, { + method: 'POST', + headers: { 'Content-Type': 'application/octet-stream' }, + body: npyBuffer, + }); + if (!resp.ok) { + throw new Error(`Bridge: failed to post activity (${resp.status})`); + } +} + +/** + * POST the results JSON (scalars + metadata) to the bridge server. + * This acts as the "done" signal for the two-POST CaDecon export. + */ +export async function postResultsToBridge( + bridgeUrl: string, + results: Record, +): Promise { + const resp = await fetch(`${bridgeUrl}/api/v1/results`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(results), + }); + if (!resp.ok) { + throw new Error(`Bridge: failed to post results (${resp.status})`); + } +} + +/** + * Export CaDecon results to the bridge server. + * Sequences: activity POST first (large binary), then results POST (small JSON, triggers done). + * Stops the heartbeat after both succeed. 
+ */ +export async function exportCaDeconToBridge( + bridgeUrl: string, + activity: Float32Array, + shape: [number, number], + results: Record, +): Promise { + await postActivityToBridge(bridgeUrl, activity, shape); + await postResultsToBridge(bridgeUrl, results); + stopBridgeHeartbeat(); +} diff --git a/packages/io/src/index.ts b/packages/io/src/index.ts index 205d00ef..5385e1e9 100644 --- a/packages/io/src/index.ts +++ b/packages/io/src/index.ts @@ -1,4 +1,5 @@ export { parseNpy } from './npy-parser.ts'; +export { writeNpy } from './npy-writer.ts'; export { parseNpz } from './npz-parser.ts'; export { validateTraceData } from './validation.ts'; export { extractCellTrace, processNpyResult } from './array-utils.ts'; @@ -9,6 +10,9 @@ export { getBridgeUrl, fetchBridgeData, postParamsToBridge, + postActivityToBridge, + postResultsToBridge, + exportCaDeconToBridge, startBridgeHeartbeat, stopBridgeHeartbeat, } from './bridge.ts'; diff --git a/packages/io/src/npy-writer.ts b/packages/io/src/npy-writer.ts new file mode 100644 index 00000000..e9277ebd --- /dev/null +++ b/packages/io/src/npy-writer.ts @@ -0,0 +1,62 @@ +// .npy binary format writer +// Inverse of npy-parser.ts — serializes a Float32Array + shape into .npy format. +// Reference: https://numpy.org/doc/2.3/reference/generated/numpy.lib.format.html + +/** + * Write a Float32Array as a .npy binary buffer (version 1.0, little-endian float32). + * + * @param data - The flat Float32Array of values + * @param shape - The array shape, e.g. [rows, cols] + * @returns ArrayBuffer containing the complete .npy file + */ +export function writeNpy(data: Float32Array, shape: number[]): ArrayBuffer { + // 1. Build header dict string + const shapeStr = shape.length === 1 ? 
`(${shape[0]},)` : `(${shape.join(', ')})`; + const headerDict = `{'descr': '= timeout: + break + + if server.last_heartbeat is not None: + if (now - server.last_heartbeat) > HEARTBEAT_TIMEOUT: + print("\nBrowser disconnected (heartbeat timeout).") + break + except KeyboardInterrupt: + print("\nBridge cancelled by user.") + finally: + server.shutdown() + + if not received or server.received_results is None: + return None + + results = server.received_results + activity = server.received_activity + if activity is None: + print("Warning: results received but activity matrix was missing.") + return None + + # Build kernel waveforms from biexp params + result_fs = results.get("fs", fs) + tau_rise = results.get("tau_rise", 0.2) + tau_decay = results.get("tau_decay", 1.0) + beta = results.get("beta", 1.0) + kernel_length = int(5.0 * tau_decay * result_fs) + kernel_slow = _build_biexp_waveform(tau_rise, tau_decay, beta, result_fs, kernel_length) + + tau_rise_fast = results.get("tau_rise_fast", 0.0) + tau_decay_fast = results.get("tau_decay_fast", 0.0) + beta_fast = results.get("beta_fast", 0.0) + if tau_decay_fast > 0 and beta_fast != 0: + kernel_length_fast = int(5.0 * tau_decay_fast * result_fs) + kernel_fast = _build_biexp_waveform( + tau_rise_fast, tau_decay_fast, beta_fast, result_fs, kernel_length_fast, + ) + else: + kernel_fast = np.empty(0, dtype=np.float32) + + # Assemble per-cell arrays + alphas = np.array(results.get("alphas", []), dtype=np.float64) + baselines = np.array(results.get("baselines", []), dtype=np.float64) + pves = np.array(results.get("pves", []), dtype=np.float64) + + # Build metadata dict + metadata = { + "tau_rise": tau_rise, + "tau_decay": tau_decay, + "beta": beta, + "tau_rise_fast": tau_rise_fast, + "tau_decay_fast": tau_decay_fast, + "beta_fast": beta_fast, + } + for key in ( + "residual", "h_free", "num_iterations", "converged", + "converged_at_iteration", "schema_version", "calab_version", + "export_date", + ): + if key in results: + 
value = results[key] + if key == "h_free" and not isinstance(value, list): + value = list(value) + metadata[key] = value + + return CaDeconResult( + activity=np.asarray(activity, dtype=np.float32), + alphas=alphas, + baselines=baselines, + pves=pves, + kernel_slow=kernel_slow, + kernel_fast=kernel_fast, + fs=result_fs, + metadata=metadata, + ) diff --git a/python/src/calab/_bridge/_server.py b/python/src/calab/_bridge/_server.py index 2bed32d4..89df6382 100644 --- a/python/src/calab/_bridge/_server.py +++ b/python/src/calab/_bridge/_server.py @@ -1,6 +1,6 @@ -"""Localhost HTTP bridge server for CaTune <-> Python communication. +"""Localhost HTTP bridge server for CaLab <-> Python communication. -Serves traces as .npy binary and receives exported params as JSON. +Serves traces as .npy binary and receives exported params/results. Binds to 127.0.0.1 only (not network-reachable). CORS enabled for HTTPS->localhost mixed-content requests. """ @@ -42,6 +42,18 @@ def _send_json(self, obj: Any) -> None: """Send a JSON-serializable object as a CORS response.""" self._send_cors_response(json.dumps(obj).encode()) + def _send_error_cors(self, code: int, message: str) -> None: + """Send an error response with CORS headers.""" + body = json.dumps({"error": message}).encode() + self.send_response(code) + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + def do_OPTIONS(self) -> None: """Handle CORS preflight.""" self._send_cors_response(b"", content_type="text/plain") @@ -52,7 +64,7 @@ def do_GET(self) -> None: elif self.path == "/api/v1/metadata": self._serve_metadata() elif self.path == "/api/v1/status": - self._send_json({"ready": True, "app": "catune"}) + self._send_json({"ready": 
True, "app": self.server.app}) elif self.path == "/api/v1/health": self._send_cors_response(b"ok", content_type="text/plain") else: @@ -64,6 +76,10 @@ def do_POST(self) -> None: elif self.path == "/api/v1/heartbeat": self.server.last_heartbeat = time.monotonic() self._send_json({"status": "ok"}) + elif self.path == "/api/v1/results/activity": + self._receive_results_activity() + elif self.path == "/api/v1/results": + self._receive_results() else: self.send_error(404, "Not Found") @@ -96,21 +112,56 @@ def _receive_params(self) -> None: self.server.params_event.set() self._send_json({"status": "ok"}) + def _receive_results_activity(self) -> None: + """Receive activity matrix as .npy binary from CaDecon.""" + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + + try: + arr = np.load(io.BytesIO(body)) + except Exception: + self._send_error_cors(400, "Invalid .npy data") + return + + self.server.received_activity = arr + self._send_json({"status": "ok"}) + + def _receive_results(self) -> None: + """Receive CaDecon results JSON (scalars + metadata). 
Triggers completion event.""" + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + + try: + results = json.loads(body) + except json.JSONDecodeError: + self._send_error_cors(400, "Invalid JSON") + return + + self.server.received_results = results + self.server.results_event.set() + self._send_json({"status": "ok"}) + class BridgeServer(HTTPServer): - """HTTP server that holds trace data and waits for params.""" + """HTTP server that holds trace data and waits for params/results.""" def __init__( self, traces: np.ndarray, fs: float, port: int = 0, + app: str = "catune", ) -> None: self.traces = np.atleast_2d(np.asarray(traces, dtype=np.float64)) self.fs = fs + self.app = app self.received_params: dict | None = None self.params_event = threading.Event() self.last_heartbeat: float | None = None + # CaDecon results (two-POST pattern) + self.received_activity: np.ndarray | None = None + self.received_results: dict | None = None + self.results_event = threading.Event() super().__init__(("127.0.0.1", port), BridgeHandler) diff --git a/python/src/calab/_cli.py b/python/src/calab/_cli.py index 0611cbdc..8c27b269 100644 --- a/python/src/calab/_cli.py +++ b/python/src/calab/_cli.py @@ -53,6 +53,52 @@ def _to_serializable(value): return value.tolist() if hasattr(value, "tolist") else value +def cmd_cadecon(args: argparse.Namespace) -> None: + """Open CaDecon for automated deconvolution.""" + from ._bridge import decon + + traces = np.load(args.file) + if traces.ndim == 1: + traces = traces.reshape(1, -1) + + result = decon( + traces, + fs=args.fs, + port=args.port, + open_browser=not args.no_browser, + ) + + if result is None: + print("No results received.", file=sys.stderr) + sys.exit(1) + + print(f"Activity shape: {result.activity.shape}") + print(f"Sampling rate: {result.fs} Hz") + print(f"Alphas: {result.alphas}") + print(f"Baselines: {result.baselines}") + print(f"PVEs: {result.pves}") + print(f"Kernel slow length: 
{len(result.kernel_slow)}") + print(f"Kernel fast length: {len(result.kernel_fast)}") + + if args.output: + np.save(f"{args.output}_activity.npy", result.activity) + results_json = { + "alphas": result.alphas.tolist(), + "baselines": result.baselines.tolist(), + "pves": result.pves.tolist(), + "fs": result.fs, + "kernel_slow_length": len(result.kernel_slow), + "kernel_fast_length": len(result.kernel_fast), + "metadata": { + k: (v.tolist() if hasattr(v, "tolist") else v) + for k, v in result.metadata.items() + }, + } + with open(f"{args.output}_results.json", "w") as f: + json.dump(results_json, f, indent=2) + print(f"Saved to {args.output}_activity.npy and {args.output}_results.json") + + def cmd_deconvolve(args: argparse.Namespace) -> None: """Batch deconvolution from file.""" from ._compute import bandpass_filter, run_deconvolution, run_deconvolution_full @@ -187,6 +233,15 @@ def main() -> None: p_tune.add_argument("--no-browser", action="store_true", help="Don't open browser") p_tune.set_defaults(func=cmd_tune) + # cadecon + p_cadecon = subparsers.add_parser("cadecon", help="Open CaDecon for automated deconvolution") + p_cadecon.add_argument("file", help="Input .npy file") + p_cadecon.add_argument("--fs", type=float, default=30.0, help="Sampling rate (Hz)") + p_cadecon.add_argument("--port", type=int, default=None, help="Server port") + p_cadecon.add_argument("--no-browser", action="store_true", help="Don't open browser") + p_cadecon.add_argument("--output", "-o", default=None, help="Output path stem") + p_cadecon.set_defaults(func=cmd_cadecon) + # deconvolve p_deconv = subparsers.add_parser("deconvolve", help="Batch deconvolution") p_deconv.add_argument("file", help="Input .npy file") diff --git a/python/src/calab/_compute.py b/python/src/calab/_compute.py index ad89741d..daea99b1 100644 --- a/python/src/calab/_compute.py +++ b/python/src/calab/_compute.py @@ -19,6 +19,51 @@ ) +class CaDeconResult(NamedTuple): + """Full result from CaDecon (automated 
deconvolution via InDeCa algorithm). + + Attributes + ---------- + activity : np.ndarray + Deconvolved activity matrix, shape ``(n_cells, n_timepoints)``, float32. + alphas : np.ndarray + Per-cell scaling factors, shape ``(n_cells,)``, float64. + baselines : np.ndarray + Per-cell baseline estimates, shape ``(n_cells,)``, float64. + pves : np.ndarray + Per-cell proportion of variance explained, shape ``(n_cells,)``, float64. + kernel_slow : np.ndarray + Slow biexponential kernel waveform, float32. + kernel_fast : np.ndarray + Fast biexponential kernel waveform, float32 (empty if single-component). + fs : float + Sampling rate in Hz. + metadata : dict + Extensible dict with biexp params, convergence info, h_free, etc. + """ + + activity: np.ndarray + alphas: np.ndarray + baselines: np.ndarray + pves: np.ndarray + kernel_slow: np.ndarray + kernel_fast: np.ndarray + fs: float + metadata: dict + + +def _build_biexp_waveform( + tau_rise: float, tau_decay: float, beta: float, fs: float, length: int, +) -> np.ndarray: + """Build a biexponential waveform: beta * (exp(-t/tau_d) - exp(-t/tau_r)). + + Uses the same 5x tau_decay length convention as the browser solver. + """ + t = np.arange(length) / fs + waveform = beta * (np.exp(-t / tau_decay) - np.exp(-t / tau_rise)) + return waveform.astype(np.float32) + + class DeconvolutionResult(NamedTuple): """Full result from FISTA deconvolution. 
diff --git a/python/tests/test_bridge.py b/python/tests/test_bridge.py index 1f34db81..72dbc03b 100644 --- a/python/tests/test_bridge.py +++ b/python/tests/test_bridge.py @@ -29,6 +29,21 @@ def bridge_server(): server.shutdown() +@pytest.fixture +def cadecon_server(): + """Start a bridge server in cadecon mode on a random port.""" + rng = np.random.default_rng(42) + traces = rng.standard_normal((3, 200)) + server = BridgeServer(traces, fs=30.0, app="cadecon") + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + yield server + + server.shutdown() + + def _get(server: BridgeServer, path: str) -> tuple[int, bytes]: """Make a GET request to the bridge server.""" url = f"http://127.0.0.1:{server.port}{path}" @@ -53,6 +68,18 @@ def _post(server: BridgeServer, path: str, data: dict) -> tuple[int, bytes]: return e.code, e.read() +def _post_binary(server: BridgeServer, path: str, data: bytes) -> tuple[int, bytes]: + """Make a POST request with binary data.""" + url = f"http://127.0.0.1:{server.port}{path}" + req = urllib.request.Request(url, data=data, method="POST") + req.add_header("Content-Type", "application/octet-stream") + try: + with urllib.request.urlopen(req, timeout=5) as resp: + return resp.status, resp.read() + except urllib.error.HTTPError as e: + return e.code, e.read() + + def test_health_endpoint(bridge_server: BridgeServer) -> None: """GET /api/v1/health returns 200 ok.""" status, body = _get(bridge_server, "/api/v1/health") @@ -178,3 +205,74 @@ def test_heartbeat_timeout_detection(bridge_server: BridgeServer) -> None: since_last = time.monotonic() - bridge_server.last_heartbeat assert since_last > HEARTBEAT_TIMEOUT + + +# --- CaDecon bridge tests --- + + +def test_status_cadecon(cadecon_server: BridgeServer) -> None: + """GET /api/v1/status returns app: cadecon.""" + status, body = _get(cadecon_server, "/api/v1/status") + assert status == 200 + data = json.loads(body) + assert data["ready"] is True + assert 
data["app"] == "cadecon" + + +def test_results_activity_post(cadecon_server: BridgeServer) -> None: + """POST /api/v1/results/activity stores a .npy array.""" + import io as sysio + + activity = np.random.default_rng(0).standard_normal((3, 100)).astype(np.float32) + buf = sysio.BytesIO() + np.save(buf, activity) + npy_bytes = buf.getvalue() + + status, body = _post_binary(cadecon_server, "/api/v1/results/activity", npy_bytes) + assert status == 200 + assert cadecon_server.received_activity is not None + npt.assert_allclose(cadecon_server.received_activity, activity, atol=1e-6) + + +def test_results_json_post(cadecon_server: BridgeServer) -> None: + """POST /api/v1/results stores JSON and triggers results_event.""" + results = {"alphas": [1.0, 2.0], "fs": 30.0, "tau_rise": 0.2} + + status, body = _post(cadecon_server, "/api/v1/results", results) + assert status == 200 + assert cadecon_server.results_event.is_set() + assert cadecon_server.received_results == results + + +def test_results_two_post_sequence(cadecon_server: BridgeServer) -> None: + """Full two-POST flow: activity first, then JSON results.""" + import io as sysio + + # 1. POST activity + activity = np.ones((2, 50), dtype=np.float32) + buf = sysio.BytesIO() + np.save(buf, activity) + status, _ = _post_binary(cadecon_server, "/api/v1/results/activity", buf.getvalue()) + assert status == 200 + + # results_event should NOT be set yet + assert not cadecon_server.results_event.is_set() + + # 2. 
POST results JSON + results = {"alphas": [1.0, 1.0], "fs": 30.0} + status, _ = _post(cadecon_server, "/api/v1/results", results) + assert status == 200 + + # Now both should be stored + assert cadecon_server.results_event.is_set() + assert cadecon_server.received_activity is not None + assert cadecon_server.received_results == results + npt.assert_array_equal(cadecon_server.received_activity, activity) + + +def test_invalid_npy_returns_400(cadecon_server: BridgeServer) -> None: + """Garbage bytes to /api/v1/results/activity returns 400.""" + status, body = _post_binary( + cadecon_server, "/api/v1/results/activity", b"not-a-npy-file", + ) + assert status == 400 diff --git a/python/tests/test_decon.py b/python/tests/test_decon.py new file mode 100644 index 00000000..c3cc80da --- /dev/null +++ b/python/tests/test_decon.py @@ -0,0 +1,95 @@ +"""Tests for CaDeconResult and _build_biexp_waveform.""" + +from __future__ import annotations + +import numpy as np +import numpy.testing as npt + +from calab._compute import CaDeconResult, _build_biexp_waveform + + +def test_cadecon_result_construction() -> None: + """CaDeconResult can be constructed and fields accessed by name.""" + activity = np.zeros((3, 100), dtype=np.float32) + alphas = np.array([1.0, 1.5, 2.0]) + baselines = np.array([0.1, 0.2, 0.3]) + pves = np.array([0.9, 0.85, 0.92]) + kernel_slow = np.ones(50, dtype=np.float32) + kernel_fast = np.empty(0, dtype=np.float32) + + result = CaDeconResult( + activity=activity, + alphas=alphas, + baselines=baselines, + pves=pves, + kernel_slow=kernel_slow, + kernel_fast=kernel_fast, + fs=30.0, + metadata={"tau_rise": 0.2, "tau_decay": 1.0}, + ) + + assert result.activity.shape == (3, 100) + assert result.alphas.shape == (3,) + assert result.baselines.shape == (3,) + assert result.pves.shape == (3,) + assert result.kernel_slow.shape == (50,) + assert result.kernel_fast.shape == (0,) + assert result.fs == 30.0 + assert result.metadata["tau_rise"] == 0.2 + + +def 
test_cadecon_result_is_namedtuple() -> None: + """CaDeconResult supports tuple unpacking.""" + result = CaDeconResult( + activity=np.zeros((1, 10), dtype=np.float32), + alphas=np.array([1.0]), + baselines=np.array([0.0]), + pves=np.array([0.9]), + kernel_slow=np.ones(5, dtype=np.float32), + kernel_fast=np.empty(0, dtype=np.float32), + fs=30.0, + metadata={}, + ) + activity, alphas, baselines, pves, ks, kf, fs, meta = result + assert fs == 30.0 + assert len(alphas) == 1 + + +def test_build_biexp_waveform_shape() -> None: + """_build_biexp_waveform returns correct length and dtype.""" + waveform = _build_biexp_waveform( + tau_rise=0.02, tau_decay=0.4, beta=1.0, fs=30.0, length=100, + ) + assert waveform.shape == (100,) + assert waveform.dtype == np.float32 + + +def test_build_biexp_waveform_starts_near_zero() -> None: + """Waveform starts at 0 (at t=0, exp(0)-exp(0) = 0).""" + waveform = _build_biexp_waveform( + tau_rise=0.02, tau_decay=0.4, beta=1.0, fs=1000.0, length=500, + ) + assert abs(waveform[0]) < 1e-6 + + +def test_build_biexp_waveform_peak_positive() -> None: + """Waveform peaks at a positive value when beta > 0.""" + waveform = _build_biexp_waveform( + tau_rise=0.02, tau_decay=0.4, beta=1.0, fs=1000.0, length=500, + ) + assert waveform.max() > 0 + + +def test_build_biexp_waveform_decays() -> None: + """Waveform value at end is less than peak (it decays).""" + waveform = _build_biexp_waveform( + tau_rise=0.02, tau_decay=0.4, beta=1.0, fs=1000.0, length=2000, + ) + assert waveform[-1] < waveform.max() + + +def test_build_biexp_waveform_beta_scaling() -> None: + """Doubling beta doubles the waveform amplitude.""" + w1 = _build_biexp_waveform(tau_rise=0.02, tau_decay=0.4, beta=1.0, fs=100.0, length=50) + w2 = _build_biexp_waveform(tau_rise=0.02, tau_decay=0.4, beta=2.0, fs=100.0, length=50) + npt.assert_allclose(w2, 2.0 * w1, atol=1e-6) From 6d9ee95cfdc2c762ce72d52fd7f70e3f5809d06a Mon Sep 17 00:00:00 2001 From: daharoni Date: Sat, 21 Mar 2026 22:34:05 -0700 
Subject: [PATCH 2/2] refactor: deduplicate bridge wait loop and export sort logic Extract shared _run_bridge() helper from tune()/decon() to eliminate ~60 lines of duplicated polling/heartbeat/shutdown boilerplate. Extract sortedCellIndices() in export-utils to avoid double-sorting lookup keys. Remove postActivityToBridge/postResultsToBridge from public io exports since they're internal to exportCaDeconToBridge. Co-Authored-By: Claude Opus 4.6 (1M context) --- apps/cadecon/src/lib/export-utils.ts | 15 +-- packages/io/src/index.ts | 2 - python/src/calab/_bridge/_apps.py | 133 ++++++++++++--------------- 3 files changed, 67 insertions(+), 83 deletions(-) diff --git a/apps/cadecon/src/lib/export-utils.ts b/apps/cadecon/src/lib/export-utils.ts index 2af37532..a85047a3 100644 --- a/apps/cadecon/src/lib/export-utils.ts +++ b/apps/cadecon/src/lib/export-utils.ts @@ -3,7 +3,12 @@ */ import { cellResultLookup, convergenceHistory, convergedAtIteration } from './iteration-store.ts'; -import { samplingRate, numCells, numTimepoints } from './data-store.ts'; +import { samplingRate, numTimepoints } from './data-store.ts'; + +/** Sorted cell indices for deterministic row order across both export functions. */ +function sortedCellIndices(): number[] { + return [...cellResultLookup().keys()].sort((a, b) => a - b); +} /** * Build a contiguous Float32Array activity matrix from per-cell sCounts. @@ -14,13 +19,11 @@ export function buildCaDeconActivityMatrix(): { shape: [number, number]; } { const lookup = cellResultLookup(); - const nCells = numCells() ?? 0; const nTime = numTimepoints() ?? 
0; - const data = new Float32Array(nCells * nTime); + const sortedCells = sortedCellIndices(); + const data = new Float32Array(sortedCells.length * nTime); - // Sorted cell indices for deterministic row order - const sortedCells = [...lookup.keys()].sort((a, b) => a - b); for (let row = 0; row < sortedCells.length; row++) { const entry = lookup.get(sortedCells[row])!; const offset = row * nTime; @@ -39,7 +42,7 @@ export function buildCaDeconResultsPayload(): Record { const history = convergenceHistory(); const fs = samplingRate() ?? 30; - const sortedCells = [...lookup.keys()].sort((a, b) => a - b); + const sortedCells = sortedCellIndices(); const alphas: number[] = []; const baselines: number[] = []; const pves: number[] = []; diff --git a/packages/io/src/index.ts b/packages/io/src/index.ts index 5385e1e9..f4e53e28 100644 --- a/packages/io/src/index.ts +++ b/packages/io/src/index.ts @@ -10,8 +10,6 @@ export { getBridgeUrl, fetchBridgeData, postParamsToBridge, - postActivityToBridge, - postResultsToBridge, exportCaDeconToBridge, startBridgeHeartbeat, stopBridgeHeartbeat, diff --git a/python/src/calab/_bridge/_apps.py b/python/src/calab/_bridge/_apps.py index 679ece9f..779418ac 100644 --- a/python/src/calab/_bridge/_apps.py +++ b/python/src/calab/_bridge/_apps.py @@ -17,6 +17,56 @@ _DEFAULT_CADECON_URL = "https://miniscope.github.io/CaLab/CaDecon/" +def _run_bridge( + server: BridgeServer, + event: threading.Event, + app_name: str, + app_url: str, + open_browser: bool, + timeout: float | None, +) -> bool: + """Start server, open browser, and wait for the bridge event. + + Returns True if the event fired (data received), False otherwise. 
+ """ + actual_port = server.port + server_thread = threading.Thread(target=server.serve_forever, daemon=True) + server_thread.start() + + bridge_param = f"http://127.0.0.1:{actual_port}" + full_url = f"{app_url}?bridge={bridge_param}" + + print(f"Bridge server running on http://127.0.0.1:{actual_port}") + print(f"Opening {app_name}: {full_url}") + + if open_browser: + webbrowser.open(full_url) + + received = False + start_time = time.monotonic() + try: + while True: + if event.wait(timeout=1.0): + received = True + break + + now = time.monotonic() + + if timeout is not None and (now - start_time) >= timeout: + break + + if server.last_heartbeat is not None: + if (now - server.last_heartbeat) > HEARTBEAT_TIMEOUT: + print("\nBrowser disconnected (heartbeat timeout).") + break + except KeyboardInterrupt: + print("\nBridge cancelled by user.") + finally: + server.shutdown() + + return received + + def tune( traces: np.ndarray, fs: float = 30.0, @@ -53,44 +103,10 @@ def tune( Keys: ``tau_rise``, ``tau_decay``, ``lambda_``, ``fs``, ``filter_enabled``. 
""" server = BridgeServer(traces, fs, port=port or 0) - actual_port = server.port - - # Start server in daemon thread - server_thread = threading.Thread(target=server.serve_forever, daemon=True) - server_thread.start() - - url = app_url or _DEFAULT_CATUNE_URL - bridge_param = f"http://127.0.0.1:{actual_port}" - full_url = f"{url}?bridge={bridge_param}" - - print(f"Bridge server running on http://127.0.0.1:{actual_port}") - print(f"Opening CaTune: {full_url}") - - if open_browser: - webbrowser.open(full_url) - - received = False - start_time = time.monotonic() - try: - while True: - if server.params_event.wait(timeout=1.0): - received = True - break - - now = time.monotonic() - - if timeout is not None and (now - start_time) >= timeout: - break - - # Detect browser disconnect (only after first heartbeat arrives) - if server.last_heartbeat is not None: - if (now - server.last_heartbeat) > HEARTBEAT_TIMEOUT: - print("\nBrowser disconnected (heartbeat timeout).") - break - except KeyboardInterrupt: - print("\nBridge cancelled by user.") - finally: - server.shutdown() + received = _run_bridge( + server, server.params_event, "CaTune", + app_url or _DEFAULT_CATUNE_URL, open_browser, timeout, + ) if received and server.received_params is not None: raw = server.received_params @@ -144,43 +160,10 @@ def decon( from .._compute import CaDeconResult, _build_biexp_waveform server = BridgeServer(traces, fs, port=port or 0, app="cadecon") - actual_port = server.port - - # Start server in daemon thread - server_thread = threading.Thread(target=server.serve_forever, daemon=True) - server_thread.start() - - url = app_url or _DEFAULT_CADECON_URL - bridge_param = f"http://127.0.0.1:{actual_port}" - full_url = f"{url}?bridge={bridge_param}" - - print(f"Bridge server running on http://127.0.0.1:{actual_port}") - print(f"Opening CaDecon: {full_url}") - - if open_browser: - webbrowser.open(full_url) - - received = False - start_time = time.monotonic() - try: - while True: - if 
server.results_event.wait(timeout=1.0): - received = True - break - - now = time.monotonic() - - if timeout is not None and (now - start_time) >= timeout: - break - - if server.last_heartbeat is not None: - if (now - server.last_heartbeat) > HEARTBEAT_TIMEOUT: - print("\nBrowser disconnected (heartbeat timeout).") - break - except KeyboardInterrupt: - print("\nBridge cancelled by user.") - finally: - server.shutdown() + received = _run_bridge( + server, server.results_event, "CaDecon", + app_url or _DEFAULT_CADECON_URL, open_browser, timeout, + ) if not received or server.received_results is None: return None