# Place Cell Results Viewer

Load saved `.pcellbundle` results and inspect them without re-running the
analysis pipeline. Supports loading a single bundle or multiple bundles for
cross-session comparison.

---

**Note:** The interactive widgets (`%matplotlib widget`) require **Jupyter Lab** with the `ipympl` backend installed:

```bash
cd notebook && jupyter lab --no-browser --port=6006
```

In [None]:
import sys
from pathlib import Path

# The notebook lives in notebook/, so the repository root is one level up.
# Prepending it (index 0) lets `placecell` resolve to the local checkout first.
project_root = Path.cwd().parent
_project_root_str = str(project_root)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display

from placecell.analysis import compute_place_field_mask
from placecell.dataset import BasePlaceCellDataset
from placecell.notebook import browse_units
from placecell.visualization import (
    plot_behavior_preview,
    plot_coverage,
    plot_diagnostics,
    plot_footprints_filled,
    plot_occupancy_preview,
    plot_position_and_traces_2d,
    plot_session_summary,
    plot_summary_scatter,
)

## Load Bundles

List the `.pcellbundle` paths to load. For a single session, use one entry.

In [None]:
BUNDLE_PATHS = [
    Path("../user_data/bundles/WL25_20251205.pcellbundle"),
]

datasets = {}
for bp in BUNDLE_PATHS:
    # Accept plain strings as well as Path objects.
    bp = Path(bp)
    # Session label: file name with its final suffix and any ".pcellbundle" removed.
    name = bp.stem.replace(".pcellbundle", "")
    if name in datasets:
        # Bundles with the same file name (e.g. from different folders) would
        # otherwise silently overwrite each other in `datasets`.
        raise ValueError(
            f"Duplicate session name {name!r} (from {bp}); "
            "rename one of the bundles so each session has a unique name."
        )
    ds = BasePlaceCellDataset.load_bundle(bp)
    datasets[name] = ds
    s = ds.summary()
    print(
        f"{name}: {s['n_total']} units, "
        f"{s['n_sig']} sig, {s['n_stable']} stable, "
        f"{s['n_place_cells']} place cells"
    )

## Cross-Session Summary

Counts and proportions of significant, stable, and place-cell units for each loaded session.
The table is always shown; the comparison plot appears only when multiple bundles are loaded.

In [None]:
# One row per loaded session: its name plus every field of ds.summary().
rows = [{"dataset": name, **ds.summary()} for name, ds in datasets.items()]
summary_df = pd.DataFrame(rows)
display(summary_df)

# The comparison plot is only meaningful with two or more sessions.
if len(datasets) >= 2:
    plot_session_summary(summary_df)
    plt.show()

---

## Per-Session Results

Select a session to inspect. Change `SESSION` to switch.

In [None]:
# Default to the first loaded bundle; set SESSION to another key to switch.
SESSION = list(datasets)[0]
ds = datasets[SESSION]
spatial_cfg = ds.spatial
print(f"Session: {SESSION}")
print(
    "Config: "
    f"bins={spatial_cfg.bins}, "
    f"activity_sigma={spatial_cfg.activity_sigma}, "
    f"n_shuffles={spatial_cfg.n_shuffles}, "
    f"p_threshold={spatial_cfg.p_value_threshold}"
)

### Behavior & Occupancy

In [None]:
# Needs both the raw and the filtered trajectory; skip gracefully otherwise.
if ds.trajectory is None or ds.trajectory_filtered is None:
    print("Trajectory data not available in this bundle.")
else:
    plot_behavior_preview(
        ds.trajectory,
        ds.trajectory_filtered,
        ds.cfg.behavior.speed_threshold,
        # Speeds are in mm/s only when a pixel-to-mm scale is available.
        speed_unit="mm/s" if ds.mm_per_px else "px/s",
    )
    plt.show()

In [None]:
# Occupancy map over the spatial bins defined by ds.x_edges / ds.y_edges.
if ds.occupancy_time is None:
    print("Occupancy data not available.")
else:
    plot_occupancy_preview(
        ds.trajectory_filtered,
        ds.occupancy_time,
        ds.valid_mask,
        ds.x_edges,
        ds.y_edges,
    )
    plt.show()

### Cell Footprints

In [None]:
# Both the max projection and the footprints must be present in the bundle.
if ds.max_proj is None or ds.footprints is None:
    print("Max projection or footprints not available.")
else:
    plot_footprints_filled(ds.max_proj, ds.footprints)
    plt.show()

### Diagnostics

Event count distribution, SI vs event count, and p-value vs event count.

In [None]:
# Use the session's configured p-value cutoff so the diagnostics agree with
# the significance threshold used throughout this notebook.
plot_diagnostics(
    ds.unit_results,
    p_value_threshold=ds.spatial.p_value_threshold,
)
plt.show()

### Significance vs Stability

