### setup

In [None]:
import warnings
from itertools import combinations
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import probeinterface
import spikeinterface as si

from matplotlib import pyplot as plt
from probeinterface import get_probe, Probe
from probeinterface.plotting import plot_probe
from spikeinterface import (
    BinaryFolderRecording,
    comparison,
    concatenate_recordings,
    exporters,
    extractors,
    postprocessing,
    preprocessing,
    sorters,
    qualitymetrics,
    widgets,
)


print(f"SpikeInterface version: {si.__version__}")

# Silence every Python warning for the remainder of the session (global filter).
warnings.simplefilter("ignore")
# The `widget` backend selected on the second line replaces the `inline` one
# selected on the first.
# NOTE(review): the `inline` line looks redundant here — confirm it is not a
# deliberate backend-switching workaround before removing it.
%matplotlib inline
%matplotlib widget

### load recording

In [None]:
def prepare_recording(
    subject_id: str = "subject_1",
    trial_id: str = "2020-08-22",
    take_slice: bool = False,
    slice_size: int = 300,
    verbose: bool = True,
) -> BinaryFolderRecording:
    """Load a trial's Neuralynx recording and reload it in a sorting-ready format.

    All segments of the raw recording are concatenated into one joint
    recording. This is optionally cut down to its first ``slice_size`` seconds,
    saved as a binary folder and loaded back from disk.

    Parameters
    ----------
    subject_id, trial_id:
        Select the raw data folder ``data/raw/{subject_id}/{trial_id}/``.
    take_slice:
        If True, keep only the first ``slice_size`` seconds and save below
        ``data/slices/``. Otherwise save the full recording below
        ``data/prepared/``.
    slice_size:
        Length of the slice in seconds. It is clipped to the recording length.
    verbose:
        Print sampling frequency, channel count and total duration after loading.

    Returns
    -------
    BinaryFolderRecording
        The recording as loaded back from the export folder.
    """
    import_path = f"data/raw/{subject_id}/{trial_id}/"
    segmented_recording = extractors.read_neuralynx(import_path, stream_id="0")
    joint_recording = concatenate_recordings(
        [
            segmented_recording.select_segments(segment_index)
            for segment_index in range(segmented_recording.get_num_segments())
        ]
    )
    fs = joint_recording.get_sampling_frequency()
    if verbose:
        nch = joint_recording.get_num_channels()
        dur = joint_recording.get_total_duration()
        print(
            "\n".join(
                (
                    "Recording loaded",
                    f"Sampling frequency: {fs} Hz",
                    f"Number of channels: {nch}",
                    f"Total duration: {dur:.2f} seconds",
                )
            )
        )

    if take_slice:
        # Frame indices must be whole numbers, but fs may be a float
        # (e.g. 32000.0), so convert seconds -> frames explicitly. Also clip the
        # end so a slice longer than the recording does not reach past its end.
        end_frame = min(
            int(round(slice_size * fs)),
            joint_recording.get_num_samples(),
        )
        recording_slice = joint_recording.frame_slice(start_frame=0, end_frame=end_frame)
    else:
        recording_slice = joint_recording

    # Build the export path explicitly instead of via str.replace("raw", ...),
    # which would also rewrite "raw" occurring inside subject_id or trial_id.
    stage = "slices" if take_slice else "prepared"
    export_path = f"data/{stage}/{subject_id}/{trial_id}/"
    recording_slice.save(folder=export_path, verbose=False)
    return si.load_extractor(export_path)


# Prepare the full (unsliced) recording of the default subject and trial.
recording = prepare_recording(
    subject_id="subject_1",
    trial_id="2020-08-22",
    take_slice=False,
)
# For a quicker test run, prepare only a slice of the recording instead:
# recording_slice = prepare_recording(take_slice=True, verbose=False)

### visualize recording

In [None]:
# Interactive trace viewer for the prepared recording, using the ipywidgets backend.
raw_plot = widgets.plot_timeseries(recording, backend="ipywidgets")

### add fake probe

