# Place Cell Analysis

Interactive notebook for analyzing place cells recorded during navigation in a 2D environment. Equivalent to:
```bash
pdm run pcell workflow visualize --config placecell/config/example_pcell_config.yaml --data user_data/WL25_20251201/WL25_20251201.yaml
```

This notebook runs the full workflow:
1. **Deconvolution** - Extract neural events using OASIS
2. **Event-Place Matching** - Match events to behavior positions
3. **Interactive Visualization** - Browse place cells with a scrollable interface

In [None]:
import sys
from pathlib import Path

# Make the repository root (the parent of the current working directory)
# importable so that the `placecell` package imports below resolve.
project_root = Path.cwd().parent
_project_root_entry = str(project_root)
if _project_root_entry not in sys.path:
    # Prepend only once so re-running this cell does not grow sys.path.
    sys.path.insert(0, _project_root_entry)

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from tqdm.notebook import tqdm
import ipywidgets as widgets
from IPython.display import display, clear_output

from placecell.config import AppConfig, DataPathsConfig
from placecell.io import load_behavior_data, load_neural_data
from placecell.analysis import (
    compute_occupancy_map, 
    compute_unit_analysis,
    build_event_place_dataframe,
    load_traces,
    load_curated_unit_ids,
)
from placecell.visualization import plot_summary_scatter

## Configuration

In [None]:
# Paths - adjust these as needed.
# Everything is anchored at project_root (defined in the imports cell).
CONFIG_PATH = project_root / "placecell/config/example_pcell_config.yaml"
DATA_PATH = project_root / "user_data/WL25_20251201/WL25_20251201.yaml"
DATA_DIR = DATA_PATH.parent  # entries of the data YAML are resolved against this folder below
OUTPUT_DIR = project_root / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # safe to re-run: existing directory is accepted

# Load configs: analysis parameters (cfg) and per-dataset file locations (data_cfg)
cfg = AppConfig.from_yaml(CONFIG_PATH)
data_cfg = DataPathsConfig.from_yaml(DATA_PATH)

# Resolve data paths relative to data directory
neural_path = DATA_DIR / data_cfg.neural_path
neural_timestamp = DATA_DIR / data_cfg.neural_timestamp
behavior_position = DATA_DIR / data_cfg.behavior_position
behavior_timestamp = DATA_DIR / data_cfg.behavior_timestamp
# The curation CSV is optional; None disables curation filtering in the deconvolution step.
curation_csv = (DATA_DIR / data_cfg.curation_csv) if data_cfg.curation_csv else None

# Echo the resolved locations so a wrong path is spotted before any heavy computation
print(f"Config: {CONFIG_PATH}")
print(f"Data: {DATA_PATH}")
print(f"Neural path: {neural_path}")
print(f"Neural timestamp: {neural_timestamp}")
print(f"Behavior position: {behavior_position}")
print(f"Behavior timestamp: {behavior_timestamp}")
print(f"Curation CSV: {curation_csv}")

In [None]:
# Extract config values into flat notebook-level names used by the cells below.

# Behavior / spatial-map settings
bodypart = cfg.behavior.bodypart
behavior_fps = cfg.behavior.behavior_fps
speed_threshold = cfg.behavior.speed_threshold  # reported in px/s by the prints and plots below
speed_window_frames = cfg.behavior.speed_window_frames
bins = cfg.behavior.spatial_map.bins
min_occupancy = cfg.behavior.spatial_map.min_occupancy
occupancy_sigma = cfg.behavior.spatial_map.occupancy_sigma
activity_sigma = cfg.behavior.spatial_map.activity_sigma
n_shuffles = cfg.behavior.spatial_map.n_shuffles
random_seed = cfg.behavior.spatial_map.random_seed  # None leaves NumPy's global RNG unseeded
event_threshold_sigma = cfg.behavior.spatial_map.event_threshold_sigma
p_value_threshold = cfg.behavior.spatial_map.p_value_threshold
stability_threshold = cfg.behavior.spatial_map.stability_threshold

# Neural config
trace_name = cfg.neural.trace_name
neural_fps = cfg.neural.fps  # used to convert frame indices to seconds for the trace panel
max_units = cfg.neural.max_units  # None means no cap on the number of units processed
g = cfg.neural.oasis.g  # indexed as g[0], g[1] and passed to oasisAR2 as g1, g2
baseline = cfg.neural.oasis.baseline  # either a "p<number>" percentile string or a numeric offset
penalty = cfg.neural.oasis.penalty  # passed to oasisAR2 as `lam`
s_min = cfg.neural.oasis.s_min

