Connect to database

In [3]:
import datajoint as dj

# Point DataJoint at the lab database server, then open the connection.
db_settings = {
    'database.host': "gl-ash.biostr.washington.edu",
    'database.user': "gabby",
    'database.port': 3306,
}
for setting_name, setting_value in db_settings.items():
    dj.config[setting_name] = setting_value

# Kept as the last expression of the cell so the connection object is displayed.
dj.conn()

[2026-02-18 13:38:31,013][INFO]: DataJoint 0.14.6 connected to gabby@gl-ash.biostr.washington.edu:3306


DataJoint connection (connected) gabby@gl-ash.biostr.washington.edu:3306

Imports

In [58]:
import os
import warnings
from getpass import getpass

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import spyglass.common as sgc
import spyglass.spikesorting.v1 as sgs
from spyglass.common.common_interval import Interval
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.decoding.v1.waveform_features import UnitWaveformFeatures

import figpack.views as vv

from gl_spyglass.utils.common_neural_functions import validate_references

In [5]:
# Bucket that figpack uploads figures to.
os.environ['FIGPACK_BUCKET'] = 'gillespielab'

# SECURITY: never commit a FIGPACK_API_KEY to the notebook. Export FIGPACK_API_KEY in the
# environment before starting Jupyter, or type it at the prompt below (input is hidden and
# is not saved in the notebook). Contact Jeremy Magland if you don't have a key.
# NOTE(review): a key was previously hard-coded in this cell; it should be treated as
# leaked and rotated.
if not os.environ.get('FIGPACK_API_KEY'):
    os.environ['FIGPACK_API_KEY'] = getpass('FIGPACK_API_KEY: ')

Choose tetrodes for your relevant subj(s), date(s)

In [7]:
# Lookup table: subject -> session date (YYYYMMDD) -> cannula label -> tetrode numbers.
# Downstream cells subtract 1 from these tetrode numbers to get electrode group names.
subj_sess_tetrodes = {
    'pippin': {
        20210421: dict(
            can1_ca1_tetrodes=[6, 9, 30],
            can2_ca1_tetrodes=[35, 36, 40],
        ),
    },
}

Choose parameters

Choose between `plot_type` 'one_from_many_tet' or 'all_from_one_tet' depending on which kind of artifact detection you're visualizing. For coincident spike detection, it's useful to look at one electrode from many tetrodes so that you can see whether the cross-tetrode artifacts are getting caught. For threshold artifact detection (from `ArtifactDetection`), it's useful to look at all the electrodes from a single tetrode, because it works per sort group (tetrode) to look for big artifacts that are present across multiple electrodes in the tetrode.

In [None]:
# Session to visualize; (subj, date) must be a key path in subj_sess_tetrodes
subj, date = 'pippin', 20210421
interval_list_name = '02_r1'
# Either 'one_from_many_tet' or 'all_from_one_tet'
plot_type = 'one_from_many_tet'
# NWB file name is derived from the subject and date
nwb_file_name = f'{subj}{date}_.nwb'

In [68]:
# Session-level lookups shared by both plot types
subj_info = subj_sess_tetrodes[subj][date]
epoch = int(interval_list_name[:2])  # first two characters of the interval name, e.g. '02_r1' -> 2
pos_interval_list_name = (sgc.IntervalList() & {'nwb_file_name': nwb_file_name, 'pipeline': 'position'}).fetch('interval_list_name')[epoch - 1]
sort_interval_list_name = interval_list_name

# Fail early on a mistyped plot_type instead of hitting a NameError (or silently reusing
# stale kernel variables) further down.
valid_plot_types = ('one_from_many_tet', 'all_from_one_tet')
if plot_type not in valid_plot_types:
    raise ValueError(f"plot_type must be one of {valid_plot_types}, got {plot_type!r}")

# get electrodes info (needed by both plot types)
electrodes_df, val_can_refs = validate_references(nwb_file_name, is_copy=True)

