# Place Cell Analysis

Interactive notebook for analyzing place cells during navigation in a 2D environment.

This notebook runs the full workflow:
1. **Deconvolution** - Extract neural events using OASIS
2. **Event-Place Matching** - Match events to behavior positions
3. **Analysis & Interactive Visualization** - Compute spatial tuning metrics and browse place cells with a scrollable interface

---

**Note:** This notebook uses interactive widgets (progress bars, sliders). When working over SSH, use **Jupyter Lab** instead of VSCode's notebook extension for proper widget support:

```bash
# On remote machine
cd notebook
jupyter lab --no-browser --port=6006

# On local machine - set up SSH tunnel
ssh -L 6006:localhost:6006 user@remote-host

# Then open http://localhost:6006 in your browser
```

In [None]:
import sys
from pathlib import Path

# Add project root to path so the `placecell` imports below resolve.
# The root is taken as the parent of the current working directory, i.e. this
# assumes the kernel was started from a direct subfolder of the repository
# (such as `notebook/`). The membership check keeps re-running this cell from
# inserting the same entry twice.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from tqdm.notebook import tqdm

from placecell.analysis import compute_occupancy_map, compute_unit_analysis
from placecell.behavior import build_event_place_dataframe, load_curated_unit_ids
from placecell.config import AppConfig, DataPathsConfig
from placecell.io import load_behavior_data, load_neural_data
from placecell.neural import build_event_index_dataframe, load_traces, run_deconvolution
from placecell.notebook import create_unit_browser
from placecell.visualization import plot_summary_scatter

## Configuration

In [None]:
# --- Input locations: edit these for your machine / session ---
CONFIG_PATH = project_root / "placecell/config/example_pcell_config.yaml"
# Alternative data root, kept for reference:
#   "/mnt/data/minizero_analysis/202512round/202511_analysis_placecell/"
DATA_PATH = Path(
    "/Volumes/ProcData/minizero_analysis/202512round/202511_analysis_placecell/"
    "20251205/WL25/WL25_20251205.yaml"
)
DATA_DIR = DATA_PATH.parent

# CSV exports written by this notebook land in <project_root>/output
OUTPUT_DIR = project_root / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Configuration objects ---
# Load the analysis config and the per-session data config, then layer the
# data-specific overrides (e.g., OASIS parameters) on top of the analysis config.
cfg = AppConfig.from_yaml(CONFIG_PATH)
data_cfg = DataPathsConfig.from_yaml(DATA_PATH)
cfg = cfg.with_data_overrides(data_cfg)

# --- Session files ---
# Entries in the data config are joined onto the directory holding the data YAML.
neural_path = DATA_DIR / data_cfg.neural_path
neural_timestamp = DATA_DIR / data_cfg.neural_timestamp
behavior_position = DATA_DIR / data_cfg.behavior_position
behavior_timestamp = DATA_DIR / data_cfg.behavior_timestamp

# The curation CSV is optional; it stays None when the data config leaves it unset.
curation_csv = None
if data_cfg.curation_csv:
    curation_csv = DATA_DIR / data_cfg.curation_csv

# Echo every resolved location so a wrong path is easy to spot
print(
    f"Config: {CONFIG_PATH}",
    f"Data: {DATA_PATH}",
    f"Neural path: {neural_path}",
    f"Neural timestamp: {neural_timestamp}",
    f"Behavior position: {behavior_position}",
    f"Behavior timestamp: {behavior_timestamp}",
    f"Curation CSV: {curation_csv}",
    sep="\n",
)

In [None]:
# Pull the settings used throughout the notebook out of the config into plain
# variables. Short aliases for the config sections keep the lines readable.
_behavior_cfg = cfg.behavior
_spatial_cfg = _behavior_cfg.spatial_map
_neural_cfg = cfg.neural
_oasis_cfg = _neural_cfg.oasis

# Behavior settings
bodypart = _behavior_cfg.bodypart
behavior_fps = _behavior_cfg.behavior_fps
speed_threshold = _behavior_cfg.speed_threshold
speed_window_frames = _behavior_cfg.speed_window_frames

