# Berke Lab Spike Sorting and Decoding 

In [None]:
import datajoint as dj
import numpy as np
import matplotlib.pyplot as plt
from pynwb import NWBHDF5IO

import spyglass.common as sgc
import spyglass.spikesorting.v1 as sgs
import spyglass.position as sgp
from spyglass.common import Nwbfile
from spyglass.utils.nwb_helper_fn import get_nwb_file

# Session to process (it must already exist in the database)
# nwb_file_name = "IM-1594_20230726_.nwb"
nwb_file_name = "IM-1478_20220726_.nwb"

# Open the NWB and report when it was created and which version of the
# conversion script produced it, so we can tell whether it is up to date
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
created_on = nwbf.file_create_date[0].strftime("%m/%d/%Y %H:%M:%S")
print(f"File created on {created_on}")
print(f"Source script version {nwbf.source_script}")

Take a quick look at the parameters we are using!

In [None]:
from spyglass.spikesorting.analysis.v1.group import UnitSelectionParams
from spyglass.decoding.v1.core import DecodingParameters

# Names of the parameter sets used at each stage of the pipeline.
# Preprocessing
preproc_param_name = "franklab_tetrode_hippocampus"
artifact_param_name = "ampl_1000_z_30_prop_075_1ms"
# Sorting
sorter = "mountainsort4"
sorter_param_name = "franklab_tetrode_hippocampus_30KHz"
# Curation
waveform_param_name = "default_not_whitened"
metric_param_name = "franklab_default"
metric_curation_param_name = "default"
# Decoding
unit_filter_params_name = "default_exclusion"
decoding_param_name = "contfrag_sorted"
# Position (for decoding)
trodes_pos_params_name = "berke_double_led_500"

# Set to False to skip displaying the stored contents of each parameter set
review_params = True
if review_params:
    # Preprocessing
    preproc_params_entry = sgs.SpikeSortingPreprocessingParameters() & {
        "preproc_param_name": preproc_param_name
    }
    display(preproc_params_entry.fetch1())

    artifact_params_entry = sgs.ArtifactDetectionParameters() & {
        "artifact_param_name": artifact_param_name
    }
    display(artifact_params_entry.fetch1("artifact_params"))

    # Sorting
    sorter_params_entry = sgs.SpikeSorterParameters() & {
        "sorter": sorter,
        "sorter_param_name": sorter_param_name,
    }
    display(sorter_params_entry.fetch1())

    # Curation
    waveform_params_entry = sgs.WaveformParameters() & {
        "waveform_param_name": waveform_param_name
    }
    display(waveform_params_entry.fetch1())

    metric_params_entry = sgs.MetricParameters() & {
        "metric_param_name": metric_param_name
    }
    display(metric_params_entry.fetch1())

    metric_curation_params_entry = sgs.MetricCurationParameters() & {
        "metric_curation_param_name": metric_curation_param_name
    }
    display(metric_curation_params_entry.fetch1())

    # Decoding
    unit_filter_params_entry = UnitSelectionParams() & {
        "unit_filter_params_name": unit_filter_params_name
    }
    display(unit_filter_params_entry.fetch1())

    # This lookup uses fetch (not fetch1), so the result is shown as an
    # array of matching entries rather than a single value
    decoding_params_entry = DecodingParameters() & {
        "decoding_param_name": decoding_param_name
    }
    display(decoding_params_entry.fetch("decoding_params"))

    # Position (for decoding)
    trodes_pos_params_entry = sgp.v1.TrodesPosParams() & {
        "trodes_pos_params_name": trodes_pos_params_name
    }
    display(trodes_pos_params_entry.fetch1("params"))

# First check out all existing entries for this nwb

Helpful if we're halfway through running this.

In [None]:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup
from spyglass.position import PositionOutput

# Every lookup below starts from this one restriction
nwb_restriction = {"nwb_file_name": nwb_file_name}

# Position output for this nwb
print(f"Entries for {nwb_file_name} in PositionOutput.TrodesPosV1")
display(PositionOutput.TrodesPosV1 & nwb_restriction)

# Sort groups for this nwb
print(f"Entries for {nwb_file_name} in sgs.SortGroup")
display(sgs.SortGroup & nwb_restriction)

# Recording selections for this nwb
print(
    f"Entries for {nwb_file_name} in sgs.SpikeSortingRecordingSelection (one per SortGroup, or more if we are trying multiple preprocessing params)"
)
recording_selection = sgs.SpikeSortingRecordingSelection() & nwb_restriction
display(recording_selection)

# Primary keys (recording ids) of those selections; used to restrict the
# tables downstream of the selection
recording_ids = recording_selection.fetch("KEY")

# Preprocessed recordings belonging to those recording ids
print(f"Entries for {nwb_file_name} in sgs.SpikeSortingRecording (one per SortGroup)")
display(sgs.SpikeSortingRecording() & recording_ids)

# Artifact detection: the selections first, then the computed entries
print(
    f"Entries for {nwb_file_name} in sgs.ArtifactDetectionSelection (one per recording_id)"
)
artifact_selection = sgs.ArtifactDetectionSelection() & recording_ids
display(artifact_selection)
artifact_ids = artifact_selection.fetch("KEY")

print(f"Entries for {nwb_file_name} in sgs.ArtifactDetection (one per recording_id)")
display(sgs.ArtifactDetection() & artifact_ids)

# Spike sorting: selections, then everything keyed by their sorting ids
print(f"Entries for {nwb_file_name} in sgs.SpikeSortingSelection")
sorting_selection = sgs.SpikeSortingSelection() & nwb_restriction
display(sorting_selection)
sorting_ids = sorting_selection.fetch("KEY")

print(f"Entries for {nwb_file_name} in sgs.SpikeSorting")
display(sgs.SpikeSorting() & sorting_ids)

print(f"Entries for {nwb_file_name} in sgs.CurationV1")
display(sgs.CurationV1() & sorting_ids)

print(f"Entries for {nwb_file_name} in sgs.MetricCurationSelection")
display(sgs.MetricCurationSelection() & sorting_ids)