# narrow down electrodes_df to good electrodes
# NOTE(review): this compares against the string 'False' -- confirm 'bad_channel' is stored as strings
electrodes_df = electrodes_df[electrodes_df['bad_channel'] == 'False']


def _good_electrode_ids(elec_group):
    """Return the electrode ids of the good channels in one electrode group."""
    return electrodes_df.loc[electrodes_df['electrode_group_name'] == elec_group, 'electrode_id'].values


def _first_good_electrode_id(elec_group):
    """Return the first good electrode id in a group, or raise a readable error if there is none."""
    good_ids = _good_electrode_ids(elec_group)
    if len(good_ids) == 0:
        raise ValueError(f'No good channels found for electrode group {elec_group} in {nwb_file_name}')
    return good_ids[0]


if plot_type == 'one_from_many_tet':
    can1_ca1_tetrodes = subj_info['can1_ca1_tetrodes']
    can2_ca1_tetrodes = subj_info['can2_ca1_tetrodes']

    # convert tetrode numbers to electrode group numbers
    can1_ca1_elec_groups = np.asarray(can1_ca1_tetrodes) - 1
    can2_ca1_elec_groups = np.asarray(can2_ca1_tetrodes) - 1

    # find the electrode id for the first good channel on each tetrode
    can1_ca1_elecs = [_first_good_electrode_id(elec_group) for elec_group in can1_ca1_elec_groups]
    can2_ca1_elecs = [_first_good_electrode_id(elec_group) for elec_group in can2_ca1_elec_groups]

    # select elecs and elec labels
    plot_elecs = np.concatenate([can1_ca1_elecs, can2_ca1_elecs])
    plot_elec_labels = np.concatenate([[f'can1 {e + 1}: elec {elec}' for e, elec in enumerate(can1_ca1_elecs)], [f'can2 {e + 1}: elec {elec}' for e, elec in enumerate(can2_ca1_elecs)]])

elif plot_type == 'all_from_one_tet':
    # select which cannula and tetrode index to plot from
    can = 1
    tet_idx = 0
    tetrode = subj_info[f'can{can}_ca1_tetrodes'][tet_idx]

    # find which electrode ids correspond to this tetrode
    elec_group = tetrode - 1
    sort_group_id = (sgs.SortGroup.SortGroupElectrode() & {'nwb_file_name': nwb_file_name, 'electrode_group_name': elec_group}).fetch('sort_group_id')[0]
    plot_elecs = _good_electrode_ids(elec_group)
    if len(plot_elecs) == 0:
        raise ValueError(f'No good channels found for electrode group {elec_group} in {nwb_file_name}')
    plot_elec_labels = np.asarray([f'can{can} {e + 1}: elec {elec}' for e, elec in enumerate(plot_elecs)])

# set colors (one entry per plotted electrode)
colors = ['black']*len(plot_elecs)

Load in the spike-sorted (600-6000 Hz) data for each selected electrode

In [69]:
# Load in the spike recording data from each electrode id of interest
plot_sort_groups = [(sgs.SortGroup.SortGroupElectrode() & {'nwb_file_name': nwb_file_name, 'electrode_id': elec}).fetch1('sort_group_id') for elec in plot_elecs]

# Look up probe_electrode one electrode id at a time so the result is always in plot_elecs
# order; an `.isin` selection returns rows in dataframe order, which can be zipped against
# the wrong electrode when plot_elecs is not sorted the same way.
probe_elecs = [electrodes_df.loc[electrodes_df['electrode_id'] == elec, 'probe_electrode'].values[0] for elec in plot_elecs]

# electrode id -> [sort group id, column used to index the recording's data array]
elec_dict = {elec: [sort_group, probe_elec] for elec, sort_group, probe_elec in zip(plot_elecs, plot_sort_groups, probe_elecs)}