# Visualization settings
trace_time_window = 600.0  # 10 minutes window for trace display

# Print a short summary of the most commonly tuned parameters
print(f"Bodypart: {bodypart}")
print(f"Speed threshold: {speed_threshold} px/s")
print(f"Bins: {bins}")
print(f"Shuffles: {n_shuffles}")
print(f"Trace name: {trace_name}")
print(f"OASIS g: {g}")

## Step 1: Deconvolution

Run OASIS deconvolution to extract neural events from calcium traces.

In [None]:
from oasis.oasis_methods import oasisAR2

# Read the calcium traces and collect the unit ids they contain (as plain ints)
print(f"Loading traces from: {neural_path / (trace_name + '.zarr')}")
C_da = load_traces(neural_path, trace_name=trace_name)
all_unit_ids = [int(raw_id) for raw_id in C_da["unit_id"].values]
print(f"Total units in traces: {len(all_unit_ids)}")

# Curation is optional: it applies only when a CSV is configured and exists on disk
if curation_csv is not None and curation_csv.exists():
    curated_ids = set(load_curated_unit_ids(curation_csv))
    all_unit_ids = [kept_id for kept_id in all_unit_ids if kept_id in curated_ids]
    print(f"After curation filter: {len(all_unit_ids)} units")

# Optional cap on the number of units; the first `max_units` (in trace order) are kept
if max_units is not None:
    if len(all_unit_ids) > max_units:
        all_unit_ids = all_unit_ids[:max_units]
        print(f"Limited to first {max_units} units")

print(f"Will process {len(all_unit_ids)} units")

In [None]:
def _baseline_level(trace, spec):
    """Return the scalar offset to subtract from `trace`.

    `spec` is either a string of the form "p<number>" (that percentile of the
    trace is used) or anything `float()` accepts (used as a fixed offset).
    """
    if isinstance(spec, str) and spec.startswith("p"):
        return float(np.percentile(trace, float(spec[1:])))
    return float(spec)


# Deconvolve every selected unit with OASIS (AR2 model)
print(f"Running OASIS deconvolution (g={g})...")

# Parallel lists: entry i of C_list / S_list belongs to good_unit_ids[i]
good_unit_ids, C_list, S_list = [], [], []

for uid in tqdm(all_unit_ids, desc="Deconvolving units"):
    raw_trace = np.ascontiguousarray(C_da.sel(unit_id=uid).values, dtype=np.float64)
    detrended = raw_trace - _baseline_level(raw_trace, baseline)

    # A failure for one unit is reported and the unit is skipped; the run continues.
    try:
        denoised, spikes = oasisAR2(detrended, g1=g[0], g2=g[1], lam=penalty, s_min=s_min)
        good_unit_ids.append(int(uid))
        C_list.append(np.asarray(denoised, dtype=float))
        S_list.append(np.asarray(spikes, dtype=float))
    except Exception as e:
        print(f"Skipping unit {uid}: {e}")

print(f"Successfully deconvolved {len(good_unit_ids)} units")

In [None]:
# Build event index DataFrame: one row per (unit, frame) where the deconvolved
# signal is positive, with columns "unit_id", "frame" and "s" (event amplitude).
#
# If no unit was deconvolved, S_list is empty and np.stack would raise; fall
# back to an empty 2-D array so the code below yields an empty, but correctly
# typed, table. Building the frame from named columns (instead of a list of
# row dicts) also guarantees the three columns exist when there are zero
# events, so later cells can still index df_all_events["unit_id"].
S_arr = np.stack(S_list, axis=0) if S_list else np.zeros((0, 0), dtype=float)

# np.nonzero walks the array in row-major order, i.e. unit by unit and, within
# a unit, by increasing frame - the same row order the per-event loop produced.
unit_rows, event_frames = np.nonzero(S_arr > 0)

event_index_df = pd.DataFrame(
    {
        "unit_id": np.asarray(good_unit_ids, dtype=np.int64)[unit_rows],
        "frame": event_frames.astype(np.int64),
        "s": S_arr[unit_rows, event_frames].astype(np.float64),
    }
)
print(f"Total events detected: {len(event_index_df)}")