# Merge-table entries for this nwb
print(f"Entries for {nwb_file_name} in sgs.SpikeSortingOutput")
merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    key=nwb_restriction, as_dict=True
)
display(SpikeSortingOutput & merge_ids)

# Sorted-spikes groups (the input to decoding)
print(f"Entries for {nwb_file_name} in SortedSpikesGroup")
display(SortedSpikesGroup & nwb_restriction)

## Define sort groups and extract recordings

For now we generally sort electrodes on the same shank together. 


I have also written some custom functions that allow us to set SortGroups based on different attributes, so we may have way more `SortGroups` than shanks (if we started with one `SortGroup` per shank, then chose custom ones with different numbers). Generally when we define custom `SortGroups` we'll start from higher numbers to make it clear that they are separate from the ones auto-assigned by `sgs.SortGroup.set_group_by_shank` (e.g. if `set_group_by_shank` assigns groups 0-24 for 25 good shanks, we might choose to start our custom `SortGroups` at id 40 to leave a clear gap in between them)


If you do this (have multiple sets of `SortGroups`), remember to choose the actual sort groups you want and don't just use all of them.

In [None]:
# Sort group ids already defined for this nwb
sort_groups_for_nwb = sgs.SortGroup & {"nwb_file_name": nwb_file_name}
existing_sort_group_ids = sort_groups_for_nwb.fetch("sort_group_id")
print(f"All existing sort group ids for this nwb: {existing_sort_group_ids}")

# Only create the by-shank groups when none exist yet, so that groups which
# are already defined are never overwritten.
no_sort_groups_yet = existing_sort_group_ids.size == 0
if no_sort_groups_yet:
    sgs.SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name)

Choose the sort group ids we actually want to use!

In [None]:
# # Use all of them:
# sort_group_ids = (sgs.SortGroup & {"nwb_file_name": nwb_file_name}).fetch("sort_group_id")

# Or pick a few:
# NOTE: For IM-1478_20220726_.nwb I'm using the 25 SortGroups with ids 0-24
# (output by set_group_by_shank)
# Yang-Sun uses custom groups with ids 40+ for clusterless decoding
n_shank_sort_groups = 25
sort_group_ids = list(range(n_shank_sort_groups))

print(f"Using sort_group_ids: {sort_group_ids}")

## Preprocessing

Filter and reference the recording so that we isolate the spike band data.

In [None]:
# One selection entry per sort group, for the interval and preprocessing
# parameters we want to sort with

group_keys = []
for sort_group_id in sort_group_ids:
    recording_selection_key = dict(
        nwb_file_name=nwb_file_name,
        sort_group_id=sort_group_id,
        interval_list_name="00_r1",
        preproc_param_name=preproc_param_name,
        team_name="Berke lab and friends",
    )
    # Register the selection
    sgs.SpikeSortingRecordingSelection.insert_selection(recording_selection_key)

    # Read back its primary key (recording_id); these keys are what we hand
    # to SpikeSortingRecording when populating
    inserted_selection = sgs.SpikeSortingRecordingSelection & recording_selection_key
    group_keys.append(inserted_selection.fetch1("KEY"))

# Show the selection entries that match our keys
display(sgs.SpikeSortingRecordingSelection & group_keys)

print("Group keys:")
print(group_keys)

Now call the `populate` method of `SpikeSortingRecording`. 

Instead of just calling `sgs.SpikeSortingRecording.populate(group_keys)` with all group_keys, we only populate the missing keys.

In [None]:
# Populate SpikeSortingRecording, but only for group_keys that have no entry yet

# Everything we might need to populate
print(f"There are {len(group_keys)} keys: {group_keys}")

# Keys that already have an entry in SpikeSortingRecording
existing_keys = (sgs.SpikeSortingRecording & group_keys).fetch("KEY", as_dict=True)
print(
    f"There are {len(existing_keys)} already in in SpikeSortingRecording: {existing_keys}"
)

# Keys still without an entry (kept in their original order)
missing_keys = []
for candidate_key in group_keys:
    if candidate_key not in existing_keys:
        missing_keys.append(candidate_key)
print(f"There are {len(missing_keys)} missing keys: {missing_keys}")

# Nothing to do when every key is present; otherwise populate the rest
if not missing_keys:
    print("All group keys already populated.")
else:
    sgs.SpikeSortingRecording().populate(missing_keys)

In [None]:
# Sanity check: show the SpikeSortingRecording entries for our group keys
populated_recordings = sgs.SpikeSortingRecording() & group_keys
display(populated_recordings)

## Plot raw ElectricalSeries directly from the NWB

This plots a chunk of `e_series = nwbf.acquisition.get('ElectricalSeries')` for each shank.

No preprocessing has happened at this point.

In [None]:
# Window to plot, expressed on the same scale as the ElectricalSeries timestamps
start_time = 43
duration = 5

# Raw (not preprocessed) ElectricalSeries from the NWB acquisition group
acquisition = nwbf.acquisition
e_series = acquisition.get("ElectricalSeries")
timestamps = e_series.timestamps
n_timestamps = len(timestamps)


# Helper to find the first index >= target_time using binary search (because we have non-uniform timestamps)
def find_index(target_time, left=0, right=n_timestamps - 1):
    """Return the first index in [left, right] whose timestamp is >= target_time.

    Looks up one element of the module-level `timestamps` per step instead of
    reading the whole series. Assumes `timestamps` is sorted ascending
    (required for binary search) — TODO confirm for this file.

    Parameters
    ----------
    target_time : float
        Time to locate, on the same scale as `timestamps`.
    left, right : int
        Inclusive bounds of the index range to search.

    Returns
    -------
    int
        The first index whose timestamp is >= target_time. If every timestamp
        in the range is earlier than target_time, returns `right + 1`, so the
        result can always be used as an exclusive slice bound without dropping
        the final sample.
    """
    # Empty search range (e.g. no timestamps at all): nothing to test
    if right < left:
        return left

    while left < right:
        mid = (left + right) // 2
        if timestamps[mid] < target_time:
            left = mid + 1
        else:
            right = mid

    # The loop narrows the range to one candidate but never tests it. If that
    # candidate is still before target_time, no index in the range qualifies.
    if timestamps[left] < target_time:
        return left + 1
    return left