In [None]:
# Significance vs. stability, annotated with the shuffle settings from the bundle's config.
spatial_cfg = ds.spatial
plot_summary_scatter(
    ds.unit_results,
    p_value_threshold=spatial_cfg.p_value_threshold,
    n_shuffles=spatial_cfg.n_shuffles,
    min_shift_seconds=spatial_cfg.min_shift_seconds,
)
plt.show()

### Trajectory + Rate Map Examples

Trajectory with event locations (top) and rate map with place field contour (bottom)
for 4 selected units; all rate maps share one color scale. Change `GALLERY_UIDS` to pick specific units.

In [None]:
place_cell_results = ds.place_cells()

# Units to display. Pin specific unit IDs here; IDs missing from the selected
# session are skipped. Set to None (or []) to instead draw up to GALLERY_N
# random place cells (or random units if the session has no place cells).
GALLERY_UIDS = [35, 42, 55, 46]  # Example fixed UIDs for demonstration
GALLERY_N = 4
GALLERY_SEED = 54

requested_uids = list(GALLERY_UIDS or [])
missing_uids = [uid for uid in requested_uids if uid not in ds.unit_results]
if missing_uids:
    print(f"Units not found in {SESSION}, skipping: {missing_uids}")
GALLERY_UIDS = [uid for uid in requested_uids if uid in ds.unit_results]
if not GALLERY_UIDS:
    pool = list(place_cell_results) if place_cell_results else list(ds.unit_results)
    if pool:
        rng = np.random.default_rng(GALLERY_SEED)
        picks = rng.choice(pool, size=min(GALLERY_N, len(pool)), replace=False)
        GALLERY_UIDS = sorted(picks.tolist())
print(f"Gallery units: {GALLERY_UIDS}")

if not GALLERY_UIDS:
    print("No units available for the gallery.")
else:
    ext = [ds.x_edges[0], ds.x_edges[-1], ds.y_edges[0], ds.y_edges[-1]]
    n = len(GALLERY_UIDS)

    # Blank out bins outside ds.valid_mask, then use one shared color scale
    # (vmin/vmax) so rates are comparable across panels.
    rate_maps = []
    for uid in GALLERY_UIDS:
        rm = ds.unit_results[uid].rate_map.copy()
        rm[~ds.valid_mask] = np.nan
        rate_maps.append(rm)
    vmin = 0
    vmax = max(np.nanmax(rm) for rm in rate_maps)

    # squeeze=False keeps `axes` 2-D with shape (2, n), even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(3.2 * n + 0.8, 6), squeeze=False)

    for col, (uid, rm) in enumerate(zip(GALLERY_UIDS, rate_maps)):
        res = ds.unit_results[uid]

        # Top: trajectory + events (matching browser styling)
        ax_t = axes[0, col]
        if ds.trajectory_filtered is not None:
            ax_t.plot(
                ds.trajectory_filtered["x"], ds.trajectory_filtered["y"],
                "k-", alpha=1, linewidth=2, rasterized=True, zorder=1,
            )
        vis_data = res.vis_data_above
        if vis_data is not None and len(vis_data) > 0:
            amps = vis_data["s"].values
            x_vals = vis_data["x"].values
            y_vals = vis_data["y"].values

            # Occupancy-normalized alpha (same as browser): divide each event's
            # amplitude by the occupancy of its spatial bin (floored at 0.01
            # to avoid dividing by ~0), then scale by the maximum.
            if ds.occupancy_time is not None:
                x_bin = np.clip(np.digitize(x_vals, ds.x_edges) - 1, 0, len(ds.x_edges) - 2)
                y_bin = np.clip(np.digitize(y_vals, ds.y_edges) - 1, 0, len(ds.y_edges) - 2)
                norm_amps = amps / np.maximum(ds.occupancy_time[x_bin, y_bin], 0.01)
            else:
                norm_amps = amps
            norm_max = np.max(norm_amps)
            if not norm_max > 0:  # all-zero (or NaN) amplitudes
                norm_max = 1.0
            # Matplotlib rejects per-point alpha values outside [0, 1].
            alphas = np.clip(norm_amps / norm_max, 0.0, 1.0)

            ax_t.scatter(
                x_vals, y_vals,
                c="red", s=100, alpha=alphas, edgecolors="none",
                rasterized=True, zorder=2,
            )
        ax_t.set_xlim(ext[0], ext[1])
        ax_t.set_ylim(ext[2], ext[3])
        ax_t.set_aspect("equal")
        ax_t.set_title(f"Cell {uid}", fontsize=9)
        ax_t.axis("off")

        # Bottom: rate map + place field contour (common color scale)
        ax_r = axes[1, col]
        im = ax_r.imshow(
            rm.T, origin="lower", extent=ext, cmap="inferno",
            aspect="equal", vmin=vmin, vmax=vmax,
        )

        field_mask = compute_place_field_mask(
            res.rate_map,
            threshold=ds.spatial.place_field_threshold,
            min_bins=ds.spatial.place_field_min_bins,
            shuffled_rate_p95=res.shuffled_rate_p95,
        )
        if np.any(field_mask):
            ax_r.contour(
                field_mask.T.astype(float), levels=[0.5],
                colors="white", linewidths=1, extent=ext, origin="lower",
            )
        ax_r.set_title(f"SI={res.si:.3f}  p={res.p_val:.3f}", fontsize=8)
        ax_r.axis("off")

    # Reserve right margin for colorbar, then position it
    fig.tight_layout(rect=[0, 0, 0.91, 1])
    bot = axes[1, 0].get_position()
    cax = fig.add_axes([0.93, bot.y0, 0.015, bot.height])
    cbar = fig.colorbar(im, cax=cax, orientation="vertical")
    cbar.set_label("Normalized spatial neural activity rate (events/s)", fontsize=9)

    plt.show()

