# MUA Detection

## Overview

_Developer Note:_ if you may make a PR in the future, be sure to copy this
notebook, and use the `gitignore` prefix `temp` to avoid future conflicts.

This is one notebook in a multi-part series on Spyglass.

- To set up your Spyglass environment and database, see
  [the Setup notebook](./00_Setup.ipynb).
- For additional info on DataJoint syntax, including table definitions and
  inserts, see
  [the Insert Data notebook](./01_Insert_Data.ipynb).
- Prior to running, please generate sorted spikes with the [spike sorting
  pipeline](./02_Spike_Sorting.ipynb) and generate input position data with
  either the [Trodes](./20_Position_Trodes.ipynb) or DLC notebooks
  ([1](./21_Position_DLC_1.ipynb), [2](./22_Position_DLC_2.ipynb),
  [3](./23_Position_DLC_3.ipynb)).

The goal of this notebook is to populate the `MuaEventsV1` table, which depends on `SortedSpikesGroup` and `PositionOutput`.

# Imports

In [None]:
import datajoint as dj
from pathlib import Path

# Load the DataJoint connection settings BEFORE importing spyglass below.
# NOTE(review): presumably spyglass reads `dj.config` when its schemas are
# imported, so keep this order. The relative path is resolved against the
# current working directory by `Path.absolute()`.
dj.config.load(
    Path("dj_local_conf_2.json").absolute()
)  # load config for database connection info

from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters

## Select Position Data

In [None]:
nwb_copy_file_name = "Jasper20251014_.nwb"  # Spyglass copy of the raw file (trailing underscore)

In [None]:
import spyglass.position.v1 as sgp

# Trodes position processing for the "pos 5 valid times" interval using the
# "single_led" parameter set.
pos_key = dict(
    nwb_file_name=nwb_copy_file_name,
    trodes_pos_params_name="single_led",
    interval_list_name="pos 5 valid times",
)
pos_selection = sgp.TrodesPosSelection()
pos_selection.insert1(pos_key, skip_duplicates=True)  # tolerate re-runs
sgp.TrodesPosV1.populate(pos_key, display_progress=True)
sgp.TrodesPosV1 & pos_key

In [None]:
from spyglass.position import PositionOutput

# Merge-table id of the Trodes position computed above. Restricting by the
# full `pos_key` (which also fixes `trodes_pos_params_name`) keeps `fetch1`
# unambiguous if the same interval was processed with several parameter sets.
pos_merge_id = (PositionOutput.TrodesPosV1() & pos_key).fetch1("merge_id")

## Select Sorted Spikes Data

In [None]:
from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup

# Every sorted-spikes group already defined for this session.
session_restriction = {"nwb_file_name": nwb_copy_file_name}
SortedSpikesGroup() & session_restriction

In [None]:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
import spyglass.spikesorting.v1 as sgs

# Curation 0 of the clusterless-thresholder sortings for this session.
sorter_keys = {
    "nwb_file_name": nwb_copy_file_name,
    "sorter": "clusterless_thresholder",
    "curation_id": 0,
}

# Named so it can be inspected: only a cell's last expression is displayed,
# so a bare expression here would never be shown.
curated_sortings = (
    sgs.SpikeSortingSelection & sorter_keys
) * SpikeSortingOutput.CurationV1 & sorter_keys

spikesorting_merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    sorter_keys, restrict_by_artifact=False
)

# Create the sorted spikes group only when it does not exist yet, so this cell
# works on a fresh database and is harmless on re-runs.
unit_filter_params_name = "default_exclusion"
group_key = {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "all_shanks",
    "unit_filter_params_name": unit_filter_params_name,
}
if not (SortedSpikesGroup & group_key):
    SortedSpikesGroup().create_group(
        group_name="all_shanks",
        nwb_file_name=nwb_copy_file_name,
        keys=[
            {"spikesorting_merge_id": merge_id}
            for merge_id in spikesorting_merge_ids
        ],
        unit_filter_params_name=unit_filter_params_name,
    )

SortedSpikesGroup & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "all_shanks",
}

In [None]:
# Units in the "all_shanks" group under the chosen unit filter.
units_restriction = dict(
    nwb_file_name=nwb_copy_file_name,
    sorted_spikes_group_name="all_shanks",
    unit_filter_params_name=unit_filter_params_name,
)
SortedSpikesGroup.Units & units_restriction

In [None]:
from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup

