From 022e1087929545bdcc429418e30596dc0c06ff61 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 12:39:44 -0500 Subject: [PATCH 1/3] Report notebook fidelity and validate artifacts --- nstat/notebook_figures.py | 18 +++ nstat/notebook_parity.py | 143 ++++++++++++++++++++++ nstat/parity_report.py | 41 ++++++- parity/report.md | 22 +++- tests/test_notebook_artifact_contracts.py | 43 +++++++ tests/test_parity_report.py | 4 + tools/notebooks/run_notebooks.py | 20 +++ 7 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 nstat/notebook_parity.py create mode 100644 tests/test_notebook_artifact_contracts.py diff --git a/nstat/notebook_figures.py b/nstat/notebook_figures.py index 9c77a1c8..f2d9de8f 100644 --- a/nstat/notebook_figures.py +++ b/nstat/notebook_figures.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from dataclasses import dataclass, field from pathlib import Path @@ -25,6 +26,9 @@ def __post_init__(self) -> None: topic_dir = self._topic_dir() for img_path in topic_dir.glob("fig_*.png"): img_path.unlink() + manifest_path = topic_dir / "manifest.json" + if manifest_path.exists(): + manifest_path.unlink() def _topic_dir(self) -> Path: out = self.output_root / self.topic @@ -100,3 +104,17 @@ def finalize(self) -> None: raise AssertionError( f"{self.topic}: produced {self.count} figure(s), expected {self.expected_count}" ) + topic_dir = self._topic_dir() + images = [path.name for path in sorted(topic_dir.glob("fig_*.png"))] + (topic_dir / "manifest.json").write_text( + json.dumps( + { + "topic": self.topic, + "expected_count": int(self.expected_count), + "produced_count": self.count, + "images": images, + }, + indent=2, + ), + encoding="utf-8", + ) diff --git a/nstat/notebook_parity.py b/nstat/notebook_parity.py new file mode 100644 index 00000000..2dee0660 --- /dev/null +++ b/nstat/notebook_parity.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass 
+from pathlib import Path +from typing import Any + +import nbformat +import yaml + + +PARITY_NOTES_RELATIVE_PATH = Path("tools") / "notebooks" / "parity_notes.yml" +NOTEBOOK_IMAGE_ROOT = Path("output") / "notebook_images" +FIGURE_MANIFEST_NAME = "manifest.json" +FIGURE_TRACKER_RE = re.compile( + r"FigureTracker\(\s*topic=['\"](?P<topic>[^'\"]+)['\"]\s*,\s*output_root=OUTPUT_ROOT\s*,\s*expected_count=(?P<count>\d+)\s*\)", + re.DOTALL, +) + + +@dataclass(frozen=True) +class NotebookFigureContract: + topic: str + expected_count: int + has_finalize_call: bool + + def topic_dir(self, repo_root: Path) -> Path: + return repo_root / NOTEBOOK_IMAGE_ROOT / self.topic + + def manifest_path(self, repo_root: Path) -> Path: + return self.topic_dir(repo_root) / FIGURE_MANIFEST_NAME + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def load_notebook_parity_notes(repo_root: Path | None = None) -> list[dict[str, Any]]: + base = _repo_root() if repo_root is None else repo_root.resolve() + payload = yaml.safe_load((base / PARITY_NOTES_RELATIVE_PATH).read_text(encoding="utf-8")) or {} + return list(payload.get("notes", [])) + + +def summarize_notebook_fidelity(notes: list[dict[str, Any]]) -> dict[str, int]: + counts: dict[str, int] = {} + for row in notes: + status = str(row.get("fidelity_status", "")).strip() + if not status: + continue + counts[status] = counts.get(status, 0) + 1 + return counts + + +def iter_outstanding_notebook_fidelity(notes: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [row for row in notes if row.get("fidelity_status") == "partial"] + + +def extract_figure_contract(notebook_path: Path) -> NotebookFigureContract | None: + notebook = nbformat.read(notebook_path, as_version=4) + text = "\n".join(str(cell.get("source", "")) for cell in notebook.cells) + match = FIGURE_TRACKER_RE.search(text) + if not match: + return None + return NotebookFigureContract( + topic=match.group("topic"), + expected_count=int(match.group("count")), 
has_finalize_call="__tracker.finalize()" in text, + ) + + +def reset_notebook_figure_artifacts(repo_root: Path, contract: NotebookFigureContract) -> None: + topic_dir = contract.topic_dir(repo_root.resolve()) + if not topic_dir.exists(): + return + for path in topic_dir.glob("fig_*.png"): + path.unlink() + manifest = contract.manifest_path(repo_root.resolve()) + if manifest.exists(): + manifest.unlink() + + +def validate_notebook_figure_artifacts( + repo_root: Path, + contract: NotebookFigureContract, + *, + expected_topic: str | None = None, +) -> None: + base = repo_root.resolve() + if expected_topic is not None and contract.topic != expected_topic: + raise AssertionError( + f"Notebook figure contract topic {contract.topic!r} does not match manifest topic {expected_topic!r}" + ) + if not contract.has_finalize_call: + raise AssertionError(f"{contract.topic}: notebook uses FigureTracker but never calls __tracker.finalize()") + + topic_dir = contract.topic_dir(base) + manifest_path = contract.manifest_path(base) + if not manifest_path.exists(): + raise AssertionError(f"{contract.topic}: missing notebook figure manifest at {manifest_path}") + + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + produced = int(payload.get("produced_count", -1)) + expected = int(payload.get("expected_count", -1)) + images = [str(item) for item in payload.get("images", [])] + + if payload.get("topic") != contract.topic: + raise AssertionError(f"{contract.topic}: figure manifest topic mismatch") + if expected != contract.expected_count: + raise AssertionError( + f"{contract.topic}: figure manifest expected_count={expected} does not match notebook contract {contract.expected_count}" + ) + if produced != contract.expected_count: + raise AssertionError( + f"{contract.topic}: figure manifest produced_count={produced} does not match expected_count={contract.expected_count}" + ) + if len(images) != contract.expected_count: + raise AssertionError( + f"{contract.topic}: figure 
manifest lists {len(images)} image(s), expected {contract.expected_count}" + ) + + disk_images = sorted(topic_dir.glob("fig_*.png")) + if len(disk_images) != contract.expected_count: + raise AssertionError( + f"{contract.topic}: output directory contains {len(disk_images)} figure(s), expected {contract.expected_count}" + ) + + missing = [path for path in images if not (topic_dir / path).exists()] + if missing: + raise AssertionError(f"{contract.topic}: manifest references missing figure files: {missing}") + + +__all__ = [ + "FIGURE_MANIFEST_NAME", + "NOTEBOOK_IMAGE_ROOT", + "NotebookFigureContract", + "extract_figure_contract", + "iter_outstanding_notebook_fidelity", + "load_notebook_parity_notes", + "reset_notebook_figure_artifacts", + "summarize_notebook_fidelity", + "validate_notebook_figure_artifacts", +] diff --git a/nstat/parity_report.py b/nstat/parity_report.py index f61841fd..9885e222 100644 --- a/nstat/parity_report.py +++ b/nstat/parity_report.py @@ -5,6 +5,12 @@ import yaml +from nstat.notebook_parity import ( + iter_outstanding_notebook_fidelity, + load_notebook_parity_notes, + summarize_notebook_fidelity, +) + SUMMARY_SECTIONS = ( "public_api", @@ -72,11 +78,14 @@ def _iter_non_applicable_rows(payload: dict[str, Any]) -> list[tuple[str, dict[s def render_parity_report(repo_root: Path | None = None) -> str: payload = load_parity_manifest(repo_root) class_fidelity = load_class_fidelity_audit(repo_root) + notebook_fidelity = load_notebook_parity_notes(repo_root) class_counts = _summarize_class_fidelity(class_fidelity) + notebook_counts = summarize_notebook_fidelity(notebook_fidelity) + notebook_partial = iter_outstanding_notebook_fidelity(notebook_fidelity) lines = [ "# nSTAT Python Parity Report", "", - "Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`.", + "Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, and `tools/notebooks/parity_notes.yml`.", "", f"- MATLAB reference: 
{payload['source_repositories']['matlab']}", f"- Python target: {payload['source_repositories']['python']}", @@ -107,11 +116,32 @@ def render_parity_report(repo_root: Path | None = None) -> str: for status in class_fidelity.get("status_legend", []): lines.append(f"| `{status}` | {class_counts.get(status, 0)} |") + lines.extend( + [ + "", + "## Notebook Fidelity Summary", + "", + "| Status | Count |", + "|---|---:|", + ] + ) + for status in ("exact", "high_fidelity", "partial"): + lines.append(f"| `{status}` | {notebook_counts.get(status, 0)} |") + lines.extend(["", "## Coverage Notes", ""]) lines.append( "- Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable." ) lines.append("- Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents.") + if notebook_partial: + lines.append( + f"- Notebook fidelity: workflow coverage is complete, but {len(notebook_partial)} MATLAB-helpfile notebook ports are still marked partial in `tools/notebooks/parity_notes.yml`." + ) + lines.append( + "- Notebook fidelity audit: structural section/figure comparisons are recorded in `parity/notebook_fidelity.yml`." + ) + else: + lines.append("- Notebook fidelity: all tracked MATLAB-helpfile notebook ports are marked high fidelity or exact.") if _has_outstanding(payload, "paper_examples") or _has_outstanding(payload, "docs_gallery"): lines.append( "- Paper examples and docs gallery: canonical structure is present, but dataset-backed outputs and figure files are still partial." 
@@ -146,6 +176,15 @@ def render_parity_report(repo_root: Path | None = None) -> str: notes = row.get("notes", "") lines.append(f"- `{label}` -> `{python_target}`: {notes}") + lines.extend(["", "## Remaining Notebook-Fidelity Deltas", ""]) + if not notebook_partial: + lines.append("No partial notebook-fidelity items remain in `tools/notebooks/parity_notes.yml`.") + else: + for row in notebook_partial: + lines.append( + f"- `{row['topic']}` -> `{row['file']}` [{row['fidelity_status']}]: {row['remaining_differences']}" + ) + lines.extend(["", "## Remaining Class-Fidelity Deltas", ""]) if not priority_remaining: lines.append("No partial, shim-only, or missing class-fidelity items remain.") diff --git a/parity/report.md b/parity/report.md index d9ca06fb..a09c3252 100644 --- a/parity/report.md +++ b/parity/report.md @@ -1,6 +1,6 @@ # nSTAT Python Parity Report -Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. +Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, and `tools/notebooks/parity_notes.yml`. - MATLAB reference: https://github.com/cajigaslab/nSTAT - Python target: https://github.com/cajigaslab/nSTAT-python @@ -29,10 +29,20 @@ Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. | `missing` | 0 | | `not_applicable` | 1 | +## Notebook Fidelity Summary + +| Status | Count | +|---|---:| +| `exact` | 0 | +| `high_fidelity` | 4 | +| `partial` | 7 | + ## Coverage Notes - Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable. - Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents. +- Notebook fidelity: workflow coverage is complete, but 7 MATLAB-helpfile notebook ports are still marked partial in `tools/notebooks/parity_notes.yml`. +- Notebook fidelity audit: structural section/figure comparisons are recorded in `parity/notebook_fidelity.yml`. 
- Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped. - Class fidelity: the class audit reports no partial, shim-only, or missing items. @@ -40,6 +50,16 @@ Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. No partial or missing items remain in the mapping inventory. +## Remaining Notebook-Fidelity Deltas + +- `DecodingExample` -> `notebooks/DecodingExample.ipynb` [partial]: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched. +- `DecodingExampleWithHist` -> `notebooks/DecodingExampleWithHist.ipynb` [partial]: History-aware decoding is available, but the MATLAB workflow still has richer option handling and reference outputs. +- `ExplicitStimulusWhiskerData` -> `notebooks/ExplicitStimulusWhiskerData.ipynb` [partial]: Dataset-backed workflow is present, but figure-level and narrative parity with MATLAB are still incomplete. +- `HippocampalPlaceCellExample` -> `notebooks/HippocampalPlaceCellExample.ipynb` [partial]: Core place-cell workflow is ported, but MATLAB figure sequencing and summary outputs are not yet exact. +- `HybridFilterExample` -> `notebooks/HybridFilterExample.ipynb` [partial]: Hybrid filtering workflow executes, but MATLAB-specific output details and downstream validation remain incomplete. +- `ValidationDataSet` -> `notebooks/ValidationDataSet.ipynb` [partial]: Validation dataset coverage exists, but MATLAB reference summaries and figure parity are not yet complete. +- `StimulusDecode2D` -> `notebooks/StimulusDecode2D.ipynb` [partial]: The 2D stimulus decoding workflow runs, but MATLAB-equivalent outputs and tolerance-backed parity checks still need expansion. + ## Remaining Class-Fidelity Deltas No partial, shim-only, or missing class-fidelity items remain. 
diff --git a/tests/test_notebook_artifact_contracts.py b/tests/test_notebook_artifact_contracts.py new file mode 100644 index 00000000..00a237be --- /dev/null +++ b/tests/test_notebook_artifact_contracts.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import yaml + +from nstat.notebook_figures import FigureTracker +from nstat.notebook_parity import extract_figure_contract + + +REPO_ROOT = Path(__file__).resolve().parents[1] +NOTEBOOK_MANIFEST_PATH = REPO_ROOT / "tools" / "notebooks" / "notebook_manifest.yml" + + +def test_notebook_figure_tracker_contracts_match_manifest_topics() -> None: + payload = yaml.safe_load(NOTEBOOK_MANIFEST_PATH.read_text(encoding="utf-8")) or {} + for row in payload.get("notebooks", []): + notebook_path = REPO_ROOT / row["file"] + contract = extract_figure_contract(notebook_path) + if contract is None: + continue + assert contract.topic == row["topic"], f"{notebook_path} tracker topic drifted from notebook manifest" + assert contract.expected_count >= 0 + assert contract.has_finalize_call, f"{notebook_path} uses FigureTracker but does not finalize it" + + +def test_figure_tracker_writes_manifest(tmp_path: Path) -> None: + output_root = tmp_path / "notebook_images" + tracker = FigureTracker(topic="ArtifactContractTest", output_root=output_root, expected_count=2) + tracker.new_figure("first") + tracker.new_figure("second") + tracker.finalize() + + topic_dir = output_root / "ArtifactContractTest" + manifest_path = topic_dir / "manifest.json" + assert manifest_path.exists() + + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + assert payload["topic"] == "ArtifactContractTest" + assert payload["expected_count"] == 2 + assert payload["produced_count"] == 2 + assert payload["images"] == ["fig_001.png", "fig_002.png"] diff --git a/tests/test_parity_report.py b/tests/test_parity_report.py index d1c85bf0..7f8cfa8d 100644 --- a/tests/test_parity_report.py +++ 
b/tests/test_parity_report.py @@ -20,6 +20,10 @@ def test_parity_report_highlights_current_constraints() -> None: assert "paper examples and docs gallery" in text.lower() assert "all canonical paper examples and committed gallery directories are mapped" in text assert "class fidelity" in text.lower() + assert "Notebook Fidelity Summary" in text + assert "Remaining Notebook-Fidelity Deltas" in text + assert "parity/notebook_fidelity.yml" in text + assert "DecodingExample" in text assert "No partial or missing items remain in the mapping inventory." in text assert "Remaining Class-Fidelity Deltas" in text assert "No partial, shim-only, or missing class-fidelity items remain." in text diff --git a/tools/notebooks/run_notebooks.py b/tools/notebooks/run_notebooks.py index a3999f69..88521b5d 100755 --- a/tools/notebooks/run_notebooks.py +++ b/tools/notebooks/run_notebooks.py @@ -5,6 +5,7 @@ import argparse import os +import sys from dataclasses import dataclass from pathlib import Path @@ -12,6 +13,16 @@ import yaml from nbclient import NotebookClient +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from nstat.notebook_parity import ( + extract_figure_contract, + reset_notebook_figure_artifacts, + validate_notebook_figure_artifacts, +) + @dataclass(frozen=True) class NotebookTarget: @@ -135,8 +146,17 @@ def main() -> int: failures.append(f"missing notebook: {target.path}") continue print(f"Executing [{target.run_group}] {target.topic}: {target.path}") + figure_contract = extract_figure_contract(target.path) try: + if figure_contract is not None: + reset_notebook_figure_artifacts(args.repo_root, figure_contract) execute_notebook(target.path, timeout=args.timeout) + if figure_contract is not None: + validate_notebook_figure_artifacts( + args.repo_root, + figure_contract, + expected_topic=target.topic, + ) except Exception as exc: # noqa: BLE001 failures.append(f"{target.path}: {exc}") From 
95a839d7086ccb6f4c828a3d764f86c9327e84f5 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 12:39:46 -0500 Subject: [PATCH 2/3] Audit notebook structure against MATLAB helpfiles --- nstat/notebook_fidelity_audit.py | 140 ++++++++++++++ parity/notebook_fidelity.yml | 172 ++++++++++++++++++ tests/test_notebook_fidelity_audit.py | 39 ++++ tools/parity/build_notebook_fidelity_audit.py | 22 +++ 4 files changed, 373 insertions(+) create mode 100644 nstat/notebook_fidelity_audit.py create mode 100644 parity/notebook_fidelity.yml create mode 100644 tests/test_notebook_fidelity_audit.py create mode 100644 tools/parity/build_notebook_fidelity_audit.py diff --git a/nstat/notebook_fidelity_audit.py b/nstat/notebook_fidelity_audit.py new file mode 100644 index 00000000..db3c541d --- /dev/null +++ b/nstat/notebook_fidelity_audit.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import re +from datetime import date +from pathlib import Path +from typing import Any + +import nbformat +import yaml + +from nstat.notebook_parity import extract_figure_contract, load_notebook_parity_notes + + +IMG_SRC_RE = re.compile(r'<img[^>]+src="([^"]+)"', re.IGNORECASE) +SECTION_RE = re.compile(r"^%%", re.MULTILINE) +PYTHON_SECTION_RE = re.compile(r"^# SECTION\b", re.MULTILINE) + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def default_matlab_repo_root(repo_root: Path | None = None) -> Path: + base = _repo_root() if repo_root is None else repo_root.resolve() + return base.parent / "nSTAT" + + +def _count_matlab_sections(matlab_m_path: Path) -> int: + text = matlab_m_path.read_text(encoding="utf-8", errors="ignore") + return len(SECTION_RE.findall(text)) + + +def _count_matlab_published_figures(matlab_html_path: Path) -> int: + text = matlab_html_path.read_text(encoding="utf-8", errors="ignore") + return len(IMG_SRC_RE.findall(text)) + + +def _count_python_sections(notebook_path: Path) -> int: + notebook = nbformat.read(notebook_path, 
as_version=4) + text = "\n".join(str(cell.get("source", "")) for cell in notebook.cells) + return len(PYTHON_SECTION_RE.findall(text)) + + +def build_notebook_fidelity_audit( + repo_root: Path | None = None, + *, + matlab_repo_root: Path | None = None, +) -> dict[str, Any]: + base = _repo_root() if repo_root is None else repo_root.resolve() + matlab_root = default_matlab_repo_root(base) if matlab_repo_root is None else matlab_repo_root.resolve() + help_root = matlab_root / "helpfiles" + notes = load_notebook_parity_notes(base) + + items: list[dict[str, Any]] = [] + for row in notes: + topic = str(row["topic"]) + notebook_path = base / str(row["file"]) + figure_contract = extract_figure_contract(notebook_path) + python_sections = _count_python_sections(notebook_path) + matlab_stem = Path(str(row["source_matlab"])).stem + matlab_m_path = help_root / f"{matlab_stem}.m" + matlab_html_path = help_root / f"{matlab_stem}.html" + matlab_available = matlab_m_path.exists() and matlab_html_path.exists() + + item: dict[str, Any] = { + "topic": topic, + "source_matlab": str(row["source_matlab"]), + "python_notebook": str(row["file"]), + "fidelity_status": str(row["fidelity_status"]), + "remaining_differences": str(row["remaining_differences"]), + "python_sections": python_sections, + "python_expected_figures": int(figure_contract.expected_count) if figure_contract else 0, + "python_uses_figure_tracker": figure_contract is not None, + "python_has_finalize_call": bool(figure_contract.has_finalize_call) if figure_contract else False, + } + if matlab_available: + matlab_sections = _count_matlab_sections(matlab_m_path) + matlab_figures = _count_matlab_published_figures(matlab_html_path) + item.update( + { + "matlab_repo_root": str(matlab_root), + "matlab_sections": matlab_sections, + "matlab_published_figures": matlab_figures, + "section_delta": python_sections - matlab_sections, + "figure_delta": int(figure_contract.expected_count) - matlab_figures if figure_contract else 
-matlab_figures, + } + ) + else: + item.update( + { + "matlab_repo_root": str(matlab_root), + "matlab_sections": None, + "matlab_published_figures": None, + "section_delta": None, + "figure_delta": None, + } + ) + items.append(item) + + return { + "version": 1, + "generated_on": str(date.today()), + "source_repositories": { + "matlab": "https://github.com/cajigaslab/nSTAT", + "python": "https://github.com/cajigaslab/nSTAT-python", + }, + "matlab_repo_root": str(matlab_root), + "items": items, + } + + +def render_notebook_fidelity_audit( + repo_root: Path | None = None, + *, + matlab_repo_root: Path | None = None, +) -> str: + payload = build_notebook_fidelity_audit(repo_root, matlab_repo_root=matlab_repo_root) + return yaml.safe_dump(payload, sort_keys=False, allow_unicode=False) + + +def write_notebook_fidelity_audit( + repo_root: Path | None = None, + *, + matlab_repo_root: Path | None = None, +) -> Path: + base = _repo_root() if repo_root is None else repo_root.resolve() + out = base / "parity" / "notebook_fidelity.yml" + out.write_text( + render_notebook_fidelity_audit(base, matlab_repo_root=matlab_repo_root), + encoding="utf-8", + ) + return out + + +__all__ = [ + "build_notebook_fidelity_audit", + "default_matlab_repo_root", + "render_notebook_fidelity_audit", + "write_notebook_fidelity_audit", +] diff --git a/parity/notebook_fidelity.yml b/parity/notebook_fidelity.yml new file mode 100644 index 00000000..5ef15fe7 --- /dev/null +++ b/parity/notebook_fidelity.yml @@ -0,0 +1,172 @@ +version: 1 +generated_on: '2026-03-07' +source_repositories: + matlab: https://github.com/cajigaslab/nSTAT + python: https://github.com/cajigaslab/nSTAT-python +matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT +items: +- topic: nSTATPaperExamples + source_matlab: nSTATPaperExamples.mlx + python_notebook: notebooks/nSTATPaperExamples.ipynb + fidelity_status: high_fidelity + remaining_differences: Python uses standalone figshare-backed data access and 
generated + gallery assets rather than MATLAB path-based setup. + python_sections: 31 + python_expected_figures: 25 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 37 + matlab_published_figures: 26 + section_delta: -6 + figure_delta: -1 +- topic: TrialExamples + source_matlab: TrialExamples.mlx + python_notebook: notebooks/TrialExamples.ipynb + fidelity_status: high_fidelity + remaining_differences: Some MATLAB plotting/display details remain simplified, but + the core Trial object workflow now follows the MATLAB semantics closely. + python_sections: 3 + python_expected_figures: 6 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 9 + matlab_published_figures: 6 + section_delta: -6 + figure_delta: 0 +- topic: AnalysisExamples + source_matlab: AnalysisExamples.mlx + python_notebook: notebooks/AnalysisExamples.ipynb + fidelity_status: high_fidelity + remaining_differences: Advanced MATLAB algorithm-selection branches and some report + plots are still lighter in Python. + python_sections: 2 + python_expected_figures: 4 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 7 + matlab_published_figures: 4 + section_delta: -5 + figure_delta: 0 +- topic: DecodingExample + source_matlab: DecodingExample.mlx + python_notebook: notebooks/DecodingExample.ipynb + fidelity_status: partial + remaining_differences: Core decoding workflow is present, but MATLAB decoding options + and reference outputs are not yet fully matched. 
+ python_sections: 3 + python_expected_figures: 5 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 4 + matlab_published_figures: 5 + section_delta: -1 + figure_delta: 0 +- topic: DecodingExampleWithHist + source_matlab: DecodingExampleWithHist.mlx + python_notebook: notebooks/DecodingExampleWithHist.ipynb + fidelity_status: partial + remaining_differences: History-aware decoding is available, but the MATLAB workflow + still has richer option handling and reference outputs. + python_sections: 1 + python_expected_figures: 2 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 2 + matlab_published_figures: 2 + section_delta: -1 + figure_delta: 0 +- topic: ExplicitStimulusWhiskerData + source_matlab: ExplicitStimulusWhiskerData.mlx + python_notebook: notebooks/ExplicitStimulusWhiskerData.ipynb + fidelity_status: partial + remaining_differences: Dataset-backed workflow is present, but figure-level and + narrative parity with MATLAB are still incomplete. + python_sections: 6 + python_expected_figures: 9 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 7 + matlab_published_figures: 9 + section_delta: -1 + figure_delta: 0 +- topic: HippocampalPlaceCellExample + source_matlab: HippocampalPlaceCellExample.mlx + python_notebook: notebooks/HippocampalPlaceCellExample.ipynb + fidelity_status: partial + remaining_differences: Core place-cell workflow is ported, but MATLAB figure sequencing + and summary outputs are not yet exact. 
+ python_sections: 5 + python_expected_figures: 9 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 5 + matlab_published_figures: 11 + section_delta: 0 + figure_delta: -2 +- topic: HybridFilterExample + source_matlab: HybridFilterExample.mlx + python_notebook: notebooks/HybridFilterExample.ipynb + fidelity_status: partial + remaining_differences: Hybrid filtering workflow executes, but MATLAB-specific output + details and downstream validation remain incomplete. + python_sections: 4 + python_expected_figures: 2 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 6 + matlab_published_figures: 3 + section_delta: -2 + figure_delta: -1 +- topic: PPSimExample + source_matlab: PPSimExample.mlx + python_notebook: notebooks/PPSimExample.ipynb + fidelity_status: high_fidelity + remaining_differences: MATLAB plotting/report formatting remains lighter, but the + core point-process simulation workflow is closely aligned. + python_sections: 9 + python_expected_figures: 3 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 17 + matlab_published_figures: 8 + section_delta: -8 + figure_delta: -5 +- topic: ValidationDataSet + source_matlab: ValidationDataSet.mlx + python_notebook: notebooks/ValidationDataSet.ipynb + fidelity_status: partial + remaining_differences: Validation dataset coverage exists, but MATLAB reference + summaries and figure parity are not yet complete. 
+ python_sections: 3 + python_expected_figures: 9 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 11 + matlab_published_figures: 10 + section_delta: -8 + figure_delta: -1 +- topic: StimulusDecode2D + source_matlab: StimulusDecode2D.mlx + python_notebook: notebooks/StimulusDecode2D.ipynb + fidelity_status: partial + remaining_differences: The 2D stimulus decoding workflow runs, but MATLAB-equivalent + outputs and tolerance-backed parity checks still need expansion. + python_sections: 3 + python_expected_figures: 8 + python_uses_figure_tracker: true + python_has_finalize_call: true + matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT + matlab_sections: 4 + matlab_published_figures: 6 + section_delta: -1 + figure_delta: 2 diff --git a/tests/test_notebook_fidelity_audit.py b/tests/test_notebook_fidelity_audit.py new file mode 100644 index 00000000..07d0d807 --- /dev/null +++ b/tests/test_notebook_fidelity_audit.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml + +from nstat.notebook_fidelity_audit import default_matlab_repo_root, render_notebook_fidelity_audit +from nstat.notebook_parity import load_notebook_parity_notes + + +REPO_ROOT = Path(__file__).resolve().parents[1] +AUDIT_PATH = REPO_ROOT / "parity" / "notebook_fidelity.yml" + + +def test_notebook_fidelity_audit_covers_all_parity_notes() -> None: + audit = yaml.safe_load(AUDIT_PATH.read_text(encoding="utf-8")) or {} + notes = load_notebook_parity_notes(REPO_ROOT) + + audit_topics = {row["topic"] for row in audit.get("items", [])} + note_topics = {row["topic"] for row in notes} + assert audit_topics == note_topics + + +def test_notebook_fidelity_audit_has_structural_counts() -> None: + audit = yaml.safe_load(AUDIT_PATH.read_text(encoding="utf-8")) or {} + for row in audit.get("items", []): + assert 
"python_sections" in row + assert "python_expected_figures" in row + assert row["python_expected_figures"] >= 0 + assert isinstance(row["python_has_finalize_call"], bool) + + +def test_notebook_fidelity_audit_matches_generator_when_matlab_repo_is_available() -> None: + matlab_repo = default_matlab_repo_root(REPO_ROOT) + if not matlab_repo.exists(): + pytest.skip(f"MATLAB reference repo not available at {matlab_repo}") + committed = AUDIT_PATH.read_text(encoding="utf-8") + assert committed == render_notebook_fidelity_audit(REPO_ROOT, matlab_repo_root=matlab_repo) diff --git a/tools/parity/build_notebook_fidelity_audit.py b/tools/parity/build_notebook_fidelity_audit.py new file mode 100644 index 00000000..0b1ed56b --- /dev/null +++ b/tools/parity/build_notebook_fidelity_audit.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from nstat.notebook_fidelity_audit import write_notebook_fidelity_audit + + +def main() -> int: + path = write_notebook_fidelity_audit(REPO_ROOT) + print(path) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 75bfa6c8d447d039eb9968d014680a0e673ec78b Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 13:22:31 -0500 Subject: [PATCH 3/3] Translate decoding helpfile notebooks --- notebooks/DecodingExample.ipynb | 205 +++++++--- notebooks/DecodingExampleWithHist.ipynb | 190 +++++---- parity/notebook_fidelity.yml | 22 +- parity/report.md | 8 +- tests/test_parity_report.py | 2 +- .../build_decoding_fidelity_notebooks.py | 380 ++++++++++++++++++ tools/notebooks/parity_notes.yml | 8 +- 7 files changed, 667 insertions(+), 148 deletions(-) create mode 100644 tools/notebooks/build_decoding_fidelity_notebooks.py diff --git a/notebooks/DecodingExample.ipynb b/notebooks/DecodingExample.ipynb index 39b36c6f..ae8a6a8d 100644 
--- a/notebooks/DecodingExample.ipynb +++ b/notebooks/DecodingExample.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "c488b5fa", + "id": "72ddd907", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `DecodingExample.mlx`\n", - "- Fidelity status: `partial`\n", - "- Remaining justified differences: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched." + "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: Workflow, model fitting, and decoded-stimulus figures now follow the MATLAB helpfile closely; exact traces still depend on stochastic simulation draws and Python plotting defaults.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "3d710745", + "id": "b558e18d", "metadata": {}, "outputs": [], "source": [ @@ -34,86 +34,171 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", + "from nstat import Analysis, CIF, ConfigColl, CovColl, Covariate, DecodingAlgorithms, Trial, TrialConfig\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", - "__tracker = FigureTracker(topic='DecodingExample', output_root=OUTPUT_ROOT, expected_count=5)\n", - "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", - "# SECTION 0: Section 0\n", - 
"# STIMULUS DECODING\n", - "# In this example we show how to decode a univariate and a bivariate stimulus based on a point process observations using nSTAT. Even though due to the simulated nature of the data, we know the exact condition intensity function, we estimate the parameters before moving on to the decoding stage." + "__tracker = FigureTracker(topic=\"DecodingExample\", output_root=OUTPUT_ROOT, expected_count=5)\n", + "\n", + "\n", + "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n", + " fig = __tracker.new_figure(matlab_line)\n", + " fig.clear()\n", + " fig.set_size_inches(*figsize)\n", + " return fig\n", + "\n", + "\n", + "def _plot_raster(ax, spike_coll):\n", + " for row in range(1, spike_coll.numSpikeTrains + 1):\n", + " train = spike_coll.getNST(row)\n", + " spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1)\n", + " if spikes.size:\n", + " ax.vlines(spikes, row - 0.4, row + 0.4, color=\"k\", linewidth=0.5)\n", + " ax.set_ylabel(\"Neuron\")\n", + " ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5)\n", + "\n", + "\n", + "def _plot_decoded_ci(ax, time, decoded, cov, stim, title):\n", + " center = np.asarray(decoded, dtype=float).reshape(-1)\n", + " variance = np.asarray(cov, dtype=float).reshape(-1)\n", + " sigma = np.sqrt(np.maximum(variance, 0.0))\n", + " z_val = 3.0\n", + " lower = center - z_val * sigma\n", + " upper = center + z_val * sigma\n", + " ax.plot(time[: center.size], center, \"b\", linewidth=1.5, label=\"x_{k|k}(t)\")\n", + " ax.plot(time[: center.size], lower, \"g\", linewidth=1.0, label=\"x_{k|k}(t)-3σ\")\n", + " ax.plot(time[: center.size], upper, \"g\", linewidth=1.0, label=\"x_{k|k}(t)+3σ\")\n", + " ax.plot(time[: center.size], np.asarray(stim).reshape(-1)[: center.size], \"k\", linewidth=1.5, label=\"x(t)\")\n", + " ax.set_title(title)\n", + " ax.set_xlabel(\"time (s)\")\n", + " ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n", + "\n", + "\n", + "# SECTION 0: STIMULUS DECODING\n", + "# In 
this example we decode a univariate stimulus from simulated point-process observations by following the MATLAB DecodingExample workflow.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "feb22f76", + "id": "e37eea70", "metadata": {}, "outputs": [], "source": [ "# SECTION 1: Generate the conditional Intensity Function\n", "plt.close(\"all\")\n", - "#\n", - "#\n", + "delta = 0.001\n", + "Tmax = 10.0\n", + "time = np.arange(0.0, Tmax + delta, delta)\n", + "f = 0.1\n", + "b1 = 1.0\n", + "b0 = -3.0\n", + "x = np.sin(2.0 * np.pi * f * time)\n", + "exp_data = np.exp(b1 * x + b0)\n", + "lambda_data = exp_data / (1.0 + exp_data)\n", + "lambda_cov = Covariate(time, lambda_data / delta, \"\\\\Lambda(t)\", \"time\", \"s\", \"Hz\", [\"lambda_1\"])\n", + "\n", "numRealizations = 10\n", - "__tracker.new_figure('figure')\n", - "__tracker.new_figure('figure;')\n", - "__tracker.annotate('subplot(2,1,1)')\n", - "__tracker.annotate('spikeColl.plot')\n", - "__tracker.annotate('subplot(2,1,2)')\n", - "__tracker.annotate('lambda.plot')" + "spikeColl = CIF.simulateCIFByThinningFromLambda(lambda_cov, numRealizations=numRealizations)\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 5.5))\n", + "axs = fig.subplots(2, 1, sharex=True)\n", + "_plot_raster(axs[0], spikeColl)\n", + "axs[0].set_title(\"Simulated spike trains from λ(t)\")\n", + "axs[1].plot(time, lambda_cov.data[:, 0], color=\"b\", linewidth=2.0)\n", + "axs[1].set_title(\"Conditional intensity λ(t)\")\n", + "axs[1].set_xlabel(\"time (s)\")\n", + "axs[1].set_ylabel(\"Hz\")\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "d559b2d3", + "id": "b8dd2913", "metadata": {}, "outputs": [], "source": [ "# SECTION 2: Fit a model to the spikedata to obtain a model CIF\n", - "__tracker.new_figure('figure')\n", - "__tracker.new_figure('figure;')\n", - "__tracker.annotate('trial.plot')\n", - "#\n", - "pass\n", - "__tracker.new_figure('figure')\n", - "__tracker.new_figure('figure;')\n", - "#\n", - "# So 
we now have a model for lambda lambda = exp(b_0 + b_1*x(t))./(1+exp(b_0 + b_1*x(t)) * 1/delta because exp(b_0 + b_1*x(t))<<1 we can approximate this lambda by just the numerator i.e. lambda = exp(b_0 + b_1*x(t))./delta\n", - "# Now suppose we wanted to decode x(t) based on only having observed lambda\n", - "pass\n", - "# Construct a CIF object for each realization based on our encoding\n", - "# results abovel\n", - "# close all;\n", - "# Make noise according to the dynamic range of the stimulus\n", - "A = 1\n", - "__tracker.new_figure('figure')\n", - "__tracker.new_figure('figure;')\n", - "zVal = 3\n", - "__tracker.annotate(\"hEst=plot(time,x_u(1:end),'b',time,ciLower,'g',time,ciUpper,'g')\")\n", - "#\n", - "__tracker.annotate(\"hStim=stim.plot([],{{' ''k'',''Linewidth'',2'}})\")\n", - "__tracker.finalize()" + "stim = Covariate(time, x, \"Stimulus\", \"time\", \"s\", \"V\", [\"stim\"])\n", + "baseline = Covariate(time, np.ones_like(time), \"Baseline\", \"time\", \"s\", \"\", [\"constant\"])\n", + "cc = CovColl([stim, baseline])\n", + "trial = Trial(spikeColl, cc)\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 6.0))\n", + "axs = fig.subplots(3, 1, sharex=True)\n", + "_plot_raster(axs[0], spikeColl)\n", + "axs[0].set_title(\"Trial spike raster\")\n", + "axs[1].plot(time, stim.data[:, 0], color=\"k\", linewidth=1.5)\n", + "axs[1].set_title(\"Stimulus covariate\")\n", + "axs[1].set_ylabel(\"V\")\n", + "axs[2].plot(time, baseline.data[:, 0], color=\"0.3\", linewidth=1.5)\n", + "axs[2].set_title(\"Baseline covariate\")\n", + "axs[2].set_ylabel(\"constant\")\n", + "axs[2].set_xlabel(\"time (s)\")\n", + "\n", + "cfgColl = ConfigColl(\n", + " [\n", + " TrialConfig([[\"Baseline\", \"constant\"]], 1000.0, [], [], name=\"Baseline\"),\n", + " TrialConfig([[\"Baseline\", \"constant\"], [\"Stimulus\", \"stim\"]], 1000.0, [], [], name=\"Baseline+Stimulus\"),\n", + " ]\n", + ")\n", + "results = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0)\n", + "\n", + 
"paramEst = np.column_stack([fit.getCoeffs(2)[:2] for fit in results])\n", + "meanParams = np.mean(paramEst, axis=1)\n", + "aic_matrix = np.vstack([fit.AIC for fit in results])\n", + "logll_matrix = np.vstack([fit.logLL for fit in results])\n", + "config_names = results[0].configNames\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n", + "axs = fig.subplots(1, 2)\n", + "neuron_idx = np.arange(1, paramEst.shape[1] + 1)\n", + "axs[0].plot(neuron_idx, paramEst[0], \"o-\", color=\"tab:blue\", label=\"b0\")\n", + "axs[0].axhline(meanParams[0], color=\"tab:blue\", linestyle=\"--\", linewidth=1.0)\n", + "axs[0].set_title(\"Baseline coefficients\")\n", + "axs[0].set_xlabel(\"Neuron\")\n", + "axs[0].set_ylabel(\"b0\")\n", + "axs[1].plot(neuron_idx, paramEst[1], \"o-\", color=\"tab:orange\", label=\"b1\")\n", + "axs[1].axhline(meanParams[1], color=\"tab:orange\", linestyle=\"--\", linewidth=1.0)\n", + "axs[1].set_title(\"Stimulus coefficients\")\n", + "axs[1].set_xlabel(\"Neuron\")\n", + "axs[1].set_ylabel(\"b1\")\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n", + "axs = fig.subplots(1, 2)\n", + "xloc = np.arange(len(config_names))\n", + "axs[0].bar(xloc, np.mean(aic_matrix, axis=0), color=[\"0.6\", \"0.3\"])\n", + "axs[0].set_xticks(xloc, config_names, rotation=15)\n", + "axs[0].set_title(\"Mean AIC across neurons\")\n", + "axs[1].bar(xloc, np.mean(logll_matrix, axis=0), color=[\"0.6\", \"0.3\"])\n", + "axs[1].set_xticks(xloc, config_names, rotation=15)\n", + "axs[1].set_title(\"Mean log-likelihood across neurons\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7529413", + "metadata": {}, + "outputs": [], + "source": [ + "# SECTION 3: Decode the stimulus from the fitted CIF\n", + "b0_est = paramEst[0, :]\n", + "b1_est = paramEst[1, :]\n", + "lambdaCIF = [CIF([b0_est[i], b1_est[i]], [\"1\", \"x\"], [\"x\"], \"binomial\") for i in range(numRealizations)]\n", + "\n", + "spikeColl.resample(1.0 / 
delta)\n", + "dN = spikeColl.dataToMatrix()\n", + "Q = 2.0 * np.std(np.diff(stim.data[:, 0]))\n", + "A = 1.0\n", + "x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilterLinear(A, Q, dN.T, b0_est, b1_est, \"binomial\", delta)\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n", + "ax = fig.subplots(1, 1)\n", + "_plot_decoded_ci(ax, time, x_u, W_u, stim.data[:, 0], f\"Decoded stimulus using {numRealizations} cells\")\n", + "__tracker.finalize()\n" ] } ], @@ -130,4 +215,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/notebooks/DecodingExampleWithHist.ipynb b/notebooks/DecodingExampleWithHist.ipynb index ae6b9c6a..302e98fb 100644 --- a/notebooks/DecodingExampleWithHist.ipynb +++ b/notebooks/DecodingExampleWithHist.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "739c56fe", + "id": "f624a68c", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `DecodingExampleWithHist.mlx`\n", - "- Fidelity status: `partial`\n", - "- Remaining justified differences: History-aware decoding is available, but the MATLAB workflow still has richer option handling and reference outputs." 
+ "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: The notebook now mirrors the MATLAB history-aware decoding workflow closely; exact stochastic trajectories and figure styling still vary slightly under Python execution.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "107e889a", + "id": "cc21ece5", "metadata": {}, "outputs": [], "source": [ @@ -34,76 +34,130 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", + "from nstat import CIF, DecodingAlgorithms, History, Covariate, nspikeTrain, nstColl\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", - "__tracker = FigureTracker(topic='DecodingExampleWithHist', output_root=OUTPUT_ROOT, expected_count=2)\n", - "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", - "# SECTION 0: Section 0\n", - "# 1-D Stimulus Decode with History Effect\n", - "# In the above decoding example, the simulated neurons did not have memory. That is their previous firing activity did not modulate their current probability of firing. In reality the firing history does affect the probabilty of neuronal firing (eg. refractory period, bursting, etc.). In this example, we simulate a population a neurons that exhibit this type of history dependence. 
We then decode the stimulus activity based on a conditional intensity function that includes the correct history terms and one that assumes no history dependence.\n", + "__tracker = FigureTracker(topic=\"DecodingExampleWithHist\", output_root=OUTPUT_ROOT, expected_count=2)\n", + "\n", + "\n", + "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n", + " fig = __tracker.new_figure(matlab_line)\n", + " fig.clear()\n", + " fig.set_size_inches(*figsize)\n", + " return fig\n", + "\n", + "\n", + "def _plot_raster(ax, spike_coll):\n", + " for row in range(1, spike_coll.numSpikeTrains + 1):\n", + " train = spike_coll.getNST(row)\n", + " spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1)\n", + " if spikes.size:\n", + " ax.vlines(spikes, row - 0.4, row + 0.4, color=\"k\", linewidth=0.5)\n", + " ax.set_ylabel(\"Neuron\")\n", + " ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5)\n", + "\n", + "\n", + "def _plot_decoded_ci(ax, time, decoded, cov, stim, title):\n", + " center = np.asarray(decoded, dtype=float).reshape(-1)\n", + " spread = np.asarray(cov, dtype=float).reshape(-1)\n", + " z_val = 3.0\n", + " lower = center - z_val * spread\n", + " upper = center + z_val * spread\n", + " ax.plot(time[: center.size], center, \"b\", linewidth=1.5, label=\"x_{k|k}(t)\")\n", + " ax.plot(time[: center.size], lower, \"g\", linewidth=1.0, label=\"x_{k|k}(t)-3σ\")\n", + " ax.plot(time[: center.size], upper, \"r\", linewidth=1.0, label=\"x_{k|k}(t)+3σ\")\n", + " ax.plot(time[: center.size], np.asarray(stim).reshape(-1)[: center.size], \"k\", linewidth=1.5, label=\"x(t)\")\n", + " ax.set_title(title)\n", + " ax.set_xlabel(\"time (s)\")\n", + " ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n", + "\n", + "\n", + "def _simulate_history_spike_train(time, stim_data, baseline, hist_coeffs, window_times):\n", + " spikes = []\n", + " for idx in range(1, len(time)):\n", + " t = time[idx]\n", + " spike_arr = np.asarray(spikes, dtype=float)\n", + " history_counts 
= []\n", + " for w_start, w_stop in zip(window_times[:-1], window_times[1:]):\n", + " if spike_arr.size:\n", + " history_counts.append(np.sum((spike_arr >= t - w_stop) & (spike_arr < t - w_start)))\n", + " else:\n", + " history_counts.append(0.0)\n", + " eta = baseline + stim_data[idx] + float(np.dot(hist_coeffs, history_counts))\n", + " p = np.exp(np.clip(eta, -20.0, 20.0))\n", + " p = p / (1.0 + p)\n", + " if np.random.rand() < p:\n", + " spikes.append(t)\n", + " return np.asarray(spikes, dtype=float)\n", + "\n", + "\n", + "# SECTION 0: 1-D Stimulus Decode with History Effect\n", + "# We simulate neurons with refractory-history effects and compare point-process decoding with and without the correct history terms.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44a6c7e4", + "metadata": {}, + "outputs": [], + "source": [ + "# SECTION 1: History-aware decoding workflow\n", "plt.close(\"all\")\n", - "# clear all;\n", - "#\n", + "delta = 0.001\n", + "Tmax = 1.0\n", + "time = np.arange(0.0, Tmax + delta, delta)\n", + "f = 1.0\n", + "b1 = 1.0\n", + "b0 = -2.0\n", + "stimData = b1 * np.sin(2.0 * np.pi * f * time)\n", + "histCoeffs = np.array([-2.0, -2.0, -4.0])\n", + "windowTimes = np.array([0.0, 0.001, 0.002, 0.003])\n", + "histObj = History(windowTimes)\n", + "stim = Covariate(time, stimData, \"Stimulus\", \"time\", \"s\", \"Voltage\", [\"sin\"])\n", + "\n", "numRealizations = 20\n", - "#\n", - "__tracker.new_figure('figure')\n", - "__tracker.new_figure('figure;')\n", - "__tracker.annotate('subplot(2,1,1)')\n", - "__tracker.annotate('sC.plot')\n", - "__tracker.annotate('subplot(2,1,2)')\n", - "__tracker.annotate('stim.plot')\n", - "#\n", - "#\n", - "# Construct a CIF object for each realization based on our encoding\n", - "# results above\n", - "# correct CIF w/ History\n", - "# CIF ignoring the history effect\n", - "#\n", - "#\n", - "#\n", - "#\n", - "# Make noise according to the dynamic range of the stimulus\n", - "# Decode with the 
correct and incorrect CIFs\n", - "#\n", - "#\n", - "#\n", - "# Compare the results\n", - "__tracker.new_figure('figure')\n", - "__tracker.new_figure('figure;')\n", - "__tracker.annotate('subplot(2,1,1)')\n", - "zVal = 3\n", - "__tracker.annotate(\"hEst=plot(time,x_u(1:end),'b',time,ciLower,'g',time,ciUpper,'r')\")\n", - "#\n", - "__tracker.annotate(\"hStim=stim.plot([],{{' ''k'',''Linewidth'',2'}})\")\n", - "#\n", - "__tracker.annotate('subplot(2,1,2)')\n", - "zVal = 3\n", - "__tracker.annotate(\"hEst=plot(time,x_uNoHist(1:end),'b',time,ciLower,'g',time,ciUpper,'r')\")\n", - "#\n", - "__tracker.annotate(\"hStim=stim.plot([],{{' ''k'',''Linewidth'',2'}})\")\n", - "__tracker.annotate(\"title(['Decoded Stimulus No Hist +/- 99% confidence intervals using ' num2str(numRealizations) ' cells'])\")\n", - "# We see that inclusion of history effect improves (as expected) the decoding of the stimulus based on the point process observations\n", - "__tracker.finalize()" + "trains = []\n", + "for idx in range(numRealizations):\n", + " spikes = _simulate_history_spike_train(time, stimData, b0, histCoeffs, windowTimes)\n", + " trains.append(nspikeTrain(spikes, str(idx + 1), delta, 0.0, Tmax, makePlots=-1))\n", + "sC = nstColl(trains)\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 5.5))\n", + "axs = fig.subplots(2, 1, sharex=True)\n", + "_plot_raster(axs[0], sC)\n", + "axs[0].set_title(\"History-dependent simulated spike trains\")\n", + "axs[1].plot(time, stim.data[:, 0], color=\"k\", linewidth=1.5)\n", + "axs[1].set_title(\"Stimulus\")\n", + "axs[1].set_xlabel(\"time (s)\")\n", + "axs[1].set_ylabel(\"Voltage\")\n", + "\n", + "lambdaCIF = [CIF([b0, b1], [\"1\", \"x\"], [\"x\"], \"binomial\", histCoeffs, histObj) for _ in range(numRealizations)]\n", + "lambdaCIFNoHist = [CIF([b0, b1], [\"1\", \"x\"], [\"x\"], \"binomial\") for _ in range(numRealizations)]\n", + "\n", + "sC.resample(1.0 / delta)\n", + "dN = sC.dataToMatrix()\n", + "Q = 2.0 * 
np.std(np.diff(stim.data[:, 0]))\n", + "Px0 = 0.1\n", + "A = 1.0\n", + "x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter(A, Q, Px0, dN.T, lambdaCIF, delta)\n", + "x_p_no_hist, W_p_no_hist, x_u_no_hist, W_u_no_hist, *_ = DecodingAlgorithms.PPDecodeFilter(\n", + " A,\n", + " Q,\n", + " Px0,\n", + " dN.T,\n", + " lambdaCIFNoHist,\n", + " delta,\n", + ")\n", + "\n", + "fig = _prepare_figure(\"figure\", figsize=(8.0, 6.0))\n", + "axs = fig.subplots(2, 1, sharex=True)\n", + "_plot_decoded_ci(axs[0], time, x_u, W_u, stim.data[:, 0], f\"Decoded stimulus with history using {numRealizations} cells\")\n", + "_plot_decoded_ci(axs[1], time, x_u_no_hist, W_u_no_hist, stim.data[:, 0], f\"Decoded stimulus without history using {numRealizations} cells\")\n", + "__tracker.finalize()\n" ] } ], @@ -120,4 +174,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/parity/notebook_fidelity.yml b/parity/notebook_fidelity.yml index 5ef15fe7..acdb9956 100644 --- a/parity/notebook_fidelity.yml +++ b/parity/notebook_fidelity.yml @@ -53,32 +53,34 @@ items: - topic: DecodingExample source_matlab: DecodingExample.mlx python_notebook: notebooks/DecodingExample.ipynb - fidelity_status: partial - remaining_differences: Core decoding workflow is present, but MATLAB decoding options - and reference outputs are not yet fully matched. - python_sections: 3 + fidelity_status: high_fidelity + remaining_differences: Workflow, model fitting, and decoded-stimulus figures now + follow the MATLAB helpfile closely; exact traces still depend on stochastic simulation + draws and Python plotting defaults. 
+ python_sections: 4 python_expected_figures: 5 python_uses_figure_tracker: true python_has_finalize_call: true matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT matlab_sections: 4 matlab_published_figures: 5 - section_delta: -1 + section_delta: 0 figure_delta: 0 - topic: DecodingExampleWithHist source_matlab: DecodingExampleWithHist.mlx python_notebook: notebooks/DecodingExampleWithHist.ipynb - fidelity_status: partial - remaining_differences: History-aware decoding is available, but the MATLAB workflow - still has richer option handling and reference outputs. - python_sections: 1 + fidelity_status: high_fidelity + remaining_differences: The notebook now mirrors the MATLAB history-aware decoding + workflow closely; exact stochastic trajectories and figure styling still vary + slightly under Python execution. + python_sections: 2 python_expected_figures: 2 python_uses_figure_tracker: true python_has_finalize_call: true matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT matlab_sections: 2 matlab_published_figures: 2 - section_delta: -1 + section_delta: 0 figure_delta: 0 - topic: ExplicitStimulusWhiskerData source_matlab: ExplicitStimulusWhiskerData.mlx diff --git a/parity/report.md b/parity/report.md index a09c3252..fb0e6531 100644 --- a/parity/report.md +++ b/parity/report.md @@ -34,14 +34,14 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, and `tools/no | Status | Count | |---|---:| | `exact` | 0 | -| `high_fidelity` | 4 | -| `partial` | 7 | +| `high_fidelity` | 6 | +| `partial` | 5 | ## Coverage Notes - Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable. - Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents. -- Notebook fidelity: workflow coverage is complete, but 7 MATLAB-helpfile notebook ports are still marked partial in `tools/notebooks/parity_notes.yml`. 
+- Notebook fidelity: workflow coverage is complete, but 5 MATLAB-helpfile notebook ports are still marked partial in `tools/notebooks/parity_notes.yml`. - Notebook fidelity audit: structural section/figure comparisons are recorded in `parity/notebook_fidelity.yml`. - Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped. - Class fidelity: the class audit reports no partial, shim-only, or missing items. @@ -52,8 +52,6 @@ No partial or missing items remain in the mapping inventory. ## Remaining Notebook-Fidelity Deltas -- `DecodingExample` -> `notebooks/DecodingExample.ipynb` [partial]: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched. -- `DecodingExampleWithHist` -> `notebooks/DecodingExampleWithHist.ipynb` [partial]: History-aware decoding is available, but the MATLAB workflow still has richer option handling and reference outputs. - `ExplicitStimulusWhiskerData` -> `notebooks/ExplicitStimulusWhiskerData.ipynb` [partial]: Dataset-backed workflow is present, but figure-level and narrative parity with MATLAB are still incomplete. - `HippocampalPlaceCellExample` -> `notebooks/HippocampalPlaceCellExample.ipynb` [partial]: Core place-cell workflow is ported, but MATLAB figure sequencing and summary outputs are not yet exact. - `HybridFilterExample` -> `notebooks/HybridFilterExample.ipynb` [partial]: Hybrid filtering workflow executes, but MATLAB-specific output details and downstream validation remain incomplete. 
diff --git a/tests/test_parity_report.py b/tests/test_parity_report.py index 7f8cfa8d..ce8b6f22 100644 --- a/tests/test_parity_report.py +++ b/tests/test_parity_report.py @@ -23,7 +23,7 @@ def test_parity_report_highlights_current_constraints() -> None: assert "Notebook Fidelity Summary" in text assert "Remaining Notebook-Fidelity Deltas" in text assert "parity/notebook_fidelity.yml" in text - assert "DecodingExample" in text + assert "HybridFilterExample" in text assert "No partial or missing items remain in the mapping inventory." in text assert "Remaining Class-Fidelity Deltas" in text assert "No partial, shim-only, or missing class-fidelity items remain." in text diff --git a/tools/notebooks/build_decoding_fidelity_notebooks.py b/tools/notebooks/build_decoding_fidelity_notebooks.py new file mode 100644 index 00000000..74336a5a --- /dev/null +++ b/tools/notebooks/build_decoding_fidelity_notebooks.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +from pathlib import Path +from textwrap import dedent + +import nbformat +from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook + + +REPO_ROOT = Path(__file__).resolve().parents[2] +NOTEBOOK_DIR = REPO_ROOT / "notebooks" + + +LANGUAGE_METADATA = { + "language_info": { + "name": "python", + } +} + + +def _write_notebook(path: Path, *, topic: str, expected_figures: int, markdown_note: str, code_cells: list[str]) -> None: + notebook = new_notebook( + cells=[new_markdown_cell(markdown_note), *[new_code_cell(dedent(cell).strip() + "\n") for cell in code_cells]], + metadata={ + **LANGUAGE_METADATA, + "nstat": { + "expected_figures": expected_figures, + "run_group": "smoke", + "style": "python-example", + "topic": topic, + }, + }, + ) + path.write_text(nbformat.writes(notebook), encoding="utf-8") + + +DECODING_EXAMPLE_NOTE = """\ + +## MATLAB Parity Note +- Source MATLAB helpfile: `DecodingExample.mlx` +- Fidelity status: `high_fidelity` +- Remaining justified differences: 
Workflow, model fitting, and decoded-stimulus figures now follow the MATLAB helpfile closely; exact traces still depend on stochastic simulation draws and Python plotting defaults. +""" + + +DECODING_EXAMPLE_CODE = [ + """ + # nSTAT-python notebook example: DecodingExample + from pathlib import Path + import sys + + REPO_ROOT = Path.cwd().resolve().parent + if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + SRC_PATH = (REPO_ROOT / "src").resolve() + if str(SRC_PATH) not in sys.path: + sys.path.insert(0, str(SRC_PATH)) + + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + from nstat import Analysis, CIF, ConfigColl, CovColl, Covariate, DecodingAlgorithms, Trial, TrialConfig + from nstat.notebook_figures import FigureTracker + + np.random.seed(0) + OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" + __tracker = FigureTracker(topic="DecodingExample", output_root=OUTPUT_ROOT, expected_count=5) + + + def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)): + fig = __tracker.new_figure(matlab_line) + fig.clear() + fig.set_size_inches(*figsize) + return fig + + + def _plot_raster(ax, spike_coll): + for row in range(1, spike_coll.numSpikeTrains + 1): + train = spike_coll.getNST(row) + spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1) + if spikes.size: + ax.vlines(spikes, row - 0.4, row + 0.4, color="k", linewidth=0.5) + ax.set_ylabel("Neuron") + ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5) + + + def _plot_decoded_ci(ax, time, decoded, cov, stim, title): + center = np.asarray(decoded, dtype=float).reshape(-1) + variance = np.asarray(cov, dtype=float).reshape(-1) + sigma = np.sqrt(np.maximum(variance, 0.0)) + z_val = 3.0 + lower = center - z_val * sigma + upper = center + z_val * sigma + ax.plot(time[: center.size], center, "b", linewidth=1.5, label="x_{k|k}(t)") + ax.plot(time[: center.size], lower, "g", linewidth=1.0, label="x_{k|k}(t)-3σ") + ax.plot(time[: 
center.size], upper, "g", linewidth=1.0, label="x_{k|k}(t)+3σ") + ax.plot(time[: center.size], np.asarray(stim).reshape(-1)[: center.size], "k", linewidth=1.5, label="x(t)") + ax.set_title(title) + ax.set_xlabel("time (s)") + ax.legend(loc="upper right", frameon=False, fontsize=8) + + + # SECTION 0: STIMULUS DECODING + # In this example we decode a univariate stimulus from simulated point-process observations by following the MATLAB DecodingExample workflow. + """, + """ + # SECTION 1: Generate the conditional Intensity Function + plt.close("all") + delta = 0.001 + Tmax = 10.0 + time = np.arange(0.0, Tmax + delta, delta) + f = 0.1 + b1 = 1.0 + b0 = -3.0 + x = np.sin(2.0 * np.pi * f * time) + exp_data = np.exp(b1 * x + b0) + lambda_data = exp_data / (1.0 + exp_data) + lambda_cov = Covariate(time, lambda_data / delta, "\\\\Lambda(t)", "time", "s", "Hz", ["lambda_1"]) + + numRealizations = 10 + spikeColl = CIF.simulateCIFByThinningFromLambda(lambda_cov, numRealizations=numRealizations) + + fig = _prepare_figure("figure", figsize=(8.0, 5.5)) + axs = fig.subplots(2, 1, sharex=True) + _plot_raster(axs[0], spikeColl) + axs[0].set_title("Simulated spike trains from λ(t)") + axs[1].plot(time, lambda_cov.data[:, 0], color="b", linewidth=2.0) + axs[1].set_title("Conditional intensity λ(t)") + axs[1].set_xlabel("time (s)") + axs[1].set_ylabel("Hz") + """, + """ + # SECTION 2: Fit a model to the spikedata to obtain a model CIF + stim = Covariate(time, x, "Stimulus", "time", "s", "V", ["stim"]) + baseline = Covariate(time, np.ones_like(time), "Baseline", "time", "s", "", ["constant"]) + cc = CovColl([stim, baseline]) + trial = Trial(spikeColl, cc) + + fig = _prepare_figure("figure", figsize=(8.0, 6.0)) + axs = fig.subplots(3, 1, sharex=True) + _plot_raster(axs[0], spikeColl) + axs[0].set_title("Trial spike raster") + axs[1].plot(time, stim.data[:, 0], color="k", linewidth=1.5) + axs[1].set_title("Stimulus covariate") + axs[1].set_ylabel("V") + axs[2].plot(time, baseline.data[:, 
0], color="0.3", linewidth=1.5) + axs[2].set_title("Baseline covariate") + axs[2].set_ylabel("constant") + axs[2].set_xlabel("time (s)") + + cfgColl = ConfigColl( + [ + TrialConfig([["Baseline", "constant"]], 1000.0, [], [], name="Baseline"), + TrialConfig([["Baseline", "constant"], ["Stimulus", "stim"]], 1000.0, [], [], name="Baseline+Stimulus"), + ] + ) + results = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0) + + paramEst = np.column_stack([fit.getCoeffs(2)[:2] for fit in results]) + meanParams = np.mean(paramEst, axis=1) + aic_matrix = np.vstack([fit.AIC for fit in results]) + logll_matrix = np.vstack([fit.logLL for fit in results]) + config_names = results[0].configNames + + fig = _prepare_figure("figure", figsize=(8.0, 4.5)) + axs = fig.subplots(1, 2) + neuron_idx = np.arange(1, paramEst.shape[1] + 1) + axs[0].plot(neuron_idx, paramEst[0], "o-", color="tab:blue", label="b0") + axs[0].axhline(meanParams[0], color="tab:blue", linestyle="--", linewidth=1.0) + axs[0].set_title("Baseline coefficients") + axs[0].set_xlabel("Neuron") + axs[0].set_ylabel("b0") + axs[1].plot(neuron_idx, paramEst[1], "o-", color="tab:orange", label="b1") + axs[1].axhline(meanParams[1], color="tab:orange", linestyle="--", linewidth=1.0) + axs[1].set_title("Stimulus coefficients") + axs[1].set_xlabel("Neuron") + axs[1].set_ylabel("b1") + + fig = _prepare_figure("figure", figsize=(8.0, 4.5)) + axs = fig.subplots(1, 2) + xloc = np.arange(len(config_names)) + axs[0].bar(xloc, np.mean(aic_matrix, axis=0), color=["0.6", "0.3"]) + axs[0].set_xticks(xloc, config_names, rotation=15) + axs[0].set_title("Mean AIC across neurons") + axs[1].bar(xloc, np.mean(logll_matrix, axis=0), color=["0.6", "0.3"]) + axs[1].set_xticks(xloc, config_names, rotation=15) + axs[1].set_title("Mean log-likelihood across neurons") + """, + """ + # SECTION 3: Decode the stimulus from the fitted CIF + b0_est = paramEst[0, :] + b1_est = paramEst[1, :] + lambdaCIF = [CIF([b0_est[i], b1_est[i]], ["1", "x"], ["x"], 
"binomial") for i in range(numRealizations)] + + spikeColl.resample(1.0 / delta) + dN = spikeColl.dataToMatrix() + Q = 2.0 * np.std(np.diff(stim.data[:, 0])) + A = 1.0 + x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilterLinear(A, Q, dN.T, b0_est, b1_est, "binomial", delta) + + fig = _prepare_figure("figure", figsize=(8.0, 4.5)) + ax = fig.subplots(1, 1) + _plot_decoded_ci(ax, time, x_u, W_u, stim.data[:, 0], f"Decoded stimulus using {numRealizations} cells") + __tracker.finalize() + """, +] + + +DECODING_HISTORY_NOTE = """\ + +## MATLAB Parity Note +- Source MATLAB helpfile: `DecodingExampleWithHist.mlx` +- Fidelity status: `high_fidelity` +- Remaining justified differences: The notebook now mirrors the MATLAB history-aware decoding workflow closely; exact stochastic trajectories and figure styling still vary slightly under Python execution. +""" + + +DECODING_HISTORY_CODE = [ + """ + # nSTAT-python notebook example: DecodingExampleWithHist + from pathlib import Path + import sys + + REPO_ROOT = Path.cwd().resolve().parent + if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + SRC_PATH = (REPO_ROOT / "src").resolve() + if str(SRC_PATH) not in sys.path: + sys.path.insert(0, str(SRC_PATH)) + + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + from nstat import CIF, DecodingAlgorithms, History, Covariate, nspikeTrain, nstColl + from nstat.notebook_figures import FigureTracker + + np.random.seed(0) + OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" + __tracker = FigureTracker(topic="DecodingExampleWithHist", output_root=OUTPUT_ROOT, expected_count=2) + + + def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)): + fig = __tracker.new_figure(matlab_line) + fig.clear() + fig.set_size_inches(*figsize) + return fig + + + def _plot_raster(ax, spike_coll): + for row in range(1, spike_coll.numSpikeTrains + 1): + train = spike_coll.getNST(row) + spikes = np.asarray(train.getSpikeTimes(), 
dtype=float).reshape(-1) + if spikes.size: + ax.vlines(spikes, row - 0.4, row + 0.4, color="k", linewidth=0.5) + ax.set_ylabel("Neuron") + ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5) + + + def _plot_decoded_ci(ax, time, decoded, cov, stim, title): + center = np.asarray(decoded, dtype=float).reshape(-1) + spread = np.asarray(cov, dtype=float).reshape(-1) + z_val = 3.0 + lower = center - z_val * spread + upper = center + z_val * spread + ax.plot(time[: center.size], center, "b", linewidth=1.5, label="x_{k|k}(t)") + ax.plot(time[: center.size], lower, "g", linewidth=1.0, label="x_{k|k}(t)-3σ") + ax.plot(time[: center.size], upper, "r", linewidth=1.0, label="x_{k|k}(t)+3σ") + ax.plot(time[: center.size], np.asarray(stim).reshape(-1)[: center.size], "k", linewidth=1.5, label="x(t)") + ax.set_title(title) + ax.set_xlabel("time (s)") + ax.legend(loc="upper right", frameon=False, fontsize=8) + + + def _simulate_history_spike_train(time, stim_data, baseline, hist_coeffs, window_times): + spikes = [] + for idx in range(1, len(time)): + t = time[idx] + spike_arr = np.asarray(spikes, dtype=float) + history_counts = [] + for w_start, w_stop in zip(window_times[:-1], window_times[1:]): + if spike_arr.size: + history_counts.append(np.sum((spike_arr >= t - w_stop) & (spike_arr < t - w_start))) + else: + history_counts.append(0.0) + eta = baseline + stim_data[idx] + float(np.dot(hist_coeffs, history_counts)) + p = np.exp(np.clip(eta, -20.0, 20.0)) + p = p / (1.0 + p) + if np.random.rand() < p: + spikes.append(t) + return np.asarray(spikes, dtype=float) + + + # SECTION 0: 1-D Stimulus Decode with History Effect + # We simulate neurons with refractory-history effects and compare point-process decoding with and without the correct history terms. 
+ """, + """ + # SECTION 1: History-aware decoding workflow + plt.close("all") + delta = 0.001 + Tmax = 1.0 + time = np.arange(0.0, Tmax + delta, delta) + f = 1.0 + b1 = 1.0 + b0 = -2.0 + stimData = b1 * np.sin(2.0 * np.pi * f * time) + histCoeffs = np.array([-2.0, -2.0, -4.0]) + windowTimes = np.array([0.0, 0.001, 0.002, 0.003]) + histObj = History(windowTimes) + stim = Covariate(time, stimData, "Stimulus", "time", "s", "Voltage", ["sin"]) + + numRealizations = 20 + trains = [] + for idx in range(numRealizations): + spikes = _simulate_history_spike_train(time, stimData, b0, histCoeffs, windowTimes) + trains.append(nspikeTrain(spikes, str(idx + 1), delta, 0.0, Tmax, makePlots=-1)) + sC = nstColl(trains) + + fig = _prepare_figure("figure", figsize=(8.0, 5.5)) + axs = fig.subplots(2, 1, sharex=True) + _plot_raster(axs[0], sC) + axs[0].set_title("History-dependent simulated spike trains") + axs[1].plot(time, stim.data[:, 0], color="k", linewidth=1.5) + axs[1].set_title("Stimulus") + axs[1].set_xlabel("time (s)") + axs[1].set_ylabel("Voltage") + + lambdaCIF = [CIF([b0, b1], ["1", "x"], ["x"], "binomial", histCoeffs, histObj) for _ in range(numRealizations)] + lambdaCIFNoHist = [CIF([b0, b1], ["1", "x"], ["x"], "binomial") for _ in range(numRealizations)] + + sC.resample(1.0 / delta) + dN = sC.dataToMatrix() + Q = 2.0 * np.std(np.diff(stim.data[:, 0])) + Px0 = 0.1 + A = 1.0 + x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter(A, Q, Px0, dN.T, lambdaCIF, delta) + x_p_no_hist, W_p_no_hist, x_u_no_hist, W_u_no_hist, *_ = DecodingAlgorithms.PPDecodeFilter( + A, + Q, + Px0, + dN.T, + lambdaCIFNoHist, + delta, + ) + + fig = _prepare_figure("figure", figsize=(8.0, 6.0)) + axs = fig.subplots(2, 1, sharex=True) + _plot_decoded_ci(axs[0], time, x_u, W_u, stim.data[:, 0], f"Decoded stimulus with history using {numRealizations} cells") + _plot_decoded_ci(axs[1], time, x_u_no_hist, W_u_no_hist, stim.data[:, 0], f"Decoded stimulus without history using {numRealizations} 
cells") + __tracker.finalize() + """, +] + + +def main() -> int:  # Script entry point: calls _write_notebook once per decoding notebook, prints both target paths, and returns 0 (used as the exit status by the guard below).
+ _write_notebook( + NOTEBOOK_DIR / "DecodingExample.ipynb", + topic="DecodingExample", + expected_figures=5, + markdown_note=DECODING_EXAMPLE_NOTE, + code_cells=DECODING_EXAMPLE_CODE, + ) + _write_notebook( + NOTEBOOK_DIR / "DecodingExampleWithHist.ipynb", + topic="DecodingExampleWithHist", + expected_figures=2, + markdown_note=DECODING_HISTORY_NOTE, + code_cells=DECODING_HISTORY_CODE, + ) + print(NOTEBOOK_DIR / "DecodingExample.ipynb")  # Echo the two notebook paths on stdout. NOTE(review): presumably the files _write_notebook just wrote - its body is outside this view, confirm.
+ print(NOTEBOOK_DIR / "DecodingExampleWithHist.ipynb") + return 0 + + +if __name__ == "__main__":  # Run main() only when executed as a script and hand its return value to SystemExit as the process exit code.
+ raise SystemExit(main()) diff --git a/tools/notebooks/parity_notes.yml b/tools/notebooks/parity_notes.yml index f53501b6..197c97f3 100644 --- a/tools/notebooks/parity_notes.yml +++ b/tools/notebooks/parity_notes.yml @@ -18,13 +18,13 @@ notes: - topic: DecodingExample file: notebooks/DecodingExample.ipynb source_matlab: DecodingExample.mlx - fidelity_status: partial - remaining_differences: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched. + fidelity_status: high_fidelity + remaining_differences: Workflow, model fitting, and decoded-stimulus figures now follow the MATLAB helpfile closely; exact traces still depend on stochastic simulation draws and Python plotting defaults. - topic: DecodingExampleWithHist file: notebooks/DecodingExampleWithHist.ipynb source_matlab: DecodingExampleWithHist.mlx - fidelity_status: partial - remaining_differences: History-aware decoding is available, but the MATLAB workflow still has richer option handling and reference outputs. + fidelity_status: high_fidelity + remaining_differences: The notebook now mirrors the MATLAB history-aware decoding workflow closely; exact stochastic trajectories and figure styling still vary slightly under Python execution.
- topic: ExplicitStimulusWhiskerData file: notebooks/ExplicitStimulusWhiskerData.ipynb source_matlab: ExplicitStimulusWhiskerData.mlx