### Unit Summary Table

Sortable table with key metrics for all units.

In [None]:
p_thresh = ds.spatial.p_value_threshold


def _finite_round4(value):
    """Return ``round(value, 4)``, or NaN when *value* is NaN or infinite."""
    return round(value, 4) if np.isfinite(value) else np.nan


def _unit_table_row(uid, res):
    """Build one summary-table row (column name -> value) for a single unit."""
    is_sig = res.p_val < p_thresh
    # Units with a NaN stability p-value are never counted as stable.
    is_stable = not np.isnan(res.stability_p_val) and res.stability_p_val < p_thresh
    return {
        "unit_id": uid,
        "n_events": len(res.unit_data),
        "SI (bits/s)": round(res.si, 4),
        "p_val": round(res.p_val, 4),
        "sig": is_sig,
        "stability_r": _finite_round4(res.stability_corr),
        "stability_z": _finite_round4(res.stability_z),
        "stability_p": _finite_round4(res.stability_p_val),
        "stable": is_stable,
        "place_cell": is_sig and is_stable,
    }


table_rows = [_unit_table_row(uid, res) for uid, res in sorted(ds.unit_results.items())]
unit_table = pd.DataFrame(table_rows).set_index("unit_id")
display(unit_table)

### Place Field Coverage

In [None]:
# Place cells as defined by the dataset (significant AND stable).
place_cell_results = ds.place_cells()
n_place_cells = len(place_cell_results)
print(f"Place cells (sig + stable): {n_place_cells} / {len(ds.unit_results)}")

if n_place_cells > 0:
    # Only the coverage map is needed here; the other two return values are ignored.
    coverage_map, _, _ = ds.coverage()
    plot_coverage(
        coverage_map,
        ds.x_edges, ds.y_edges, ds.valid_mask, n_place_cells,
    )
    plt.show()
else:
    # Fixed mis-encoded em dash ("â€”") in the message.
    print("No place cells — skipping coverage plot.")

### Speed + Place Cell Traces

Animal speed and 20 example place cell calcium traces, time-synced.
Cells are sorted by spatial information (highest first).

In [None]:
# Requires at least one place cell and the raw trajectory.
if not place_cell_results or ds.trajectory is None:
    print("No place cells or trajectory data available.")
else:
    behavior_cfg = ds.cfg.behavior
    plot_position_and_traces_2d(
        ds.trajectory,
        place_cell_results,
        behavior_fps=behavior_cfg.behavior_fps,
        speed_threshold=behavior_cfg.speed_threshold,
        trajectory_filtered=ds.trajectory_filtered,
        speed_unit="mm/s" if ds.mm_per_px else "px/s",
    )
    plt.show()

### Interactive Unit Browser

Browse all units with full detail panels (footprint, trajectory, rate maps, trace).

In [None]:
%matplotlib widget

fig, controls = browse_units(ds)
plt.show()
display(controls)

### Place Cell Browser (Significant AND Stable Only)

In [None]:
%matplotlib widget

if n_place_cells > 0:
    fig_pc, controls_pc = browse_units(
        ds, unit_results=place_cell_results,
        place_field_threshold=ds.spatial.place_field_threshold,
    )
    plt.show()
    display(controls_pc)
else:
    print("No cells passed both significance and stability tests.")

### Place Cell Gallery (disabled)

Full rate maps and split-half stability for every place cell.
Uncomment the cell below to enable.

In [None]:
# Uncomment to show the full place cell gallery (one row per place cell)
# pc = place_cell_results
# if len(pc) > 0:
#     sorted_uids = sorted(pc.keys(), key=lambda u: pc[u].stability_p_val)
#     ...
# NOTE(review): the plotting body above is elided ("..."), so uncommenting these
# lines alone draws nothing. The visible part only orders place cells by
# ascending stability_p_val (lowest first); restore the per-row plotting code
# before enabling this cell.
pass  # keeps the cell syntactically non-empty while the gallery is disabled