# Save event index; the event-place matching step reads it back from this path
event_index_csv = OUTPUT_DIR / "event_index_notebook.csv"
event_index_df.to_csv(event_index_csv, index=False)
print(f"Saved event index to: {event_index_csv}")

## Step 2: Event-Place Matching

Match neural events to behavior positions.

In [None]:
# Build event-place dataframe.
# The event index is handed over as a file path (the CSV written by the
# previous cell), together with the timestamp/position files and the
# behavior settings extracted from the config.
print("Matching events to behavior positions...")

event_place_df = build_event_place_dataframe(
    event_index_path=event_index_csv,
    neural_timestamp_path=neural_timestamp,
    behavior_position_path=behavior_position,
    behavior_timestamp_path=behavior_timestamp,
    bodypart=bodypart,
    behavior_fps=behavior_fps,
    speed_threshold=speed_threshold,
    speed_window_frames=speed_window_frames,
)

# Later cells read the "unit_id" and "speed" columns of this table.
print(f"Event-place entries: {len(event_place_df)}")
print(f"Unique units: {event_place_df['unit_id'].nunique()}")

# Save event-place (optional)
event_place_csv = OUTPUT_DIR / "event_place_notebook.csv"
event_place_df.to_csv(event_place_csv, index=False)
print(f"Saved event-place to: {event_place_csv}")

## Step 3: Load Data for Visualization

In [None]:
# Keep only events whose speed is strictly above the threshold
# (rows with a NaN speed compare False and are dropped as well).
is_moving = event_place_df["speed"] > speed_threshold
df_filtered = event_place_df.loc[is_moving].copy()

# Independent copy of the unfiltered event index, used for the trace panel
df_all_events = event_index_df.copy()

print(f"Speed-filtered events: {len(df_filtered)}")
print(f"Unique units after filtering: {df_filtered['unit_id'].nunique()}")

In [None]:
# Load behavior data.
# Two tables come back: `trajectory_with_speed` (its "speed" column feeds the
# speed histogram below) and `trajectory_df` (its "x"/"y" columns are plotted
# as the trajectory and passed to the occupancy / unit analysis).
# NOTE(review): presumably trajectory_df is the speed-filtered subset of
# trajectory_with_speed - confirm against load_behavior_data.
trajectory_with_speed, trajectory_df = load_behavior_data(
    behavior_position=behavior_position,
    behavior_timestamp=behavior_timestamp,
    bodypart=bodypart,
    speed_window_frames=speed_window_frames,
    speed_threshold=speed_threshold,
)

print(f"Trajectory frames: {len(trajectory_df)}")

In [None]:
# Compute occupancy map.
# Returns the occupancy array, a mask of valid bins, and the x/y bin edges;
# the edges are reused below as image extents for the occupancy and rate maps.
occupancy_time, valid_mask, x_edges, y_edges = compute_occupancy_map(
    trajectory_df=trajectory_df,
    bins=bins,
    behavior_fps=behavior_fps,
    occupancy_sigma=occupancy_sigma,
    min_occupancy=min_occupancy,
)

print(f"Occupancy map shape: {occupancy_time.shape}")
# valid_mask.sum() counts the bins flagged valid out of all bins
print(f"Valid bins: {valid_mask.sum()} / {valid_mask.size}")

In [None]:
# Load neural data (for visualization).
# Any of the three returned objects may be None; the prints here and the
# browser's render function both handle a missing item explicitly.
traces, max_proj, footprints = load_neural_data(
    neural_path=neural_path,
    trace_name=trace_name,
)

print(f"Traces shape: {traces.shape if traces is not None else 'None'}")
print(f"Max proj shape: {max_proj.shape if max_proj is not None else 'None'}")
print(f"Footprints shape: {footprints.shape if footprints is not None else 'None'}")

## Step 4: Compute Unit Analysis

In [None]:
# Seed NumPy's global RNG (only when a seed is configured)
if random_seed is not None:
    np.random.seed(random_seed)

# Units to analyse: every unit with at least one speed-filtered event, in sorted order
unique_units = sorted(df_filtered["unit_id"].unique())
n_units = len(unique_units)
print(f"Computing analysis for {n_units} units...")

