# 1D Maze Results Viewer

Load saved `.pcellbundle` results from 1D maze analysis and inspect them
without re-running the pipeline. Supports loading a single bundle or
multiple bundles for cross-session comparison.

---

**Note:** The interactive widget cells (`%matplotlib widget`) require **Jupyter Lab** with `ipympl` installed:

```bash
cd notebook && jupyter lab --no-browser --port=6006
```

In [None]:
import sys
from pathlib import Path

# Put the repository root (one level above the notebook's working directory)
# at the front of the import path so the local `placecell` package resolves.
project_root = Path.cwd().parent
_project_root_str = str(project_root)
if _project_root_str not in sys.path:  # idempotent across cell re-runs
    sys.path.insert(0, _project_root_str)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display

from placecell.dataset import PlaceCellDataset
from placecell.notebook import browse_units_1d, create_shuffle_browser_1d
from placecell.visualization import (
    plot_diagnostics,
    plot_graph_overlay,
    plot_occupancy_preview_1d,
    plot_session_summary,
    plot_shuffle_test_1d,
    plot_summary_scatter,
)

## Load Bundles

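List the `.pcellbundle` paths to load (relative paths resolve against the notebook's working directory). For a single session, use one entry. Each bundle is keyed by its file name without the `.pcellbundle` extension, so file names must be unique.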

In [None]:
BUNDLE_PATHS = [
    Path("bundles/WL25_20251219.pcellbundle"),
]

# Fail early with a clear message instead of an IndexError in the session cell.
if not BUNDLE_PATHS:
    raise ValueError("BUNDLE_PATHS is empty: list at least one .pcellbundle to load.")

# Session name -> loaded PlaceCellDataset, in BUNDLE_PATHS order.
datasets = {}
for bp in BUNDLE_PATHS:
    bp = Path(bp)  # also accept plain string entries
    name = bp.stem.replace(".pcellbundle", "")
    if name in datasets:
        # Bundles with the same file name (e.g. from different folders) would
        # otherwise silently overwrite each other in `datasets`.
        raise ValueError(
            f"Duplicate session name {name!r} from {bp}; bundle file names must be unique."
        )
    ds = PlaceCellDataset.load_bundle(bp)
    datasets[name] = ds
    s = ds.summary()
    print(
        f"{name}: {s['n_total']} units, "
        f"{s['n_sig']} sig, {s['n_stable']} stable, "
        f"{s['n_place_cells']} place cells"
    )

## Cross-Session Summary

Counts and proportions of significant, stable, and place cell units across sessions.
The summary table is always shown; the comparison plot is drawn only when multiple bundles are loaded.

In [None]:
# One row per loaded session: its name followed by the summary() counts.
rows = [{"dataset": name, **ds.summary()} for name, ds in datasets.items()]
summary_df = pd.DataFrame(rows)
display(summary_df)

# A cross-session comparison plot needs at least two sessions.
if len(datasets) >= 2:
    plot_session_summary(summary_df)
    plt.show()

---

## Per-Session Results

Select a session to inspect. By default the first loaded bundle is used; set `SESSION` to another bundle name to switch.

In [None]:
# Default to the first loaded bundle; assign another key of `datasets` to switch.
SESSION = list(datasets)[0]
ds = datasets[SESSION]
p_thresh = ds.spatial_1d.p_value_threshold
print("Session:", SESSION)
print(
    "Config: "
    f"bin_width_mm={ds.spatial_1d.bin_width_mm}, "
    f"activity_sigma={ds.spatial_1d.activity_sigma}, "
    f"n_shuffles={ds.spatial_1d.n_shuffles}, "
    f"p_threshold={p_thresh}"
)

### Behavior Graph Overlay

Zone polylines from the behavior graph overlaid on the video frame.

In [None]:
if ds.graph_polylines is None:
    print("No behavior graph available in this bundle.")
else:
    # Overlay the behavior-graph polylines on the bundle's stored video frame.
    plot_graph_overlay(
        ds.graph_polylines,
        ds.graph_mm_per_pixel,
        video_frame=ds.behavior_video_frame,
        tube_order=ds.maze_cfg.tube_order,
    )
    plt.show()

### Occupancy

In [None]:
# Occupancy plotting needs both the occupancy histogram and the bin edges.
if ds.occupancy_time is None or ds.edges_1d is None:
    print("Occupancy data not available in this bundle.")
else:
    plot_occupancy_preview_1d(
        ds.trajectory_1d_filtered,
        ds.occupancy_time,
        ds.valid_mask,
        ds.edges_1d,
        trajectory_1d=ds.trajectory_1d,
        tube_boundaries=ds.tube_boundaries,
        tube_labels=ds.effective_tube_order,
    )
    plt.show()

### Diagnostics

Event count distribution, SI vs event count, and p-value vs event count.

In [None]:
# Per-unit diagnostic panels, using the session's significance threshold.
plot_diagnostics(
    ds.unit_results,
    p_value_threshold=p_thresh,
)
plt.show()

### Significance vs Stability

In [None]:
# Scatter of significance vs. stability for every unit in the session.
plot_summary_scatter(
    ds.unit_results,
    p_value_threshold=p_thresh,
)
plt.show()

### Unit Summary Table

Table of key metrics for all units, indexed by unit ID (sort it with e.g. `unit_table.sort_values("p_val")`).

In [None]:
def _finite_round(value, ndigits=4):
    """Round *value* to *ndigits* places, mapping non-finite values (NaN/inf) to NaN."""
    return round(value, ndigits) if np.isfinite(value) else np.nan


# Explicit column order keeps the table well-formed even when the bundle has
# no units: an empty DataFrame would have no "unit_id" column to index on.
_UNIT_TABLE_COLUMNS = [
    "unit_id",
    "n_events",
    "SI (bits/s)",
    "p_val",
    "sig",
    "stability_r",
    "stability_z",
    "stability_p",
    "stable",
    "place_cell",
]

table_rows = []
for uid in sorted(ds.unit_results):
    res = ds.unit_results[uid]
    is_sig = res.p_val < p_thresh
    # A NaN stability p-value is treated as "not stable".
    is_stable = not np.isnan(res.stability_p_val) and res.stability_p_val < p_thresh
    table_rows.append({
        "unit_id": uid,
        "n_events": len(res.unit_data),
        "SI (bits/s)": round(res.si, 4),
        "p_val": round(res.p_val, 4),
        "sig": is_sig,
        "stability_r": _finite_round(res.stability_corr),
        "stability_z": _finite_round(res.stability_z),
        "stability_p": _finite_round(res.stability_p_val),
        "stable": is_stable,
        "place_cell": is_sig and is_stable,
    })

unit_table = pd.DataFrame(table_rows, columns=_UNIT_TABLE_COLUMNS).set_index("unit_id")
display(unit_table)

### Population Rate Map

Rate maps of all place cells sorted by peak position.

In [None]:
# Bin edges may be absent from a bundle (the occupancy cell guards for this too),
# so skip the plot instead of failing inside the plotting routine.
if ds.edges_1d is None:
    print("Spatial bin edges not available in this bundle.")
else:
    plot_shuffle_test_1d(
        ds.unit_results,
        ds.edges_1d,
        p_value_threshold=p_thresh,
        tube_boundaries=ds.tube_boundaries,
        tube_labels=ds.effective_tube_order,
    )
    plt.show()

### Per-Unit Shuffle Browser

Browse each unit's rate map, SI shuffle distribution, and stability shuffle distribution.

In [None]:
%matplotlib widget

fig_shuf, controls_shuf = create_shuffle_browser_1d(
    ds.unit_results,
    ds.edges_1d,
    p_value_threshold=p_thresh,
    tube_boundaries=ds.tube_boundaries,
    tube_labels=ds.effective_tube_order,
)
plt.show()
display(controls_shuf)

### Interactive Unit Browser

Browse individual units: rate maps (1st half / 2nd half / full), shuffle histograms, and calcium trace with events.

In [None]:
%matplotlib widget

fig_units, controls_units = browse_units_1d(ds)
plt.show()
display(controls_units)

### Place Cell Browser (Significant AND Stable Only)

In [None]:
%matplotlib widget

place_cell_results = ds.place_cells()
n_place_cells = len(place_cell_results)
print(f"Place cells (sig + stable): {n_place_cells} / {len(ds.unit_results)}")

if n_place_cells > 0:
    fig_pc, controls_pc = browse_units_1d(ds, unit_results=place_cell_results)
    plt.show()
    display(controls_pc)
else:
    print("No cells passed both significance and stability tests.")

### Place Cell Gallery

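Rate maps and split-half stability for each place cell, sorted by stability p-value (most stable first). Columns: full rate map with 1st/2nd-half overlays, rate-map heatmap, SI shuffle distribution, and stability shuffle distribution.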

In [None]:
def _gallery_draw_boundaries(ax, tube_boundaries):
    """Draw a dotted vertical line at each tube boundary, if boundaries exist."""
    # Explicit None check: truth-testing a multi-element NumPy array raises.
    if tube_boundaries is None:
        return
    for b in tube_boundaries:
        ax.axvline(b, color="gray", linestyle=":", linewidth=0.5, alpha=0.5)


def _gallery_rate_overlay(ax, uid, res, centers, edges, tube_boundaries, show_legend):
    """Plot the full rate map (filled) with the 1st/2nd-half rate maps overlaid."""
    rate_map = np.asarray(res.rate_map, dtype=float)
    valid = np.isfinite(rate_map)
    ax.fill_between(centers, 0, np.where(valid, rate_map, 0),
                    alpha=0.15, color="black", where=valid)
    ax.plot(centers, rate_map, color="black", linewidth=1.2, label="Full")
    for half, color, label in (
        (res.rate_map_first, "steelblue", "1st"),
        (res.rate_map_second, "coral", "2nd"),
    ):
        half_map = np.asarray(half, dtype=float)
        ax.plot(centers, np.where(np.isfinite(half_map), half_map, np.nan),
                color=color, linewidth=0.8, alpha=0.7, label=label)
    _gallery_draw_boundaries(ax, tube_boundaries)
    ax.set_xlim(edges[0], edges[-1])
    ax.tick_params(labelbottom=False, labelsize=6)
    ax.set_title(f"#{uid}  r={res.stability_corr:.2f}", fontsize=8)
    if show_legend:
        ax.legend(fontsize=5, loc="upper right")


def _gallery_heatmap(ax, res, edges, tube_boundaries):
    """Draw a single-row heatmap of the rate map, scaled to the unit's own peak."""
    row = np.asarray(res.rate_map, dtype=float)
    row = np.where(np.isfinite(row), row, 0.0)
    peak = float(row.max()) if row.size else 0.0
    ax.imshow(
        row[np.newaxis, :],
        aspect="auto",
        cmap="jet",
        vmin=0,
        # Per-unit peak scaling: a fixed vmax=1 saturates maps whose rates exceed 1.
        vmax=peak if peak > 0 else 1,
        interpolation="nearest",
        # Use the same position coordinates as the line plot so boundaries align.
        extent=(edges[0], edges[-1], 0, 1),
    )
    _gallery_draw_boundaries(ax, tube_boundaries)
    ax.set_yticks([])
    ax.tick_params(labelbottom=False)
    ax.set_title(f"SI={res.si:.3f}", fontsize=8)


def _gallery_shuffle_hist(ax, shuffled, observed, p_val, label, xlabel):
    """Histogram a shuffle null distribution and mark the observed value.

    Non-finite shuffle values are dropped: np.histogram (used by plt.hist)
    raises when NaN is present and the bin range is auto-detected.
    """
    values = np.asarray(shuffled, dtype=float).ravel()
    values = values[np.isfinite(values)]
    if values.size > 0:
        ax.hist(values, bins=20, color="gray", alpha=0.7, edgecolor="none")
        ax.axvline(observed, color="red", linestyle="--", linewidth=1.5)
        passed = np.isfinite(p_val) and p_val < p_thresh
        ax.set_title(f"{label} p={p_val:.3f}", fontsize=8,
                     color="green" if passed else "red")
    else:
        ax.text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=8,
                transform=ax.transAxes)
        ax.set_title(label, fontsize=8, color="gray")
    ax.tick_params(labelsize=6)
    ax.set_xlabel(xlabel, fontsize=7)


pc = place_cell_results
if len(pc) > 0:
    # Most stable first; any NaN stability p-values sort to the end.
    sorted_uids = sorted(
        pc, key=lambda u: (np.isnan(pc[u].stability_p_val), pc[u].stability_p_val)
    )
    n = len(sorted_uids)
    edges = np.asarray(ds.edges_1d, dtype=float)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # squeeze=False keeps `axes` 2-D even when there is a single place cell.
    fig, axes = plt.subplots(n, 4, figsize=(14, 2.0 * n), squeeze=False)

    for i, uid in enumerate(sorted_uids):
        res = pc[uid]
        row_axes = axes[i]
        _gallery_rate_overlay(row_axes[0], uid, res, centers, edges,
                              ds.tube_boundaries, show_legend=(i == 0))
        _gallery_heatmap(row_axes[1], res, edges, ds.tube_boundaries)
        _gallery_shuffle_hist(row_axes[2], res.shuffled_sis, res.si,
                              res.p_val, "SI", "SI")
        _gallery_shuffle_hist(row_axes[3], res.shuffled_stability, res.stability_corr,
                              res.stability_p_val, "Stab", "r")

    fig.tight_layout()
    plt.show()
else:
    print("No place cells to display.")