In [None]:
import os

# When the kernel starts inside notebooks/, step up one directory (presumably the
# repository root) so the `workflow` package imported in the next cell can be found.
_, cwd_tail = os.path.split(os.getcwd())
if cwd_tail == "notebooks":
    os.chdir("..")

In [None]:
import datajoint as dj
from datetime import datetime
import spikeinterface as si
from spikeinterface import widgets, exporters, postprocessing, qualitymetrics, sorters
from workflow.pipeline import *
from workflow.utils.paths import (
    get_ephys_root_data_dir,
    get_raw_root_data_dir,
    get_processed_root_data_dir,
)
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory

### Select from the following list of experiments

In [None]:
# Attributes to show alongside each experiment's primary key.
experiment_columns = [
    "experiment_end_time",
    "drug_name",
    "drug_concentration",
    "experiment_plan",
]
experiment_table = culture.Experiment().proj(*experiment_columns)
# reset_index() moves the frame's index back into ordinary columns for display.
display(experiment_table.fetch(format="frame").reset_index())

### Create `spike_sorting` sessions

In [None]:
# Fields shared by the session and its probe entry -- edit the time window once here.
_session_fields = dict(
    organoid_id="O09",
    experiment_start_time="2023-05-18 12:25:00",
    insertion_number=0,
    start_time="2023-05-18 12:25:00",
    end_time="2023-05-18 12:30:00",
)

session_info = {**_session_fields, "session_type": "spike_sorting"}

session_probe_info = {
    **_session_fields,
    "probe": "Q983",  # probe serial number
    "port_id": "A",  # Port ID ("A", "B", etc.)
    "used_electrodes": [],  # empty if all electrodes were used
}

In [None]:
# Validate the session's time window, then insert the session and its probe.
SPIKE_SORTING_DURATION = 120  # maximum allowed session length, in minutes

# Start and end time of the session. It should be within the experiment time range.
TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
start_time = datetime.strptime(session_info["start_time"], TIME_FORMAT)
end_time = datetime.strptime(session_info["end_time"], TIME_FORMAT)
duration = (end_time - start_time).total_seconds() / 60  # minutes

# Separate checks so the error names the condition that actually failed.
assert session_info["session_type"] == "spike_sorting", (
    f"Session type must be 'spike_sorting', got {session_info['session_type']!r}"
)
# duration must be positive (end after start) and no longer than the allowed maximum.
assert 0 < duration <= SPIKE_SORTING_DURATION, (
    f"Session end_time must be after start_time and the duration must be at most "
    f"{SPIKE_SORTING_DURATION} minutes, got {duration:g} minutes"
)

# skip_duplicates=True keeps this cell safe to re-run.
ephys.EphysSession.insert1(session_info, ignore_extra_fields=True, skip_duplicates=True)

ephys.EphysSessionProbe.insert1(
    session_probe_info, ignore_extra_fields=True, skip_duplicates=True
)

# Build the restriction as a new dict instead of deleting from `session_probe_info`,
# so re-running this cell does not raise KeyError. The list-valued `used_electrodes`
# is excluded from the restriction (presumably because a list is not a usable
# equality restriction in DataJoint -- NOTE(review): confirm).
session_probe_key = {
    k: v for k, v in session_probe_info.items() if k != "used_electrodes"
}
display(ephys.EphysSession & session_info)
display(ephys.EphysSessionProbe & session_probe_key)

key = (ephys.EphysSession & session_info).fetch1("KEY")

### Insert clustering parameters

#### Sample parameter dictionary. It is expected to contain `SI_SORTING_PARAMS`, `SI_PREPROCESSING_METHOD`, `SI_WAVEFORM_EXTRACTION_PARAMS`, `SI_QUALITY_METRICS_PARAMS`, and `SI_JOB_KWARGS`

- `SI_SORTING_PARAMS`: Run `si.sorters.get_default_sorter_params(sorter_name)` to get the default parameters for a sorter. Modify values if needed. If empty, the sorter runs with its default parameters.