# Primary key of the sorted-spikes group used for MUA detection below.
sorted_spikes_group_key = dict(
    nwb_file_name=nwb_copy_file_name,
    sorted_spikes_group_name="all_shanks",
    unit_filter_params_name="default_exclusion",
)

SortedSpikesGroup & sorted_spikes_group_key

# Setting MUA Parameters

In [None]:
# Display every stored MUA parameter set.
MuaEventsParameters()

Here are the default parameters:

In [None]:
default_params_restriction = {"mua_param_name": "default"}
(MuaEventsParameters() & default_params_restriction).fetch1()

Putting everything together: create a key and populate the `MuaEventsV1` table.

In [None]:
# Key combining the MUA parameters, the sorted-spikes group, the position
# merge id and the interval in which events are detected.
mua_key = {
    "mua_param_name": "default",
    **sorted_spikes_group_key,
    "pos_merge_id": pos_merge_id,
    "detection_interval": "pos 5 valid times",
}

# populate() only computes keys that are not yet in the table, so re-running
# is safe; without it `fetch1_dataframe` below fails on a fresh database.
MuaEventsV1().populate(mua_key)
MuaEventsV1 & mua_key

Now we can use `fetch1_dataframe` to get the MUA events, including their start times, end times, and speed.

In [None]:
mua_entry = MuaEventsV1 & mua_key
mua_times = mua_entry.fetch1_dataframe()  # one row per detected MUA event
mua_times

## Plotting

From this, we can plot MUA firing rate and speed together.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4))

# Speed from the MuaEventsV1 table; its index provides the sample times.
speed = MuaEventsV1.get_speed(mua_key)
time = speed.index.to_numpy()
speed = speed.to_numpy()
# Firing rate evaluated on the same time grid as speed.
multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time)

# Window: 1,000 samples before to 5,000 samples after an example event.
example_event = 10  # index label in mua_times of the event to center on
event_idx = np.searchsorted(time, mua_times.loc[example_event].start_time)
# Clamp the start at 0: a negative slice start would count from the END of the
# arrays and give an empty or unrelated window for events near the beginning.
time_slice = slice(max(event_idx - 1_000, 0), event_idx + 5_000)

axes[0].plot(
    time[time_slice],
    multiunit_firing_rate[time_slice],
    color="black",
)
axes[0].set_ylabel("firing rate (Hz)")
axes[0].set_title("multiunit activity")
axes[1].fill_between(time[time_slice], speed[time_slice], color="lightgrey")
axes[1].set_ylabel("speed (cm/s)")
axes[1].set_xlabel("time (s)")

# Shade every MUA event lying strictly inside the plotted window.
window_times = time[time_slice]
in_window = np.logical_and(
    mua_times["start_time"] > window_times.min(),
    mua_times["end_time"] < window_times.max(),
)
for _, mua_time in mua_times.loc[in_window].iterrows():
    axes[0].axvspan(
        mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5
    )

We can also create a figurl to visualize the data.

In [None]:
# Interactive figurl view of this MUA entry; `zscore_mua=True` presumably
# shows the z-scored MUA rate — confirm against MuaEventsV1.create_figurl.
mua_entry = MuaEventsV1 & mua_key
mua_entry.create_figurl(zscore_mua=True)

In [None]:
# Laser stimulation analysis: read directly from the raw NWB file.
# NOTE: `io` is deliberately left open because later cells slice datasets
# from `nwbf` (e.g. `.data[:]`); call `io.close()` when finished.
import pynwb

raw_nwb_path = "/stelmo/nwb/raw/Jasper20251014.nwb"
io = pynwb.NWBHDF5IO(raw_nwb_path, mode="r")
nwbf = io.read()
nwbf

In [None]:
# Every timestamp of the raw "e-series" acquisition, read fully into memory.
# NOTE(review): not used by the later cells shown here — consider removing.
e_series = nwbf.acquisition["e-series"]
total_timestamps = e_series.timestamps[:]

In [None]:
# Row 5 of the NWB epochs table; unpacking requires exactly three values
# (presumably start time, stop time and tags — TODO confirm column order).
epoch_row_values = nwbf.intervals["epochs"][5].to_numpy()[0]
epoch_start_time, epoch_stop_time, epoch_name = epoch_row_values

In [None]:
# "Laser" events from the behavior module's behavioral_events interface.
laser_series = (
    nwbf.processing["behavior"]
    .data_interfaces["behavioral_events"]
    .time_series["Laser"]
)
laser_data = laser_series.data[:]
laser_timestamps = laser_series.timestamps[:]

