### Setup: imports and recording loading

In [None]:
import numpy as np
import pandas as pd
import spikeinterface as si
from matplotlib import pyplot as plt
from probeinterface import get_probe, Probe
from spikeinterface import comparison, extractors, sorters, widgets

In [None]:
# Load the prepared recording for this session, then restrict it to a
# six-channel subset (channel ids are strings in this recording).
recording = si.load_extractor("data/prepared/subject_1/2020-08-22/")
selected_channels = ["64", "65", "66", "67", "68", "69"]
recording_sub = recording.channel_slice(selected_channels)

### run sorting with the selected sorters

In [None]:
# Sort the channel subset with MountainSort4. Built-in filtering is disabled;
# presumably the "prepared" recording is already filtered upstream — TODO confirm.
ms4_params = dict(
    output_folder="data/sorted/mountainsort4",
    verbose=True,
    filter=False,
    num_workers=4,
)
sorting_ms4 = sorters.run_mountainsort4(recording_sub, **ms4_params)

In [None]:
# Sort the same subset with Wave_clus, run from the compiled Docker image
# (no local MATLAB installation required).
wc_params = dict(
    docker_image="spikeinterface/waveclus-compiled-base:latest",
    output_folder="data/sorted/waveclus",
    verbose=True,
)
sorting_wc = sorters.run_waveclus(recording_sub, **wc_params)

### compare sorting results

In [None]:
# Reload both sorting results from disk so this section can run without
# re-executing the (slow) sorting cells above.
sorting_ms4, sorting_wc = map(
    sorters.read_sorter_folder,
    ("data/sorted/mountainsort4/", "data/sorted/waveclus/"),
)

In [None]:
# Match units across the two sorters. Spikes within delta_time (ms) count as
# coincident, and a unit pair is matched when its agreement score is at least
# match_score.
sorter_names = ["mountainsort4", "waveclus"]
multicomp = comparison.compare_multiple_sorters(
    [sorting_ms4, sorting_wc],
    name_list=sorter_names,
    match_score=0.5,
    delta_time=0.5,
)

In [None]:
# Keep only units found by both sorters, then list the spike count of each
# agreed unit (the bare list is the cell's displayed output).
agreement_sorting = multicomp.get_agreement_sorting(minimum_agreement_count=2)
agreed_unit_ids = agreement_sorting.get_unit_ids()
[len(agreement_sorting.get_unit_spike_train(unit_id)) for unit_id in agreed_unit_ids]

In [None]:
# Agreement matrix for the single pairwise comparison between the two sorters.
fig = plt.figure(figsize=(20, 20))
ms4_vs_wc = multicomp.comparisons[("mountainsort4", "waveclus")]
widgets.plot_agreement_matrix(ms4_vs_wc, figure=fig)
plt.show()

### Attach a probe and extract waveforms

In [None]:
# Waveform extraction and plotting need a probe, but the real electrode
# geometry is not used here, so build a dummy one: one circular contact per
# channel, placed on a vertical line. The contacts need distinct positions;
# with every contact at (0, 0), widgets that lay channels out by location draw
# all channels on top of each other.
contact_pitch = 50  # arbitrary spacing in probe units (presumably um) — not physical
n = recording_sub.get_num_channels()
positions = np.zeros(shape=(n, 2))
positions[:, 1] = np.arange(n) * contact_pitch
probe = Probe()
probe.set_contacts(positions=positions, shapes="circle", shape_params=dict(radius=5))
# Contact i is wired to channel i of recording_sub.
probe.set_device_channel_indices(np.arange(n))
rec_w_probe = recording_sub.set_probe(probe)

In [None]:
# Extract waveforms of the agreed units from the probe-annotated recording.
# max_spikes_per_unit=None keeps every spike instead of a random subset, which
# can be slow and memory-heavy for units with many spikes.
# NOTE(review): extract_waveforms may refuse to write into an existing folder;
# delete it (or pass overwrite=True) before re-running — confirm for the
# installed spikeinterface version.
waveforms_agreement = si.extract_waveforms(
    rec_w_probe,
    agreement_sorting,
    folder="data/waveforms/agreement",
    max_spikes_per_unit=None,
)

In [None]:
# Plot the waveforms of every agreed unit on one wide figure, arranged in a
# grid with 6 subplot columns.
fig = plt.figure(figsize=(30, 5))
widgets.plot_unit_waveforms(
    waveforms_agreement,
    figure=fig,
    ncols=6,
)
# tight_layout only adjusts existing axes, so it must run after the widget has
# drawn into the figure; on an empty figure it has nothing to lay out.
fig.tight_layout()
plt.show()