spike_recording_data = None
timestamps = None
for e, (elec, (sort_group, probe_elec)) in enumerate(tqdm(elec_dict.items(), total=len(elec_dict))):
    recording_id = (sgs.SpikeSortingRecordingSelection() & {'nwb_file_name': nwb_file_name, 'interval_list_name': sort_interval_list_name, 'sort_group_id': sort_group}).fetch1('recording_id')
    spike_sorting_recording = (sgs.SpikeSortingRecording & {'recording_id': recording_id}).fetch_nwb()
    if len(spike_sorting_recording) == 0:
        print(f'WARNING: no SpikeSortingRecording found for electrode {elec} (sort group {sort_group}); its trace is left as zeros')
        continue
    recording_series = spike_sorting_recording[0]['object_id']
    data = recording_series.data[:, probe_elec]
    if spike_recording_data is None:
        # Allocate on the first recording that actually loads (not only when e == 0), so a
        # missing first recording no longer leaves the buffer undefined.
        # NOTE(review): timestamps are read once, assuming every sort group shares the same
        # timestamps for this interval -- confirm.
        timestamps = recording_series.timestamps[:]
        spike_recording_data = np.zeros((len(plot_elecs), data.shape[0]))
    spike_recording_data[e, :] = data

if spike_recording_data is None:
    raise ValueError(f'No spike sorting recordings could be loaded for {nwb_file_name}, {sort_interval_list_name}')

# convert to dataframe (rows = samples indexed by timestamp, columns = electrode ids)
spike_recording_df = pd.DataFrame(spike_recording_data.T, columns=plot_elecs, index=timestamps)

0it [00:00, ?it/s]

Set plotting parameters

In [None]:
# --- what to plot ---
full_data, data_type = spike_recording_df, 'dataframe'
elecs, labels = plot_elecs, plot_elec_labels

# --- time axis ---
time_window = None
zero_win_start = False
downsamp = 1  # 1 means no downsampling
sampling_frequency_hz = 30000  # make sure to update this if your data is not sampled at 30 kHz!

# --- figure appearance ---
hide_nav = False
hide_time_labels = False
elec_separation = 1000  # can adjust this if you want more or less separation between the electrodes in the plot

# --- overlays ---
plot_artifacts = True
plot_spikes = False  # can only plot spikes for a single sort group (all_from_one_tet)
artifact_type = 'coincident_spikes'  # 'coincident_spikes' or 'large_amplitude_events'

Process the data before visualization

In [71]:
# Sanity-check the combination of options chosen above.
# coincident_spikes is meant for one_from_many_tet; large_amplitude_events is meant for all_from_one_tet.

# This combination still runs (the coincident-spike path does not need a sort group), so
# only warn instead of stopping the notebook.
if artifact_type == 'coincident_spikes' and plot_type == 'all_from_one_tet':
    warnings.warn("coincident_spikes artifact type is not recommended for all_from_one_tet plot type, consider switching to large_amplitude_events for better visualization of artifacts")

# The two combinations below cannot work: sort_group_id is only defined for the
# all_from_one_tet plot type, so stop here with a clear error.
# (The original compared against 'one_from_many_tets', which never matched the real option name.)
if artifact_type == 'large_amplitude_events' and plot_type == 'one_from_many_tet':
    raise ValueError("large_amplitude_events artifact type is not recommended for one_from_many_tet plot type, consider switching to coincident_spikes for better visualization of artifacts")

if plot_spikes and plot_type == 'one_from_many_tet':
    raise ValueError("spikes cannot be plotted for one_from_many_tet plot type, consider switching to all_from_one_tet if you want to plot spikes")

In [72]:
# Re-reference the timestamps so that time 0 is the interval list start (or the window start).
# Only the first valid_times entry of the interval list is used.
int_start_time, int_end_time = (sgc.IntervalList() & {'nwb_file_name': nwb_file_name, 'interval_list_name': interval_list_name}).fetch1('valid_times')[0]
int_length = int_end_time - int_start_time

# time_window, when given, is (start, end) in seconds relative to the interval start
if time_window:
    win_length = time_window[1] - time_window[0]
    int_window_start = time_window[0]
    int_window_end = time_window[1]
else:
    win_length = None
    int_window_start = 0
    int_window_end = int_length

