### setup

In [None]:
import warnings
from itertools import product
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import probeinterface
import spikeinterface as si

from matplotlib import pyplot as plt
from probeinterface import get_probe, Probe
from probeinterface.plotting import plot_probe
from spikeinterface import (
    ConcatenateSegmentRecording,
    comparison,
    concatenate_recordings,
    exporters,
    extractors,
    postprocessing,
    preprocessing,
    sorters,
    qualitymetrics,
    widgets,
)


# Print the installed SpikeInterface version for reference.
print(f"SpikeInterface version: {si.__version__}")

# Silence every warning for the rest of the session. This filter is global, so
# it also hides warnings emitted by SpikeInterface and its dependencies.
warnings.simplefilter("ignore")
# NOTE(review): `%matplotlib widget` runs right after `%matplotlib inline`, so
# the interactive widget backend is presumably the one left active; the
# `inline` line looks redundant unless it works around an environment quirk.
%matplotlib inline
%matplotlib widget

### load recording

In [None]:
def load_and_join_recording(
    subject_id: str = "subject_1",
    trial_id: str = "2020-08-22",
    verbose: bool = True,
    data_dir: str = "data/raw",
    stream_id: str = "0",
) -> ConcatenateSegmentRecording:
    """Load a trial's Neuralynx recording and join its segments into one object.

    The trial folder is read from ``{data_dir}/{subject_id}/{trial_id}/``.
    Each segment of the raw recording is selected individually and the
    resulting recordings are passed, in segment order, to
    ``concatenate_recordings``.

    Parameters
    ----------
    subject_id : str
        Name of the subject sub-folder inside ``data_dir``.
    trial_id : str
        Name of the trial sub-folder inside the subject folder.
    verbose : bool
        If True, print the number of channels and the total duration.
    data_dir : str
        Root folder containing the raw data (default ``"data/raw"``).
    stream_id : str
        Neuralynx stream to read (default ``"0"``).

    Returns
    -------
    ConcatenateSegmentRecording
        The concatenated recording.
    """
    segmented_recording = extractors.read_neuralynx(
        f"{data_dir}/{subject_id}/{trial_id}/", stream_id=stream_id
    )
    # One sub-recording per segment, kept in their original order.
    segments = [
        segmented_recording.select_segments(segment_index)
        for segment_index in range(segmented_recording.get_num_segments())
    ]
    joint_recording = concatenate_recordings(segments)
    if verbose:
        nch = joint_recording.get_num_channels()
        dur = joint_recording.get_total_duration()
        print(
            f"Recording loaded.\nNumber of channels: {nch}\nTotal duration: {dur:.2f} seconds"
        )
    return joint_recording


recording = load_and_join_recording()

### take a sample slice of the recording

In [None]:
fs = recording.get_sampling_frequency()
slice_dur = 300  # length of the sample slice, in seconds

# The sampling frequency is a float, so frame indices are converted to int
# explicitly. The end frame is capped at the recording length so that a
# recording shorter than `slice_dur` is taken whole instead of requesting
# frames past its end.
start_frame = 0
end_frame = min(int(round(slice_dur * fs)), recording.get_num_samples())
recording_sub = recording.frame_slice(
    start_frame=start_frame,
    end_frame=end_frame,
)

# NOTE(review): save() may refuse to write into an existing "data/slice"
# folder when this cell is re-run - confirm for the installed version.
recording_saved = recording_sub.save(folder="data/slice")
recording_loaded = si.load_extractor("data/slice/")

### visualize recording

In [None]:
# Interactive trace viewer for the reloaded slice, using the ipywidgets backend.
raw_plot = widgets.plot_timeseries(recording_loaded, backend="ipywidgets")

### run sorters

In [None]:
# Docker image to use for each sorter that is run below. Commented-out
# entries are skipped for the reason noted next to them.
docker_dict = dict(
    # combinato="spikeinterface/combinato-base:latest", # single channel
    hdsort="spikeinterface/hdsort-compiled-base:latest",
    herdingspikes="spikeinterface/herdingspikes-base:latest",
    ironclust="spikeinterface/ironclust-compiled-base:latest",
    # kilosort="spikeinterface/kilosort-compiled-base:latest", # requires GPU
    # kilosort2="spikeinterface/kilosort2-compiled-base:latest", # requires GPU
    # kilosort2_5="spikeinterface/kilosort2_5-compiled-base:latest", # requires GPU
    # klusta="spikeinterface/klusta-base:latest", # error in docker image
    mountainsort4="spikeinterface/mountainsort4-base:latest",
    # pykilosort="spikeinterface/pykilosort-base:latest", # requires GPU
    spykingcircus="spikeinterface/spyking-circus-base:latest",
    tridesclous="spikeinterface/tridesclous-base:latest",
    waveclus="spikeinterface/waveclus-compiled-base:latest",
    # yass="spikeinterface/yass-base:latest", # requires GPU
)

sorting_dict = {}  # sorter name -> result of every successful run
sorting_errors = {}  # sorter name -> exception raised by a failed run

# Default parameters of every registered sorter; used below to check which
# sorters expose a "filter" option.
param_dict = {
    sorter: sorters.get_default_sorter_params(sorter)
    for sorter in sorters.available_sorters()
}

for sorter, image in tqdm(docker_dict.items()):
    # NOTE(review): the recording sorted here is the raw slice (no
    # preprocessing is applied in this section) and the sorter's own filter
    # is switched off - confirm the data is already band-passed. Only sorters
    # whose option is literally named "filter" are affected by this.
    filter_kwarg = (
        dict(filter=False) if "filter" in param_dict.get(sorter, {}) else {}
    )
    try:
        sorting_dict[sorter] = sorters.run_sorter(
            sorter,
            recording_loaded,
            docker_image=image,
            output_folder=f"data/sorted/{sorter}",
            **filter_kwarg,
        )
    except Exception as e:
        # Best effort: one failing sorter must not stop the others, but the
        # exception is kept so it can be inspected after the loop.
        sorting_errors[sorter] = e
        print(
            f"Sorting with {sorter} failed due to the following exception: "
            f"{type(e).__name__}: {e}"
        )