Connect to database

In [2]:
import datajoint as dj

# Connection settings for the DataJoint database used by this notebook.
_DB_SETTINGS = {
    'database.host': "gl-ash.biostr.washington.edu",
    'database.user': "gabby",
    'database.port': 3306,
}
for _option, _value in _DB_SETTINGS.items():
    dj.config[_option] = _value

# Connect using the settings above; as the cell's last expression, the
# connection object is displayed as the cell output.
dj.conn()

  import pkg_resources
[2026-02-18 12:54:45,892][INFO]: DataJoint 0.14.6 connected to gabby@gl-ash.biostr.washington.edu:3306


DataJoint connection (connected) gabby@gl-ash.biostr.washington.edu:3306

Imports

In [29]:
import os
import numpy as np

import spyglass.common as sgc
from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1
import spyglass.lfp as lfp
from spyglass.lfp.analysis.v1 import LFPBandV1
from spyglass.lfp.lfp_merge import LFPOutput
import spyglass.position as sgp

import figpack.views as vv

from non_local_detector.visualization.static import get_multiunit_firing_rate

from gl_spyglass.utils.figpack_visualization_functions import create_2D_decode_view_figpack
from gl_spyglass.utils.interval_functions import insert_mobile_times_interval
from gl_spyglass.custom_spyglass_tables.grouped_ripple import LFPBandGroup, RippleLFPSelection, RippleParameters, RippleTimesGroup

In [30]:
# Bucket name exported for figpack via the FIGPACK_BUCKET environment variable.
os.environ['FIGPACK_BUCKET'] = 'gillespielab'

# Never store a FIGPACK_API_KEY in the notebook itself. Export your own key in
# the environment before starting Jupyter (contact Jeremy Magland if you don't
# have one); if it is not set, you are prompted for it here and the value is not
# echoed or saved with the notebook.
# NOTE(review): the key that used to be hardcoded in this cell should be treated
# as leaked and revoked.
if not os.environ.get('FIGPACK_API_KEY'):
    from getpass import getpass

    os.environ['FIGPACK_API_KEY'] = getpass('FIGPACK_API_KEY: ')

### 1. Select parameters to load your decoding results

In [None]:
# --- Session and epoch to visualize ---
nwb_file_name = 'pippin20210421_.nwb'
interval_list_name = '02_r1'

# --- Coincident-spike removal settings (used below to rebuild the interval names) ---
spike_closeness_threshold = 0.00005
max_coincident_fraction = 0.5
removal_window_s = 0.001

# DECODING INTERVALS FOR SINGLE INTERVALS
# The first two characters of the interval name are parsed as a 1-based epoch
# number, which indexes into this session's position interval names.
epoch = int(interval_list_name[:2])
position_interval_names = (
    sgc.IntervalList() & {'nwb_file_name': nwb_file_name, 'pipeline': 'position'}
).fetch('interval_list_name')
# NOTE(review): this relies on the fetch order matching epoch order — confirm for your session
pos_interval_list_name = position_interval_names[epoch - 1]
trodes_pos_params_name = 'default_decoding'
mobile_interval_list_name = insert_mobile_times_interval(
    nwb_file_name,
    pos_interval_list_name,
    trodes_pos_params_name,
    speed_thresh=4,
    time_thresh=1,
)

# Note: if you don't have intervals during trials, change this to whatever you call your decoding and encoding intervals
# Both interval names share the same suffix describing the coincident-spike filtering.
coincident_filter_suffix = (
    f' coincident_spikes_removed_times rem_{removal_window_s}'
    f' close_{spike_closeness_threshold} frac_{max_coincident_fraction} (during trials)'
)
during_trials_filt_mobile_interval = f'{mobile_interval_list_name}{coincident_filter_suffix}'
during_trials_filt_pos_interval = f'{pos_interval_list_name}{coincident_filter_suffix}'

# --- Names identifying the decoding entry ---
wf_group_name = f'ca1_waveforms {interval_list_name}'
pos_group_name = f'{interval_list_name} decoding'
decoding_param_name = 'contfrag_clusterless_placebin3_100chunks_blocksize100__nocache'
encoding_interval = during_trials_filt_mobile_interval
decoding_interval = during_trials_filt_pos_interval
estimate_decoding_params = True

