## Import libraries

In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import spikeinterface.full as si

%matplotlib widget

Matplotlib created a temporary cache directory at /tmp/matplotlib-apu1hcd0 because the default path (/home/jupyter-ikharitonov/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


## Define paths and configure default parameters

In [2]:
# Main working directory
base_folder = Path.home() / 'RANCZLAB-NAS' / 'iakov' / 'allen_sorting_07_11_23'

# Input data and sorting
binary_filename = 'spike_band.dat'
output_folder = base_folder / 'kilosort3_sorting_files'
sorting_save_path = base_folder / 'kilosort3_sorting_output'

In [3]:
# Default parameters for compute-intensive steps
# n_jobs is the number of CPU cores used. -1 means use all cores.
job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=True)

# Plotting backend
backend = 'ipywidgets'

## Load preprocessed recording and sorting

In [5]:
recording = si.load_extractor(base_folder / "preprocessed_recording")
sorting = si.load_extractor(sorting_save_path)



## Extract waveforms

In [11]:
waveforms = si.extract_waveforms(recording, sorting, folder = base_folder / "waveforms_dense", sparse=False, overwrite=True, **job_kwargs)
print(waveforms)

Exception: The recording is not filtered, you must filter it using `bandpass_filter()`.If the recording is already filtered, you can also do `recording.annotate(is_filtered=True).
If you trully want to extract unfiltered waveforms, use `allow_unfiltered=True`.

## Compute metrics

In [None]:
# Sparsity, pca scores, spike amplitudes, spike locations, correlograms, template similarity, quality metrics list

In [None]:
sparsity = si.compute_sparsity(waveforms, method='radius', radius_um=100.0)
sparse_waveforms = si.extract_waveforms(recording, sorting, folder = base_folder / "waveforms_sparse", sparsity=sparsity, **job_kwargs)

In [None]:
principal_components = si.compute_principal_components(waveforms, n_components=3, load_if_exists=False, **job_kwargs)

In [None]:
spike_amplitudes = si.compute_spike_amplitudes(waveforms, outputs="by_unit", load_if_exists=True, **job_kwargs)

In [None]:
unit_locations = si.compute_unit_locations(we, method="monopolar_triangulation", load_if_exists=True, **job_kwargs)
spike_locations = si.compute_spike_locations(we, method="center_of_mass", load_if_exists=True, **job_kwargs)

In [None]:
cross_correlograms, bins = si.compute_correlograms(waveforms)

In [None]:
template_similarity = si.compute_template_similarity(waveforms)

In [None]:
template_metrics = si.calculate_template_metrics(waveforms)

In [None]:
metric_names = si.get_quality_metric_list()
quality_metrics = si.compute_quality_metrics(waveforms, metric_names=metric_names, verbose=True, **job_kwargs)

## Display postprocessing information

### Spike trains

In [None]:
si.plot_rasters(sorting)

### Templates

In [None]:
si.plot_unit_templates(waveforms, backend=backend)

### Spike amplitudes

In [None]:
si.plot_amplitudes(waveforms, backend=backend)

### Unit locations

In [None]:
si.plot_unit_locations(waveforms, backend=backend)

### Spike locations

In [None]:
si.plot_spike_locations(waveforms, max_spikes_per_unit=300, backend=backend)

### Autocorrelograms

In [None]:
si.plot_autocorrelograms(waveforms, unit_ids=sorting.unit_ids[::30])

### Cross-correlograms

In [None]:
si.plot_crosscorrelograms(waveforms, unit_ids=sorting.unit_ids[::30])

### Template metrics

In [None]:
display(template_metrics)

In [None]:
si.plot_template_metrics(waveforms, include_metrics=["peak_to_valley", "half_width"], backend=backend)

### Quality metrics

In [None]:
display(quality_metrics)

In [None]:
plot_metrics = ["amplitude_cutoff", "presence_ratio", "isi_violations_ratio", "snr"]
si.plot_quality_metrics(we, include_metrics=plot_metrics, backend=backend)

## Curate data

In [None]:
isi_viol_thresh = 0.2
amp_cutoff_thresh = 0.1

curation_query = f"amplitude_cutoff < {amp_cutoff_thresh} & isi_violations_ratio < {isi_viol_thresh}"

keep_units = quality_metrics.query(curation_query)
keep_unit_ids = keep_units.index.values

sorting_curated = sorting.select_units(keep_unit_ids)
sorting_curated.save(folder = base_folder / 'curated_sorting_output', format='npz_folder', **job_kwargs)
print(f"Number of units before curation: {len(sorting.get_unit_ids())}")
print(f"Number of units after curation: {len(sorting_curated.get_unit_ids())}")

waveforms_curated = waveforms.select_units(keep_unit_ids, new_folder = base_folder / "waveforms_curated")
print(waveforms_curated)