# Turn the time window into sample indices with the binary search above, so
# the full timestamp series never has to be read
end_time = start_time + duration
start_idx = find_index(start_time)
end_idx = find_index(end_time)


def safe_index(vector_data, indices):
    """Look up `indices` in an NWB VectorData column, in the order requested.

    The column itself is only ever indexed with the indices sorted ascending
    (works around HDF5 ordering limits); the values are then rearranged back
    into the caller's original order.
    """
    requested = np.asarray(indices)
    ascending = np.argsort(requested)
    # Index the column in ascending order
    values_ascending = np.array(vector_data[requested[ascending]])
    # Argsort of a permutation is its inverse, which undoes the sort
    restore_order = np.argsort(ascending)
    return values_ascending[restore_order]


# Timestamps and samples (all channels) for the selected window
t = timestamps[start_idx:end_idx]
data_chunk = e_series.data[start_idx:end_idx, :]
electrode_table = nwbf.electrodes
electrode_region = e_series.electrodes
electrode_ids = electrode_region.data[:]  # indices into the electrode table

# Electrode-table metadata for each electrode referenced by the series
# (presumably one entry per column of data_chunk, in column order — verify)
shanks = safe_index(electrode_table["probe_shank"], electrode_ids)
electrodes = safe_index(electrode_table["probe_electrode"], electrode_ids)
electrode_names = safe_index(electrode_table["electrode_name"], electrode_ids)
open_ephys_names = safe_index(electrode_table["open_ephys_channel_str"], electrode_ids)
bad_channels = safe_index(electrode_table["bad_channel"], electrode_ids)

# Vertical spacing between stacked traces
offset = 1000
unique_shanks = np.unique(shanks)

# One figure per shank, channels stacked bottom-to-top by probe_electrode
for shank in unique_shanks:
    shank_mask = shanks == shank
    sort_idx = np.argsort(electrodes[shank_mask])

    # Restrict to this shank, then order by electrode number
    data_shank = data_chunk[:, shank_mask][:, sort_idx]
    names_shank = electrode_names[shank_mask][sort_idx]
    chan_shank = open_ephys_names[shank_mask][sort_idx]
    bad_shank = bad_channels[shank_mask][sort_idx]

    fig, ax = plt.subplots(figsize=(14, 5))
    # Labels sit just past the right-hand end of the traces
    label_x = t[-1] + 0.001 * (t[-1] - t[0])
    for i, trace in enumerate(data_shank.T):
        baseline = i * offset
        y = trace + baseline
        ax.axhline(baseline, color="gray", linestyle="--", lw=0.5)
        # Channels flagged as bad are drawn faded
        if bad_shank[i]:
            alpha = 0.3
        else:
            alpha = 1.0
        ax.plot(t, y, lw=0.6, alpha=alpha)
        label = f"{names_shank[i]} - {chan_shank[i]}"
        ax.text(label_x, baseline, label, va="bottom", fontsize=9)

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Voltage (µV, offset by 1000)")
    ax.set_title(f"{e_series.name} — Shank {shank}")
    # Extra 10% on the right leaves room for the channel labels
    ax.set_xlim(t[0], t[0] + duration * 1.1)
    fig.tight_layout()
    plt.show()

## Plot preprocessed ElectricalSeries

This plots a chunk of `recording = sgs.SpikeSortingRecording.get_recording(spikesorting_group_key)` for each SortGroup.



In [None]:
# Plot the window start_time .. start_time + duration from each preprocessed
# recording, one figure per SortGroup.
for spikesorting_group_key in group_keys:
    recording = sgs.SpikeSortingRecording.get_recording(spikesorting_group_key)

    channel_ids = recording.get_channel_ids()
    # Kept under its own name so the raw-NWB `timestamps` (still read by
    # find_index) are not overwritten by this cell.
    recording_times = recording.get_times()

    # Locate the window on this recording's own time base. Sample indices from
    # the raw ElectricalSeries are only valid here if both series start on the
    # same sample, which is not guaranteed — TODO confirm for your interval.
    window_end_time = start_time + duration
    start_frame, end_frame = (
        int(frame)
        for frame in np.searchsorted(recording_times, [start_time, window_end_time])
    )
    if start_frame == end_frame:
        print(
            f"Recording ID {spikesorting_group_key['recording_id']} has no samples "
            f"between {start_time} and {window_end_time}, skipping."
        )
        continue

    traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
    time_axis = recording_times[start_frame:end_frame]

    plt.figure(figsize=(14, 5))
    # Labels sit just past the right-hand end of the traces
    label_x = time_axis[-1] + 0.001 * (time_axis[-1] - time_axis[0])
    for i, channel_id in enumerate(channel_ids):
        baseline = i * offset
        y = traces[:, i] + baseline
        plt.axhline(baseline, color="gray", linestyle="--", lw=0.5)
        # No bad-channel flag is available here, so every trace is drawn opaque
        plt.plot(time_axis, y, lw=0.6)
        plt.text(label_x, baseline, f"{channel_id}", va="bottom", fontsize=9)
    plt.xlabel("Time (s)")
    plt.ylabel("Voltage (µV, offset by 1000)")
    plt.title(f"Recording ID {spikesorting_group_key['recording_id']}")
    # Extra 10% on the right leaves room for the channel labels
    plt.xlim(time_axis[0], time_axis[0] + duration * 1.1)
    plt.tight_layout()
    plt.show()

## Artifact Detection

In [None]:
print(group_keys)

# One artifact-detection selection per preprocessed recording
artifact_detection_keys = []

for group_key in group_keys:
    artifact_selection_key = dict(
        recording_id=group_key["recording_id"],
        artifact_param_name=artifact_param_name,
    )
    # Register the selection
    sgs.ArtifactDetectionSelection.insert_selection(artifact_selection_key)

    # Read back its primary key (artifact_id); these keys are what we hand to
    # ArtifactDetection when populating
    inserted_selection = sgs.ArtifactDetectionSelection & artifact_selection_key
    artifact_detection_keys.append(inserted_selection.fetch1("KEY"))