# Spatial map settings (binning, smoothing, shuffling, significance, stability)
bins = _spatial_cfg.bins
min_occupancy = _spatial_cfg.min_occupancy
occupancy_sigma = _spatial_cfg.occupancy_sigma
activity_sigma = _spatial_cfg.activity_sigma
n_shuffles = _spatial_cfg.n_shuffles
random_seed = _spatial_cfg.random_seed
event_threshold_sigma = _spatial_cfg.event_threshold_sigma
p_value_threshold = _spatial_cfg.p_value_threshold
stability_threshold = _spatial_cfg.stability_threshold
stability_method = _spatial_cfg.stability_method
min_shift_seconds = _spatial_cfg.min_shift_seconds
si_weight_mode = _spatial_cfg.si_weight_mode

# Neural settings
trace_name = _neural_cfg.trace_name
neural_fps = _neural_cfg.fps
max_units = _neural_cfg.max_units

# OASIS deconvolution parameters
g = _oasis_cfg.g
baseline = _oasis_cfg.baseline
penalty = _oasis_cfg.penalty
s_min = _oasis_cfg.s_min

# Visualization settings
trace_time_window = 600.0  # 10 minutes window for trace display

# Report the values that most affect the results
print(
    f"Bodypart: {bodypart}",
    f"Speed threshold: {speed_threshold} px/s",
    f"Bins: {bins}",
    f"Shuffles: {n_shuffles}",
    f"Min shift: {min_shift_seconds}s",
    f"SI weight mode: {si_weight_mode}",
    f"Stability method: {stability_method}",
    f"Trace name: {trace_name}",
    f"OASIS g: {g}",
    sep="\n",
)

## Unit Subset (for quick iteration)

Set `UNIT_IDS` to a list of specific unit IDs to process only those units.
Set to `None` to process all units.

In [None]:
# Set to a list of unit IDs for quick iteration, e.g. [0, 5, 42]
# Set to None to process all units
# When a list is given it takes precedence over the `max_units` limit from the
# config; requested IDs that are not available in the deconvolution cell are
# reported there with a warning and skipped.
UNIT_IDS = None

## Step 1: Deconvolution

Run OASIS deconvolution to extract neural events from calcium traces.

In [None]:
# Load traces
print(f"Loading traces from: {neural_path / (trace_name + '.zarr')}")
C_da = load_traces(neural_path, trace_name=trace_name)
all_unit_ids = list(map(int, C_da["unit_id"].values))
print(f"Total units in traces: {len(all_unit_ids)}")

# Filter by curation CSV if provided.
# A configured-but-missing file is reported instead of being silently ignored,
# so an unintended run on uncurated units is visible in the output.
if curation_csv is not None:
    if curation_csv.exists():
        curated_ids = set(load_curated_unit_ids(curation_csv))
        all_unit_ids = [uid for uid in all_unit_ids if uid in curated_ids]
        print(f"After curation filter: {len(all_unit_ids)} units")
    else:
        print(f"Warning: curation CSV not found, skipping curation filter: {curation_csv}")

# Apply unit subset: an explicit UNIT_IDS list wins over the config's max_units.
if UNIT_IDS is not None:
    available = set(all_unit_ids)
    # Drop repeated IDs (keeping first-seen order) so a unit listed twice is
    # not deconvolved and exported twice.
    requested = list(dict.fromkeys(UNIT_IDS))
    all_unit_ids = [uid for uid in requested if uid in available]
    missing = set(requested) - available
    if missing:
        print(f"Warning: requested unit IDs not found: {sorted(missing)}")
    print(f"Selected {len(all_unit_ids)} specific units: {all_unit_ids}")
elif max_units is not None and len(all_unit_ids) > max_units:
    all_unit_ids = all_unit_ids[:max_units]
    print(f"Limited to first {max_units} units (from config)")

print(f"Will process {len(all_unit_ids)} units")

# Run OASIS deconvolution
print(f"Running OASIS deconvolution (g={g})...")