- `SI_PREPROCESSING_METHOD`: Name of a preprocessing function defined in `si_preprocessing.py`.
- `SI_WAVEFORM_EXTRACTION_PARAMS`: Waveform extraction parameters. If empty, the default parameters are used.
- `SI_QUALITY_METRICS_PARAMS`: Quality metric parameters. If empty, the default parameters are used.
- `SI_JOB_KWARGS`: Job keyword arguments (e.g. `n_jobs`, `chunk_size`). If empty, the default parameters are used.

In [None]:
available_sorters = ephys_sorter.SI_SORTERS  # sorters supported by spikeinterface
available_sorters

- Print the default parameters for a sorter.

In [None]:
sorter_name = "spykingcircus2"
# Start from the sorter's defaults when building SI_SORTING_PARAMS below.
default_sorter_params = si.sorters.get_default_sorter_params(sorter_name)
default_sorter_params

- Create a parameter dictionary

In [None]:
# Full parameter set for ephys.ClusteringParamSet.
# NOTE(review): keep key order (including nested dicts) unchanged -- the paramset may be
# identified by a hash of this dict, so reordering could create a "new" paramset.
params = {
    "SI_SORTING_PARAMS": {
        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
        "waveforms": {
            "max_spikes_per_unit": 200,
            "overwrite": True,
            "sparse": True,
            "method": "energy",
            "threshold": 0.25,
        },
        "filtering": {"freq_min": 150, "dtype": "float32"},
        "detection": {"peak_sign": "neg", "detect_threshold": 4},
        "selection": {
            "method": "smart_sampling_amplitudes",
            "n_peaks_per_channel": 5000,
            "min_n_peaks": 20000,
            "select_per_channel": False,
        },
        "clustering": {"legacy": False},
        "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
        "apply_preprocessing": True,
        "shared_memory": True,
        "cache_preprocessing": {
            "mode": "memory",
            "memory_limit": 0.5,
            "delete_cache": True,
        },
        "multi_units_only": False,
        "job_kwargs": {"n_jobs": 0.8},
        "debug": False,
    },
    # Name of a preprocessing function from si_preprocessing.py.
    "SI_PREPROCESSING_METHOD": "organoid_preprocessing",
    "SI_WAVEFORM_EXTRACTION_PARAMS": {
        "ms_before": 1.0,
        "ms_after": 2.0,
        "max_spikes_per_unit": 500,
    },
    "SI_QUALITY_METRICS_PARAMS": {"n_components": 5, "mode": "by_channel_local"},
    "SI_JOB_KWARGS": {"n_jobs": -1, "chunk_size": 30000},
}

- Insert the parameter set. Specify `clustering_method` (select one of the sorters listed above), `paramset_desc` (optional), and `paramset_idx` (int).

In [None]:
# Index, sorter and description for this parameter set; `clustering_method`
# should be one of the sorters listed above.
paramset_idx = 0
clustering_method = "spykingcircus2"
paramset_desc = ""

# Pass the `paramset_idx` variable (not a literal) so the inserted paramset matches
# the index used when inserting into ephys.ClusteringTask below.
ephys.ClusteringParamSet.insert_new_params(
    paramset_idx=paramset_idx,
    clustering_method=clustering_method,
    paramset_desc=paramset_desc,
    params=params,
)

### Select a session and paramset_idx and insert into `ephys.ClusteringTask` to trigger spike-sorting.

In [None]:
# Queue spike sorting for this session with the chosen parameter set.
clustering_task = key | {"paramset_idx": paramset_idx}

# skip_duplicates=True makes re-running this cell a no-op instead of a duplicate-key
# error, consistent with the session inserts above.
ephys.ClusteringTask.insert1(clustering_task, skip_duplicates=True)

ephys.ClusteringTask & key

### Wait until spike-sorting is finished. Explore spike-sorting results in downstream tables

In [None]:
curated_units = ephys.CuratedClustering.Unit & key  # sorted units for this session
curated_units

In [None]:
peak_waveforms = ephys.WaveformSet.PeakWaveform & key  # peak waveforms for this session
peak_waveforms

In [None]:
cluster_metrics = ephys.QualityMetrics.Cluster & key  # per-cluster quality metrics
cluster_metrics