In [None]:
def compute_min_dist(positions: np.ndarray) -> float:
    """Return the smallest Euclidean distance between any two contact positions.

    Parameters
    ----------
    positions:
        Array-like of shape (n_contacts, n_dims), with one row per contact.

    Returns
    -------
    float
        The minimal pairwise distance. Returns ``inf`` when fewer than two
        contacts are given: there is no pair to measure, so any minimum-distance
        requirement is trivially met.
    """
    points = np.asarray(positions, dtype=float)
    n_points = len(points)
    if n_points < 2:
        return float("inf")
    # Pairwise difference vectors via broadcasting: shape (n, n, n_dims).
    deltas = points[:, np.newaxis, :] - points[np.newaxis, :, :]
    distances = np.linalg.norm(deltas, axis=-1)
    # Take each unordered pair once, skipping the zero self-distances on the diagonal.
    rows, cols = np.triu_indices(n_points, k=1)
    return float(distances[rows, cols].min())


# Fake probe: one circular contact (radius 5) per recording channel, placed
# uniformly at random in a 100 x 100 area. Positions are redrawn until every
# pair of contacts is at least 10 apart (in the same units as the positions).
# NOTE(review): this rejection sampling slows down quickly as the channel count
# grows and may effectively never finish for many channels on this area —
# confirm the channel count stays small.
num_channels = recording.get_num_channels()
positions = np.random.uniform(low=0, high=100, size=(num_channels, 2))
while compute_min_dist(positions) < 10:
    positions = np.random.uniform(low=0, high=100, size=(num_channels, 2))

probe = Probe()
probe.set_contacts(positions=positions, shapes="circle", shape_params=dict(radius=5))
# Wire contact i to recording channel i. The contact count above is derived
# from the same channel count, so the two always agree.
probe.set_device_channel_indices(np.arange(num_channels))
rec_w_probe = recording.set_probe(probe)

# Switch to static inline figures and draw the fake probe's contact layout.
%matplotlib inline
plot_probe(probe)
plt.show()

### run sorters

In [None]:
# Sorters to run, each mapped to the Docker image that provides it.
# Commented-out entries are skipped for the reason noted beside them.
docker_dict = {
    # "combinato": "spikeinterface/combinato-base:latest",  # single channel
    "hdsort": "spikeinterface/hdsort-compiled-base:latest",
    "herdingspikes": "spikeinterface/herdingspikes-base:latest",
    "ironclust": "spikeinterface/ironclust-compiled-base:latest",
    # "kilosort": "spikeinterface/kilosort-compiled-base:latest",  # requires GPU
    # "kilosort2": "spikeinterface/kilosort2-compiled-base:latest",  # requires GPU
    # "kilosort2_5": "spikeinterface/kilosort2_5-compiled-base:latest",  # requires GPU
    # "klusta": "spikeinterface/klusta-base:latest",  # error in docker image
    "mountainsort4": "spikeinterface/mountainsort4-base:latest",
    # "pykilosort": "spikeinterface/pykilosort-base:latest",  # requires GPU
    "spykingcircus": "spikeinterface/spyking-circus-base:latest",
    "tridesclous": "spikeinterface/tridesclous-base:latest",
    "waveclus": "spikeinterface/waveclus-compiled-base:latest",
    # "yass": "spikeinterface/yass-base:latest",  # requires GPU
}

# Successful sorting results, keyed by sorter name.
sorting_dict = {}

# Default parameters of every sorter listed by sorters.available_sorters().
param_dict = {
    name: sorters.get_default_sorter_params(name)
    for name in sorters.available_sorters()
}

# Run each sorter in its Docker image. A failing sorter is reported and the
# loop continues with the next one instead of aborting.
for sorter, image in tqdm(docker_dict.items()):
    try:
        # Pass filter=False only to sorters whose default parameters include "filter".
        filter_kwarg = {"filter": False} if "filter" in param_dict[sorter] else {}
        sorting_dict[sorter] = sorters.run_sorter(
            sorter,
            rec_w_probe,
            output_folder=f"data/sorted/{sorter}",
            docker_image=image,
            **filter_kwarg,
        )
    except Exception as e:
        print(f"Sorting with {sorter} failed due to the following exception: {e}")