# good_unit_ids and S_list are index-aligned in later cells (S_list[i] is read
# for good_unit_ids[i]); C_list is returned alongside them.
good_unit_ids, C_list, S_list = run_deconvolution(
    C_da=C_da,
    unit_ids=all_unit_ids,
    g=g,
    baseline=baseline,
    penalty=penalty,
    s_min=s_min,
    progress_bar=lambda x: tqdm(x, desc="Deconvolving units"),
)

print(f"Successfully deconvolved {len(good_unit_ids)} units")

In [None]:
# Build the event index table from the deconvolution outputs and write it to
# disk; the event-place matching step later reads it back from this CSV path.
event_index_csv = OUTPUT_DIR / "event_index_notebook.csv"
event_index_df = build_event_index_dataframe(good_unit_ids, S_list)
event_index_df.to_csv(event_index_csv, index=False)
print(
    f"Event index: {len(event_index_df)} events from {event_index_df['unit_id'].nunique()} units",
    f"Saved to: {event_index_csv}",
    sep="\n",
)

In [None]:
%matplotlib widget
import ipywidgets as widgets
from matplotlib.lines import Line2D

# Interactive deconvolution preview: one wide axes on a widget canvas that
# fills the cell width, with the canvas header and toolbar hidden.
fig_deconv, ax_deconv = plt.subplots(figsize=(10, 2.5))
fig_deconv.canvas.layout.width = "100%"
fig_deconv.canvas.header_visible = False
fig_deconv.canvas.toolbar_visible = False

# Browsing limits: how long the recording is, how much of it is shown at once,
# and how many deconvolved units can be stepped through.
_deconv_max_time = C_da.sizes["frame"] / neural_fps
_deconv_time_window = 600.0  # 10 min window
n_good = len(good_unit_ids)


def _render_deconv(unit_idx, t_start):
    """Redraw the deconvolution preview for one unit and one time window.

    Clears ``ax_deconv`` and plots the unit's trace between ``t_start`` and
    ``t_start + _deconv_time_window`` (capped at the end of the recording),
    with the deconvolved events drawn as red stems rising from the bottom of
    the axes. Stem heights are scaled relative to the largest event inside
    the visible window and occupy at most 30% of the axes height.

    Parameters
    ----------
    unit_idx : int
        Position in ``good_unit_ids`` / ``S_list`` (not the unit ID itself).
    t_start : float
        Left edge of the displayed window, in seconds.
    """
    ax_deconv.clear()
    uid = good_unit_ids[unit_idx]
    trace = C_da.sel(unit_id=uid).values
    t_full = np.arange(len(trace)) / neural_fps
    spikes = S_list[unit_idx]

    t_end = min(_deconv_max_time, t_start + _deconv_time_window)
    mask = (t_full >= t_start) & (t_full <= t_end)

    ax_deconv.plot(t_full[mask], trace[mask], "b-", linewidth=0.5, alpha=0.7)

    # Spike stems in time window
    spike_frames = np.nonzero(spikes > 0)[0]
    spike_times = spike_frames / neural_fps
    spike_mask = (spike_times >= t_start) & (spike_times <= t_end)
    if np.any(spike_mask):
        st = spike_times[spike_mask]
        sa = np.asarray(spikes[spike_frames[spike_mask]], dtype=float)
        # Read the limits after the trace is drawn so stems scale to the data.
        y_min, y_max = ax_deconv.get_ylim()
        amp_max = sa.max() if sa.max() > 0 else 1.0
        max_spike_h = (y_max - y_min) * 0.3
        # One collection for all stems instead of one Line2D artist per event;
        # this keeps redraws fast for units with many events.
        stem_base = np.full_like(st, y_min, dtype=float)
        stem_top = stem_base + (sa / amp_max) * max_spike_h
        ax_deconv.vlines(st, stem_base, stem_top, colors="red", linewidths=0.8)

    ax_deconv.set_xlim(t_start, t_end)
    ax_deconv.set_ylabel(trace_name, fontsize=9)
    ax_deconv.set_xlabel("Time (s)", fontsize=9)
    # The em dash is written as an escape so it cannot be garbled again by a
    # file-encoding round trip.
    ax_deconv.set_title(
        f"Deconvolution Preview \u2014 Unit {uid} ({unit_idx + 1}/{n_good})",
        fontsize=10,
    )
    # Proxy handles: the legend describes both layers even when the current
    # window contains no events.
    ax_deconv.legend(
        handles=[
            Line2D([0], [0], color="blue", linewidth=0.5, label="Fluorescence"),
            Line2D([0], [0], color="red", linewidth=1.5, label="Deconvolved spikes"),
        ],
        loc="upper right",
        fontsize=8,
        framealpha=0.9,
    )
    fig_deconv.canvas.draw_idle()


