In [None]:
import os

# Step up one directory when the kernel was started from inside `notebooks/`,
# so that the relative paths and imports used below resolve from its parent.
current_dir_name = os.path.split(os.getcwd())[1]
if current_dir_name == "notebooks":
    os.chdir("..")

In [None]:
import warnings
from datetime import datetime
from pathlib import Path

# Silence FutureWarnings before the heavier packages below are imported.
warnings.simplefilter(action="ignore", category=FutureWarning)

import datajoint as dj
import numpy as np
import probeinterface as pi
import seaborn as sns
import spikeinterface as si
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
from matplotlib import pyplot as plt
from probeinterface import plotting
from spikeinterface import (
    exporters,
    extractors,
    postprocessing,
    preprocessing,
    qualitymetrics,
    sorters,
    widgets,
)

import workflow
from workflow.pipeline import *
from workflow.utils.ingestion_utils import El2ROW
from workflow.utils.paths import (
    get_ephys_root_data_dir,
    get_raw_root_data_dir,
    get_processed_root_data_dir,
)

In [None]:
from dotenv import load_dotenv

# Read key/value pairs from a `.env` file into the process environment.
# NOTE(review): presumably this supplies the database credentials / data paths
# used by the workflow — confirm which variables are expected.
load_dotenv()

***The examples in this notebook use a sample dataset to demonstrate how to explore results. Please replace these entries with your database entries to view and analyze your data.***

### Select from the following list of experiments

In [None]:
# List the available experiments with the attributes needed to choose one.
experiment_attributes = (
    "experiment_end_time",
    "drug_name",
    "drug_concentration",
    "experiment_plan",
)
experiments_frame = (
    culture.Experiment().proj(*experiment_attributes).fetch(format="frame")
)
display(experiments_frame.reset_index())

### Create `spike_sorting` sessions

In [None]:
# Attributes shared by the session entry and its probe entry.
common_session_key = {
    "organoid_id": "O09",
    "experiment_start_time": "2023-05-18 12:25:00",
    "insertion_number": 0,
    "start_time": "2023-05-18 12:25:00",
    "end_time": "2023-05-18 12:30:00",
}

# Session entry: the shared attributes plus the session type.
session_info = {**common_session_key, "session_type": "spike_sorting"}

# Probe entry: the shared attributes plus the probe description.
session_probe_info = {
    **common_session_key,
    "probe": "Q983",  # probe serial number
    "port_id": "A",  # Port ID ("A", "B", etc.)
    "used_electrodes": [],  # empty if all electrodes were used
}

In [None]:
# Validate the session defined above, insert it together with its probe entry,
# and display both rows.
SPIKE_SORTING_DURATION = 120  # maximum allowed session length, in minutes
SESSION_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

start_time = datetime.strptime(session_info["start_time"], SESSION_TIME_FORMAT)
end_time = datetime.strptime(session_info["end_time"], SESSION_TIME_FORMAT)
duration = (end_time - start_time).total_seconds() / 60  # minutes

assert (
    session_info["session_type"] == "spike_sorting"
), "Session type must be 'spike_sorting'"
assert 0 < duration <= SPIKE_SORTING_DURATION, (
    f"Session duration must be positive and at most {SPIKE_SORTING_DURATION} minutes "
    f"(got {duration:g} minutes)"
)

ephys.EphysSession.insert1(session_info, ignore_extra_fields=True, skip_duplicates=True)

ephys.EphysSessionProbe.insert1(
    session_probe_info, ignore_extra_fields=True, skip_duplicates=True
)

# Restrict by everything except `used_electrodes`. A filtered copy is used
# instead of `del` so that `session_probe_info` stays intact and re-running this
# cell neither raises KeyError nor inserts a probe entry without the list.
session_probe_key = {
    name: value
    for name, value in session_probe_info.items()
    if name != "used_electrodes"
}
display(ephys.EphysSession & session_info)
display(ephys.EphysSessionProbe & session_probe_key)

### Load data

In [None]:
# Look up the spike-sorting session selected above, together with its drug name.
spike_sorting_restriction = {"session_type": "spike_sorting"}
query = (
    culture.Experiment().proj("drug_name") * ephys.EphysSession
) & spike_sorting_restriction
key = (query & session_info).fetch1()