# Pull the (downsampled) timestamps out of the data container. An unknown data_type is an
# error: otherwise the `timestamps` left over from the loading cell would be silently reused.
if data_type == 'eseries':
    all_timestamps = np.asarray(full_data.timestamps[::downsamp])
elif data_type == 'dataframe':
    all_timestamps = full_data.index.values[::downsamp]
else:
    raise ValueError(f"data_type must be 'eseries' or 'dataframe', got {data_type!r}")

win_mask = (all_timestamps >= (int_start_time + int_window_start)) & (all_timestamps < (int_start_time + int_window_end))
if not win_mask.any():
    raise ValueError(f'No samples fall inside the requested window ({int_window_start}, {int_window_end}) of {interval_list_name}')
timestamps = all_timestamps[win_mask]

# Shift without mutating in place, so re-running this cell always gives the same result.
if zero_win_start:
    win_start_time = timestamps[0]
    timestamps = timestamps - win_start_time
else:
    timestamps = timestamps - int_start_time  # zero to the beginning of the interval list

# initialize TimeSeriesGraph view
view = vv.TimeseriesGraph(
    legend_opts={"location": "northeast"},
    hide_x_gridlines=True,
    hide_nav_toolbar=hide_nav,
    hide_time_axis_labels=hide_time_labels,
    y_label='Voltage (uV)',
)

Load in and plot artifacts

In [73]:
if plot_artifacts:
    # Load in artifact times
    if artifact_type == 'large_amplitude_events':
        # plot spike artifact intervals from ArtifactDetection
        spike_artifacts_params_name = 'amp_3000_0.5_prop'
        artifact_label = spike_artifacts_params_name  # used in the "no artifacts" message below
        recording_id = (sgs.SpikeSortingRecordingSelection() & {
                'nwb_file_name': nwb_file_name,
                'interval_list_name': interval_list_name,
                'sort_group_id': sort_group_id,
                'preproc_param_name': 'default'
            }).fetch1('recording_id')
        try:
            artifact_id = str((sgs.ArtifactDetectionSelection() & {'artifact_param_name': spike_artifacts_params_name, 'recording_id': recording_id}).fetch1('artifact_id'))
            # NOTE(review): confirm this interval list holds the artifact intervals themselves and
            # not the artifact-removed valid times; if it is the latter, it needs the same
            # subtraction as the coincident_spikes branch below.
            spike_artifact_times = (sgc.IntervalList() & {'nwb_file_name': nwb_file_name, 'interval_list_name': artifact_id, 'pipeline': 'spikesorting_artifact_v1'}).fetch1('valid_times')
        except Exception as err:
            # Best-effort: keep the notebook running, but report what actually went wrong.
            print(f'WARNING: ArtifactDetection not yet populated for {nwb_file_name}, {interval_list_name} with artifact param name {spike_artifacts_params_name}, skipping')
            print(f'  (lookup failed with: {err!r})')
            spike_artifact_times = []

    elif artifact_type == 'coincident_spikes':
        # plot spike artifact intervals from custom coincident spike detection pipeline
        spike_closeness_threshold = 0.00005
        max_coincident_fraction = 0.5
        removal_window_s = 0.001
        pipeline = 'coincident spike detection'

        # the position interval list is picked by the epoch number at the start of the interval name
        epoch = int(interval_list_name[:2])
        pos_interval_list_name = (
            sgc.IntervalList()
            & {"nwb_file_name": nwb_file_name, "pipeline": "position"}
        ).fetch("interval_list_name")[epoch - 1]
        coinc_interval_list_name =  f'{pos_interval_list_name} coincident_spikes_removed_times rem_{removal_window_s} close_{spike_closeness_threshold} frac_{max_coincident_fraction}'
        artifact_label = coinc_interval_list_name

        spike_valid_times = (sgc.IntervalList() & {'nwb_file_name': nwb_file_name,
                                    'interval_list_name': coinc_interval_list_name,
                                    'pipeline': pipeline}).fetch1('valid_times')

        # artifacts = position valid times minus the times that survived coincident-spike removal
        spike_artifact_times = Interval((sgc.IntervalList() & {'nwb_file_name': nwb_file_name, 'interval_list_name': pos_interval_list_name}).fetch1('valid_times')).subtract(spike_valid_times).times

    else:
        raise ValueError(f"artifact_type must be 'coincident_spikes' or 'large_amplitude_events', got {artifact_type!r}")

    # Normalize to an (n_intervals, 2) float array so the empty case and slicing behave uniformly
    spike_artifact_times = np.asarray(spike_artifact_times, dtype=np.float64).reshape(-1, 2)

    # Plot the artifact intervals
    if len(spike_artifact_times) == 0:
        # artifact_label is defined for both artifact types (the original message referenced
        # coinc_interval_list_name, which does not exist for large_amplitude_events)
        print(f'No artifacts were detected with artifact params {artifact_label}')
    else:
        # convert artifact times to be on the same timescale as the data (zeroed to start of the interval list)
        if zero_win_start:
            spike_artifact_times = spike_artifact_times - win_start_time
        else:
            spike_artifact_times = spike_artifact_times - int_start_time

        # Keep intervals that start after the first plotted sample and end by the last one.
        # A mask (rather than first/last index search) also handles the case where no
        # interval starts inside the window, which previously left the start index at 0.
        first_timestamp = timestamps[0]
        last_timestamp = timestamps[-1]
        in_window = (spike_artifact_times[:, 0] > first_timestamp) & (spike_artifact_times[:, 1] <= last_timestamp)

        if not in_window.any():
            print(f'No artifacts fall inside the plotted time window for artifact params {artifact_label}')
        else:
            # plot artifacts interval times
            view.add_interval_series(
                name='detected artifacts',
                t_start=spike_artifact_times[in_window, 0],
                t_end=spike_artifact_times[in_window, 1],
                color='lightcoral',
                border_color='lightcoral',
            )