# Widgets
# Unit selector: a 0-based index into good_unit_ids. continuous_update=False
# means the value (and therefore the redraw) only changes once the drag ends.
_unit_slider_d = widgets.IntSlider(
    value=0,
    min=0,
    max=n_good - 1,
    step=1,
    description="Unit:",
    continuous_update=False,
    layout=widgets.Layout(width="100%"),
)

# Window start in seconds; the upper bound keeps a full window inside the
# recording and collapses to 0 when the recording is shorter than one window.
_time_slider_d = widgets.FloatSlider(
    value=0,
    min=0,
    max=max(0, _deconv_max_time - _deconv_time_window),
    step=10,
    description="Time (s):",
    continuous_update=False,
    layout=widgets.Layout(width="100%"),
)

# Step buttons placed on either side of the unit slider
_prev_btn_d = widgets.Button(
    description="< Prev",
    layout=widgets.Layout(width="70px"),
)
_next_btn_d = widgets.Button(
    description="Next >",
    layout=widgets.Layout(width="70px"),
)


def _on_prev_d(_):
    """Move the unit slider one step back, wrapping from the first unit to the last."""
    current_idx = _unit_slider_d.value
    _unit_slider_d.value = (current_idx - 1) % n_good


def _on_next_d(_):
    """Move the unit slider one step forward, wrapping from the last unit to the first."""
    current_idx = _unit_slider_d.value
    _unit_slider_d.value = (current_idx + 1) % n_good


# Wire the step buttons to their handlers. The handlers only change the unit
# slider's value; the redraw is triggered by the slider's own value observer.
_prev_btn_d.on_click(_on_prev_d)
_next_btn_d.on_click(_on_next_d)


def _update_deconv(_=None):
    """Re-render the preview from the current slider positions.

    The argument is the change notification passed by ``observe``; it is
    ignored because both values are read directly from the sliders.
    """
    unit_idx = _unit_slider_d.value
    t_start = _time_slider_d.value
    _render_deconv(unit_idx, t_start)


# Redraw whenever either slider settles on a new value
for _slider_d in (_unit_slider_d, _time_slider_d):
    _slider_d.observe(_update_deconv, names="value")

# Layout: [< Prev | unit slider | Next >] on the top row, time slider below it
_nav_d = widgets.HBox(
    children=[_prev_btn_d, _unit_slider_d, _next_btn_d],
    layout=widgets.Layout(width="100%"),
)
_controls_d = widgets.VBox(
    children=[_nav_d, _time_slider_d],
    layout=widgets.Layout(width="100%"),
)

# Draw the first unit at t = 0, then show the figure followed by its controls
_render_deconv(0, 0)
plt.show()
display(_controls_d)

## Step 2: Event-Place Matching

Match neural events to behavior positions.

In [None]:
# Build event-place dataframe from the event index CSV written earlier plus
# the neural/behavior timestamp and position files of this session.
print("Matching events to behavior positions...")

event_place_df = build_event_place_dataframe(
    # inputs on disk
    event_index_path=event_index_csv,
    neural_timestamp_path=neural_timestamp,
    behavior_position_path=behavior_position,
    behavior_timestamp_path=behavior_timestamp,
    # behavior settings from the config
    bodypart=bodypart,
    behavior_fps=behavior_fps,
    speed_threshold=speed_threshold,
    speed_window_frames=speed_window_frames,
)