# Show the selection entries that match our keys
display(sgs.ArtifactDetectionSelection() & artifact_detection_keys)

print("Artifact detection keys:")
print(artifact_detection_keys)

Now call the `populate` method of `ArtifactDetection`.

Instead of just calling `sgs.ArtifactDetection.populate(artifact_detection_keys)` with all artifact_detection_keys, we only populate the missing keys.

In [None]:
# Populate ArtifactDetection, but only for keys that have no entry yet

# Everything we might need to populate
print(f"There are {len(artifact_detection_keys)} keys: {artifact_detection_keys}")

# Keys that already have an entry in ArtifactDetection
already_detected = sgs.ArtifactDetection & artifact_detection_keys
existing_keys = already_detected.fetch("KEY", as_dict=True)
print(
    f"There are {len(existing_keys)} already in in ArtifactDetection: {existing_keys}"
)

# Keys still without an entry (kept in their original order)
missing_keys = []
for candidate_key in artifact_detection_keys:
    if candidate_key not in existing_keys:
        missing_keys.append(candidate_key)
print(f"There are {len(missing_keys)} missing keys: {missing_keys}")

# Nothing to do when every key is present; otherwise populate the rest
if not missing_keys:
    print("All group keys already populated.")
else:
    sgs.ArtifactDetection().populate(missing_keys)

In [None]:
# Sanity check: show the ArtifactDetection entries for our keys
detected_artifacts = sgs.ArtifactDetection() & artifact_detection_keys
display(detected_artifacts)

## Run Spike Sorting

The spike sorting pipeline is powered by `spikeinterface`, a community-developed Python package that enables one to easily apply multiple spike sorters to a single recording. Some spike sorters have special requirements, such as GPU. Others need to be installed separately from spyglass. In the Frank lab, we have been using `mountainsort4`, though the pipeline has been tested with `mountainsort5`, `kilosort2_5`, `kilosort3`, and `ironclust` as well.

When using `mountainsort5`, make sure to run `pip install mountainsort5`. `kilosort2_5`, `kilosort3`, and `ironclust` are MATLAB-based, but we can run these without having to install MATLAB thanks to `spikeinterface`. It does require downloading additional files (as singularity containers) so make sure to do `pip install spython`. These sorters also require GPU access, so also do `pip install cuda-python` (and make sure your computer does have a GPU). 

In [None]:
# Insert into SpikeSortingSelection

spike_sorting_keys = []

for group_key in group_keys:
    recording_id = group_key["recording_id"]

    # Look up the artifact selection made with OUR artifact parameters.
    # Restricting on recording_id alone makes fetch1 fail as soon as a
    # recording has selections for more than one artifact parameter set.
    artifact_selection = sgs.ArtifactDetectionSelection & {
        "recording_id": recording_id,
        "artifact_param_name": artifact_param_name,
    }
    art_id = artifact_selection.fetch1("artifact_id")

    # The sorting key below uses the interval list named after the artifact
    # id. Artifact detection does not always finish for every recording, so
    # skip the ones whose interval list is missing and carry on with the rest.
    interval_list_name = str(art_id)
    interval_list_entry = sgc.IntervalList() & {
        "interval_list_name": interval_list_name
    }
    # Truth-testing the query checks for existence without fetching rows
    if not interval_list_entry:
        print(f"No interval list entry for {art_id}, skipping.")
        continue

    ss_key = {
        "recording_id": recording_id,
        "sorter": sorter,
        "nwb_file_name": nwb_file_name,
        "interval_list_name": interval_list_name,
        "sorter_param_name": sorter_param_name,
    }
    # Insert into the selection table
    sgs.SpikeSortingSelection.insert_selection(ss_key)

    # Grab the primary key (sorting_id) and add to our list so we can insert into SpikeSorting
    spike_sorting_keys.append((sgs.SpikeSortingSelection & ss_key).proj().fetch1("KEY"))

# Look at everything we inserted!
display(sgs.SpikeSortingSelection() & spike_sorting_keys)

print("Spike sorting keys:")
print(spike_sorting_keys)

Now call the `populate` method of `SpikeSorting`.

Instead of just calling `sgs.SpikeSorting.populate(spike_sorting_keys)` with all spike_sorting_keys, we only populate the missing keys.

In [None]:
# Populate SpikeSorting, but only for keys that have no entry yet

# Everything we might need to populate
print(f"There are {len(spike_sorting_keys)} keys: {spike_sorting_keys}")

# Keys that already have an entry in sgs.SpikeSorting
already_sorted = sgs.SpikeSorting & spike_sorting_keys
existing_keys = already_sorted.fetch("KEY", as_dict=True)
print(f"There are {len(existing_keys)} already in in sgs.SpikeSorting: {existing_keys}")

# Keys still without an entry (kept in their original order)
missing_keys = []
for candidate_key in spike_sorting_keys:
    if candidate_key not in existing_keys:
        missing_keys.append(candidate_key)
print(f"There are {len(missing_keys)} missing keys: {missing_keys}")

# Nothing to do when every key is present; otherwise populate the rest
if not missing_keys:
    print("All group keys already populated.")
else:
    sgs.SpikeSorting.populate(missing_keys)

In [None]:
# Sanity check: show the SpikeSorting entries for our keys
completed_sortings = sgs.SpikeSorting() & spike_sorting_keys
display(completed_sortings)

The spike sorting results (spike times of detected units) are saved in an NWB file. We can access this in two ways. First, we can access it via the `fetch_nwb` method, which allows us to directly access the spike times saved in the `units` table of the NWB file. Second, we can access it as a `spikeinterface.NWBSorting` object. This allows us to take advantage of the rich APIs of `spikeinterface` to further analyze the sorting. 

In [None]:
# Load each sorting both ways. The two names are rebound on every pass, so
# after the loop they hold the results for the last key only.
for ss_key in spike_sorting_keys:
    sorting_entry = sgs.SpikeSorting & ss_key
    sorting_nwb = sorting_entry.fetch_nwb()
    sorting_si = sgs.SpikeSorting.get_sorting(ss_key)