No artifacts were detected with artifact params pos 1 valid times coincident_spikes_removed_times rem_0.001 close_5e-05 frac_0.5


Load in and plot spike times

In [74]:
spike_time_window = 0.00005  # window to plot around each spike time (in seconds)
if plot_spikes:
    # load in spike times to confirm that we're picking up the spikes we think we are with the clusterless thresholder
    features_param_name = 'amplitude'
    sort_interval_name = interval_list_name
    sorter_name = 'clusterless_thresholder'
    sorting_param_name = 'default_clusterless'

    recording_id = (sgs.SpikeSortingRecordingSelection() & {
                        'nwb_file_name': nwb_file_name,
                        'interval_list_name': sort_interval_name,
                        'preproc_param_name': 'default',
                        'sort_group_id': sort_group_id,
                    }).fetch1('recording_id')

    spikesorting_merge_id = ((SpikeSortingOutput.CurationV1 * sgs.SpikeSortingSelection) & {
        'nwb_file_name': nwb_file_name,
        'recording_id': recording_id,
        'sorter': sorter_name,
        'sorter_param_name': sorting_param_name,
    }).fetch1('merge_id')

    waveform_s_key = {
            'spikesorting_merge_id': spikesorting_merge_id,
            'features_param_name': features_param_name,
        }

    sort_group_spike_times, spike_waveform_features = (
        UnitWaveformFeatures & waveform_s_key
    ).fetch_data()

    # Use the first entry of the fetched spike times (one particular sort group); treat a
    # completely empty fetch as "no spikes" instead of failing on the [0] lookup.
    if len(sort_group_spike_times) == 0:
        unit_spike_times = np.empty(0, dtype=np.float64)
    else:
        unit_spike_times = np.asarray(sort_group_spike_times[0], dtype=np.float64)

    # Build an (n_spikes, 2) array of [start, end] around each spike. This must be an array:
    # the original list of tuples raised TypeError when the time offset was subtracted below.
    sort_group_spike_time_int_starts = unit_spike_times - spike_time_window
    sort_group_spike_time_int_ends = unit_spike_times + spike_time_window
    sort_group_spike_time_ints = np.column_stack([sort_group_spike_time_int_starts, sort_group_spike_time_int_ends])

    # add spike times to plot
    if len(sort_group_spike_time_ints) == 0:
        print(f'No spike time intervals were detected with sorter {sorter_name}')
    else:
        # convert spike times to be on the same timescale as the data (zeroed to start of the interval list)
        if zero_win_start:
            sort_group_spike_time_ints = sort_group_spike_time_ints - win_start_time
        else:
            sort_group_spike_time_ints = sort_group_spike_time_ints - int_start_time

        # keep spike intervals that start after the first plotted sample and end by the last one
        first_timestamp = timestamps[0]
        last_timestamp = timestamps[-1]
        in_window = (sort_group_spike_time_ints[:, 0] > first_timestamp) & (sort_group_spike_time_ints[:, 1] <= last_timestamp)

        if not in_window.any():
            print(f'No spike time intervals fall inside the plotted time window for sorter {sorter_name}')
        else:
            # plot spike times interval times
            view.add_interval_series(
                name='detected spikes',
                t_start=sort_group_spike_time_ints[in_window, 0],
                t_end=sort_group_spike_time_ints[in_window, 1],
                color='green',
                border_color='green',
            )