# Keep events inside the selected epoch (start exclusive, stop inclusive).
# The mask is built once and reused for values and timestamps.
in_epoch = (laser_timestamps > epoch_start_time) & (
    laser_timestamps <= epoch_stop_time
)
laser_data_epoch = laser_data[in_epoch]  # boolean indexing -> copy
laser_timestamps_epoch = laser_timestamps[in_epoch]

if laser_data_epoch.size == 0:
    raise ValueError(
        f"No Laser events between {epoch_start_time} and {epoch_stop_time}."
    )
# Force the first in-epoch value to 0. NOTE(review): presumably so a laser
# state carried over from before the epoch is not counted as an onset.
laser_data_epoch[0] = 0

In [None]:
import numpy as np

# Timestamps of in-epoch laser events whose value is 1 (laser on).
laser_on_times_epoch = laser_timestamps_epoch[laser_data_epoch == 1]

# Group on-times separated by <= 10 ms and keep the FIRST of each group:
# the first on-time always starts a group, and any on-time preceded by a gap
# larger than `min_gap_s` starts a new one. (Marking elements *followed* by a
# gap would select the last on-time of each group instead.) Building the mask
# from a full-size array also handles an empty input.
min_gap_s = 0.01
is_group_start = np.ones(laser_on_times_epoch.shape, dtype=bool)
is_group_start[1:] = np.diff(laser_on_times_epoch) > min_gap_s
first_laser_on_times = laser_on_times_epoch[is_group_start]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# MUA firing rate on the `time` grid used by the earlier plots.
multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time)

laser_stim = 350  # which entry of first_laser_on_times to display
laser_onset = first_laser_on_times[laser_stim]

# Laser-aligned window: 200 ms before to 200 ms after the onset.
start_time = laser_onset - 0.2
stop_time = laser_onset + 0.2
start_idx, stop_idx = np.searchsorted(time, [start_time, stop_time])
time_slice = slice(start_idx, stop_idx)

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(
    time[time_slice], multiunit_firing_rate[time_slice], color="black", lw=1.2
)
ax.set_xlabel("Time (s)")
ax.set_ylabel("MUA firing rate (Hz)")
ax.set_title("Jasper: Example MUA During Opto Stim")

# Shade the hard-coded 100 ms laser-on period after the onset.
ax.axvspan(
    laser_onset, laser_onset + 0.1, color="red", alpha=0.3, label="Laser ON"
)

ax.legend()
plt.tight_layout()
plt.show()


In [None]:
from scipy.stats import zscore

# Same example stimulation, with the firing rate z-scored along axis 0
# (scipy's default) over the whole trace.
multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time)
z_multiunit_firing_rate = zscore(multiunit_firing_rate)

laser_stim = 350  # which entry of first_laser_on_times to display
laser_onset = first_laser_on_times[laser_stim]

# Laser-aligned window: 200 ms before to 200 ms after the onset.
start_time = laser_onset - 0.2
stop_time = laser_onset + 0.2
start_idx, stop_idx = np.searchsorted(time, [start_time, stop_time])
time_slice = slice(start_idx, stop_idx)

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(
    time[time_slice], z_multiunit_firing_rate[time_slice], color="black", lw=1.2
)
ax.set_xlabel("Time (s)")
ax.set_ylabel("MUA firing rate (Z-Score)")
ax.set_title("Jasper: Example MUA During Opto Stim")

# Shade the hard-coded 100 ms laser-on period after the onset.
ax.axvspan(
    laser_onset, laser_onset + 0.1, color="red", alpha=0.3, label="Laser ON"
)

ax.legend()
plt.tight_layout()
plt.show()


In [None]:

# Peri-stimulus average of the MUA firing rate (Hz) across all laser onsets.
window = [-0.2, 0.2]  # seconds before/after laser onset
dt = np.median(np.diff(time))  # sampling interval (s)
# round() rather than int() truncation: window / dt can land just below an
# integer (e.g. 99.99999...) because of floating-point error.
samples_before = int(round(window[0] / dt))
samples_after = int(round(window[1] / dt))
window_samples = np.arange(samples_before, samples_after)

# --- Collect one peri-laser MUA trace per onset ---
laser_times = np.array(first_laser_on_times)
valid_traces = []  # one MUA segment per trial