# unit_id -> dict with analysis results plus the data needed by the browser
unit_results = {}
for unit_id in tqdm(unique_units, desc="Computing unit analysis"):
    result = compute_unit_analysis(
        unit_id=unit_id,
        df_filtered=df_filtered,
        trajectory_df=trajectory_df,
        occupancy_time=occupancy_time,
        valid_mask=valid_mask,
        x_edges=x_edges,
        y_edges=y_edges,
        activity_sigma=activity_sigma,
        event_threshold_sigma=event_threshold_sigma,
        n_shuffles=n_shuffles,
        behavior_fps=behavior_fps,
        min_occupancy=min_occupancy,
        stability_threshold=stability_threshold,
    )

    # Event tables kept for visualization
    vis_data_above = result["events_above_threshold"]
    if df_all_events is None:
        vis_data_below = pd.DataFrame()
    else:
        # All events of this unit (no speed filter) whose amplitude exceeds vis_threshold
        unit_all_events = df_all_events[df_all_events["unit_id"] == unit_id]
        vis_data_below = unit_all_events[unit_all_events["s"] > result["vis_threshold"]]

    # Fluorescence trace and its time axis in seconds; both stay None when unavailable
    trace_data, trace_times = None, None
    if traces is not None:
        try:
            trace_data = traces.sel(unit_id=int(unit_id)).values
            trace_times = np.arange(len(trace_data)) / neural_fps
        except (KeyError, IndexError):
            pass

    # Copy the analysis fields through unchanged, then attach the visualization extras
    unit_summary = {
        field: result[field]
        for field in ("rate_map", "si", "shuffled_sis", "p_val", "stability_corr", "stability_z")
    }
    unit_summary["vis_data_above"] = vis_data_above
    unit_summary["vis_data_below"] = vis_data_below
    unit_summary["unit_data"] = result["unit_data"]
    unit_summary["trace_data"] = trace_data
    unit_summary["trace_times"] = trace_times
    unit_results[unit_id] = unit_summary

print(f"Done. Computed analysis for {len(unit_results)} units.")

## Occupancy Preview

In [None]:
# Three-panel sanity check: trajectory, occupancy map, speed distribution
fig_occ, axes_occ = plt.subplots(1, 3, figsize=(14, 4))
ax_traj, ax_occmap, ax_speed = axes_occ

# Panel 1: trajectory
ax_traj.plot(trajectory_df["x"], trajectory_df["y"], "k-", alpha=0.5, linewidth=0.5)
ax_traj.set_title("Trajectory (speed-filtered)")
ax_traj.set_aspect("equal")
ax_traj.axis("off")

# Panel 2: occupancy map with the valid-bin boundary outlined in white.
# Maps are transposed for display and placed using the bin edges as extent.
map_extent = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]
im = ax_occmap.imshow(
    occupancy_time.T,
    origin="lower",
    cmap="hot",
    aspect="equal",
    extent=map_extent,
)
ax_occmap.contour(
    valid_mask.T,
    levels=[0.5],
    colors="white",
    linewidths=1.5,
    extent=map_extent,
)
ax_occmap.set_title(f"Occupancy (sigma={occupancy_sigma}, min={min_occupancy}s)")
plt.colorbar(im, ax=ax_occmap, label="Time (s)")

# Panel 3: speed histogram, with values above the 99th percentile (of the
# non-NaN speeds) clipped into the last bin so outliers do not stretch the axis
all_speeds = trajectory_with_speed["speed"].values
non_nan_speeds = all_speeds[~np.isnan(all_speeds)]
speed_max = np.percentile(non_nan_speeds, 99)
ax_speed.hist(all_speeds.clip(max=speed_max), bins=50, color="gray", alpha=0.7)
ax_speed.axvline(
    speed_threshold,
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Threshold={speed_threshold}",
)
ax_speed.set_title("Speed distribution")
ax_speed.set_xlabel("Speed (px/s)")
ax_speed.legend()

plt.tight_layout()
plt.show()

## Summary Scatter Plot

In [None]:
# Summary figure over all analysed units; the two thresholds are the same
# values the interactive browser uses for its pass/fail labels.
fig_scatter = plot_summary_scatter(
    unit_results,
    p_value_threshold=p_value_threshold,
    stability_threshold=stability_threshold,
)
plt.show()

## Interactive Cell Browser

Use the **Unit** slider (or the **< Prev** / **Next >** buttons) to step through cells, and the **Time (s)** slider to scroll the 10-minute trace window.