Note that the spike times from `fetch_nwb` are in units of seconds, aligned with the timestamps of the recording. The spike times of the `spikeinterface.NWBSorting` object are in units of samples (as is generally true for sorting objects in `spikeinterface`).

## Automatic Curation

Next step is to curate the results of spike sorting. This is often necessary because spike sorting algorithms are not perfect;
they often return clusters that are clearly not biological in origin, and sometimes oversplit clusters that should have been merged.
We have two main ways of curating spike sorting: by computing quality metrics followed by thresholding, and manually applying curation labels.
To do either, we first insert the spike sorting to `CurationV1` using `insert_curation` method.


In [None]:
curation_key_list = []

for ss_key in spike_sorting_keys:
    sorting_id = str(ss_key["sorting_id"])

    # Check if this sorting_id has already been inserted with curation_id=0
    # (the id this notebook expects the initial curation to have)
    initial_curation_key = {"sorting_id": sorting_id, "curation_id": 0}
    initial_curation_entry = sgs.CurationV1() & initial_curation_key

    # Truth-testing the query checks for existence without fetching rows.
    # Only insert when there is no initial curation for this sorting yet.
    if initial_curation_entry:
        print(f"Entry for {initial_curation_key} already exists in sgs.CurationV1")
    else:
        sgs.CurationV1.insert_curation(
            sorting_id=sorting_id,
            description="initial automatic curation",
        )

    curation_key_list.append(initial_curation_key)

# Look at everything we inserted!
display(sgs.CurationV1() & curation_key_list)

print("Curation keys (initial automatic curation):")
print(curation_key_list)

We will first do an automatic curation based on quality metrics. Under the hood, this part again makes use of `spikeinterface`. Some of the quality metrics that we often compute are the nearest neighbor isolation and noise overlap metrics, as well as SNR and ISI violation rate. For computing some of these metrics, the waveforms must be extracted and projected onto a feature space. Thus here we set the parameters for waveform extraction as well as how to curate the units based on these metrics (e.g. if `nn_noise_overlap` is greater than 0.1, mark as `noise`).

In [None]:
# One metric-curation selection per sorting, built on its initial curation
# (curation_id=0) with the waveform / metric / curation parameter sets
metric_curation_keys = []
for ss_key in spike_sorting_keys:
    mc_key = dict(
        sorting_id=str(ss_key["sorting_id"]),
        curation_id=0,
        waveform_param_name=waveform_param_name,
        metric_param_name=metric_param_name,
        metric_curation_param_name=metric_curation_param_name,
    )

    # Register the selection
    sgs.MetricCurationSelection.insert_selection(mc_key)

    # Read back its primary key (metric_curation_id); these keys are what we
    # hand to MetricCuration when populating
    inserted_selection = sgs.MetricCurationSelection & mc_key
    metric_curation_keys.append(inserted_selection.fetch1("KEY"))

# Show the selection entries that match our keys
display(sgs.MetricCurationSelection() & metric_curation_keys)

print("Metric curation keys:")
print(metric_curation_keys)

Now call the `populate` method of `MetricCuration`.

Instead of just calling `sgs.MetricCuration.populate(metric_curation_keys)` with all metric_curation_keys, we only populate the missing keys.

In [None]:
# Populate MetricCuration, but only for keys that have no entry yet

# Everything we might need to populate
print(f"There are {len(metric_curation_keys)} keys: {metric_curation_keys}")

# Keys that already have an entry in sgs.MetricCuration
already_curated = sgs.MetricCuration & metric_curation_keys
existing_keys = already_curated.fetch("KEY", as_dict=True)
print(f"There are {len(existing_keys)} already in in MetricCuration: {existing_keys}")

# Keys still without an entry (kept in their original order)
missing_keys = []
for candidate_key in metric_curation_keys:
    if candidate_key not in existing_keys:
        missing_keys.append(candidate_key)
print(f"There are {len(missing_keys)} missing keys: {missing_keys}")

# Nothing to do when every key is present; otherwise populate the rest
if not missing_keys:
    print("All group keys already populated.")
else:
    sgs.MetricCuration().populate(missing_keys)

In [None]:
# Sanity check: show the MetricCuration entries for our keys
completed_metric_curations = sgs.MetricCuration() & metric_curation_keys
display(completed_metric_curations)

To do another round of curation, fetch the relevant info and insert back into CurationV1 using `insert_curation`.


Because this is the second round, we have `curation_id=1` and `parent_curation_id=0` (to match the `curation_id=0` of the first round we inserted)


In [None]:
curation_key_list_round2 = []

for mc_key in metric_curation_keys:
    # Sorting this metric curation belongs to. Looked up once and passed as a
    # string everywhere, the same way the initial curation was inserted.
    sorting_id = str((sgs.MetricCurationSelection & mc_key).fetch1("sorting_id"))

    # Check if this sorting_id has already been inserted with curation_id=1
    round_2_key = {"sorting_id": sorting_id, "curation_id": 1}
    round_2_curation_entry = sgs.CurationV1() & round_2_key

    # Truth-testing the query checks for existence without fetching rows.
    # Only insert a second round when this sorting does not have one yet.
    if round_2_curation_entry:
        print(f"Entry for {round_2_key} already exists in sgs.CurationV1")
    else:
        # Carry the metric-curation results over into the new curation,
        # which points back at the initial curation (curation_id=0)
        labels = sgs.MetricCuration.get_labels(mc_key)
        merge_groups = sgs.MetricCuration.get_merge_groups(mc_key)
        metrics = sgs.MetricCuration.get_metrics(mc_key)
        sgs.CurationV1.insert_curation(
            sorting_id=sorting_id,
            parent_curation_id=0,
            labels=labels,
            merge_groups=merge_groups,
            metrics=metrics,
            description="after metric curation",
        )

    curation_key_list_round2.append(round_2_key)

# Look at everything we inserted!
display(sgs.CurationV1() & curation_key_list_round2)

print("Curation keys (after metric curation):")
print(curation_key_list_round2)

In [None]:
# Both curation rounds side by side (second-round keys listed first)
all_curation_keys = curation_key_list_round2 + curation_key_list
display(sgs.CurationV1() & all_curation_keys)