# Folder/figure title: organoid, start time, end time and drug name (spaces removed).
timestamp_format = "%Y%m%d%H%M"
title_parts = (
    key["organoid_id"],
    key["start_time"].strftime(timestamp_format),
    key["end_time"].strftime(timestamp_format),
    key["drug_name"].replace(" ", ""),
)
title = "_".join(title_parts)

# All outputs of this notebook are written below this folder.
spike_sorting_path = get_processed_root_data_dir() / "spike_sorting" / title
spike_sorting_path.mkdir(exist_ok=True, parents=True)

# Raw files whose timestamp falls within the session window, oldest first.
time_window = f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
files, file_times = (ephys.EphysRawFile & time_window).fetch(
    "file_path", "file_time", order_by="file_time"
)

for file in files:
    print(file)
print(f"\nNumber of files: {len(files)} ({key['drug_name']})")

In [None]:
# 1. Read each raw file of the session as a recording object.
# 2. Concatenate them into a single recording for the session.
# The result is cached as `recording.pkl`; when that file exists it is loaded
# instead of re-reading the raw files.

stream_name = "RHD2000 amplifier channel"
recording_pickle = spike_sorting_path / "recording.pkl"

if recording_pickle.exists():
    recording = si.load_extractor(recording_pickle)
else:
    if len(files) == 0:
        raise ValueError(
            "No raw files were found for this session, so no recording can be built."
        )

    # Resolve every path first so a missing file is reported before any reading starts.
    raw_file_paths = [find_full_path(get_ephys_root_data_dir(), f) for f in files]

    file_recordings = []
    for file in raw_file_paths:
        print(f"Processing {file}.")
        file_recordings.append(
            si.extractors.read_intan(file, stream_name=stream_name)
        )

    # One flat concatenation instead of nesting a new concatenation per file.
    if len(file_recordings) == 1:
        recording = file_recordings[0]
    else:
        recording = si.concatenate_recordings(file_recordings)

    recording.dump_to_pickle(
        file_path=recording_pickle
    )  # lazy dumping (not actual traces, only the information on how to reconstruct the recording gets dumped)
    # recording.save(folder=spike_sorting_path)  # save on disk
recording

In [None]:
# Useful APIs

# traces = recording.get_traces(return_scaled=True)  # return values in uV
# recording.get_times() # get timestamps
# recording.get_time_info()  # {'sampling_frequency': 20000.0, 't_start': None, 'time_vector': None}
# recording.neo_reader
# recording.has_time_vector()  # false
# recording.sampling_frequency

### Generate probe

In [None]:
# Get the probe entry of this session and the electrode configuration of its probe type.
manufacturer = "neuronexus"
probe_info = (ephys.EphysSessionProbe & key).fetch1()
probe_type = ((probe.Probe * ephys.EphysSessionProbe()) & key).fetch1("probe_type")

electrode_query = probe.ElectrodeConfig.Electrode & (
    probe.ElectrodeConfig & {"probe_type": probe_type}
)
number_of_electrodes = len(electrode_query)

# An empty or missing `used_electrodes` means that all electrodes were used.
# The explicit None/length test also works when the stored value is a numpy
# array, for which `value or default` would raise.
used_electrodes = probe_info["used_electrodes"]
if used_electrodes is None or len(used_electrodes) == 0:
    used_electrodes = list(range(number_of_electrodes))
else:
    used_electrodes = [int(elec) for elec in used_electrodes]
probe_info["used_electrodes"] = used_electrodes

used_electrode_set = set(used_electrodes)
unused_electrodes = [
    elec for elec in range(number_of_electrodes) if elec not in used_electrode_set
]

# Build the IN list explicitly: formatting a one-element tuple yields "(5,)",
# which is not valid SQL.
if used_electrodes:
    electrode_list_sql = ", ".join(str(elec) for elec in used_electrodes)
    electrode_query &= f"electrode IN ({electrode_list_sql})"

# Fetch both attributes once and reuse them for the map and the sorted channel indices.
channels, electrodes = electrode_query.fetch("channel", "electrode")

# Keys have the form "<port>-<zero-padded channel>", e.g. "A-007".
channel_to_electrode_map = {
    f'{probe_info["port_id"]}-{int(channel):03d}': electrode
    for channel, electrode in zip(channels, electrodes)
}
print(channel_to_electrode_map)
lfp_indices = np.sort(np.array(channels, dtype=int))

# Useful probeinterface attributes once a probe object exists:
#   <probe>.device_channel_indices
#   <probe>.contact_ids

In [None]:
# Draw the session probe, highlight the used electrodes, and attach the probe
# to the recording.
fig, ax = plt.subplots(figsize=(7, 7))

# 32-contact linear probe whose device channel indices are taken from El2ROW.
linear_probe = pi.generate_linear_probe(
    num_elec=32, ypitch=100, contact_shape_params={"radius": 15}
)
linear_probe.set_device_channel_indices(El2ROW)

# Red = used electrode, white = unused. If the membership test raises TypeError
# (e.g. `used_electrodes` is None), every contact is drawn in red.
try:
    contact_colors = [
        "r" if electrode in probe_info["used_electrodes"] else "w"
        for electrode in range(number_of_electrodes)
    ]
except TypeError:
    contact_colors = ["r"] * number_of_electrodes

# Plot the probe on a bare axis: no spines, no x ticks, no y tick marks.
pi.plotting.plot_probe(linear_probe, ax=ax, contacts_colors=contact_colors)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.yaxis.set_ticks_position("none")
ax.set_xticks([])
ax.set_xlabel("")
ax.set_ylabel("($\\mu m$)", fontsize=10)
ax.set_title("\n".join([title, probe_type]))

# Label each contact with "<port>-<device channel>" to the right of the contact.
contact_positions = linear_probe.contact_positions
port_id = probe_info["port_id"]
device_channel_indices = [
    f"{port_id}-{ch:03}" for ch in linear_probe.device_channel_indices
]
for (x_pos, y_pos), channel_label in zip(contact_positions, device_channel_indices):
    ax.text(x_pos + 100, y_pos, channel_label, va="center", fontsize=8)

# Save the figure once; an existing file is left untouched.
probe_figure_path = spike_sorting_path / "probe.pdf"
if not probe_figure_path.exists():
    fig.savefig(probe_figure_path)

# Attach the probe to the recording and show its contacts ordered by contact ID.
recording = recording.set_probe(linear_probe)
probe_frame = recording.get_probe().to_dataframe(complete=True)
probe_frame.sort_values(by="contact_ids", key=lambda col: col.astype(int))

#### Preprocessing

In [None]:
# Drop the channels that belong to unused electrodes (nothing to do when all were used).
# NOTE(review): electrode numbers are passed as string channel IDs here — confirm
# this matches the channel IDs of the loaded recording. Re-running this cell after
# the channels are already gone will presumably fail — confirm.
if len(unused_electrodes) > 0:
    channel_ids_to_drop = np.array(list(map(str, unused_electrodes)))
    recording = recording.remove_channels(remove_channel_ids=channel_ids_to_drop)
print(recording)
print(recording.get_probe())

In [None]:
# Band-pass filter the recording (300-6000 Hz), then apply a common median reference.
# `spikeinterface.preprocessing` is imported in the import cell at the top of the notebook.
recording_f = si.preprocessing.bandpass_filter(
    recording=recording, freq_min=300, freq_max=6000
)
recording_cmr = si.preprocessing.common_reference(
    recording=recording_f, operator="median"
)

# Compare a short snippet of the first channel before and after preprocessing.
# The window is clamped so that it also fits recordings shorter than the default start frame.
SNIPPET_START_FRAME = 100_000
SNIPPET_NUM_FRAMES = 1_000
num_samples = recording.get_num_samples()
snippet_start = max(0, min(SNIPPET_START_FRAME, num_samples - SNIPPET_NUM_FRAMES))
snippet_end = min(snippet_start + SNIPPET_NUM_FRAMES, num_samples)

trace_raw = recording.get_traces(
    start_frame=snippet_start, end_frame=snippet_end, return_scaled=True
)
trace_preprocessed = recording_cmr.get_traces(
    start_frame=snippet_start, end_frame=snippet_end, return_scaled=True
)

fig, ax = plt.subplots()
ax.plot(trace_raw[:, 0], label="Raw")
ax.plot(trace_preprocessed[:, 0], label="Preprocessed")
ax.set(
    xlabel=f"Sample (from frame {snippet_start})",
    ylabel="Amplitude (scaled)",
    title="First channel: raw vs. preprocessed",
)
ax.legend()

# Free the snippets; only the lazy recording objects are needed downstream.
del trace_raw, trace_preprocessed

In [None]:
# Plot the first 20000 frames of every preprocessed channel, stacked with a fixed
# vertical offset between channels. `seaborn` is imported in the import cell at
# the top of the notebook.
data = recording_cmr.get_traces(start_frame=0, end_frame=20000, return_scaled=True)

fig, ax = plt.subplots(figsize=(20, 10))
offset = 50  # vertical spacing between consecutive channels, in trace units

# The time axis is the same for every channel, so build it once.
time_seconds = np.arange(data.shape[0]) / recording_cmr.sampling_frequency
ytick_loc = [i * offset for i in range(data.shape[1])]

for i, baseline in enumerate(ytick_loc):
    ax.plot(time_seconds, data[:, i] + baseline, linewidth=0.2)

ax.set_yticks(ytick_loc)
# NOTE(review): labels assume trace i corresponds to the i-th used electrode — confirm
# against the channel order of `recording_cmr`.
ax.set_yticklabels([device_channel_indices[i] for i in probe_info["used_electrodes"]])
ax.set_title(title)
ax.tick_params(length=0)
ax.set(xlabel="Time (s)", ylabel="Channel")
sns.despine(right=True, left=True)

# Save the figure once; an existing file is left untouched.
raw_trace_path = spike_sorting_path / "raw_trace.png"
if not raw_trace_path.exists():
    fig.savefig(raw_trace_path)

### Run sorter

#### SpyKING CIRCUS 2

In [None]:
# Install the following packages to run SpyKING CIRCUS 2
# !pip install hdbscan
# !pip install numba

In [None]:
# Sort the preprocessed recording with SpyKING CIRCUS 2.
# The result is cached as `sorting.pkl`; when that file exists it is loaded
# instead of running the sorter again.
sorter_name = "spykingcircus2"
sorting_folder = spike_sorting_path / sorter_name
sorting_pickle = sorting_folder / "sorting.pkl"

if not sorting_pickle.exists():
    sorting = si.sorters.run_sorter(
        sorter_name=sorter_name,
        recording=recording_cmr,
        output_folder=sorting_folder,
        remove_existing_folder=True,
        verbose=True,
    )
    # Alternative: sorting.save(folder=sorting_folder)
    sorting.dump_to_pickle(file_path=sorting_pickle)
else:
    sorting = si.load_extractor(sorting_pickle)

In [None]:
# Waveform extraction.
# Reuse the waveform folder when it is already on disk; otherwise extract the
# waveforms from the preprocessed recording and print a summary.
waveform_folder = sorting_folder / "waveform"

if not waveform_folder.exists():
    we = si.extract_waveforms(
        recording_cmr,
        sorting,
        folder=waveform_folder,
        ms_before=1.5,
        ms_after=2.0,
        max_spikes_per_unit=500,
    )
    print(we)
else:
    we = si.load_waveforms(waveform_folder, with_recording=True)

In [None]:
# Plot the template of the first sorted unit.
# The unit ID is looked up in `sorting.unit_ids` instead of assuming that an ID
# equal to 0 exists, which matches how the cells below select units.
first_unit_id = sorting.unit_ids[0]

fig, ax = plt.subplots()
ax.plot(we.get_template(unit_id=first_unit_id))
ax.set(title=f"Template of unit {first_unit_id}", xlabel="Sample", ylabel="Amplitude");

In [None]:
# Show the templates of the first five units in a single row.
# (Omit `unit_ids` to plot every unit.)
first_unit_ids = sorting.unit_ids[:5]
si.widgets.plot_unit_templates(we, unit_ids=first_unit_ids, ncols=5)

In [None]:
# Plot the waveforms and the median template of one unit.
# NOTE: `unit_id` is a position in `sorting.unit_ids`, not a unit ID itself.
unit_id = 34
assert unit_id < len(sorting.unit_ids), (
    f"unit_id={unit_id} is out of range: the sorting has only "
    f"{len(sorting.unit_ids)} units"
)
selected_unit_id = sorting.unit_ids[unit_id]

si.widgets.plot_unit_waveforms(we, unit_ids=[selected_unit_id])

# Median template of the selected unit on the first channel.
fig, ax = plt.subplots()
template = we.get_template(unit_id=selected_unit_id, mode="median")
ax.plot(template[:, 0])
ax.set(
    title=f"Median template of unit {selected_unit_id} (first channel)",
    xlabel="Sample",
    ylabel="Amplitude",
)
plt.show()

In [None]:
# Raster of all units for the first five seconds of the recording.
# (Pass `unit_ids=[...]` to `plot_rasters` to restrict the raster to selected units.)
raster_path = spike_sorting_path / "raster.png"

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
si.widgets.plot_rasters(sorting, time_range=[0, 5], ax=ax)
ax.set(ylabel="Unit ID", title=title)
sns.despine()

# Save the figure once; an existing file is left untouched.
if not raster_path.exists():
    fig.savefig(raster_path)

#### QC metrics

In [None]:
# Compute the quality metrics of the sorted units and save them as CSV.
#
# Some of the requested metrics depend on waveform extensions:
#   - amplitude_cutoff / amplitude_median use `spike_amplitudes`
#   - drift uses `spike_locations`
# They are computed first (or re-loaded when already present) so these metrics
# are not left empty. NOTE(review): per the SpikeInterface documentation a missing
# extension only produces a warning and NaN values — confirm for the installed version.
si.postprocessing.compute_spike_amplitudes(waveform_extractor=we, load_if_exists=True)
si.postprocessing.compute_spike_locations(waveform_extractor=we, load_if_exists=True)

quality_metric_names = [
    "firing_rate",
    "snr",
    "presence_ratio",
    "isi_violation",
    "num_spikes",
    "amplitude_cutoff",
    "amplitude_median",
    "sliding_rp_violation",
    "rp_violation",
    "drift",
]
metrics = si.qualitymetrics.compute_quality_metrics(
    we, metric_names=quality_metric_names
)

metrics.to_csv(sorting_folder / "metrics.csv")
metrics

In [None]:
# Export a sorting report to `<sorting_folder>/report`.
# Spike amplitudes and a small set of quality metrics are computed on the
# waveform extractor before the export.
# NOTE(review): this recomputes quality metrics with a reduced metric list, which
# presumably replaces a larger set stored on `we` by an earlier cell — confirm intended.
report_folder = sorting_folder / "report"

si.postprocessing.compute_spike_amplitudes(waveform_extractor=we)
# si.postprocessing.compute_correlograms(waveform_extractor=we)
si.qualitymetrics.compute_quality_metrics(
    waveform_extractor=we, metric_names=["snr", "isi_violation", "presence_ratio"]
)

# `remove_if_exists=True` lets the cell be re-run: without it the export fails
# once the report folder is already on disk.
si.exporters.export_report(we, output_folder=report_folder, remove_if_exists=True)

#### Kilosort2

In [None]:
# !pip install docker
# !pip install cuda-python

In [None]:
# Sort the preprocessed recording with Kilosort2 inside a Docker container.
#
# The output goes next to the other sorter results, under `spike_sorting_path`.
# The previous folder name was built from the organoid ID and the time of day
# only, so sessions starting at the same clock time on different days shared a folder.
# As in the SpyKING CIRCUS 2 cell, the result is cached as `sorting.pkl` so that
# re-running the notebook does not repeat the sorter run.
sorter_name = "kilosort2"
sorting_folder = spike_sorting_path / sorter_name
sorting_pickle = sorting_folder / "sorting.pkl"

if sorting_pickle.exists():
    sorting_kilosort = si.load_extractor(sorting_pickle)
else:
    sorting_kilosort = si.sorters.run_sorter(
        recording=recording_cmr,
        sorter_name=sorter_name,
        output_folder=sorting_folder,
        remove_existing_folder=True,
        verbose=True,
        docker_image=True,
    )
    sorting_kilosort.dump_to_pickle(file_path=sorting_pickle)