print(
    f"Event-place entries: {len(event_place_df)}",
    f"Unique units: {event_place_df['unit_id'].nunique()}",
    sep="\n",
)

# Save event-place (optional)
event_place_csv = OUTPUT_DIR / "event_place_notebook.csv"
event_place_df.to_csv(event_place_csv, index=False)
print(f"Saved event-place to: {event_place_csv}")

## Step 3: Analysis & Visualization

Compute spatial tuning metrics (rate maps, spatial information, stability) and visualize results.

In [None]:
# Keep only events whose matched speed reaches the threshold; the unfiltered
# event index is copied separately so both views are available downstream.
_fast_enough = event_place_df["speed"] >= speed_threshold
df_filtered = event_place_df[_fast_enough].copy()
df_all_events = event_index_df.copy()

print(
    f"Speed-filtered events: {len(df_filtered)}",
    f"Unique units after filtering: {df_filtered['unit_id'].nunique()}",
    sep="\n",
)

In [None]:
# Load behavior data
# Two tables come back: `trajectory_with_speed` is plotted below as "All
# frames", `trajectory_df` as the speed-filtered trajectory and is the one
# used for the occupancy map and per-unit analysis.
# NOTE(review): the exact filtering rule lives in placecell.io — confirm there.
trajectory_with_speed, trajectory_df = load_behavior_data(
    behavior_position=behavior_position,
    behavior_timestamp=behavior_timestamp,
    bodypart=bodypart,
    speed_window_frames=speed_window_frames,
    speed_threshold=speed_threshold,
)

print(f"Trajectory frames: {len(trajectory_df)}")

In [None]:
# Preview: behavior data (raw vs filtered trajectory + speed histogram)
fig_beh, (ax_raw, ax_filt, ax_hist) = plt.subplots(1, 3, figsize=(10, 3.5))

# Left and middle panels share one recipe: the path over all frames, then the
# path over the speed-filtered frames only.
_trajectory_panels = (
    (ax_raw, trajectory_with_speed, f"All frames ({len(trajectory_with_speed)})"),
    (ax_filt, trajectory_df, f"Speed > {speed_threshold} px/s ({len(trajectory_df)})"),
)
for _ax, _traj, _title in _trajectory_panels:
    _ax.plot(_traj["x"], _traj["y"], "k-", linewidth=0.3, alpha=0.5)
    _ax.set_title(_title)
    _ax.set_aspect("equal")
    _ax.axis("off")

# Right panel: speed histogram. Values are clipped at the 99th percentile so a
# few extreme speeds do not stretch the x-axis; the dashed line marks the
# threshold used for filtering.
all_speeds = trajectory_with_speed["speed"].dropna()
speed_max = np.percentile(all_speeds, 99)
ax_hist.hist(
    all_speeds.clip(upper=speed_max),
    bins=50,
    color="gray",
    edgecolor="black",
    alpha=0.7,
)
ax_hist.axvline(
    speed_threshold,
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Threshold: {speed_threshold}",
)
ax_hist.set(xlabel="Speed (px/s)", ylabel="Count", title="Speed Distribution")
ax_hist.legend()

fig_beh.tight_layout()
plt.show()

In [None]:
# Compute occupancy map from the speed-filtered trajectory.
# Returns the occupancy array (plotted below with a "Time (s)" colorbar), a
# boolean-like mask of usable bins, and the bin edges along x and y that are
# reused as the image extent in the plots and passed to the per-unit analysis.
# NOTE(review): how `min_occupancy` and `occupancy_sigma` enter the mask is
# defined in placecell.analysis — confirm there.
occupancy_time, valid_mask, x_edges, y_edges = compute_occupancy_map(
    trajectory_df=trajectory_df,
    bins=bins,
    behavior_fps=behavior_fps,
    occupancy_sigma=occupancy_sigma,
    min_occupancy=min_occupancy,
)

print(f"Occupancy map shape: {occupancy_time.shape}")
print(f"Valid bins: {valid_mask.sum()} / {valid_mask.size}")

In [None]:
# Preview: occupancy map with trajectory and valid mask
fig_occ, (ax_traj, ax_occ, ax_mask) = plt.subplots(1, 3, figsize=(10, 3.5))