## For now, we skip manual curation.

## Insert into merge table for downstream usage 

Regardless of the curation method used, to make use of spike sorting results in downstream pipelines like Decoding, we will need to insert them into the `SpikeSortingOutput` merge table. 

In [None]:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput

# Insert the second round curation results into the merge table
for key in curation_key_list_round2:

    # Primary key of the curation entry we want in the merge table
    merge_insert_key = (sgs.CurationV1 & key).fetch("KEY", as_dict=True)

    # Check the CurationV1 part table, which is where sorting_id and
    # curation_id are stored. The master table has neither attribute, and
    # DataJoint ignores dict keys a table does not have, so restricting the
    # master would match every merge entry rather than just this curation.
    merge_entry = SpikeSortingOutput.CurationV1 & merge_insert_key

    # Truth-testing the query checks for existence without fetching rows.
    # If it hasn't been inserted yet, insert it.
    if merge_entry:
        print(f"Entry for {merge_insert_key} already exists in SpikeSortingOutput")
    else:
        SpikeSortingOutput.insert(merge_insert_key, part_name="CurationV1")

Look at our entries in `SpikeSortingOutput`

In [None]:
# Restrict on nwb file name and curation id only for now
# (sorter, interval_list_name, etc. could be added to narrow this further)
selection_key = dict(nwb_file_name=nwb_file_name, curation_id=1)
merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    key=selection_key,
    sources="v1",
    as_dict=True,
)

# Show the merge-table entries that match
matching_merge_entries = SpikeSortingOutput() & merge_ids
display(matching_merge_entries)

---------------------------------------


# Decode from sorted spikes

The elements we will need to decode with sorted spikes are:
- `PositionGroup`
- `SortedSpikesGroup`
- `DecodingParameters`
- `encoding_interval`
- `decoding_interval`


In [None]:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup
import spyglass.spikesorting.v1 as sgs

# Which sortings to decode from: this nwb, this sorter, second curation round
sorter_keys = dict(
    nwb_file_name=nwb_file_name,
    sorter=sorter,
    curation_id=1,
)

# Show the sorting selections joined with their merge-table entries
matching_selections = sgs.SpikeSortingSelection & sorter_keys
matching_outputs = SpikeSortingOutput.CurationV1 & sorter_keys
display(matching_selections * matching_outputs)

# Merge ids of those sortings
spikesorting_merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    sorter_keys, restrict_by_artifact=False
)
n_merge_ids = len(spikesorting_merge_ids)
print(f"We have {n_merge_ids} merge_ids for this group {spikesorting_merge_ids}")

In [None]:
# Create a new sorted spikes group
ss_group_name = "sorted_spikes_group"

group_entry = SortedSpikesGroup & {
    "nwb_file_name": nwb_file_name,
    "sorted_spikes_group_name": ss_group_name,
}

# Truth-testing the query checks for existence without fetching rows.
# Only create the group when no group of this name exists for this nwb.
if group_entry:
    print("SortedSpikesGroup already exists!")
else:
    SortedSpikesGroup().create_group(
        group_name=ss_group_name,
        nwb_file_name=nwb_file_name,
        keys=[
            {"spikesorting_merge_id": merge_id} for merge_id in spikesorting_merge_ids
        ],
        unit_filter_params_name=unit_filter_params_name,
    )

# Check out the new group
display(
    SortedSpikesGroup
    & {"nwb_file_name": nwb_file_name, "sorted_spikes_group_name": ss_group_name}
)

# And look at the sorting within the group (for our unit filter parameters)
display(
    SortedSpikesGroup.Units
    & {
        "nwb_file_name": nwb_file_name,
        "sorted_spikes_group_name": ss_group_name,
        "unit_filter_params_name": unit_filter_params_name,
    }
)

## Grouping Position Data

Note that we can use the `upsample_rate` parameter to define the rate (in Hz) to which position data will be upsampled for decoding. This is useful if we want to decode at a finer time scale than the position data sampling frequency. In practice, a value of 500Hz is used in many analyses. Skipping or providing a null value for this parameter will default to using the position sampling rate.

You will also want to specify the name of the position variables if they are different from the default names. The default names are `position_x` and `position_y`.

In [None]:
from spyglass.position import PositionOutput
import spyglass.position as sgp

pos_group_name = "sorted_spikes_pos_group"

# Set up position key for position we want to use to decode
position_selection_key = {
    "nwb_file_name": nwb_file_name,
    "interval_list_name": "pos 0 valid times",  # Berke lab has only one epoch, so this is always our interval list name
    "trodes_pos_params_name": trodes_pos_params_name,
}

# Insert into selection table
sgp.v1.TrodesPosSelection.insert1(
    position_selection_key,
    skip_duplicates=True,
)

# Fetch the primary key so we can populate the position table
# (it's actually the same as position_selection_key so we could have just used that)
position_key = (sgp.v1.TrodesPosSelection() & position_selection_key).fetch1("KEY")
pos_entry = PositionOutput.TrodesPosV1() & position_key

# Truth-testing the query checks for existence without fetching rows.
# Only populate when the merge table has no position entry for this key yet.
if pos_entry:
    print(f"Entry for {position_key} already exists in sgp.v1.TrodesPosV1")
else:
    sgp.v1.TrodesPosV1.populate(position_key)

# Look at it!
display(PositionOutput.TrodesPosV1 & position_key)

print(f"Pos selection key: {position_selection_key}")
print(f"Pos key: {position_key}")

In [None]:
from spyglass.decoding.v1.core import PositionGroup

# Merge ids of the position entries that will make up the group
position_merge_ids = (PositionOutput.TrodesPosV1 & position_key).fetch("merge_id")

pos_group_entry = PositionGroup & {
    "nwb_file_name": nwb_file_name,
    "position_group_name": pos_group_name,
}

# Truth-testing the query checks for existence without fetching rows.
# If we don't have a position group yet, create it!
if pos_group_entry:
    print("Position group already exists!")
