# Place Cell Results Viewer

Load saved `.pcellbundle` results and inspect them without re-running the
analysis pipeline. Supports loading a single bundle or multiple bundles for
cross-session comparison.

---

**Note:** The interactive browsers use the `%matplotlib widget` backend, which requires **Jupyter Lab** with `ipympl` installed:

```bash
cd notebook && jupyter lab --no-browser --port=6006
```

In [None]:
import sys
from pathlib import Path

# The notebook is expected to run from `notebook/`; put the repository root
# (its parent directory) at the front of sys.path so `placecell` is importable.
project_root = Path.cwd().parent
project_root_str = str(project_root)
if project_root_str not in sys.path:
    sys.path.insert(0, project_root_str)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display

from placecell.dataset import BasePlaceCellDataset
from placecell.notebook import browse_units
from placecell.visualization import (
    plot_behavior_preview,
    plot_coverage,
    plot_diagnostics,
    plot_footprints,
    plot_occupancy_preview,
    plot_session_summary,
    plot_summary_scatter,
)

## Load Bundles

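List the `.pcellbundle` paths to load in `BUNDLE_PATHS`. For a single session, use one entry. Each session is named after its bundle file (without the `.pcellbundle` extension), so bundle file names must be unique.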

In [None]:
BUNDLE_PATHS = [
    Path("bundles/WL25_20251205.pcellbundle"),
]

# Loaded datasets keyed by session name: the bundle's file name with the
# `.pcellbundle` suffix removed.
datasets = {}
for bp in BUNDLE_PATHS:
    name = bp.stem.replace(".pcellbundle", "")
    # Two bundles with the same file name (e.g. from different folders) would
    # silently overwrite each other here and drop a session from every later
    # cell; fail loudly instead.
    if name in datasets:
        raise ValueError(
            f"Duplicate session name {name!r} (from {bp}); "
            "each bundle in BUNDLE_PATHS must have a unique file name."
        )
    ds = BasePlaceCellDataset.load_bundle(bp)
    datasets[name] = ds
    s = ds.summary()
    print(
        f"{name}: {s['n_total']} units, "
        f"{s['n_sig']} sig, {s['n_stable']} stable, "
        f"{s['n_place_cells']} place cells"
    )

## Cross-Session Summary

Counts and proportions of significant, stable, and place cell units across sessions.
The table is always shown; the comparison plot is only drawn when multiple bundles are loaded.

In [None]:
# One row per session: the dataset name followed by its summary() counts.
rows = [{"dataset": name, **ds.summary()} for name, ds in datasets.items()]
summary_df = pd.DataFrame(rows)
display(summary_df)

# A cross-session comparison plot only makes sense with two or more sessions.
if len(datasets) <= 1:
    pass
else:
    plot_session_summary(summary_df)
    plt.show()

---

## Per-Session Results

Select a session to inspect. `SESSION` defaults to the first loaded bundle; set it to another session name (a key of `datasets`) to switch.

In [None]:
# Default to the first loaded session; edit SESSION to inspect another one.
SESSION = list(datasets)[0]
ds = datasets[SESSION]
print(f"Session: {SESSION}")
print(
    "Config: "
    f"bins={ds.spatial.bins}, "
    f"activity_sigma={ds.spatial.activity_sigma}, "
    f"n_shuffles={ds.spatial.n_shuffles}, "
    f"p_threshold={ds.spatial.p_value_threshold}"
)

### Behavior & Occupancy

In [None]:
# Both the raw and the speed-filtered trajectory are needed for the preview.
if ds.trajectory is None or ds.trajectory_filtered is None:
    print("Trajectory data not available in this bundle.")
else:
    # Label speeds in mm/s when the bundle has a truthy mm_per_px scale,
    # otherwise in px/s.
    plot_behavior_preview(
        ds.trajectory,
        ds.trajectory_filtered,
        ds.cfg.behavior.speed_threshold,
        speed_unit="px/s" if not ds.mm_per_px else "mm/s",
    )
    plt.show()

In [None]:
# plot_occupancy_preview takes the filtered trajectory as well as the
# occupancy map, and the bundle may lack either (the trajectory cell above
# treats `trajectory_filtered` as optional), so check both before plotting.
if ds.occupancy_time is not None and ds.trajectory_filtered is not None:
    plot_occupancy_preview(
        ds.trajectory_filtered, ds.occupancy_time,
        ds.valid_mask, ds.x_edges, ds.y_edges,
    )
    plt.show()
else:
    print("Occupancy data not available.")

### Cell Footprints

In [None]:
# Footprints are drawn over the max projection, so both must be in the bundle.
if ds.max_proj is None or ds.footprints is None:
    print("Max projection or footprints not available.")
else:
    plot_footprints(ds.max_proj, ds.footprints)
    plt.show()

### Diagnostics

Event count distribution, SI vs event count, and p-value vs event count.

In [None]:
# Diagnostics use the significance cutoff from the bundle's spatial config.
plot_diagnostics(
    ds.unit_results,
    p_value_threshold=ds.spatial.p_value_threshold,
)
plt.show()

### Significance vs Stability

In [None]:
# Same significance cutoff as the rest of the notebook (bundle spatial config).
plot_summary_scatter(
    ds.unit_results,
    p_value_threshold=ds.spatial.p_value_threshold,
)
plt.show()

### Unit Summary Table

Table of key metrics for all units, indexed by `unit_id` (sort it with e.g. `unit_table.sort_values("SI (bits/s)", ascending=False)`).

In [None]:
# Significance cutoff from the bundle's spatial config. It is applied both to
# the SI test (`p_val`) and to the split-half stability test (`stability_p_val`).
p_thresh = ds.spatial.p_value_threshold


def _round_finite(value, ndigits=4):
    """Round *value* to *ndigits* decimals, returning NaN if it is NaN or +/-inf."""
    return round(value, ndigits) if np.isfinite(value) else np.nan


# Explicit column order. Passing it to the DataFrame constructor also keeps a
# `unit_id` column when there are no units, so `set_index` does not raise
# KeyError on an empty table.
UNIT_TABLE_COLUMNS = [
    "unit_id",
    "n_events",
    "SI (bits/s)",
    "p_val",
    "sig",
    "stability_r",
    "stability_z",
    "stability_p",
    "stable",
    "place_cell",
]

table_rows = []
for uid, res in sorted(ds.unit_results.items()):
    is_sig = res.p_val < p_thresh
    # A NaN stability p-value counts as "not stable".
    is_stable = (
        not np.isnan(res.stability_p_val) and res.stability_p_val < p_thresh
    )
    table_rows.append({
        "unit_id": uid,
        "n_events": len(res.unit_data),
        "SI (bits/s)": round(res.si, 4),
        "p_val": round(res.p_val, 4),
        "sig": is_sig,
        "stability_r": _round_finite(res.stability_corr),
        "stability_z": _round_finite(res.stability_z),
        "stability_p": _round_finite(res.stability_p_val),
        "stable": is_stable,
        # A place cell must pass both the SI significance and stability tests.
        "place_cell": is_sig and is_stable,
    })

unit_table = pd.DataFrame(table_rows, columns=UNIT_TABLE_COLUMNS).set_index(
    "unit_id"
)
display(unit_table)

### Place Field Coverage

In [None]:
# Units the dataset classifies as place cells (reported below as "sig + stable").
place_cell_results = ds.place_cells()
n_place_cells = len(place_cell_results)
print(f"Place cells (sig + stable): {n_place_cells} / {len(ds.unit_results)}")

if n_place_cells > 0:
    coverage_map, n_cells_arr, coverage_frac = ds.coverage()
    plot_coverage(
        coverage_map, n_cells_arr, coverage_frac,
        ds.x_edges, ds.y_edges, ds.valid_mask, n_place_cells,
    )
    plt.show()
else:
    # Plain em dash (the previous text was UTF-8 mis-decoded as "â€”").
    print("No place cells — skipping coverage plot.")

### Interactive Unit Browser

Browse all units with full detail panels (footprint, trajectory, rate maps, trace).

In [None]:
%matplotlib widget

fig, controls = browse_units(ds)
plt.show()
display(controls)

### Place Cell Browser (Significant AND Stable Only)

In [None]:
%matplotlib widget

if n_place_cells > 0:
    fig_pc, controls_pc = browse_units(
        ds, unit_results=place_cell_results,
        place_field_threshold=ds.spatial.place_field_threshold,
    )
    plt.show()
    display(controls_pc)
else:
    print("No cells passed both significance and stability tests.")

### Place Cell Gallery

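One row per place cell, showing:
- the first-half and second-half rate maps;
- the full-session rate map with its place-field contour;
- the SI and split-half stability shuffle distributions, with the observed value as a dashed red line.

Rows are sorted by stability p-value (lowest, i.e. most stable, first).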

In [None]:
from placecell.analysis import compute_place_field_mask


def _show_masked_rate_map(ax, rate_map, valid_mask, extent):
    """Draw *rate_map* on *ax* with bins outside *valid_mask* blanked to NaN."""
    masked = rate_map.copy()
    masked[~valid_mask] = np.nan
    # Transposed so axis 0 of the rate map runs horizontally, matching the
    # x_edges part of *extent*.
    ax.imshow(
        masked.T, origin="lower", extent=extent, cmap="inferno", aspect="equal",
    )
    ax.axis("off")


pc = place_cell_results
if len(pc) > 0:
    # Lowest stability p-value (most stable) first.
    sorted_uids = sorted(pc, key=lambda u: pc[u].stability_p_val)
    n = len(sorted_uids)
    ext = [ds.x_edges[0], ds.x_edges[-1], ds.y_edges[0], ds.y_edges[-1]]
    pf_thresh = ds.spatial.place_field_threshold
    pf_min_bins = ds.spatial.place_field_min_bins
    p_thresh = ds.spatial.p_value_threshold

    # squeeze=False keeps `axes` 2-D with shape (n, 5) even when n == 1.
    fig, axes = plt.subplots(n, 5, figsize=(12, 2.4 * n), squeeze=False)

    hist_kw = dict(bins=20, color="gray", alpha=0.7, edgecolor="black", linewidth=0.5)

    for i, uid in enumerate(sorted_uids):
        res = pc[uid]
        ax_first, ax_second, ax_full, ax_si, ax_stab = axes[i]

        # Column 0: 1st half
        _show_masked_rate_map(ax_first, res.rate_map_first, ds.valid_mask, ext)
        ax_first.set_title(f"#{uid}  1st half", fontsize=8)

        # Column 1: 2nd half, annotated with the split-half correlation
        _show_masked_rate_map(ax_second, res.rate_map_second, ds.valid_mask, ext)
        ax_second.set_title(
            f"2nd half  r={res.stability_corr:.2f}", fontsize=8,
        )

        # Column 2: full session with place field contour
        _show_masked_rate_map(ax_full, res.rate_map, ds.valid_mask, ext)
        field_mask = compute_place_field_mask(
            res.rate_map, res.shuffled_rate_p95,
            threshold=pf_thresh, min_bins=pf_min_bins,
        )
        if np.any(field_mask):
            # The 0.5 level traces the boundary of the boolean field mask.
            ax_full.contour(
                field_mask.T.astype(float), levels=[0.5],
                colors="red", linewidths=1, extent=ext, origin="lower",
            )
        ax_full.set_title(f"Full  SI={res.si:.3f}", fontsize=8)

        # Column 3: SI shuffle distribution vs observed SI (dashed line)
        ax_si.hist(res.shuffled_sis, **hist_kw)
        ax_si.axvline(res.si, color="red", linestyle="--", linewidth=1.5)
        sig_color = "green" if res.p_val < p_thresh else "red"
        ax_si.set_title(f"SI p={res.p_val:.3f}", fontsize=8, color=sig_color)
        ax_si.tick_params(labelsize=6)
        ax_si.set_xlabel("SI", fontsize=7)

        # Column 4: stability shuffle distribution vs observed correlation
        if len(res.shuffled_stability) > 0:
            ax_stab.hist(res.shuffled_stability, **hist_kw)
            ax_stab.axvline(
                res.stability_corr, color="red", linestyle="--", linewidth=1.5,
            )
            is_stable = (
                not np.isnan(res.stability_p_val)
                and res.stability_p_val < p_thresh
            )
            stab_color = "green" if is_stable else "red"
            ax_stab.set_title(
                f"Stab p={res.stability_p_val:.3f}",
                fontsize=8, color=stab_color,
            )
        else:
            ax_stab.text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=8)
            ax_stab.set_title("Stab", fontsize=8, color="gray")
        ax_stab.tick_params(labelsize=6)
        ax_stab.set_xlabel("r", fontsize=7)

    fig.tight_layout()
    plt.show()
else:
    print("No place cells to display.")