# Both images are drawn in trajectory coordinates, spanning the outer bin edges
_map_extent = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]

# Left: the speed-filtered path
ax_traj.plot(trajectory_df["x"], trajectory_df["y"], "k-", alpha=0.5, linewidth=0.3)
ax_traj.set_title("Trajectory (filtered)")
ax_traj.set_aspect("equal")
ax_traj.axis("off")

# Middle: occupancy heatmap (transposed and drawn with a lower-left origin)
im = ax_occ.imshow(
    occupancy_time.T,
    origin="lower",
    extent=_map_extent,
    cmap="hot",
    aspect="equal",
)
ax_occ.set_title(f"Occupancy (sigma={occupancy_sigma})")
plt.colorbar(im, ax=ax_occ, label="Time (s)")

# Right: valid-bin mask rendered as a grayscale image
ax_mask.imshow(
    valid_mask.T.astype(float),
    origin="lower",
    extent=_map_extent,
    cmap="gray",
    aspect="equal",
)
ax_mask.set_title(f"Valid bins ({valid_mask.sum()}/{valid_mask.size}, min={min_occupancy}s)")
ax_mask.axis("off")

fig_occ.tight_layout()
plt.show()

In [None]:
# Load neural data (for visualization)
# Any of the three results may be None (each is None-checked below and in the
# later plotting / analysis cells), so downstream code must tolerate missing
# traces, max projection, or footprints.
traces, max_proj, footprints = load_neural_data(
    neural_path=neural_path,
    trace_name=trace_name,
)

print(f"Traces shape: {traces.shape if traces is not None else 'None'}")
print(f"Max proj shape: {max_proj.shape if max_proj is not None else 'None'}")
print(f"Footprints shape: {footprints.shape if footprints is not None else 'None'}")

In [None]:
# Plot max projection, cell footprints, and overlay
if max_proj is not None and footprints is not None:
    fig_fp, (ax_mp, ax_fp, ax_ov) = plt.subplots(1, 3, figsize=(10, 3.5))

    unit_ids_fp = footprints.coords["unit_id"].values
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Left: max projection
    ax_mp.imshow(max_proj, cmap="gray", aspect="equal")
    ax_mp.set_title("Max Projection")
    ax_mp.axis("off")

    # Backgrounds for the two contour panels: black (middle) and the max
    # projection (right). Drawn before the contours so the contours sit on top.
    ax_fp.imshow(np.zeros_like(max_proj), cmap="gray", aspect="equal")
    ax_ov.imshow(max_proj, cmap="gray", aspect="equal")

    # Single pass over the units: each footprint is selected and its peak
    # computed once, then the same contour (30% of peak, colour cycled by unit
    # position) is drawn on both panels.
    for i, uid in enumerate(unit_ids_fp):
        fp = footprints.sel(unit_id=uid).values
        fp_peak = fp.max()
        # Skip empty footprints; a NaN peak also fails this test and is skipped.
        if fp_peak > 0:
            contour_color = colors[i % len(colors)]
            for _ax in (ax_fp, ax_ov):
                _ax.contour(fp, levels=[fp_peak * 0.3], colors=[contour_color], linewidths=1)

    ax_fp.set_title(f"Cell Footprints ({len(unit_ids_fp)})")
    ax_fp.axis("off")

    ax_ov.set_title("Overlay")
    ax_ov.axis("off")

    fig_fp.tight_layout()
    plt.show()
else:
    print("Max projection or footprints not available, skipping plot.")

In [None]:
# Set random seed (legacy global NumPy seeding) before the per-unit analysis,
# which is given `n_shuffles`, so repeated runs are comparable.
if random_seed is not None:
    np.random.seed(random_seed)

# Compute analysis for each unit (limited to the deconvolved units).
# Membership is tested against a set so the filter stays linear in the number
# of units instead of scanning the good_unit_ids list once per candidate.
_good_unit_id_set = set(good_unit_ids)
unique_units = [
    uid for uid in sorted(df_filtered["unit_id"].unique()) if uid in _good_unit_id_set
]
n_units = len(unique_units)
print(f"Computing analysis for {n_units} units...")