else:
    PositionGroup().create_group(
        nwb_file_name=nwb_file_name,
        group_name=pos_group_name,
        keys=[{"pos_merge_id": merge_id} for merge_id in position_merge_ids],
        upsample_rate=500,  # Hz
    )

## Decoding

Now we can decode the position using the sorted spikes using the `SortedSpikesDecodingSelection` table. 

In [None]:
from spyglass.decoding import SortedSpikesDecodingSelection

# Everything that identifies this decoding run: which spikes, which position
# data, which decoding parameters, and the intervals to encode and decode on
selection_key = dict(
    sorted_spikes_group_name=ss_group_name,
    unit_filter_params_name=unit_filter_params_name,
    position_group_name=pos_group_name,
    decoding_param_name=decoding_param_name,
    nwb_file_name=nwb_file_name,
    encoding_interval="00_r1",  # to encode using the entire session, this is always our interval list name
    decoding_interval="epoch0_block1",
    estimate_decoding_params=False,
)

# Register the selection; a selection that already exists is left as-is
SortedSpikesDecodingSelection.insert1(selection_key, skip_duplicates=True)

Run decoding

In [None]:
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1

decoding_entry = SortedSpikesDecodingV1 & selection_key

# Truth-testing the query checks for existence without fetching rows.
# Run the decoding only if we don't have output for this selection yet.
if decoding_entry:
    print("Decoding entry already exists!")
else:
    SortedSpikesDecodingV1.populate(selection_key)

In [None]:
from spyglass.decoding.decoding_merge import DecodingOutput

# Merge-table entry for this decoding run
decoding_output_entry = DecodingOutput.SortedSpikesDecodingV1 & selection_key
display(decoding_output_entry)

# Load the decoding results and show them
decoding_run = SortedSpikesDecodingV1 & selection_key
decoding_results = decoding_run.fetch_results()
display(decoding_results)

### Plot place fields

In [None]:
from spyglass.decoding.decoding_merge import DecodingOutput
import matplotlib.pyplot as plt

max_firing_rate = 15  # spikes/s; upper limit of the color scale
show_colorbar = False

# Fetch classifier
classifier = (SortedSpikesDecodingV1 & selection_key).fetch_model()
fs = classifier.sampling_frequency

# The first environment is used for every panel, so look it up once
environment = classifier.environments[0]
x_edges, y_edges = environment.edges_[0], environment.edges_[1]

# Fetch place fields and reshape
place_fields = classifier.encoding_model_[("", 0)]["place_fields"]  # units, place_bins
print(place_fields.shape)
place_fields = place_fields.reshape(
    (-1, *environment.centers_shape_)
)  # units, x, y
print(place_fields.shape)

# Set up subplots
n_units = place_fields.shape[0]

n_cols = 10  # number of columns in the grid
n_rows = int(np.ceil(n_units / n_cols))
# squeeze=False keeps `axes` a 2-D array for any grid size
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
)

# Plot the place field for each unit
for i in range(n_units):
    ax = axes.flat[i]
    # pcolormesh expects rows to run along y, hence the transpose of the
    # (x, y) field. Scaling by fs presumably converts per-time-bin rates to
    # spikes/s, matching the colorbar label.
    m = ax.pcolormesh(
        x_edges,
        y_edges,
        place_fields[i].T * fs,
        vmin=0,
        vmax=max_firing_rate,
    )

    # Optional add colorbar
    if show_colorbar:
        fig.colorbar(m, ax=ax, label="spikes/s")

    ax.set_title(f"Unit {i}", fontsize=8)
    ax.set_aspect("equal")
    ax.set_axis_off()

# Turn off any unused subplots
for j in range(n_units, n_rows * n_cols):
    axes.flat[j].axis("off")

plt.tight_layout()
plt.show()

## Create a figurl

In [None]:
# NOTE: This cell is currently disabled (everything below is commented out).
# Uncomment it to build an interactive 2D decoding figurl from
# `decoding_results` and `selection_key`; it imports from
# `non_local_detector.visualization`.
#
# from non_local_detector.visualization import (
#     create_interactive_2D_decoding_figurl,
# )

# (
#     position_info,
#     position_variable_names,
# ) = SortedSpikesDecodingV1.fetch_position_info(selection_key)
# results_time = decoding_results.acausal_posterior.isel(intervals=0).time.values
# position_info = position_info.loc[results_time[0] : results_time[-1]]

# env = SortedSpikesDecodingV1.fetch_environments(selection_key)[0]
# spike_times = SortedSpikesDecodingV1.fetch_spike_data(selection_key)

# url = create_interactive_2D_decoding_figurl(
#     position_time=position_info.index.to_numpy(),
#     position=position_info[position_variable_names],
#     env=env,
#     results=decoding_results,
#     posterior=decoding_results.acausal_posterior.isel(intervals=0)
#     .unstack("state_bins")
#     .sum("state"),
#     spike_times=spike_times,
#     head_dir=position_info["orientation"],
#     speed=position_info["speed"],
# )
# url

-----------------------------------------------

# Evaluate!

Check out our units: plot their ISI distributions, spike counts, and the animal's position at the time of each spike.

In [None]:
import pandas as pd

# Set up a dataframe with one row per unit
# Set up a dataframe with one row per unit.
# Collect one dataframe per curation key and concatenate once, rather than
# growing a dataframe inside the loop (which re-copies all rows each time and
# concatenates with an empty frame, which pandas has deprecated).
units_per_sortgroup = [
    sgs.CurationV1().get_sorting(curation_key, as_dataframe=True)
    for curation_key in curation_key_list_round2
]
# pd.concat raises on an empty list, so fall back to an empty dataframe
all_units = pd.concat(units_per_sortgroup) if units_per_sortgroup else pd.DataFrame()

# Exclude units with "noise" in curation_label
is_noise = all_units["curation_label"].apply(lambda labels: "noise" in labels)
# .copy() makes good_units an independent dataframe, so columns can be added
# to it later without pandas' SettingWithCopyWarning
good_units = all_units[~is_noise].copy()
display(good_units)