# Restriction used by every ClusterlessDecodingV1 fetch in this notebook.
selection_key = dict(
    waveform_features_group_name=wf_group_name,
    position_group_name=pos_group_name,
    decoding_param_name=decoding_param_name,
    nwb_file_name=nwb_file_name,
    encoding_interval=encoding_interval,
    decoding_interval=decoding_interval,
    estimate_decoding_params=estimate_decoding_params,
)

pos 1 mobile times already exists, skipping


### 2. Load in decoding results

In [36]:
# Restrict the decoding table to the selected entry and fetch its results
decoding_entry = ClusterlessDecodingV1 & selection_key
decoding_results = decoding_entry.fetch_results()

# Unstack "state_bins" and sum over "state" to collapse the acausal posterior across states
acausal_posterior = decoding_results.acausal_posterior
posterior = acausal_posterior.unstack("state_bins").sum("state")

# Actual position, plus the names of the position columns within it
position_info, position_variable_names = ClusterlessDecodingV1.fetch_position_info(
    selection_key
)

# Decoding environment: the first of the environments fetched for this entry
environments = ClusterlessDecodingV1.fetch_environments(selection_key)
env = environments[0]

  self._dims = self._parse_dimensions(dims)
  self._dims = self._parse_dimensions(dims)
  self._dims = self._parse_dimensions(dims)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)


### 3. Process decoding information before inputting into figpack function

In [37]:
# Centers of the place bins selected by the environment's track-interior mask.
interior_place_bin_centers = env.place_bin_centers_[
    env.is_track_interior_.ravel()
]

# Bin size along x and y, taken as the spacing between the first two posterior
# coordinates on each axis.
# NOTE(review): assumes evenly spaced bins and at least two bins per axis — confirm.
place_bin_size = (
    posterior.x_position.values[1] - posterior.x_position.values[0],
    posterior.y_position.values[1] - posterior.y_position.values[0],
)

# Position columns and their timestamps (the index of position_info).
position = position_info[position_variable_names]
position_time = position_info.index.values

# Head direction, read from the 'orientation' column.
head_dir = position_info['orientation']

### 4. Create decode figpack view

In [38]:
# Gather the inputs of the 2D decode view by keyword, then build the view.
decode_view_inputs = dict(
    position_time=position_time,
    position=position,
    posterior=posterior,
    interior_place_bin_centers=interior_place_bin_centers,
    place_bin_size=place_bin_size,
    head_dir=head_dir,
)
decode_view = create_2D_decode_view_figpack(**decode_view_inputs)

### 5. Add any additional views from below (or create other custom ones that are relevant for your data)

In [39]:
# Values shared by all of the views created below.
# Toolbar / time-axis-label visibility passed to every TimeseriesGraph:
hide_nav = False
hide_time_labels = False
# First position timestamp; the views subtract it from their time arrays.
start_time = position_time[0]

#### State probability view

In [40]:
# Fixed palette for the state lines; colors repeat from the start when there
# are more states than palette entries, so no state is silently left out.
COLOR_CYCLE = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]

state_prob_view = vv.TimeseriesGraph(
    legend_opts={"location": "northeast"},
    hide_nav_toolbar=hide_nav,
    hide_time_axis_labels=hide_time_labels,
    y_label='state probability'
)

# Plot against the decoding results' own time axis so that t and y come from
# the same dataset and always have matching lengths.
state_prob_time = np.asarray(decoding_results.time.values - start_time)

# One line per decoded state.
for state_index, state in enumerate(decoding_results.states.values):
    state_prob_view.add_line_series(
        name=state,
        t=state_prob_time,
        y=np.asarray(
            decoding_results.sel(states=state).acausal_state_probabilities,
            dtype=np.float32,
        ),
        color=COLOR_CYCLE[state_index % len(COLOR_CYCLE)],
        width=1,
    )

#### Speed view

In [41]:
# Speed trace, read from the 'speed' column of the position data.
speed = position_info['speed'].values

