# Online pipeline for Feldman lab

In [None]:
# Automatically reload imported modules before executing each cell, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
import time
import nwbwidgets

%matplotlib notebook

In [None]:
# Parallelization settings passed to the processing steps below:
#   n_jobs   -> number of concurrent jobs
#   chunk_mb -> max amount of RAM in MB
n_jobs, chunk_mb = 8, 2000

## 1) Load the AP recording and cut a short test excerpt

In [None]:
# base_path = Path("D:/Feldman")
# #base_path = Path("/Users/abuccino/Documents/Data/catalyst/brody/A256_bank1_2020_09_30_g0")
# #base_data_path = Path("D:/Neuropixels/Neuropixels/A256_bank1_2020_09_30/A256_bank1_2020_09_30_g0")
# base_data_path = Path("20210115_NPX_and_behavior/2021_01_15_E105/towersTask_g0")
# ap_bin_path = base_data_path / "towersTask_g0_imec0" / "towersTask_g0_t0.imec0.ap.bin"
# lf_bin_path = base_data_path / "towersTask_g0_imec0" / "towersTask_g0_t0.imec0.lf.bin"

In [None]:
# Session to process. `session_name` is the SpikeGLX run folder name and already
# includes the gate index (e.g. "_g1").
base_path = Path("/Users/abuccino/Documents/Data/catalyst/feldman/")
session_name = "LR_210209_2_g1"
# session_name = "LR_210209_g1"
# session_name = "LR_210209_2_g0"

session_path = base_path / session_name
probe_path = session_path / f"{session_name}_imec0"
ap_bin_path = probe_path / f"{session_name}_t0.imec0.ap.bin"
lf_bin_path = probe_path / f"{session_name}_t0.imec0.lf.bin"
# The NIDQ file lives in the run folder itself and carries no probe (`imec0`) tag.
# The gate index is already part of `session_name`, so no extra "_g0" is appended.
nidq_bin_path = session_path / f"{session_name}_t0.nidq.bin"

In [None]:
# Open the SpikeGLX AP-band binary as a spikeextractors recording.
recording_ap = se.SpikeGLXRecordingExtractor(file_path=ap_bin_path)

In [None]:
# Sampling rate first, then total length in seconds = frames / sampling rate.
fs = recording_ap.get_sampling_frequency()
duration = recording_ap.get_num_frames() / fs
print(f"Duration: {np.round(duration, 1)} s")

In [None]:
# For testing, cut out a 2-minute excerpt (60 s -> 180 s) of the AP recording.
# Frame indices must be integers: `fs` may be a float, and a float frame index
# breaks the trace slicing done by SubRecordingExtractor. The end is clamped so a
# recording shorter than 180 s still yields a valid (shorter) excerpt.
subrec_start_s, subrec_stop_s = 60, 180
subrec_start_frame = int(subrec_start_s * fs)
subrec_end_frame = min(int(subrec_stop_s * fs), recording_ap.get_num_frames())
assert subrec_start_frame < subrec_end_frame, "recording is shorter than the excerpt start time"

subrec = se.SubRecordingExtractor(recording_ap, start_frame=subrec_start_frame, end_frame=subrec_end_frame)
subrec.get_num_frames() / subrec.get_sampling_frequency()

## 2) Quick spike detection by channel

In [None]:
# Detect spikes separately on each channel of the excerpt; each "unit" of the
# resulting `sorting_ch` corresponds to one channel.
# perf_counter is monotonic and intended for measuring elapsed time, unlike time.time().
t_start = time.perf_counter()
sorting_ch = st.sortingcomponents.detect_spikes(recording=subrec,
                                                n_jobs=n_jobs,
                                                chunk_mb=chunk_mb,
                                                verbose=True)
t_stop = time.perf_counter()
print(f"Elapsed time for detection: {t_stop - t_start:.2f} s")

In [None]:
# Each unit id in `sorting_ch` is a channel on which at least one spike was detected.
n_detected_channels = len(sorting_ch.get_unit_ids())
print(f"Detected spikes on {n_detected_channels} channels")

In [None]:
# Raster plot of the per-channel detections.
wr = sw.plot_rasters(sorting=sorting_ch)

### (optional) Remove channels below a certain firing rate

In [None]:
# Drop channels whose firing rate over the excerpt is below `firing_rate_threshold`
# (presumably spikes/s); threshold_sign="less" selects the units to remove.
firing_rate_threshold = 0.1
num_frames = subrec.get_num_frames()

sorting_high_fr = st.curation.threshold_firing_rates(
    sorting_ch, duration_in_frames=num_frames, threshold=firing_rate_threshold, threshold_sign="less"
)

In [None]:
# threshold_sign='less' removes only channels strictly below the threshold,
# so every kept channel satisfies fr >= threshold.
n_kept_channels = len(sorting_high_fr.get_unit_ids())
print(f"Detected spikes on {n_kept_channels} channels with fr >= {firing_rate_threshold}")

## 3) Parse behavioral data from the NIDQ file

In [None]:
# TODO

## 4) Save spike and behavior info to NWB

In [None]:
# The name of the NWBFile containing behavioral data. It is placed in the active
# session folder (`base_data_path` only existed in a commented-out cell).
nwbfile_path = base_path / session_name / f"Feldman_prototype_{session_name}.nwb"

# Write the same excerpt the spikes were detected on, so spike frames in the sorting
# line up with the stored traces (they are relative to `subrec`, not `recording_ap`).
se.NwbRecordingExtractor.write_recording(
    recording=subrec,
    save_path=nwbfile_path,
    overwrite=False  # this appends the file. True would write a new file
)
# Uses the firing-rate-curated sorting; pass `sorting_ch` instead if the optional
# thresholding step was skipped.
se.NwbSortingExtractor.write_sorting(
    sorting=sorting_high_fr,
    save_path=nwbfile_path,
    overwrite=False  # this appends the file. True would write a new file
)

## 5) View output vs. behavior in NWBWidgets

In [None]:
# Open the NWB file read-only and browse it interactively. The handle is left open
# because the widgets may read data from the file on demand; call `io.close()` when done.
from pynwb import NWBHDF5IO

io = NWBHDF5IO(str(nwbfile_path), mode='r')
nwb = io.read()

nwbwidgets.nwb2widget(nwb)