In [None]:
# Add mean isi for each unit (NaN for units with fewer than two spikes).
# good_units was produced by filtering another dataframe, so assigning a
# column in place can raise SettingWithCopyWarning; `assign` returns a new
# dataframe with the extra column instead.
good_units = good_units.assign(
    mean_isi=good_units["spike_times"].apply(
        lambda s: np.diff(np.sort(s)).mean() if len(s) > 1 else np.nan
    )
)
display(good_units)

# Plot histogram of ISIs
plt.figure(figsize=(6, 4))
plt.hist(good_units["mean_isi"].dropna(), bins=50)
plt.xlabel("Mean ISI (s)")
plt.ylabel("Count")
plt.title("Distribution of mean ISIs")
plt.show()

# Plot histogram of spike counts
plt.figure(figsize=(6, 4))
plt.hist(good_units["num_spikes"].dropna(), bins=50)
plt.xlabel("Number of spikes")
plt.ylabel("Count")
plt.title("Distribution of spike counts")
plt.show()

### Plot ISI histograms for each unit

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Longest ISI to show; same units as spike_times (presumably seconds)
max_isi = 0.15

# Set up subplots
n_units = len(good_units)

n_cols = 10  # number of columns in the grid
n_rows = int(np.ceil(n_units / n_cols))
# squeeze=False keeps `axes` a 2-D array for any grid size
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
)

# Plot the ISI histogram for each unit.
# The per-unit array is deliberately NOT called `spikes`: that name is used
# for the (n_time, n_neurons) indicator matrix elsewhere in this notebook, and
# re-running this cell would otherwise overwrite it.
for plot_idx, unit_spike_times in enumerate(good_units["spike_times"]):
    ax = axes.flat[plot_idx]
    isi = np.diff(np.sort(unit_spike_times))
    # Only ISIs within the plotted range are binned
    ax.hist(isi[isi <= max_isi], bins=50)
    ax.set_title(f"Unit {plot_idx}", fontsize=8)
    ax.set_xlim(0, max_isi)

# Turn off any unused subplots
for j in range(n_units, n_rows * n_cols):
    axes.flat[j].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Fetch the position info we used for decoding:
# look up the key(s) of the Trodes position entry, then pull the dataframe
# for that entry from the PositionOutput merge table.
trodes_pos_entry = PositionOutput.TrodesPosV1() & position_key
pos_merge_id = trodes_pos_entry.fetch("KEY")
pos_df = (PositionOutput & pos_merge_id).fetch1_dataframe()
display(pos_df)

# Make it an array with shape (n_time, 2)
xy_columns = ["position_x", "position_y"]
positions = pos_df[xy_columns].to_numpy()

In [None]:
import numpy as np

# Bin edges are the timestamps of the position used for decoding, plus one
# extra edge (one sample interval past the end) to catch trailing spikes.
position_times = pos_df.index
last_interval = np.diff(position_times[-2:]).mean()
timestamps = np.append(position_times, position_times[-1] + last_interval)

# Binary indicator array with shape (n_time, n_neurons):
# 1 if the neuron fired at least once in that time bin, else 0
n_time = len(timestamps) - 1
n_neurons = len(good_units)
spikes = np.zeros((n_time, n_neurons), dtype=int)

for unit_idx, unit_spike_times in enumerate(good_units["spike_times"]):
    counts_per_bin = np.histogram(unit_spike_times, bins=timestamps)[0]
    # Boolean is cast to 0/1 on assignment into the int array
    spikes[:, unit_idx] = counts_per_bin > 0

In [None]:
# Sanity check that this worked.
# `spikes` only records whether a unit fired in a bin, so the per-unit sums
# can be smaller than the raw spike counts (several spikes in one bin, or
# spikes outside the position time range).
spikes_per_neuron = spikes.sum(axis=0)
original_counts = good_units["spike_times"].map(len).to_numpy()
for counts in (
    spikes_per_neuron,
    original_counts,
    original_counts - spikes_per_neuron,
):
    print(counts)

### Plot spike locations for each unit

In [None]:
import matplotlib.pyplot as plt
import numpy as np

n_time, n_units = spikes.shape

# Grid layout: fixed number of columns, as many rows as needed
n_cols = 10
n_rows = int(np.ceil(n_units / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))

# One panel per unit: every visited position in gray, and the rat's position
# in the time bins where the unit fired in red. Leftover panels are hidden.
for panel_idx, ax in enumerate(axes.flat):
    if panel_idx >= n_units:
        ax.axis("off")
        continue

    fired = spikes[:, panel_idx].astype(bool)

    ax.scatter(positions[:, 0], positions[:, 1], s=1, alpha=0.1, color="gray")
    ax.scatter(positions[fired, 0], positions[fired, 1], s=1, color="red")
    ax.set_title(f"Unit {panel_idx}", fontsize=8)
    ax.set_aspect("equal")
    ax.set_axis_off()

plt.tight_layout()
plt.show()

In [None]:
# All of the relevant primary keys we created.


def _show_keys(title, keys, table):
    """Print a title and the keys, then display `table` restricted to them."""
    print(title)
    print(keys)
    display(table & keys)


_show_keys(
    "Keys for sgs.SpikeSortingRecording",
    group_keys,
    sgs.SpikeSortingRecording,
)
_show_keys(
    "Keys for sgs.ArtifactDetection",
    artifact_detection_keys,
    sgs.ArtifactDetection(),
)
_show_keys(
    "Keys for sgs.SpikeSorting",
    spike_sorting_keys,
    sgs.SpikeSorting(),
)
_show_keys(
    "Keys for sgs.CurationV1 (initial round of curation)",
    curation_key_list,
    sgs.CurationV1(),
)
_show_keys(
    "Keys for sgs.MetricCurationSelection",
    metric_curation_keys,
    sgs.MetricCurationSelection(),
)
_show_keys(
    "Keys for sgs.CurationV1 (after metric curation)",
    curation_key_list_round2,
    sgs.CurationV1(),
)
_show_keys(
    "Keys for SpikeSortingOutput",
    merge_ids,
    SpikeSortingOutput(),
)
_show_keys(
    "Key for DecodingOutput.SortedSpikesDecodingV1",
    selection_key,
    DecodingOutput.SortedSpikesDecodingV1,
)