# Empty graph with the shared toolbar / time-label settings.
speed_graph_options = dict(
    legend_opts={"location": "northeast"},
    hide_nav_toolbar=hide_nav,
    hide_time_axis_labels=hide_time_labels,
    y_label='speed (cm/s)',
)
speed_view = vv.TimeseriesGraph(**speed_graph_options)

# Single black line, with time measured from start_time.
speed_elapsed_time = np.asarray(position_time - start_time)
speed_values = np.asarray(speed, dtype=np.float32)
speed_view.add_line_series(
    name="speed [cm/s]",
    t=speed_elapsed_time,
    y=speed_values,
    color="black",
    width=1,
)

#### Multiunit firing rate view with detected ripples

In [42]:
# Note - loading in the spike times might take a while (up to ~15 min), be patient!
# The fetch returns two items; only the first (the spike times) is used below.
spike_data = ClusterlessDecodingV1.fetch_spike_data(selection_key)
spike_times, _ = spike_data

ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - cached version: 0.2.3, loaded version: 0.2.1
  self.warn_for_ignored_namespaces(ignored_namespaces)
ndx-franklab-novela - ca

In [43]:
# Evaluate the multiunit firing rate at the time points of the decoding results.
decoding_time = decoding_results.time.values
multiunit_firing_rate = get_multiunit_firing_rate(spike_times, decoding_time)

In [44]:
# Empty graph with the shared toolbar / time-label settings.
multiunit_view = vv.TimeseriesGraph(
    legend_opts={"location": "northeast"},
    hide_nav_toolbar=hide_nav,
    hide_time_axis_labels=hide_time_labels,
    y_label='multiunit firing rate (spikes/s)',
)

# The firing rate was evaluated at decoding_results.time.values, so plot it
# against that same time axis (shifted by start_time like the other views).
# This keeps t and y the same length even if the position timestamps differ
# from the decoding timestamps.
multiunit_view.add_line_series(
    name="multiunit firing rate (spikes/s)",
    t=np.asarray(decoding_results.time.values) - start_time,
    y=np.asarray(multiunit_firing_rate['firing_rate'].values),
    color="black",
    width=1,
)

NOTE: Ripple detection must already be populated before the detected ripples can be added. If you have not run ripple detection yet, comment out (or skip) the following two cells.

In [None]:
# LOAD IN RIPPLE TIMES
# --- Names of the LFP / artifact / ripple entries to look up ---
lfp_electrode_group_name = 'good_single_elecs'
lfp_sampling_rate = 1_000
lfp_filter_name = 'LFP 0-400 Hz'
artifact_params_name = 'mad_7_0.66_thresh_200ms'
ripple_filter_name = 'Ripple 100-250 Hz'
ripple_param_name = 'shvartsman_sd4_part2'

# Sampling rate of the data (Hz).
# NOTE(review): the original author was not certain this is the sampling rate of the original data — confirm.
raw_sampling_rate = 30_000

# Fields shared by the LFP lookup and the artifact-removal lookup.
lfp_common_key = {
    'nwb_file_name': nwb_file_name,
    'lfp_electrode_group_name': lfp_electrode_group_name,
    'target_interval_list_name': interval_list_name,
    'filter_name': lfp_filter_name,
    'filter_sampling_rate': raw_sampling_rate,
}