# unit_id -> everything the summary plot and the interactive browser need
unit_results = {}
for unit_id in tqdm(unique_units, desc="Computing unit analysis"):
    result = compute_unit_analysis(
        unit_id=unit_id,
        df_filtered=df_filtered,
        trajectory_df=trajectory_df,
        occupancy_time=occupancy_time,
        valid_mask=valid_mask,
        x_edges=x_edges,
        y_edges=y_edges,
        activity_sigma=activity_sigma,
        event_threshold_sigma=event_threshold_sigma,
        n_shuffles=n_shuffles,
        behavior_fps=behavior_fps,
        min_occupancy=min_occupancy,
        occupancy_sigma=occupancy_sigma,
        stability_threshold=stability_threshold,
        stability_method=stability_method,
        min_shift_seconds=min_shift_seconds,
        si_weight_mode=si_weight_mode,
    )

    # Visualization data: the analysis' own above-threshold events, plus this
    # unit's events from the unfiltered index whose amplitude `s` exceeds the
    # visualization threshold returned by the analysis.
    vis_data_above = result["events_above_threshold"]
    vis_data_below = pd.DataFrame()
    if df_all_events is not None:
        unit_all_events = df_all_events[df_all_events["unit_id"] == unit_id]
        vis_data_below = unit_all_events[unit_all_events["s"] > result["vis_threshold"]]

    # Trace data: optional. The time axis assumes a constant neural frame rate.
    trace_data = None
    trace_times = None
    if traces is not None:
        try:
            trace_data = traces.sel(unit_id=int(unit_id)).values
            trace_times = np.arange(len(trace_data)) / neural_fps
        except (KeyError, IndexError):
            # Deliberate best-effort: a unit without a stored trace is kept in
            # the results with trace_data / trace_times left as None.
            pass

    unit_results[unit_id] = {
        "rate_map": result["rate_map"],
        "si": result["si"],
        "shuffled_sis": result["shuffled_sis"],
        "p_val": result["p_val"],
        "stability_corr": result["stability_corr"],
        "stability_z": result["stability_z"],
        "stability_p_val": result["stability_p_val"],
        "rate_map_first": result["rate_map_first"],
        "rate_map_second": result["rate_map_second"],
        "vis_data_above": vis_data_above,
        "vis_data_below": vis_data_below,
        "unit_data": result["unit_data"],
        "trace_data": trace_data,
        "trace_times": trace_times,
    }

print(f"Done. Computed analysis for {len(unit_results)} units.")

## Summary Scatter Plot

In [None]:
# One summary figure over all analyzed units, built from the per-unit results
# collected above. The two thresholds from the config are passed through —
# presumably to mark the significance / stability cut-offs; confirm in
# placecell.visualization.
plot_summary_scatter(
    unit_results=unit_results,
    p_value_threshold=p_value_threshold,
    stability_threshold=stability_threshold,
)
plt.show()

## Interactive Cell Browser

Use the slider to scroll through cells. Use the time slider to scroll through the trace.

In [None]:
%matplotlib widget

# Build the interactive browser from everything computed above: per-unit
# results, the filtered trajectory, the unfiltered event index, the imaging
# data (max projection / footprints, which may be None), and the occupancy
# map with its bin edges. The call returns the figure and a controls widget,
# which are displayed separately below.
fig, controls = create_unit_browser(
    unit_results=unit_results,
    unique_units=unique_units,
    trajectory_df=trajectory_df,
    df_all_events=df_all_events,
    max_proj=max_proj,
    footprints=footprints,
    x_edges=x_edges,
    y_edges=y_edges,
    occupancy_time=occupancy_time,
    trace_name=trace_name,
    neural_fps=neural_fps,
    speed_threshold=speed_threshold,
    p_value_threshold=p_value_threshold,
    stability_threshold=stability_threshold,
    trace_time_window=trace_time_window,
)

# Display both figure and controls
plt.show()
display(controls)