# Debugging the DataJoint Pipeline: Spike Sorting Computations

This notebook helps advanced users quickly debug the computations and investigate errors that may occur while a DataJoint `populate` call is running.

It provides a brief guide to dissecting the `make` function, which gives a deeper understanding of the pipeline's computational steps and makes issues faster to resolve.

**Note: This notebook is a supplementary debugging tool and does not replace software development best practices.**

The spike sorting analysis is managed by the `ephys_sorter` schema, which contains three main tables in the DataJoint pipeline:

1. PreProcessing

2. SIClustering

3. PostProcessing

Review the code for each table in [`si_spike_sorting.py`](https://github.com/dj-sciops/utah_organoids_element-array-ephys/blob/main/element_array_ephys/spike_sorting/si_spike_sorting.py) before you start debugging.
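The three tables hand their results to one another through files on disk. The sketch below rebuilds the folder layout that the `make` code in Step 3 writes to, which can be a quick way to see which stage last produced output. `expected_outputs` and the example directory are hypothetical; the function only mirrors the path construction in the copied code and never touches the database or the file system.

```python
from pathlib import Path


def expected_outputs(output_dir: Path, clustering_method: str) -> dict:
    """Hypothetical helper: paths each spike sorting table is expected to write.

    Mirrors the path construction in the copied `make` code only.
    """
    # Same normalization as the make functions: "." in the method name becomes "_"
    sorter_name = clustering_method.replace(".", "_")
    sorter_dir = output_dir / sorter_name
    return {
        "PreProcessing": sorter_dir / "recording" / "si_recording.pkl",
        "SIClustering": sorter_dir / "spike_sorting" / "si_sorting.pkl",
        "PostProcessing": sorter_dir / "sorting_analyzer",
    }


# Example with a made-up output directory
paths = expected_outputs(Path("/data/outbox/O09"), "spykingcircus2")
for table, path in paths.items():
    print(f"{table:15s} -> {path}")
```

On a real run you could pair this with `path.exists()` to check which stage has already written its results.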

### **Key Steps**

- **Setup**

- **Step 1: Select Session of Interest**

- **Step 2: `populate` Necessary Tables before Spike Sorting**

- **Step 3: Execute Each Part of the Spike Sorting Computations to Debug**

#### **Setup**

First, import the necessary packages for the data pipeline and essential schemas.

In [4]:
import os

if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [26]:
import datajoint as dj
import datetime
import pandas as pd
import numpy as np

In [6]:
from workflow.pipeline import ephys, ephys_sorter

[2024-07-24 23:02:09,486][INFO]: Connecting milagros@db.datajoint.com:3306
[2024-07-24 23:02:11,072][INFO]: Connected milagros@db.datajoint.com:3306


#### **Step 1: Select Session of Interest**


In [15]:
session_key = {
    "organoid_id": "O09",
    "experiment_start_time": datetime.datetime(2023, 5, 18, 12, 25),
    "start_time": "2023-05-18 12:25:00",
    "end_time": "2023-05-18 12:26:30",
}

Ensure your `session_key` has already been inserted into the `EphysSession` and `EphysSessionProbe` tables. If not, follow the notebook [CREATE_new_session.ipynb](./CREATE_new_session.ipynb).


In [18]:
ephys.EphysSession * ephys.EphysSessionProbe & session_key

| organoid_id (e.g. O17) | experiment_start_time | insertion_number | start_time | end_time | session_type | probe (unique identifier for this model of probe, e.g. serial number) | port_id | used_electrodes (list of electrode IDs used in this session; if null, all electrodes are used) |
|---|---|---|---|---|---|---|---|---|
| O09 | 2023-05-18 12:25:00 | 0 | 2023-05-18 12:25:00 | 2023-05-18 12:26:30 | spike_sorting | Q983 | A | =BLOB= |


In [19]:
ephys_key = {**session_key, "insertion_number": 0}
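`ephys_key` extends `session_key` with the probe's `insertion_number`. When a dictionary restricts a DataJoint table, only the attributes present in that table's heading are matched and the rest are ignored, which is why one dictionary can restrict several tables. The pure-Python sketch below mimics that behavior on lists of rows; `restrict` and the rows are hypothetical stand-ins, not DataJoint's implementation.

```python
def restrict(rows: list, key: dict) -> list:
    """Hypothetical stand-in for `table & key`.

    Attributes in `key` that are not columns of the rows are ignored.
    """
    matched = []
    for row in rows:
        shared = [attr for attr in key if attr in row]
        if all(row[attr] == key[attr] for attr in shared):
            matched.append(row)
    return matched


session_key = {"organoid_id": "O09", "start_time": "2023-05-18 12:25:00"}
ephys_key = {**session_key, "insertion_number": 0}  # new dict; session_key is unchanged

# Made-up rows: a session-level table (no insertion_number) and a probe-level table
sessions = [
    {"organoid_id": "O09", "start_time": "2023-05-18 12:25:00"},
    {"organoid_id": "O17", "start_time": "2023-05-18 12:25:00"},
]
probes = [
    {"organoid_id": "O09", "start_time": "2023-05-18 12:25:00", "insertion_number": 0},
    {"organoid_id": "O09", "start_time": "2023-05-18 12:25:00", "insertion_number": 1},
]

print(len(restrict(sessions, ephys_key)))  # insertion_number is ignored for this table
print(len(restrict(probes, ephys_key)))    # insertion_number is used for this table
```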

#### **Step 2: `populate` Necessary Tables before Spike Sorting**


Populate the necessary tables:

In [20]:
ephys.EphysSessionInfo.populate()

Ensure that a `ClusteringTask` has been defined for this `session_key` and the desired `paramset_idx` (`101` in this example). If not, follow the notebook [CREATE_new_clustering_task.ipynb](./CREATE_new_clustering_task.ipynb).

In [24]:
key = (ephys.ClusteringTask & ephys_key & "paramset_idx=101").fetch1("KEY")
key

{'organoid_id': 'O09',
 'experiment_start_time': datetime.datetime(2023, 5, 18, 12, 25),
 'insertion_number': 0,
 'start_time': datetime.datetime(2023, 5, 18, 12, 25),
 'end_time': datetime.datetime(2023, 5, 18, 12, 26, 30),
 'paramset_idx': 101}
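`fetch1` only succeeds when the restricted query contains exactly one row, so the cell above fails if the restriction matches no `ClusteringTask` or several of them (for example, when the `paramset_idx` filter is left out). The sketch below imitates that rule in plain Python. `fetch1_key` is a hypothetical helper, the rows are made up, and the primary key is shortened for illustration; the real one, shown in the output above, also includes the time attributes.

```python
PRIMARY_KEY = ["organoid_id", "insertion_number", "paramset_idx"]  # shortened for illustration


def fetch1_key(rows: list, primary_key: list) -> dict:
    """Hypothetical stand-in for `(table & restriction).fetch1("KEY")`."""
    if len(rows) != 1:
        raise ValueError(f"expected exactly one row, got {len(rows)}")
    return {attr: rows[0][attr] for attr in primary_key}


# Made-up clustering tasks for one probe with two parameter sets
tasks = [
    {"organoid_id": "O09", "insertion_number": 0, "paramset_idx": 100, "clustering_output_dir": ""},
    {"organoid_id": "O09", "insertion_number": 0, "paramset_idx": 101, "clustering_output_dir": ""},
]

only_101 = [t for t in tasks if t["paramset_idx"] == 101]
print(fetch1_key(only_101, PRIMARY_KEY))  # primary key only; non-key attributes are dropped

try:
    fetch1_key(tasks, PRIMARY_KEY)  # no paramset_idx restriction -> two rows
except ValueError as err:
    print("fetch1 would fail:", err)
```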

#### **Step 3: Execute Each Part of the Spike Sorting Computations to Debug**

To debug, copy the code of each of the three `make` functions into the cells below. Running it cell by cell reproduces the `si_recording` and `si_sorting` objects so you can explore and test them.
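Running a `make` body by hand works because `populate` essentially finds the keys of the table's key source that are not yet in the table and calls `make(key)` for each one. The real method also wraps each call in a transaction and supports options such as `reserve_jobs` and `suppress_errors`. The toy class below, with hypothetical names and no database, sketches only that core loop as a mental model.

```python
from typing import Optional


class ToyComputed:
    """Hypothetical, database-free mimic of a DataJoint Computed table."""

    def __init__(self, key_source: list):
        self.key_source = key_source  # upstream keys this table should compute
        self.rows = []                # stands in for the table's stored entries

    def make(self, key: dict) -> None:
        # The real PreProcessing/SIClustering/PostProcessing tables do their work here
        self.rows.append({**key, "result": f"processed {key['organoid_id']}"})

    def _is_done(self, key: dict) -> bool:
        return any(all(row[k] == v for k, v in key.items()) for row in self.rows)

    def populate(self, restriction: Optional[dict] = None) -> None:
        restriction = restriction or {}
        for key in self.key_source:
            # Keep keys that match the restriction, then skip those already computed
            selected = all(key[k] == v for k, v in restriction.items() if k in key)
            if selected and not self._is_done(key):
                self.make(key)


table = ToyComputed([
    {"organoid_id": "O09", "paramset_idx": 101},
    {"organoid_id": "O17", "paramset_idx": 101},
])
table.populate({"organoid_id": "O09"})  # computes only the O09 key
table.populate({"organoid_id": "O09"})  # nothing left to do for O09
table.make({"organoid_id": "O17", "paramset_idx": 101})  # direct call, as when debugging
print(len(table.rows))
```

Calling `make` directly, or running its body as in the cells below, bypasses the pending-key check and the transaction that `populate` would normally provide.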

In [31]:
import spikeinterface as si
from element_array_ephys import probe, readers
from element_interface.utils import find_full_path, memoized_result
from spikeinterface import exporters, postprocessing, qualitymetrics, sorters

# Adapted import: load `si_preprocessing` explicitly so the copied `make` code can use it
from element_array_ephys.spike_sorting import si_preprocessing

In [32]:
# ----------------- PreProcessing Make Function Copied Here ----------------- #

# Get clustering method and output directory.
clustering_method, output_dir, params = (
    ephys.ClusteringTask * ephys.ClusteringParamSet & key
).fetch1("clustering_method", "clustering_output_dir", "params")
acq_software = (ephys.EphysRawFile & key).fetch("acq_software", limit=1)[0]

# Get sorter method and create output directory.
sorter_name = clustering_method.replace(".", "_")

for required_key in (
    "SI_PREPROCESSING_METHOD",
    "SI_SORTING_PARAMS",
    "SI_POSTPROCESSING_PARAMS",
):
    if required_key not in params:
        raise ValueError(
            f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution"
        )

# Set directory to store recording file.
if not output_dir:
    output_dir = ephys.ClusteringTask.infer_output_dir(key, relative=True, mkdir=True)
    # update clustering_output_dir
    ephys.ClusteringTask.update1(
        {**key, "clustering_output_dir": output_dir.as_posix()}
    )
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
recording_dir = output_dir / sorter_name / "recording"
recording_dir.mkdir(parents=True, exist_ok=True)
recording_file = recording_dir / "si_recording.pkl"

# Get probe information to recording object
probe_info = (probe.Probe * ephys.EphysSessionProbe & key).fetch1()
electrode_query = probe.ElectrodeConfig.Electrode & (
    probe.ElectrodeConfig & {"probe_type": probe_info["probe_type"]}
)

# Filter for used electrodes. If probe_info["used_electrodes"] is None, it means all electrodes were used.
number_of_electrodes = len(electrode_query)
probe_info["used_electrodes"] = (
    probe_info["used_electrodes"]
    if probe_info["used_electrodes"] is not None and len(probe_info["used_electrodes"])
    else list(range(number_of_electrodes))
)
unused_electrodes = [
    elec
    for elec in range(number_of_electrodes)
    if elec not in probe_info["used_electrodes"]
]
electrodes_df = (
    (probe.ProbeType.Electrode * electrode_query)
    .fetch(format="frame", order_by="electrode")
    .reset_index()[["electrode", "x_coord", "y_coord", "shank", "channel_idx"]]
)

"""Get the row indices of the port from the data matrix."""
session_info = (ephys.EphysSessionInfo & key).fetch1("session_info")
port_indices = np.array(
    [
        ind
        for ind, ch in enumerate(session_info["amplifier_channels"])
        if ch["port_prefix"] == probe_info["port_id"]
    ]
)  # get the row indices of the port

# Create SI recording extractor object
si_extractor: si.extractors.neoextractors = (
    si.extractors.extractorlist.recording_extractor_full_dict[
        acq_software.replace(" ", "").lower()
    ]
)  # data extractor object

files, file_times = (
    ephys.EphysRawFile
    & key
    & f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
).fetch("file_path", "file_time", order_by="file_time")

si_recording = None
# Read data. Concatenate if multiple files are found.
for file_path in (find_full_path(ephys.get_ephys_root_data_dir(), f) for f in files):
    if si_recording is None:
        stream_name = [
            s for s in si_extractor.get_streams(file_path)[0] if "amplifier" in s
        ][0]
        si_recording: si.BaseRecording = si_extractor(
            file_path, stream_name=stream_name
        )
    else:
        si_recording: si.BaseRecording = si.concatenate_recordings(
            [
                si_recording,
                si_extractor(file_path, stream_name=stream_name),
            ]
        )

si_recording = si_recording.channel_slice(
    si_recording.channel_ids[port_indices]
)  # select only the port data

# Create SI probe object
si_probe = readers.probe_geometry.to_probeinterface(electrodes_df)
si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
si_recording.set_probe(probe=si_probe, in_place=True)

# Account for additional electrodes being removed
if unused_electrodes:
    chn_ids_to_remove = [
        f"{probe_info['port_id']}-{electrodes_df.channel_idx.iloc[elec]:03d}"
        for elec in unused_electrodes
    ]
else:
    chn_ids_to_remove = []

si_recording = si_recording.remove_channels(remove_channel_ids=chn_ids_to_remove)

# Run preprocessing and save results to output folder
si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
si_recording = si_preproc_func(si_recording)
si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir)
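The electrode and channel bookkeeping in the cell above is a common source of ID mismatches. The sketch below replays the same logic on made-up inputs with NumPy and pandas, so you can check which port rows are selected and which channel IDs would be removed without loading any data. The 4-electrode probe, `session_info`, and `channel_idx` values are hypothetical; channel IDs follow the `"<port>-<channel:03d>"` pattern used in the cell.

```python
import numpy as np
import pandas as pd

# Toy inputs (hypothetical): a 4-electrode probe on port "A"
port_id = "A"
used_electrodes = [0, 2]  # None or [] would mean "all electrodes used"
electrodes_df = pd.DataFrame(
    {"electrode": [0, 1, 2, 3], "channel_idx": [2, 3, 0, 1]}
)  # rows ordered by electrode, as with order_by="electrode" in the cell
session_info = {
    "amplifier_channels": [
        {"port_prefix": "A"}, {"port_prefix": "B"},
        {"port_prefix": "A"}, {"port_prefix": "A"}, {"port_prefix": "A"},
    ]
}

number_of_electrodes = len(electrodes_df)
# Same fallback as the make function: None/empty means every electrode was used
if used_electrodes is None or not len(used_electrodes):
    used_electrodes = list(range(number_of_electrodes))
unused_electrodes = [e for e in range(number_of_electrodes) if e not in used_electrodes]

# Rows of the data matrix that belong to this probe's port
port_indices = np.array(
    [i for i, ch in enumerate(session_info["amplifier_channels"]) if ch["port_prefix"] == port_id]
)

# Channel IDs that remove_channels() would drop; .iloc relies on the electrode ordering
chn_ids_to_remove = [
    f"{port_id}-{electrodes_df.channel_idx.iloc[e]:03d}" for e in unused_electrodes
]

print("port rows:", port_indices.tolist())
print("unused electrodes:", unused_electrodes)
print("channel IDs to remove:", chn_ids_to_remove)
```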

In [33]:
# `si_recording` can be explored here
si_recording

In [50]:
# ----------------- SIClustering Make Function Copied Here ----------------- #

# Load recording object.
clustering_method, output_dir, params = (
    ephys.ClusteringTask * ephys.ClusteringParamSet & key
).fetch1("clustering_method", "clustering_output_dir", "params")
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
sorter_name = clustering_method.replace(".", "_")
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
si_recording: si.BaseRecording = si.load_extractor(
    recording_file, base_folder=output_dir
)

sorting_params = params["SI_SORTING_PARAMS"]
sorting_output_dir = output_dir / sorter_name / "spike_sorting"


# Run sorting
@memoized_result(
    uniqueness_dict=sorting_params,
    output_directory=sorting_output_dir,
)
def _run_sorter():
    # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
    si_sorting: si.BaseSorting = si.sorters.run_sorter(
        sorter_name=sorter_name,
        recording=si_recording,
        output_folder=sorting_output_dir,
        remove_existing_folder=True,
        verbose=True,
        docker_image=sorter_name not in si.sorters.installed_sorters(),
        **sorting_params,
    )

    # Save sorting object
    sorting_save_path = sorting_output_dir / "si_sorting.pkl"
    si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)


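_run_sorter()

# `sorting_save_path` is local to `_run_sorter`, so define it here before printing it
sorting_save_path = sorting_output_dir / "si_sorting.pkl"
print(f"`si_sorting` object is saved in the `sorting_save_path` {sorting_save_path}")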

[2024-07-24 23:33:18,246][INFO]: Existing results found, skip '_run_sorter'


`si_sorting` object is saved in the `sorting_save_path` /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_202305181226/O09/spykingcircus2_101/spykingcircus2/spike_sorting/si_sorting.pkl
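The log line `Existing results found, skip '_run_sorter'` comes from the `memoized_result` decorator: when results for the same parameters already exist in the output directory, the decorated function is not run again. The exact bookkeeping lives in `element_interface.utils`. The sketch below is a simplified, hypothetical re-creation (`toy_memoized_result`, a `.params_hash` marker file, and a made-up `detect_threshold` parameter) meant only for reasoning about when a cell will recompute. Assuming the real decorator behaves similarly, forcing a rerun while debugging means changing the parameters or clearing the output directory.

```python
import functools
import hashlib
import json
import tempfile
from pathlib import Path


def toy_memoized_result(uniqueness_dict: dict, output_directory: Path):
    """Hypothetical, simplified stand-in for element_interface's memoized_result."""
    # Hash the parameters that define "the same computation"
    params_hash = hashlib.md5(
        json.dumps(uniqueness_dict, sort_keys=True, default=str).encode()
    ).hexdigest()
    marker = Path(output_directory) / ".params_hash"  # hypothetical marker file

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if marker.exists() and marker.read_text() == params_hash:
                print(f"Existing results found, skip '{func.__name__}'")
                return None
            print(f"No existing results found, calling '{func.__name__}'")
            result = func(*args, **kwargs)
            marker.parent.mkdir(parents=True, exist_ok=True)
            marker.write_text(params_hash)
            return result

        return wrapper

    return decorator


out_dir = Path(tempfile.mkdtemp()) / "spike_sorting"
calls = []


@toy_memoized_result(uniqueness_dict={"detect_threshold": 5}, output_directory=out_dir)
def _run_sorter():
    calls.append("ran")


_run_sorter()  # first call computes and records the parameter hash
_run_sorter()  # second call with the same parameters is skipped
```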


In [51]:
# ----------------- PostProcessing First Part of the Make Function Copied Here ----------------- #

# Load recording & sorting object.
clustering_method, output_dir, params = (
    ephys.ClusteringTask * ephys.ClusteringParamSet & key
).fetch1("clustering_method", "clustering_output_dir", "params")
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
sorter_name = clustering_method.replace(".", "_")

recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"

si_recording: si.BaseRecording = si.load_extractor(
    recording_file, base_folder=output_dir
)
si_sorting: si.BaseSorting = si.load_extractor(
    sorting_file, base_folder=output_dir
)

In [52]:
# `si_sorting` can be explored here
si_sorting

In [53]:
# ----------------- PostProcessing Second Part of the Make Function Copied Here ----------------- #

postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]

job_kwargs = postprocessing_params.get(
    "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
)

analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"


@memoized_result(
    uniqueness_dict=postprocessing_params,
    output_directory=analyzer_output_dir,
)
def _sorting_analyzer_compute():
    # Sorting Analyzer
    sorting_analyzer = si.create_sorting_analyzer(
        sorting=si_sorting,
        recording=si_recording,
        format="binary_folder",
        folder=analyzer_output_dir,
        sparse=True,
        overwrite=True,
        **job_kwargs,
    )

    # Extensions are computed in the order returned by sorting_analyzer.get_computable_extensions();
    # each one uses the parameters given in the extensions_params dictionary and is skipped if not listed there
    extensions_params = postprocessing_params.get("extensions", {})
    extensions_to_compute = {
        ext_name: extensions_params[ext_name]
        for ext_name in sorting_analyzer.get_computable_extensions()
        if ext_name in extensions_params
    }

    sorting_analyzer.compute(extensions_to_compute, **job_kwargs)

    # Save to phy format
    if postprocessing_params.get("export_to_phy", False):
        si.exporters.export_to_phy(
            sorting_analyzer=sorting_analyzer,
            output_folder=analyzer_output_dir / "phy",
            use_relative_path=True,
            **job_kwargs,
        )
    # Generate SpikeInterface report
    if postprocessing_params.get("export_report", True):
        si.exporters.export_report(
            sorting_analyzer=sorting_analyzer,
            output_folder=analyzer_output_dir / "spikeinterface_report",
            **job_kwargs,
        )


_sorting_analyzer_compute()

[2024-07-24 23:34:50,476][INFO]: No existing results found, calling '_sorting_analyzer_compute'


Run:
phy template-gui  /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_202305181226/O09/spykingcircus2_101/spykingcircus2/sorting_analyzer/phy/params.py
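In the PostProcessing cell, only extensions listed under `SI_POSTPROCESSING_PARAMS["extensions"]` are computed, in the order returned by `sorting_analyzer.get_computable_extensions()`; names that are not computable are dropped without an error. When a metric you expected is missing from the report, replaying that selection on its own can help. The sketch below does so in plain Python. The extension names, their parameters, and the `computable` list are illustrative placeholders, not a verified list for any SpikeInterface version.

```python
# Illustrative parameters, shaped like params["SI_POSTPROCESSING_PARAMS"]
postprocessing_params = {
    "extensions": {
        "templates": {},
        "random_spikes": {"max_spikes_per_unit": 500},
        "waveforms": {"ms_before": 1.0, "ms_after": 2.0},
        "not_a_real_extension": {},  # misspelled or unsupported names end up ignored
    },
}

# Stand-in for sorting_analyzer.get_computable_extensions(); the real order comes from SpikeInterface
computable = ["random_spikes", "waveforms", "templates", "noise_levels", "quality_metrics"]

# Same default as the cell when "job_kwargs" is not specified
job_kwargs = postprocessing_params.get("job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"})

extensions_params = postprocessing_params.get("extensions", {})
extensions_to_compute = {
    ext_name: extensions_params[ext_name]
    for ext_name in computable
    if ext_name in extensions_params
}

ignored = sorted(set(extensions_params) - set(extensions_to_compute))
print("job_kwargs:", job_kwargs)
print("will compute (in order):", list(extensions_to_compute))
print("ignored:", ignored)
```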


In [58]:
print(
    f"Now you can explore the SpikeInterface report here: {analyzer_output_dir / 'spikeinterface_report'}\n"
    f"And the results using Phy here: {analyzer_output_dir / 'phy'}"
)

Now you can explore the SpikeInterface report here: /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_202305181226/O09/spykingcircus2_101/spykingcircus2/sorting_analyzer/spikeinterface_report
And the results using Phy here: /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_202305181226/O09/spykingcircus2_101/spykingcircus2/sorting_analyzer/phy