lfp_s_key = {
    **lfp_common_key,
    'target_sampling_rate': lfp_sampling_rate,  # sampling rate of the lfp output (Hz)
}
lfp_merge_id = (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1('merge_id')

# select ripples to include
# NOTE(review): this reassigns trodes_pos_params_name, which the parameter cell set to 'default_decoding'
trodes_pos_params_name = 'default'
pos_s_key = dict(
    nwb_file_name=nwb_file_name,
    interval_list_name=pos_interval_list_name,
    trodes_pos_params_name=trodes_pos_params_name,
)
pos_key = (sgp.v1.TrodesPosSelection() & pos_s_key).fetch1("KEY")
pos_merge_key = sgp.PositionOutput.merge_get_part(pos_key).fetch1("KEY")
pos_merge_id = pos_merge_key['merge_id']

# The artifact-removed interval list doubles as the ripple band's target interval.
lfp_artifact_s_key = {
    **lfp_common_key,
    'artifact_params_name': artifact_params_name,
}
artifact_removed_interval_list_name = (
    lfp.v1.LFPArtifactRemovedIntervalList() & lfp_artifact_s_key
).fetch1('artifact_removed_interval_list_name')
rip_interval_list_name = artifact_removed_interval_list_name

# Look up the ripple-band LFP entry; its nwb_file_name is read from the parent of the LFP merge entry.
lfp_parent_nwb_file_name = LFPOutput.merge_get_parent(
    {"merge_id": lfp_merge_id}
).fetch1("nwb_file_name")
lfp_band_s_key = {
    'lfp_merge_id': lfp_merge_id,
    'filter_name': ripple_filter_name,
    'filter_sampling_rate': lfp_sampling_rate,
    'target_interval_list_name': rip_interval_list_name,
    'lfp_band_sampling_rate': lfp_sampling_rate,
    'nwb_file_name': lfp_parent_nwb_file_name,
}
lfp_band_key = (LFPBandV1 & lfp_band_s_key).fetch1("KEY")

# Load in ripple data from the ripple group table
lfp_band_key.update(ripple_param_name=ripple_param_name)
ripple_table = RippleTimesGroup().RippleTimes() & lfp_band_key
ripple_data = ripple_table.fetch1_dataframe()
# Two columns per ripple: start time, then end time.
ripple_times = ripple_data[['start_time', 'end_time']].values

In [None]:
# Column 0 holds the ripple start times and column 1 the end times.
ripple_starts, ripple_ends = ripple_times[:, 0], ripple_times[:, 1]

# Add the ripples to the multiunit view as yellow intervals, with time
# measured from start_time like the other series.
multiunit_view.add_interval_series(
    name='detected ripples',
    t_start=np.asarray(ripple_starts) - start_time,
    t_end=np.asarray(ripple_ends) - start_time,
    color='yellow',
)

### 6. Plot all the views together and generate the figpack

In [47]:
# Title shared by the layout and the uploaded figure.
figure_title = f"{nwb_file_name} {interval_list_name}"

# First column: the 2D decode view on its own.
vertical_panel1_content = [vv.LayoutItem(decode_view, stretch=1, title="Decode")]

# Second column: the three timeseries views, each given the same stretch.
timeseries_panels = [
    (state_prob_view, "Probability of State"),
    (speed_view, "Speed"),
    (multiunit_view, 'Multiunit Firing Rate + Detected Ripples'),
]
vertical_panel2_content = [
    vv.LayoutItem(panel_view, stretch=0.33, title=panel_title)
    for panel_view, panel_title in timeseries_panels
]

# Wrap each column in a vertical box, then place the two boxes side by side.
column_items = [
    vv.LayoutItem(
        vv.Box(
            direction='vertical',
            show_titles=True,
            items=panel_items,
        )
    )
    for panel_items in (vertical_panel1_content, vertical_panel2_content)
]
layout = vv.Box(
    title=figure_title,
    direction='horizontal',
    items=column_items,
)

# Called with upload=True; kept as the last expression so its return value is shown as the cell output.
layout.show(title=f"{figure_title} 2D decoding", upload=True)

Found 8 files to upload, total size: 30.24 MB
Uploading 8 files in batches of 20 with up to 16 concurrent uploads per batch...
Processing batch 1/1 (8 files)...
Uploaded 1/8: assets/index-GPjx4QpG.css
Uploaded 2/8: data.zarr/.zmetadata
Uploaded 3/8: extension_manifest.json
Uploaded 4/8: extension-figpack-franklab.js
Uploaded 5/8: index.html
Uploaded 6/8: assets/index-BY1Hwjm4.js
Uploaded 7/8: assets/neurosift-logo-CLsuwLMO.png
Uploaded 8/8: data.zarr/_consolidated_0.dat
Creating manifest...
Total size: 30.24 MB
Uploading manifest.json...
Finalizing figure...
Upload completed successfully
View the figure at: https://figures.figpack.org/figures/default/f51e32bf620ac21e4e22e3ef/index.html


'https://figures.figpack.org/figures/default/f51e32bf620ac21e4e22e3ef/index.html'