In [None]:
# Interactive browser with ipympl (requires: pdm install -G notebook)
%matplotlib widget

# The browser figure is built a single time; render_unit() redraws into it
fig = plt.figure(figsize=(16, 9))

# Fixed panel rectangles, created in the same order as the axes names
ax1, ax2, ax3, ax3_cbar, ax4, ax5 = (
    fig.add_axes(panel_rect)
    for panel_rect in (
        [0.03, 0.42, 0.18, 0.45],     # ax1: max projection
        [0.25, 0.42, 0.18, 0.45],     # ax2: trajectory
        [0.47, 0.42, 0.16, 0.45],     # ax3: rate map
        [0.635, 0.49, 0.015, 0.315],  # ax3_cbar: colorbar
        [0.74, 0.42, 0.18, 0.45],     # ax4: SI histogram
        [0.05, 0.08, 0.90, 0.28],     # ax5: trace
    )
)

# Figure-level text artists from the last render (removed on the next render)
text_annotations = []

def render_unit(unit_idx, trace_start):
    """Redraw every panel of the browser figure for one unit (in-place update).

    Reads notebook-level state built by earlier cells (``unique_units``,
    ``unit_results``, ``trajectory_df``, ``max_proj``, ``footprints``,
    ``df_all_events``, the bin edges and thresholds) and draws into the
    pre-created axes ``ax1``..``ax5`` / ``ax3_cbar`` of ``fig``.

    Args:
        unit_idx: 0-based position of the unit in ``unique_units``.
        trace_start: Start of the visible trace window in seconds. The window
            is ``trace_time_window`` seconds long and is clipped to the end
            of the trace; negative values are treated as 0.

    Returns:
        None. The figure is updated and a redraw is requested.
    """
    global text_annotations

    unit_id = unique_units[unit_idx]
    result = unit_results[unit_id]

    # Clear all axes
    for ax in [ax1, ax2, ax3, ax3_cbar, ax4, ax5]:
        ax.clear()

    # Figure-level texts are not owned by any axes, so remove them explicitly
    for txt in text_annotations:
        txt.remove()
    text_annotations = []

    # 1. Max projection with neuron footprint
    if max_proj is not None:
        ax1.imshow(max_proj, cmap="gray", aspect="equal")
        if footprints is not None:
            try:
                unit_fp = footprints.sel(unit_id=unit_id).values
                if unit_fp.max() > 0:
                    # Outline the footprint at 30% of its peak value
                    ax1.contour(unit_fp, levels=[unit_fp.max() * 0.3],
                                colors="red", linewidths=1.5)
            except (KeyError, IndexError, ValueError):
                # Best effort: a missing or undrawable footprint just leaves the panel without outline
                pass
        ax1.set_title(f"Unit {unit_id}")
    else:
        ax1.text(0.5, 0.5, "No max projection", ha="center", va="center",
                 transform=ax1.transAxes)
        ax1.set_title(f"Unit {unit_id}")
    ax1.axis("off")

    # 2. Trajectory + events (marker opacity scales with event amplitude)
    vis_data_above = result["vis_data_above"]
    ax2.plot(trajectory_df["x"], trajectory_df["y"], "k-", alpha=1.0,
             linewidth=1, zorder=1)

    if not vis_data_above.empty:
        amps = vis_data_above["s"].values
        amp_max = np.max(amps) if len(amps) > 0 and np.max(amps) > 0 else 1.0
        alphas = amps / amp_max
        ax2.scatter(vis_data_above["x"], vis_data_above["y"], c="red",
                    s=30, alpha=alphas, zorder=2)

    ax2.set_title(f"Trajectory ({len(vis_data_above)} events)")
    ax2.set_aspect("equal")
    ax2.axis("off")

    # 3. Rate map (transposed for display; color scale fixed to [0, 1])
    rate_map_data = result["rate_map"].T
    im = ax3.imshow(rate_map_data, origin="lower",
                    extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
                    aspect="equal", cmap="jet")
    ax3.set_title("Rate map")
    ax3.axis("off")
    im.set_clim(0.0, 1.0)
    # Use the browser figure directly: this function runs from widget callbacks,
    # when pyplot's "current figure" may be a different (or newly created) one.
    fig.colorbar(im, cax=ax3_cbar)
    ax3_cbar.set_ylabel("Norm. rate", rotation=270, labelpad=10)

    # 4. SI histogram: shuffled distribution with the observed value marked
    ax4.hist(result["shuffled_sis"], bins=15, color="gray", alpha=0.7,
             edgecolor="black")
    ax4.axvline(result["si"], color="red", linestyle="--", linewidth=2)
    ax4.set_title(f"SI: {result['si']:.2f}, p={result['p_val']:.3f}")
    ax4.set_xlabel("SI (bits/s)")
    ax4.set_ylabel("Count")
    ax4.set_box_aspect(1)

    # 5. Trace with events
    if result["trace_data"] is not None and result["trace_times"] is not None:
        trace = result["trace_data"]
        t_full = result["trace_times"]

        # Visible window [t_start, t_end], clipped to the available trace
        t_max = t_full[-1] if len(t_full) > 0 else trace_time_window
        t_start = max(0, trace_start)
        t_end = min(t_max, t_start + trace_time_window)

        mask = (t_full >= t_start) & (t_full <= t_end)
        t_visible = t_full[mask]
        trace_visible = trace[mask]

        ax5.plot(t_visible, trace_visible, "b-", linewidth=0.5, label="Fluorescence")

        # Event spikes: gray = every detected event of the unit in the window,
        # red = the subset stored in vis_data_above (drawn on top of the gray ones)
        event_times_gray, event_amps_gray = [], []
        event_times_red, event_amps_red = [], []

        if df_all_events is not None:
            unit_all = df_all_events[df_all_events["unit_id"] == unit_id]
            if "frame" in unit_all.columns and "s" in unit_all.columns and not unit_all.empty:
                event_t = unit_all["frame"].values / neural_fps
                event_a = unit_all["s"].values
                m = (event_t >= t_start) & (event_t <= t_end)
                if np.any(m):
                    event_times_gray = event_t[m]
                    event_amps_gray = event_a[m]

        if "frame" in vis_data_above.columns and "s" in vis_data_above.columns and not vis_data_above.empty:
            event_t = vis_data_above["frame"].values / neural_fps
            event_a = vis_data_above["s"].values
            m = (event_t >= t_start) & (event_t <= t_end)
            if np.any(m):
                event_times_red = event_t[m]
                event_amps_red = event_a[m]

        # Scale spikes: the y-limits must be read after the trace is plotted so
        # the tallest spike occupies 30% of the trace's vertical range.
        y_min, y_max = ax5.get_ylim()
        baseline_y = y_min
        all_amps = np.concatenate([
            event_amps_gray if len(event_amps_gray) > 0 else [],
            event_amps_red if len(event_amps_red) > 0 else [],
        ])
        amp_max = np.max(all_amps) if len(all_amps) > 0 else 1.0
        y_range = y_max - y_min
        max_spike_height = y_range * 0.3

        def spike_tops(amplitudes):
            """Return the y-coordinate of each spike tip for the given amplitudes."""
            if amp_max > 0:
                return baseline_y + (np.asarray(amplitudes, dtype=float) / amp_max) * max_spike_height
            return baseline_y

        # One line collection per color instead of one Line2D per event keeps
        # slider updates responsive; explicit zorder keeps red above gray above the trace.
        if len(event_times_gray) > 0:
            ax5.vlines(event_times_gray, baseline_y, spike_tops(event_amps_gray),
                       colors="gray", linewidth=1.5, zorder=3)
        if len(event_times_red) > 0:
            ax5.vlines(event_times_red, baseline_y, spike_tops(event_amps_red),
                       colors="red", linewidth=1.5, zorder=4)

        ax5.set_xlim(t_start, t_end)
        ax5.set_xlabel("Time (s)")
        ax5.set_ylabel(trace_name)

        # Legend built from proxy artists, listing only what is visible in the window
        legend_elements = [Line2D([0], [0], color="blue", linewidth=0.5, label="Fluorescence")]
        if len(event_times_gray) > 0:
            legend_elements.append(Line2D([0], [0], color="gray", linewidth=1.5,
                                          label=f"Events (< {speed_threshold:.1f} px/s)"))
        if len(event_times_red) > 0:
            legend_elements.append(Line2D([0], [0], color="red", linewidth=1.5,
                                          label=f"Events (>= {speed_threshold:.1f} px/s)"))
        ax5.legend(handles=legend_elements, loc="upper left", fontsize=8, framealpha=0.9)
    else:
        ax5.text(0.5, 0.5, "No trace data", ha="center", va="center",
                 transform=ax5.transAxes)

    # Status text
    n_events = len(result["unit_data"]) if not result["unit_data"].empty else 0
    p_val = result["p_val"]
    stab_corr = result["stability_corr"]

    sig_pass = p_val < p_value_threshold
    sig_text = "pass" if sig_pass else "fail"
    sig_color = "green" if sig_pass else "red"

    # A NaN stability correlation is shown as "N/A" rather than as a failure
    if np.isnan(stab_corr):
        stab_text, stab_color = "N/A", "gray"
    else:
        stab_pass = stab_corr >= stability_threshold
        stab_text = "pass" if stab_pass else "fail"
        stab_color = "green" if stab_pass else "red"

    txt = fig.text(0.02, 0.98,
                   f"Unit ID: {unit_id} ({unit_idx + 1}/{n_units}) | N={n_events} events",
                   ha="left", va="top", fontsize=11, fontweight="bold",
                   transform=fig.transFigure)
    text_annotations.append(txt)

    txt = fig.text(0.02, 0.95, f"Significance (p={p_val:.3f}): ", ha="left", va="top",
                   fontsize=10, transform=fig.transFigure)
    text_annotations.append(txt)
    txt = fig.text(0.18, 0.95, sig_text, ha="left", va="top", fontsize=10,
                   fontweight="bold", color=sig_color, transform=fig.transFigure)
    text_annotations.append(txt)

    stab_str = f"r={stab_corr:.2f}" if not np.isnan(stab_corr) else ""
    txt = fig.text(0.02, 0.92, f"Stability ({stab_str}): ", ha="left", va="top",
                   fontsize=10, transform=fig.transFigure)
    text_annotations.append(txt)
    txt = fig.text(0.16, 0.92, stab_text, ha="left", va="top", fontsize=10,
                   fontweight="bold", color=stab_color, transform=fig.transFigure)
    text_annotations.append(txt)

    fig.canvas.draw_idle()