Plot spike-sorted electrode data

In [75]:
# Plot electrodes in a uniform time series
view.add_uniform_series(
    name='ca1 electrodes',
    start_time_sec=timestamps[0],
    sampling_frequency_hz=sampling_frequency_hz,
    data=full_data[elecs].values,
    channel_names=labels,
    colors=colors,
    width=1,
    channel_spacing=elec_separation,
    timestamps_for_inserting_nans=timestamps,
)

Generate and show plot

In [76]:
# Wrap the time series view in a titled vertical box
layout = vv.Box(
    title=f"{nwb_file_name} {interval_list_name}",
    direction='vertical',
    items=[
        vv.LayoutItem(view, title='ca1 electrodes', stretch=1),
    ]
)

# Only mention the coincident-spike interval list in the figure title when it was actually
# loaded in this run; with plot_artifacts = False the name is undefined (or stale from an
# earlier run of the kernel).
if plot_artifacts and artifact_type == 'coincident_spikes':
    artifact_title_suffix = coinc_interval_list_name
else:
    artifact_title_suffix = ''

# Last expression of the cell: uploads the figure, and the returned value is displayed
layout.show(title=f"{nwb_file_name} {interval_list_name} {artifact_title_suffix}", upload=True)

Found 46 files to upload, total size: 3735.76 MB
Uploading 46 files in batches of 20 with up to 16 concurrent uploads per batch...
Processing batch 1/3 (20 files)...
Uploaded 1/46: extension_manifest.json
Uploaded 2/46: index.html
Uploaded 3/46: data.zarr/_consolidated_22.dat
Uploaded 4/46: data.zarr/_consolidated_33.dat
Uploaded 5/46: data.zarr/_consolidated_35.dat
Uploaded 6/46: data.zarr/_consolidated_1.dat
Uploaded 7/46: data.zarr/_consolidated_0.dat
Uploaded 8/46: data.zarr/_consolidated_26.dat
Uploaded 9/46: data.zarr/_consolidated_16.dat
Uploaded 10/46: data.zarr/_consolidated_17.dat
Uploaded 11/46: data.zarr/_consolidated_4.dat
Uploaded 12/46: data.zarr/_consolidated_25.dat
Uploaded 13/46: data.zarr/_consolidated_10.dat
Uploaded 14/46: data.zarr/_consolidated_14.dat
Uploaded 15/46: data.zarr/_consolidated_28.dat
Uploaded 16/46: data.zarr/_consolidated_29.dat
Uploaded 17/46: data.zarr/_consolidated_27.dat
Uploaded 18/46: data.zarr/_consolidated_3.dat
Uploaded 19/46: data.zarr/_c

'https://gillespielab.figpack.org/figures/default/d502ce9ab999deedbacad7ad/index.html'