# Place Cell Analysis

Interactive notebook for analyzing place cells recorded during navigation of a 2D environment. It is equivalent to running:
```bash
pdm run pcell workflow visualize --config placecell/config/example_pcell_config.yaml --data user_data/WL25_20251201/WL25_20251201.yaml
```

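This notebook runs the full workflow:
1. **Deconvolution** - Extract neural events from calcium traces using OASIS
2. **Event-Place Matching** - Match events to behavior positions and speeds
3. **Unit Analysis** - Compute the occupancy map, then per-unit rate maps, shuffle-based significance, and stability
4. **Interactive Visualization** - Browse place cells with a scrollable interface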

---

**Note:** This notebook uses interactive widgets (progress bars, sliders). The cell browser also uses `%matplotlib widget`, which requires the `ipympl` package. When working over SSH, use **Jupyter Lab** instead of VSCode's notebook extension for proper widget support:

```bash
# On remote machine
cd notebook
jupyter lab --no-browser --port=6006

# On local machine - set up SSH tunnel
ssh -L 6006:localhost:6006 user@remote-host

# Then open http://localhost:6006 in your browser
```

In [None]:
import os
import sys
import warnings
from pathlib import Path

# Add project root to path so the in-repo `placecell` package is importable.
def find_project_root(start: Path) -> Path:
    """Return the nearest directory (``start`` or an ancestor) containing a ``placecell`` dir.

    Falls back to ``start.parent`` (the historical assumption that the notebook is
    run from ``<repo>/notebook``) when no such directory exists above ``start``.
    """
    for candidate in (start, *start.parents):
        if (candidate / "placecell").is_dir():
            return candidate
    return start.parent


project_root = find_project_root(Path.cwd())
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from tqdm.notebook import tqdm

from placecell.analysis import (
    build_event_place_dataframe,
    compute_occupancy_map,
    compute_unit_analysis,
    load_curated_unit_ids,
    load_traces,
)
from placecell.config import AppConfig, DataPathsConfig
from placecell.io import load_behavior_data, load_neural_data
from placecell.notebook import build_event_index_dataframe, create_unit_browser, run_deconvolution
from placecell.visualization import plot_summary_scatter

## Configuration

In [None]:
# Paths - adjust the defaults below, or override them with environment variables so
# another session can be analysed without editing the notebook:
#   PCELL_CONFIG_PATH - analysis config YAML
#   PCELL_DATA_PATH   - per-session data YAML; its parent directory is the data root
#   PCELL_OUTPUT_DIR  - directory for the intermediate CSVs written below
CONFIG_PATH = Path(
    os.environ.get("PCELL_CONFIG_PATH")
    or project_root / "placecell/config/example_pcell_config.yaml"
)
DATA_PATH = Path(
    os.environ.get("PCELL_DATA_PATH")
    or (
        "/mnt/data/minizero_analysis/202512round/202511_analysis_placecell/"
        "20251201/WL25/WL25_20251201.yaml"
    )
)
DATA_DIR = DATA_PATH.parent
OUTPUT_DIR = Path(os.environ.get("PCELL_OUTPUT_DIR") or project_root / "output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Load configs
cfg = AppConfig.from_yaml(CONFIG_PATH)
data_cfg = DataPathsConfig.from_yaml(DATA_PATH)

# Apply data-specific overrides (e.g., OASIS parameters)
cfg = cfg.with_data_overrides(data_cfg)

# Resolve data paths relative to data directory
neural_path = DATA_DIR / data_cfg.neural_path
neural_timestamp = DATA_DIR / data_cfg.neural_timestamp
behavior_position = DATA_DIR / data_cfg.behavior_position
behavior_timestamp = DATA_DIR / data_cfg.behavior_timestamp
curation_csv = (DATA_DIR / data_cfg.curation_csv) if data_cfg.curation_csv else None

# Surface missing inputs now rather than as an obscure error several cells later.
for _label, _path in [
    ("Neural path", neural_path),
    ("Neural timestamp", neural_timestamp),
    ("Behavior position", behavior_position),
    ("Behavior timestamp", behavior_timestamp),
]:
    if not _path.exists():
        warnings.warn(f"{_label} does not exist: {_path}")

print(f"Config: {CONFIG_PATH}")
print(f"Data: {DATA_PATH}")
print(f"Neural path: {neural_path}")
print(f"Neural timestamp: {neural_timestamp}")
print(f"Behavior position: {behavior_position}")
print(f"Behavior timestamp: {behavior_timestamp}")
print(f"Curation CSV: {curation_csv}")

In [None]:
# Copy the settings used throughout the notebook into plain variables.
spatial_map_cfg = cfg.behavior.spatial_map
oasis_cfg = cfg.neural.oasis

# Behavior / spatial-map config
bodypart = cfg.behavior.bodypart
behavior_fps = cfg.behavior.behavior_fps
speed_threshold = cfg.behavior.speed_threshold
speed_window_frames = cfg.behavior.speed_window_frames
bins = spatial_map_cfg.bins
min_occupancy = spatial_map_cfg.min_occupancy
occupancy_sigma = spatial_map_cfg.occupancy_sigma
activity_sigma = spatial_map_cfg.activity_sigma
n_shuffles = spatial_map_cfg.n_shuffles
random_seed = spatial_map_cfg.random_seed
event_threshold_sigma = spatial_map_cfg.event_threshold_sigma
p_value_threshold = spatial_map_cfg.p_value_threshold
stability_threshold = spatial_map_cfg.stability_threshold

# Neural / OASIS config
trace_name = cfg.neural.trace_name
neural_fps = cfg.neural.fps
max_units = cfg.neural.max_units
g = oasis_cfg.g
baseline = oasis_cfg.baseline
penalty = oasis_cfg.penalty
s_min = oasis_cfg.s_min

# Visualization settings: length of the trace window shown in the cell browser
trace_time_window = 600.0  # 10 minutes window for trace display

print(
    f"Bodypart: {bodypart}",
    f"Speed threshold: {speed_threshold} px/s",
    f"Bins: {bins}",
    f"Shuffles: {n_shuffles}",
    f"Trace name: {trace_name}",
    f"OASIS g: {g}",
    sep="\n",
)

## Step 1: Deconvolution

Run OASIS deconvolution to extract neural events from calcium traces.

In [None]:
# Load raw traces; unit ids are converted to plain ints for the filters below.
print(f"Loading traces from: {neural_path / (trace_name + '.zarr')}")
C_da = load_traces(neural_path, trace_name=trace_name)
all_unit_ids = [int(uid) for uid in C_da["unit_id"].values]
print(f"Total units in traces: {len(all_unit_ids)}")

# Filter by curation CSV if one is configured. A configured-but-missing file is
# reported instead of silently analysing every (uncurated) unit.
if curation_csv is not None:
    if curation_csv.exists():
        curated_ids = set(load_curated_unit_ids(curation_csv))
        all_unit_ids = [uid for uid in all_unit_ids if uid in curated_ids]
        print(f"After curation filter: {len(all_unit_ids)} units")
    else:
        warnings.warn(f"Curation CSV not found, skipping curation filter: {curation_csv}")

# Apply max_units limit if configured (keeps the first units in trace order)
if max_units is not None and len(all_unit_ids) > max_units:
    all_unit_ids = all_unit_ids[:max_units]
    print(f"Limited to first {max_units} units")

# An empty selection (e.g. curation ids that match no trace id) would otherwise
# only show up as empty results several cells later.
if not all_unit_ids:
    warnings.warn("No units left to process; later steps will produce empty results.")

print(f"Will process {len(all_unit_ids)} units")

In [None]:
# Run OASIS deconvolution on every selected unit. Only the units returned in
# `good_unit_ids` move on to later steps; C_list / S_list are presumably the
# per-unit outputs aligned with `good_unit_ids` (confirm in placecell.notebook).
print(f"Running OASIS deconvolution (g={g})...")

good_unit_ids, C_list, S_list = run_deconvolution(
    C_da=C_da,
    unit_ids=all_unit_ids,
    g=g,
    baseline=baseline,
    penalty=penalty,
    s_min=s_min,
    progress_bar=lambda x: tqdm(x, desc="Deconvolving units"),
)

print(f"Successfully deconvolved {len(good_unit_ids)} units")

# Report units that did not come back, so they don't vanish silently.
n_dropped = len(all_unit_ids) - len(good_unit_ids)
if n_dropped > 0:
    print(f"Dropped {n_dropped} unit(s) that were not deconvolved successfully")

In [None]:
# Build event index DataFrame; its row count is reported as the number of detected events.
event_index_df = build_event_index_dataframe(good_unit_ids, S_list)
print(f"Total events detected: {len(event_index_df)}")
if event_index_df.empty:
    warnings.warn("No events detected; Step 2 will have nothing to match.")

# Save event index. This is NOT optional: Step 2 (`build_event_place_dataframe`)
# reads the events back from this CSV via `event_index_path`.
# NOTE: the filename is fixed, so each run (and each dataset) overwrites it.
event_index_csv = OUTPUT_DIR / "event_index_notebook.csv"
event_index_df.to_csv(event_index_csv, index=False)
print(f"Saved event index to: {event_index_csv}")

## Step 2: Event-Place Matching

Match neural events to behavior positions.

In [None]:
# Pair each detected event (read from the event index CSV) with the behavior
# tracking data; the resulting table carries `unit_id` and `speed` columns.
print("Matching events to behavior positions...")

event_place_kwargs = dict(
    event_index_path=event_index_csv,
    neural_timestamp_path=neural_timestamp,
    behavior_position_path=behavior_position,
    behavior_timestamp_path=behavior_timestamp,
    bodypart=bodypart,
    behavior_fps=behavior_fps,
    speed_threshold=speed_threshold,
    speed_window_frames=speed_window_frames,
)
event_place_df = build_event_place_dataframe(**event_place_kwargs)

print(
    f"Event-place entries: {len(event_place_df)}",
    f"Unique units: {event_place_df['unit_id'].nunique()}",
    sep="\n",
)

# Persist the matched table for inspection outside the notebook (optional:
# no later cell in this notebook reads this file back).
event_place_csv = OUTPUT_DIR / "event_place_notebook.csv"
event_place_df.to_csv(event_place_csv, index=False)
print(f"Saved event-place to: {event_place_csv}")

## Step 3: Load Data for Visualization

In [None]:
# Keep only events that occurred while the animal moved faster than the speed
# threshold. NOTE(review): `build_event_place_dataframe` also receives
# `speed_threshold`, so this may be a redundant safeguard; confirm.
df_filtered = event_place_df.loc[lambda frame: frame["speed"] > speed_threshold].copy()

# Independent copy of the unfiltered event index, used for visualization.
df_all_events = event_index_df.copy()

print(
    f"Speed-filtered events: {len(df_filtered)}",
    f"Unique units after filtering: {df_filtered['unit_id'].nunique()}",
    sep="\n",
)

In [None]:
# Load tracking data. Two frames come back: `trajectory_with_speed` (its `speed`
# column feeds the speed histogram below) and `trajectory_df`, which the
# occupancy and rate maps are built from (presumably speed-filtered; the
# preview plot labels it that way).
trajectory_with_speed, trajectory_df = load_behavior_data(
    bodypart=bodypart,
    behavior_position=behavior_position,
    behavior_timestamp=behavior_timestamp,
    speed_threshold=speed_threshold,
    speed_window_frames=speed_window_frames,
)

print(f"Trajectory frames: {len(trajectory_df)}")

In [None]:
# Occupancy per spatial bin, computed from `trajectory_df` with the configured
# smoothing and minimum-occupancy settings. `valid_mask` and the bin edges are
# shared by every unit's rate map.
occupancy_time, valid_mask, x_edges, y_edges = compute_occupancy_map(
    trajectory_df=trajectory_df,
    bins=bins,
    min_occupancy=min_occupancy,
    occupancy_sigma=occupancy_sigma,
    behavior_fps=behavior_fps,
)

print(
    f"Occupancy map shape: {occupancy_time.shape}",
    f"Valid bins: {valid_mask.sum()} / {valid_mask.size}",
    sep="\n",
)

In [None]:
# Load the arrays the cell browser needs; any of them may come back as None,
# which is reported below rather than treated as an error.
# NOTE(review): traces were already loaded in Step 1 via `load_traces`; this may
# re-read the same data.
traces, max_proj, footprints = load_neural_data(trace_name=trace_name, neural_path=neural_path)

for _name, _arr in (("Traces", traces), ("Max proj", max_proj), ("Footprints", footprints)):
    print(f"{_name} shape: {_arr.shape if _arr is not None else 'None'}")

## Step 4: Compute Unit Analysis

In [None]:
# Seed NumPy's legacy global RNG so the shuffle-based significance test is
# reproducible (assumes compute_unit_analysis draws from np.random; confirm).
if random_seed is not None:
    np.random.seed(random_seed)

# Analyse every unit that has at least one speed-filtered event.
unique_units = sorted(df_filtered["unit_id"].unique())
n_units = len(unique_units)
print(f"Computing analysis for {n_units} units...")

# Split the unfiltered event table by unit once, rather than re-scanning the
# whole table for every unit inside the loop. Units without rows get an empty
# frame with the original columns, matching the old per-unit boolean mask.
if df_all_events is not None:
    all_events_by_unit = {uid: grp for uid, grp in df_all_events.groupby("unit_id")}
    no_events = df_all_events.iloc[0:0]
else:
    all_events_by_unit, no_events = {}, pd.DataFrame()

units_missing_trace = []
unit_results = {}
for unit_id in tqdm(unique_units, desc="Computing unit analysis"):
    result = compute_unit_analysis(
        unit_id=unit_id,
        df_filtered=df_filtered,
        trajectory_df=trajectory_df,
        occupancy_time=occupancy_time,
        valid_mask=valid_mask,
        x_edges=x_edges,
        y_edges=y_edges,
        activity_sigma=activity_sigma,
        event_threshold_sigma=event_threshold_sigma,
        n_shuffles=n_shuffles,
        behavior_fps=behavior_fps,
        min_occupancy=min_occupancy,
        stability_threshold=stability_threshold,
    )

    # Visualization data: events returned by the analysis, plus this unit's
    # unfiltered events whose amplitude `s` exceeds `vis_threshold`.
    # NOTE(review): despite the `_below` name, this keeps events *above*
    # `vis_threshold`; confirm the intended meaning of the key.
    vis_data_above = result["events_above_threshold"]
    vis_data_below = pd.DataFrame()
    if df_all_events is not None:
        unit_all_events = all_events_by_unit.get(unit_id, no_events)
        vis_data_below = unit_all_events[unit_all_events["s"] > result["vis_threshold"]]

    # Trace data for the browser. Units missing from `traces` are collected and
    # reported after the loop instead of being skipped silently.
    trace_data = None
    trace_times = None
    if traces is not None:
        try:
            trace_data = traces.sel(unit_id=int(unit_id)).values
        except (KeyError, IndexError):
            units_missing_trace.append(unit_id)
        else:
            # Time axis in seconds, assuming samples start at t=0 and are spaced 1/neural_fps apart.
            trace_times = np.arange(len(trace_data)) / neural_fps

    unit_results[unit_id] = {
        "rate_map": result["rate_map"],
        "si": result["si"],
        "shuffled_sis": result["shuffled_sis"],
        "p_val": result["p_val"],
        "stability_corr": result["stability_corr"],
        "stability_z": result["stability_z"],
        "vis_data_above": vis_data_above,
        "vis_data_below": vis_data_below,
        "unit_data": result["unit_data"],
        "trace_data": trace_data,
        "trace_times": trace_times,
    }

if units_missing_trace:
    warnings.warn(
        f"No trace found for {len(units_missing_trace)} unit(s); "
        f"first few: {units_missing_trace[:10]}"
    )

print(f"Done. Computed analysis for {len(unit_results)} units.")

## Occupancy Preview

In [None]:
fig_occ, axes_occ = plt.subplots(1, 3, figsize=(14, 4))
# Shared spatial extent for the occupancy image and its validity contour.
map_extent = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]

# Trajectory
axes_occ[0].plot(trajectory_df["x"], trajectory_df["y"], "k-", alpha=0.5, linewidth=0.5)
axes_occ[0].set_title("Trajectory (speed-filtered)")
axes_occ[0].set_aspect("equal")
axes_occ[0].axis("off")

# Occupancy map. The white contour outlines bins marked valid in `valid_mask`.
# The transpose puts the first array axis (presumably x bins) on the horizontal axis.
im = axes_occ[1].imshow(
    occupancy_time.T, origin="lower", cmap="hot", aspect="equal", extent=map_extent
)
axes_occ[1].contour(
    valid_mask.T, levels=[0.5], colors="white", linewidths=1.5, extent=map_extent
)
axes_occ[1].set_title(f"Occupancy (sigma={occupancy_sigma}, min={min_occupancy}s)")
axes_occ[1].set_xlabel("x")
axes_occ[1].set_ylabel("y")
plt.colorbar(im, ax=axes_occ[1], label="Time (s)")

# Speed distribution. Non-finite speeds (if any) are dropped explicitly.
# Values are clipped at the 99th percentile so a few outliers don't squash the
# histogram; clipped samples pile up in the last bin.
all_speeds = trajectory_with_speed["speed"].to_numpy(dtype=float)
finite_speeds = all_speeds[np.isfinite(all_speeds)]
if finite_speeds.size > 0:
    speed_max = np.percentile(finite_speeds, 99)
    axes_occ[2].hist(finite_speeds.clip(max=speed_max), bins=50, color="gray", alpha=0.7)
else:
    axes_occ[2].text(
        0.5, 0.5, "no finite speed values", ha="center", va="center",
        transform=axes_occ[2].transAxes,
    )
axes_occ[2].axvline(speed_threshold, color="red", linestyle="--", linewidth=2,
                    label=f"Threshold={speed_threshold}")
axes_occ[2].set_title("Speed distribution")
axes_occ[2].set_xlabel("Speed (px/s)")
axes_occ[2].set_ylabel("Count")
axes_occ[2].legend()

plt.tight_layout()
plt.show()

## Summary Scatter Plot

In [None]:
# One-figure overview of all analysed units against the significance
# (p-value) and stability cutoffs from the config.
fig_scatter = plot_summary_scatter(
    unit_results,
    stability_threshold=stability_threshold,
    p_value_threshold=p_value_threshold,
)
plt.show()

## Interactive Cell Browser

Use the unit slider to step through cells and the time slider to scroll along the trace. The visible window is `trace_time_window` seconds long.

In [None]:
%matplotlib widget

fig, controls = create_unit_browser(
    unit_results=unit_results,
    unique_units=unique_units,
    trajectory_df=trajectory_df,
    df_all_events=df_all_events,
    max_proj=max_proj,
    footprints=footprints,
    x_edges=x_edges,
    y_edges=y_edges,
    trace_name=trace_name,
    neural_fps=neural_fps,
    speed_threshold=speed_threshold,
    p_value_threshold=p_value_threshold,
    stability_threshold=stability_threshold,
    trace_time_window=trace_time_window,
)
display(fig)
display(controls)