for t in laser_times:
    # Skip onsets whose window would extend past the recorded data.
    if (t + window[0] < time[0]) or (t + window[1] > time[-1]):
        continue

    center_idx = np.searchsorted(time, t)
    trace = multiunit_firing_rate[
        center_idx + samples_before : center_idx + samples_after
    ]
    if len(trace) == len(window_samples):  # ensure consistent length
        valid_traces.append(trace)

if not valid_traces:
    raise ValueError(
        "No laser onset has a complete peri-stimulus window within the data."
    )

valid_traces = np.array(valid_traces)  # first axis: trials
n_trials = valid_traces.shape[0]

# --- Mean ± SEM across trials ---
mean_mua = valid_traces.mean(axis=0)
# SEM uses the sample standard deviation (ddof=1); undefined for one trial.
if n_trials > 1:
    sem_mua = valid_traces.std(axis=0, ddof=1) / np.sqrt(n_trials)
else:
    sem_mua = np.zeros_like(mean_mua)
time_axis = window_samples * dt  # time relative to laser onset (s)

# Flatten for plotting (presumably the rate may carry a singleton axis).
time_axis = np.ravel(time_axis)
mean_mua = np.ravel(mean_mua)
sem_mua = np.ravel(sem_mua)

# --- Plot ---
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time_axis, mean_mua, color='black', label='Mean MUA')
ax.fill_between(time_axis, mean_mua - sem_mua, mean_mua + sem_mua,
                color='gray', alpha=0.3, label='± SEM')

# Shade the laser period (hard-coded 100 ms)
ax.axvspan(0, 0.1, color='red', alpha=0.3, label='Laser ON')

ax.set_xlabel("Time relative to laser onset (s)")
ax.set_ylabel("MUA firing rate (Hz)")
ax.set_title("Jasper: Average MUA Across All Opto Stims")
ax.legend()
plt.tight_layout()
plt.show()


In [None]:
# Peri-stimulus average of the z-scored MUA firing rate across laser onsets.
window = [-0.2, 0.2]  # seconds before/after laser onset
dt = np.median(np.diff(time))  # sampling interval (s)
# round() rather than int() truncation: window / dt can land just below an
# integer (e.g. 99.99999...) because of floating-point error.
samples_before = int(round(window[0] / dt))
samples_after = int(round(window[1] / dt))
window_samples = np.arange(samples_before, samples_after)

# --- Collect one peri-laser MUA trace per onset ---
laser_times = np.array(first_laser_on_times)
valid_traces = []  # one MUA segment per trial

for t in laser_times:
    # Skip onsets whose window would extend past the recorded data.
    if (t + window[0] < time[0]) or (t + window[1] > time[-1]):
        continue

    center_idx = np.searchsorted(time, t)
    trace = z_multiunit_firing_rate[
        center_idx + samples_before : center_idx + samples_after
    ]
    if len(trace) == len(window_samples):  # ensure consistent length
        valid_traces.append(trace)

if not valid_traces:
    raise ValueError(
        "No laser onset has a complete peri-stimulus window within the data."
    )

valid_traces = np.array(valid_traces)  # first axis: trials
n_trials = valid_traces.shape[0]

# --- Mean ± SEM across trials ---
mean_mua = valid_traces.mean(axis=0)
# SEM uses the sample standard deviation (ddof=1); undefined for one trial.
if n_trials > 1:
    sem_mua = valid_traces.std(axis=0, ddof=1) / np.sqrt(n_trials)
else:
    sem_mua = np.zeros_like(mean_mua)
time_axis = window_samples * dt  # time relative to laser onset (s)

# Flatten for plotting (presumably the rate may carry a singleton axis).
time_axis = np.ravel(time_axis)
mean_mua = np.ravel(mean_mua)
sem_mua = np.ravel(sem_mua)

# --- Plot ---
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time_axis, mean_mua, color='black', label='Mean MUA')
ax.fill_between(time_axis, mean_mua - sem_mua, mean_mua + sem_mua,
                color='gray', alpha=0.3, label='± SEM')

# Shade the laser period (hard-coded 100 ms)
ax.axvspan(0, 0.1, color='red', alpha=0.3, label='Laser ON')

ax.set_xlabel("Time relative to laser onset (s)")
ax.set_ylabel("MUA firing rate (z-score)")
ax.set_title("Jasper: Average MUA Across All Opto Stims")
ax.legend()
plt.tight_layout()
plt.show()