Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 145 additions & 60 deletions notebooks/DecodingExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
"cells": [
{
"cell_type": "markdown",
"id": "c488b5fa",
"id": "72ddd907",
"metadata": {},
"source": [
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `DecodingExample.mlx`\n",
"- Fidelity status: `partial`\n",
"- Remaining justified differences: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched."
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: Workflow, model fitting, and decoded-stimulus figures now follow the MATLAB helpfile closely; exact traces still depend on stochastic simulation draws and Python plotting defaults.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d710745",
"id": "b558e18d",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,86 +34,171 @@
"matplotlib.use(\"Agg\")\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from scipy.io import loadmat\n",
"\n",
"from nstat.data_manager import ensure_example_data\n",
"from nstat import Analysis, CIF, ConfigColl, CovColl, Covariate, DecodingAlgorithms, Trial, TrialConfig\n",
"from nstat.notebook_figures import FigureTracker\n",
"\n",
"np.random.seed(0)\n",
"DATA_DIR = ensure_example_data(download=True)\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic='DecodingExample', output_root=OUTPUT_ROOT, expected_count=5)\n",
"\n",
"def _load_example_globals(name: str) -> dict[str, object]:\n",
" candidates = [\n",
" Path(name),\n",
" DATA_DIR / name,\n",
" DATA_DIR / \"mEPSCs\" / name,\n",
" DATA_DIR / \"Place Cells\" / name,\n",
" DATA_DIR / \"Explicit Stimulus\" / name,\n",
" ]\n",
" for path in candidates:\n",
" if path.exists():\n",
" data = loadmat(path)\n",
" return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n",
" return {}\n",
"\n",
"# SECTION 0: Section 0\n",
"# STIMULUS DECODING\n",
"# In this example we show how to decode a univariate and a bivariate stimulus based on a point process observations using nSTAT. Even though due to the simulated nature of the data, we know the exact condition intensity function, we estimate the parameters before moving on to the decoding stage."
"__tracker = FigureTracker(topic=\"DecodingExample\", output_root=OUTPUT_ROOT, expected_count=5)\n",
"\n",
"\n",
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
" fig.clear()\n",
" fig.set_size_inches(*figsize)\n",
" return fig\n",
"\n",
"\n",
"def _plot_raster(ax, spike_coll):\n",
" for row in range(1, spike_coll.numSpikeTrains + 1):\n",
" train = spike_coll.getNST(row)\n",
" spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1)\n",
" if spikes.size:\n",
" ax.vlines(spikes, row - 0.4, row + 0.4, color=\"k\", linewidth=0.5)\n",
" ax.set_ylabel(\"Neuron\")\n",
" ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5)\n",
"\n",
"\n",
"def _plot_decoded_ci(ax, time, decoded, cov, stim, title):\n",
" center = np.asarray(decoded, dtype=float).reshape(-1)\n",
" variance = np.asarray(cov, dtype=float).reshape(-1)\n",
" sigma = np.sqrt(np.maximum(variance, 0.0))\n",
" z_val = 3.0\n",
" lower = center - z_val * sigma\n",
" upper = center + z_val * sigma\n",
" ax.plot(time[: center.size], center, \"b\", linewidth=1.5, label=\"x_{k|k}(t)\")\n",
" ax.plot(time[: center.size], lower, \"g\", linewidth=1.0, label=\"x_{k|k}(t)-3σ\")\n",
" ax.plot(time[: center.size], upper, \"g\", linewidth=1.0, label=\"x_{k|k}(t)+3σ\")\n",
" ax.plot(time[: center.size], np.asarray(stim).reshape(-1)[: center.size], \"k\", linewidth=1.5, label=\"x(t)\")\n",
" ax.set_title(title)\n",
" ax.set_xlabel(\"time (s)\")\n",
" ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n",
"\n",
"\n",
"# SECTION 0: STIMULUS DECODING\n",
"# In this example we decode a univariate stimulus from simulated point-process observations by following the MATLAB DecodingExample workflow.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feb22f76",
"id": "e37eea70",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 1: Generate the conditional Intensity Function\n",
"plt.close(\"all\")\n",
"#\n",
"#\n",
"delta = 0.001\n",
"Tmax = 10.0\n",
"time = np.arange(0.0, Tmax + delta, delta)\n",
"f = 0.1\n",
"b1 = 1.0\n",
"b0 = -3.0\n",
"x = np.sin(2.0 * np.pi * f * time)\n",
"exp_data = np.exp(b1 * x + b0)\n",
"lambda_data = exp_data / (1.0 + exp_data)\n",
"lambda_cov = Covariate(time, lambda_data / delta, \"\\\\Lambda(t)\", \"time\", \"s\", \"Hz\", [\"lambda_1\"])\n",
"\n",
"numRealizations = 10\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate('subplot(2,1,1)')\n",
"__tracker.annotate('spikeColl.plot')\n",
"__tracker.annotate('subplot(2,1,2)')\n",
"__tracker.annotate('lambda.plot')"
"spikeColl = CIF.simulateCIFByThinningFromLambda(lambda_cov, numRealizations=numRealizations)\n",
"\n",
"fig = _prepare_figure(\"figure\", figsize=(8.0, 5.5))\n",
"axs = fig.subplots(2, 1, sharex=True)\n",
"_plot_raster(axs[0], spikeColl)\n",
"axs[0].set_title(\"Simulated spike trains from λ(t)\")\n",
"axs[1].plot(time, lambda_cov.data[:, 0], color=\"b\", linewidth=2.0)\n",
"axs[1].set_title(\"Conditional intensity λ(t)\")\n",
"axs[1].set_xlabel(\"time (s)\")\n",
"axs[1].set_ylabel(\"Hz\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d559b2d3",
"id": "b8dd2913",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 2: Fit a model to the spikedata to obtain a model CIF\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate('trial.plot')\n",
"#\n",
"pass\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"#\n",
"# So we now have a model for lambda lambda = exp(b_0 + b_1*x(t))./(1+exp(b_0 + b_1*x(t)) * 1/delta because exp(b_0 + b_1*x(t))<<1 we can approximate this lambda by just the numerator i.e. lambda = exp(b_0 + b_1*x(t))./delta\n",
"# Now suppose we wanted to decode x(t) based on only having observed lambda\n",
"pass\n",
"# Construct a CIF object for each realization based on our encoding\n",
"# results abovel\n",
"# close all;\n",
"# Make noise according to the dynamic range of the stimulus\n",
"A = 1\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"zVal = 3\n",
"__tracker.annotate(\"hEst=plot(time,x_u(1:end),'b',time,ciLower,'g',time,ciUpper,'g')\")\n",
"#\n",
"__tracker.annotate(\"hStim=stim.plot([],{{' ''k'',''Linewidth'',2'}})\")\n",
"__tracker.finalize()"
"stim = Covariate(time, x, \"Stimulus\", \"time\", \"s\", \"V\", [\"stim\"])\n",
"baseline = Covariate(time, np.ones_like(time), \"Baseline\", \"time\", \"s\", \"\", [\"constant\"])\n",
"cc = CovColl([stim, baseline])\n",
"trial = Trial(spikeColl, cc)\n",
"\n",
"fig = _prepare_figure(\"figure\", figsize=(8.0, 6.0))\n",
"axs = fig.subplots(3, 1, sharex=True)\n",
"_plot_raster(axs[0], spikeColl)\n",
"axs[0].set_title(\"Trial spike raster\")\n",
"axs[1].plot(time, stim.data[:, 0], color=\"k\", linewidth=1.5)\n",
"axs[1].set_title(\"Stimulus covariate\")\n",
"axs[1].set_ylabel(\"V\")\n",
"axs[2].plot(time, baseline.data[:, 0], color=\"0.3\", linewidth=1.5)\n",
"axs[2].set_title(\"Baseline covariate\")\n",
"axs[2].set_ylabel(\"constant\")\n",
"axs[2].set_xlabel(\"time (s)\")\n",
"\n",
"cfgColl = ConfigColl(\n",
" [\n",
" TrialConfig([[\"Baseline\", \"constant\"]], 1000.0, [], [], name=\"Baseline\"),\n",
" TrialConfig([[\"Baseline\", \"constant\"], [\"Stimulus\", \"stim\"]], 1000.0, [], [], name=\"Baseline+Stimulus\"),\n",
" ]\n",
")\n",
"results = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0)\n",
"\n",
"paramEst = np.column_stack([fit.getCoeffs(2)[:2] for fit in results])\n",
"meanParams = np.mean(paramEst, axis=1)\n",
"aic_matrix = np.vstack([fit.AIC for fit in results])\n",
"logll_matrix = np.vstack([fit.logLL for fit in results])\n",
"config_names = results[0].configNames\n",
"\n",
"fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n",
"axs = fig.subplots(1, 2)\n",
"neuron_idx = np.arange(1, paramEst.shape[1] + 1)\n",
"axs[0].plot(neuron_idx, paramEst[0], \"o-\", color=\"tab:blue\", label=\"b0\")\n",
"axs[0].axhline(meanParams[0], color=\"tab:blue\", linestyle=\"--\", linewidth=1.0)\n",
"axs[0].set_title(\"Baseline coefficients\")\n",
"axs[0].set_xlabel(\"Neuron\")\n",
"axs[0].set_ylabel(\"b0\")\n",
"axs[1].plot(neuron_idx, paramEst[1], \"o-\", color=\"tab:orange\", label=\"b1\")\n",
"axs[1].axhline(meanParams[1], color=\"tab:orange\", linestyle=\"--\", linewidth=1.0)\n",
"axs[1].set_title(\"Stimulus coefficients\")\n",
"axs[1].set_xlabel(\"Neuron\")\n",
"axs[1].set_ylabel(\"b1\")\n",
"\n",
"fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n",
"axs = fig.subplots(1, 2)\n",
"xloc = np.arange(len(config_names))\n",
"axs[0].bar(xloc, np.mean(aic_matrix, axis=0), color=[\"0.6\", \"0.3\"])\n",
"axs[0].set_xticks(xloc, config_names, rotation=15)\n",
"axs[0].set_title(\"Mean AIC across neurons\")\n",
"axs[1].bar(xloc, np.mean(logll_matrix, axis=0), color=[\"0.6\", \"0.3\"])\n",
"axs[1].set_xticks(xloc, config_names, rotation=15)\n",
"axs[1].set_title(\"Mean log-likelihood across neurons\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7529413",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 3: Decode the stimulus from the fitted CIF\n",
"b0_est = paramEst[0, :]\n",
"b1_est = paramEst[1, :]\n",
"lambdaCIF = [CIF([b0_est[i], b1_est[i]], [\"1\", \"x\"], [\"x\"], \"binomial\") for i in range(numRealizations)]\n",
"\n",
"spikeColl.resample(1.0 / delta)\n",
"dN = spikeColl.dataToMatrix()\n",
"Q = 2.0 * np.std(np.diff(stim.data[:, 0]))\n",
"A = 1.0\n",
"x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilterLinear(A, Q, dN.T, b0_est, b1_est, \"binomial\", delta)\n",
"\n",
"fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"_plot_decoded_ci(ax, time, x_u, W_u, stim.data[:, 0], f\"Decoded stimulus using {numRealizations} cells\")\n",
"__tracker.finalize()\n"
]
}
],
Expand All @@ -130,4 +215,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading