In [None]:
import os

# If the kernel was started inside a folder named "notebooks", step up to its
# parent directory (presumably the repository root, so project imports and
# relative paths resolve from there -- confirm against the project layout).
cwd_name = os.path.split(os.getcwd())[1]
if cwd_name == "notebooks":
    os.chdir(os.pardir)

In [None]:
import datajoint as dj
from datetime import datetime
from pathlib import Path
from workflow.pipeline import *
import matplotlib.pyplot as plt
import numpy as np

### Select the clustering task

In [None]:
# Preview the clustering table, then take the first entry's primary key; the
# cells below in this notebook restrict their queries with `key`.
display(ephys.Clustering())

clustering_keys = ephys.Clustering.fetch("KEY")
key = clustering_keys[0]
key

### Basic information on spike-sorted units can be found in the `ephys.CuratedClustering.Unit` table

In [None]:
# Restrict the unit table to the selected clustering entry; being the cell's
# last expression, the restricted query is what gets rendered as output.
units_for_key = ephys.CuratedClustering.Unit & key
units_for_key

### Examine unit waveforms

In [None]:
# Preview the waveform tables: `WaveformSet` first, then its nested
# `PeakWaveform` table.
for waveform_table in (ephys.WaveformSet, ephys.WaveformSet.PeakWaveform):
    display(waveform_table())

In [None]:
# Get waveform
# `unit` is also read by the plotting cell that follows.
unit = 1

peak_waveform_query = ephys.WaveformSet.PeakWaveform & key & f"unit={unit}"
peak_waveform_query.fetch1("peak_electrode_waveform")

In [None]:
# Plot the waveform and auto-correlogram
from element_array_ephys.plotting.unit_level import (
    plot_auto_correlogram,
    plot_waveform,
)

# "sample_rate" is read from the stored session info and divided by 1e3.
session_info = (ephys.EphysSessionInfo & key).fetch1("session_info")
sampling_rate = session_info["sample_rate"] / 1e3  # in kHz

# Join the selected unit with its peak waveform so one fetch1 returns the
# waveform, the spike times and the quality label together.
selected_unit = ephys.CuratedClustering.Unit & key & f"unit={unit}"
unit_with_waveform = selected_unit * ephys.WaveformSet.PeakWaveform
peak_electrode_waveform, spike_times, cluster_quality_label = unit_with_waveform.fetch1(
    "peak_electrode_waveform", "spike_times", "cluster_quality_label"
)

# Get the figure
waveform_fig = plot_waveform(
    waveform=peak_electrode_waveform,
    sampling_rate=sampling_rate,
)
correlogram_fig = plot_auto_correlogram(
    spike_times=spike_times,
    bin_size=0.001,
    window_size=1,
)

for unit_figure in (waveform_fig, correlogram_fig):
    display(unit_figure)

### Examine quality metrics

In [None]:
# Preview the quality-metric tables: `QualityMetrics` first, then its nested
# `Cluster` table, which the histogram cell below queries.
for metrics_table in (ephys.QualityMetrics, ephys.QualityMetrics.Cluster):
    display(metrics_table())

### Plot histograms of quality metrics

In [None]:
def plot_metric(ax, data, bins, x_axis_label=None, title=None, color='k', smoothing=True, density=False):
    """Draw the histogram of one metric as a line on `ax`, optionally smoothed.

    A function modified from https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html

    Args:
        ax: Matplotlib axes to draw on.
        data: Metric values. Anything `np.asarray(data, dtype=float)` accepts;
            `None` and NaN entries are ignored.
        bins: Bin specification forwarded to `np.histogram`.
        x_axis_label: Label for the x axis (only applied when a line is drawn).
        title: Axes title; set even when there is nothing to plot.
        color: Line color.
        smoothing: If True, smooth the histogram with a Gaussian kernel (sigma = 1 bin).
        density: Forwarded to `np.histogram`.

    Returns:
        None. The axes is modified in place.
    """
    from scipy.ndimage import gaussian_filter1d

    # Coerce to float so `None` entries become NaN, then keep only real values.
    # Checking "is anything left" (instead of truthiness / sum of the data)
    # means valid data that is all zeros, or that sums to zero, is still drawn.
    values = np.asarray(data, dtype=float).ravel()
    values = values[~np.isnan(values)]

    if values.size:
        counts, edges = np.histogram(values, bins=bins, density=density)
        left_edges = edges[:-1]

        if smoothing:
            # With density=False the counts are integers, and gaussian_filter1d
            # writes its result in the input dtype; cast to float first so the
            # smoothed curve is not truncated to whole numbers.
            heights = gaussian_filter1d(counts.astype(float), 1)
        else:
            heights = counts

        ax.plot(left_edges, heights, color=color)
        ax.set_xlabel(x_axis_label)
        ax.set_ylim([0, None])

    ax.set_title(title)
    ax.spines[['right', 'top']].set_visible(False)
    
query = ephys.QualityMetrics.Cluster & key

fig, axes = plt.subplots(4, 4, figsize=(12, 9))
axes = axes.flatten()
plt.suptitle(f"Cluster Quality Metrics for {key}", y=.99, fontsize=12)

# One entry per panel, in display order:
# (attribute to fetch, panel title, np.linspace arguments for the bin edges).
metric_panels = [
    ("firing_rate", "Firing Rate (Hz) (log$_{10}$)", (-3, 2, 100)),
    ("snr", "Signal-to-Noise Ratio", (0, 10, 100)),
    ("presence_ratio", "Presence Ratio", (0, 1, 100)),
    ("isi_violation", "ISI Violation", (0, 1, 100)),
    ("number_violation", "Number Violation", (0, 1000, 100)),
    ("amplitude_cutoff", "Amplitude Cutoff", (0, 0.5, 100)),
    ("isolation_distance", "Isolation Distance", (0, 170, 50)),
    ("l_ratio", "L-Ratio", (0, 1, 100)),
    ("d_prime", "d-Prime", (0, 15, 50)),
    ("nn_hit_rate", "Nearest-Neighbors Hit Rate", (0, 1, 100)),
    ("nn_miss_rate", "Nearest-Neighbors Miss Rate", (0, 1, 100)),
    ("silhouette_score", "Silhouette Score", (0, 1, 100)),
    ("max_drift", "Max Drift", (0, 100, 100)),
    ("cumulative_drift", "Cumulative Drift", (0, 100, 100)),
]

for panel_index, (attribute, panel_title, bin_spec) in enumerate(metric_panels):
    data = query.fetch(attribute)
    if attribute == "firing_rate":
        # Firing rates are histogrammed on a log10 scale (see the panel title).
        data = np.log10(data)
    bins = np.linspace(*bin_spec)
    plot_metric(axes[panel_index], data, bins, title=panel_title)
    # Only the first panel of each row of four carries the y-axis label.
    if panel_index % 4 == 0:
        axes[panel_index].set_ylabel("Count")

# The 4x4 grid has more axes than metrics; drop the unused ones.
for unused_ax in axes[len(metric_panels):]:
    unused_ax.remove()
plt.tight_layout()

## Download the spike-sorted data (SpikeInterface)

In [ ]:
# Project-local helper; its implementation is not visible from this notebook.
from workflow.utils.initiate_session import download_spike_sorted_results

# Called with the clustering `key` selected near the top of the notebook.
# NOTE(review): presumably downloads the SpikeInterface sorting output for that
# entry -- confirm the destination and side effects in
# workflow.utils.initiate_session.
download_spike_sorted_results(key)