# Latest trace end time over all units (0 when no unit has a trace);
# it bounds the range of the time slider.
max_trace_time = max(
    [0]
    + [
        unit_entry["trace_times"][-1]
        for unit_entry in unit_results.values()
        if unit_entry["trace_times"] is not None and len(unit_entry["trace_times"]) > 0
    ]
)

# Sliders: one selects the unit, the other the start of the visible trace window.
# continuous_update=False means a redraw is triggered only when the handle is released.
def _slider_layout():
    """Return a fresh 600px-wide layout (one Layout instance per slider)."""
    return widgets.Layout(width="600px")


unit_slider = widgets.IntSlider(
    description="Unit:",
    value=0,
    min=0,
    max=n_units - 1,
    step=1,
    continuous_update=False,
    layout=_slider_layout(),
)

# The window start can go at most up to (longest trace - window length), never below 0
trace_slider = widgets.FloatSlider(
    description="Time (s):",
    value=0,
    min=0,
    max=max(0, max_trace_time - trace_time_window),
    step=10,
    continuous_update=False,
    layout=_slider_layout(),
)

# Prev/Next buttons move the unit slider by one position, wrapping around at both ends
prev_btn = widgets.Button(description="< Prev", layout=widgets.Layout(width="80px"))
next_btn = widgets.Button(description="Next >", layout=widgets.Layout(width="80px"))


def _step_unit(delta):
    """Shift the unit slider by `delta` positions, modulo the number of units."""
    unit_slider.value = (unit_slider.value + delta) % n_units


def on_prev(b):
    """Click handler: go to the previous unit."""
    _step_unit(-1)


def on_next(b):
    """Click handler: go to the next unit."""
    _step_unit(1)


prev_btn.on_click(on_prev)
next_btn.on_click(on_next)

def update(change=None):
    """Slider callback: redraw the figure for the current slider values.

    `change` is the event passed by the widget observer and is not used.
    """
    render_unit(unit_slider.value, trace_slider.value)


# Redraw whenever either slider's value changes
for _observed_slider in (unit_slider, trace_slider):
    _observed_slider.observe(update, names="value")

# Layout: navigation row on top, time slider underneath
nav_box = widgets.HBox([prev_btn, unit_slider, next_btn])
controls = widgets.VBox([nav_box, trace_slider])

# Draw the first unit at time 0 before showing the controls
render_unit(0, 0)

display(controls)