# DataJoint Workflow for Neuropixels Analysis with Kilosort

+ This notebook will describe the steps for interacting with the data processed with the workflow *outside* of the Docker container.

+ This workflow is assembled from 4 DataJoint elements:
     + [element-lab](https://github.com/datajoint/element-lab)
     + [element-animal](https://github.com/datajoint/element-animal)
     + [element-session](https://github.com/datajoint/element-session)
     + [element-array-ephys](https://github.com/datajoint/element-array-ephys)

+ DataJoint provides abundant functions to query and fetch data. For detailed tutorials, visit our [general tutorial site](https://playground.datajoint.io/).

## Requirements: 
Before getting started, you will need local copies of the following repositories:
+ [ecephys](https://github.com/jenniferColonell/ecephys_spike_sorting.git)
+ [kilosort](https://github.com/MouseLand/Kilosort/releases/tag/v2.5)
+ [npy-matlab](https://github.com/kwikteam/npy-matlab.git)
+ [CatGT](https://billkarsh.github.io/SpikeGLX/#catgt)
+ [T-prime](https://billkarsh.github.io/SpikeGLX/#tprime)
+ [C-waves](https://billkarsh.github.io/SpikeGLX/#post-processing-tools)

+ You will also need to properly configure Kilosort and the [MATLAB Engine API for Python](https://www.mathworks.com/help/matlab/matlab_external/install-the-matlab-engine-for-python.html).

#### Log in to the database and load your modules:

In [None]:
# Notebook setup: fix the working directory, connect to the database, and
# import the pipeline modules and libraries used by the cells below.
import os
# When the kernel starts inside a folder named "notebooks", move up one level
# (presumably so that the relative 'dj_local_conf.json' path and the
# `workflow` package resolve from the repository root).
if os.path.basename(os.getcwd()) == "notebooks": os.chdir("..")
import datajoint as dj
# Load DataJoint settings from the local config file, then open the connection.
dj.config.load('dj_local_conf.json')
dj.conn()

# Schema modules of the assembled workflow; imported after the connection is set up.
from workflow.pipeline import lab, subject, session, probe, ephys
from pathlib import Path
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Set environment variables pointing at the locally installed spike-sorting tools.
# Each %env line assigns one variable for the current kernel session.
# NOTE(review): these are machine-specific absolute paths; edit them to match
# your own install locations before running this cell.
%env ecephys_directory= C:\Users\janet\Documents\GitHub\ecephys_spike_sorting\ecephys_spike_sorting
%env kilosort_repository= C:\Users\janet\Documents\GitHub\Kilosort-2.5
%env npy_matlab_repository= C:\Users\janet\Documents\GitHub\npy-matlab
%env catGTPath= C:\Users\janet\Documents\GitHub\CatGT
%env tPrime_path= C:\Users\janet\Documents\GitHub\TPrime
%env cWaves_path= C:\Users\janet\Documents\GitHub\C_Waves
%env kilosort_output_tmp= C:\Users\janet\Documents\GitHub\kilosort_output_tmp


## 1. Insert Subject, Session, and Session Directory

In [None]:
# Identifiers of the subject and session that the cells below register.
subject_name = "EF314"
session_num = 20231009
# Session directory in the form "<subject>\<session>". It is built from the two
# values above so that changing the subject or session cannot leave a stale path.
data_dir = "\\".join([subject_name, str(session_num)])

In [None]:
# Register the subject: assemble the row first, then insert it.
subject_record = {
    "subject": subject_name,
    "sex": "M",
    "subject_birth_date": "2022-12-27",
    "subject_description": "practice NP",
}
subject.Subject.insert1(subject_record)

In [None]:
# Row describing this session; it is reused by later cells as the session key.
session_key = {
    "subject": subject_name,
    "session_id": session_num,
    "session_datetime": "2023-10-09 12:00:00",
}

session.Session.insert1(session_key)

In [None]:
# Attach the session's data directory to the session registered above.
session_dir_record = {
    "subject": session_key["subject"],
    "session_id": session_key["session_id"],
    "session_dir": data_dir,
}
session.SessionDirectory.insert1(session_dir_record)

In [None]:
# Session key extended with the session directory.
sd_key = {**session_key, "session_dir": data_dir}

## 2. Register Ephys Recording and Probe

In [None]:
# Show the probe types currently available in the ProbeType lookup table.
probe.ProbeType()

# If your probe type is not listed, insert it into the lookup table first
# (uncomment the line below and fill in the probe type):

#probe.ProbeType.insert1(dict(probe_type = ''))

In [None]:
# The probe identifier must be taken from the recording's imec.meta file
# (field imDatPrb_sn).
probe_record = {"probe": "19454421672", "probe_type": "neuropixels 1.0 - 3B"}
probe.Probe.insert1(probe_record)

In [None]:
# Display the Probe table to confirm the insert.
probe.Probe()

In [None]:
# Link the probe to this session as insertion number 2.
# The probe, subject and session values must satisfy the foreign-key
# restrictions of the ProbeInsertion table.
probe_insertion = {
    "subject": session_key["subject"],
    "session_id": session_key["session_id"],
    "insertion_number": 2,
    "probe": "19454421672",
}
ephys.ProbeInsertion.insert1(probe_insertion)

# Display the table to confirm the insert.
ephys.ProbeInsertion()

In [None]:
# Session key extended with the insertion number; later cells use it to
# restrict queries to this probe insertion.
insertion_key = {**session_key, "insertion_number": 2}
insertion_key

In [None]:
# Populate EphysRecording restricted to this insertion (with a progress bar),
# then display the table.
ephys.EphysRecording.populate(insertion_key, display_progress=True)
ephys.EphysRecording()

## 3. Run the clustering task

In [None]:
# Parameter set handed to the Kilosort 2.5 clustering method in the next step.
# Key order and values are kept exactly as registered.
params = {
    # NOTE(review): Neuropixels AP-band data is normally sampled at ~30 kHz;
    # confirm that 10000 is the intended value for this recording.
    "fs": 10000,
    "fshigh": 150,
    "minfr_goodchannels": 0.1,
    "Th": [10, 4],
    "lam": 10,
    "AUCsplit": 0.9,
    "minFR": 0.02,
    "momentum": [20, 400],
    "sigmaMask": 30,
    "ThPr": 8,
    "spkTh": -6,
    "reorder": 1,
    "nskip": 25,
    "GPU": 1,
    "nfilt_factor": 4,
    "ntbuff": 64,
    "whiteningRange": 32,
    "nSkipCov": 25,
    "scaleproc": 200,
    "nPCs": 3,
    "useRAM": 0,
    "run_CatGT": False,
}

# Register the parameter set under index 3.
# clustering_method must be one of the entries of the lookup table.
paramset_entry = dict(
    clustering_method="kilosort2.5",
    paramset_idx=3,
    params=params,
    paramset_desc="Spike sorting using Kilosort2.5",
)
ephys.ClusteringParamSet.insert_new_params(**paramset_entry)

# Display the table to confirm the new parameter set.
ephys.ClusteringParamSet()

In [None]:
# Define the clustering task for insertion 2 using parameter set 3.
clustering_task = {
    "subject": session_key["subject"],
    "session_id": session_key["session_id"],
    "insertion_number": 2,
    "paramset_idx": 3,
    "clustering_output_dir": r'EF314\20231009\Ephys\kilosort2-5',
    # task_mode is "load" or "trigger": "trigger" runs the spike sorting,
    # "load" loads existing clustering output files.
    "task_mode": "trigger",
}
ephys.ClusteringTask.insert1(clustering_task)

In [None]:
# Query the primary key of the clustering task to run.
# The lookup is restricted by insertion_key as well as by paramset_idx:
# restricting by paramset_idx alone makes fetch1 raise as soon as any other
# session or insertion has a task that uses parameter set 3.
clustering_key = (
    ephys.ClusteringTask & insertion_key & {"paramset_idx": 3}
).fetch1("KEY")
clustering_key

In [None]:
# Populate Clustering for the selected task only, showing a progress bar.
ephys.Clustering.populate(clustering_key, display_progress=True)

## 4. Curate the Clustering Results

In [None]:
# Primary key of the Clustering entry to curate.
# Restricting by insertion_key in addition to paramset_idx keeps fetch1 from
# raising once other sessions or insertions have been clustered with
# parameter set 3.
curation_key = (
    ephys.Clustering & insertion_key & {"paramset_idx": 3}
).fetch1("KEY")
curation_key

In [None]:
#ephys.Clustering().create1_from_clustering_task(curation_key)

In [None]:
# Populate CuratedClustering and then WaveformSet, both restricted to curation_key.
# NOTE(review): WaveformSet presumably depends on CuratedClustering, so keep this order.
ephys.CuratedClustering.populate(curation_key)
ephys.WaveformSet.populate(curation_key, display_progress=True)

In [None]:
# Display the CuratedClustering.Unit table.
ephys.CuratedClustering.Unit()

## 5. Insert LFP recordings

In [None]:
# Primary key of the recording whose LFP will be computed.
# The lookup is restricted with insertion_key (subject, session and insertion
# number): restricting by insertion_number alone matches insertion 2 of every
# session, so fetch1 would raise once a second session is registered.
LFP_key = (ephys.EphysRecording & insertion_key).fetch1("KEY")

In [None]:
# Populate LFP for the selected recording only, showing a progress bar.
ephys.LFP.populate(LFP_key, display_progress=True)

## 6. Visualize your results

In [None]:
# Fetch the mean LFP trace of this session's insertion.
# insertion_key pins subject, session and insertion number; restricting by
# insertion_number alone would make fetch1 raise once several sessions have
# an insertion 2.
lfp_average = (ephys.LFP & insertion_key).fetch1("lfp_mean")

In [None]:
# Plot the mean LFP trace fetched above.
plt.plot(lfp_average)
# Take the insertion number from insertion_key so the title always names the
# insertion that was actually fetched (it previously said "Insertion 1" while
# the data came from insertion 2).
plt.title(f"Average LFP Waveform for Insertion {insertion_key['insertion_number']}")
plt.xlabel("Samples")
plt.ylabel("microvolts (uV)");

In [None]:
# Restrict the curated units to this insertion and to a hand-picked list of
# unit ids, then fetch each unit's id together with its spike times.
selected_units = (
    ephys.CuratedClustering.Unit
    & insertion_key
    & 'unit IN ("6","7","9","14","15","17","19")'
)
units, unit_spiketimes = selected_units.fetch("unit", "spike_times")

In [None]:
# Raster plot: every spike is drawn as a tick at (spike time, unit id).
# For each unit, build an array shaped like its spike times, filled with the unit id.
unit_labels = [
    np.full_like(spikes, unit_id)
    for unit_id, spikes in zip(units, unit_spiketimes)
]
x = np.hstack(unit_spiketimes)
y = np.hstack(unit_labels)
plt.plot(x, y, "|")
plt.xlabel("Time (s)")
plt.ylabel("Unit");

In [None]:
# Select unit 15 of this insertion and fetch its primary key.
unit_query = ephys.CuratedClustering.Unit & insertion_key & "unit = '15'"
unit_key = unit_query.fetch1("KEY")

# Join the unit with its peak waveform and fetch the single matching row.
waveform_query = (
    ephys.CuratedClustering.Unit * ephys.WaveformSet.PeakWaveform
) & unit_key
unit_data = waveform_query.fetch1()

In [None]:
# Plot the peak-electrode waveform of the selected unit against time.
# The stored sampling rate is divided by 1000 so that sample index / rate
# gives milliseconds.
sampling_rate = (ephys.EphysRecording & insertion_key).fetch1(
    "sampling_rate"
) / 1000  # in kHz
peak_waveform = unit_data["peak_electrode_waveform"]
# Sample indices 0..n-1 converted to milliseconds.
time_ms = np.arange(peak_waveform.size) / sampling_rate
plt.plot(time_ms, peak_waveform)
plt.xlabel("Time (ms)")
# A trailing semicolon suppresses the Text repr, matching the other plotting cells.
plt.ylabel(r"Voltage ($\mu$V)");