# SpikeInterface Processing Pipeline for OpenEphys Neuropixels 2 & raw Axona recordings
### Jake Swann, 2024

##### This notebook takes as input a spreadsheet describing NP2 OpenEphys recordings and sorts every unsorted recording in a loop. All recordings made from one animal on a given day are concatenated and sorted together, then split back into individual trials afterwards.
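The trial-level split is handled inside the `spelt` sorting helpers, which are not shown in this notebook. The sketch below illustrates the bookkeeping involved. It assumes that trials are concatenated end to end in order and that spike times are given in samples of the concatenated recording. Under those assumptions, each spike can be mapped back to its trial by comparing it against the cumulative trial start samples. `trial_boundaries` and `split_spike_times` are hypothetical names, not part of `spelt` or SpikeInterface.

```python
from bisect import bisect_right
from itertools import accumulate


def trial_boundaries(num_samples_per_trial):
    """Start sample of each trial within the concatenated recording (hypothetical helper)."""
    # Trial i starts where the previous trials end: [0, n0, n0 + n1, ...]
    return list(accumulate(num_samples_per_trial, initial=0))[:-1]


def split_spike_times(spike_times, num_samples_per_trial):
    """Assign concatenated-recording spike times (in samples) back to trials.

    Assumes trials were concatenated end to end in the given order.
    Returns one list per trial, with times re-referenced to that trial's first sample.
    """
    starts = trial_boundaries(num_samples_per_trial)
    total = sum(num_samples_per_trial)
    per_trial = [[] for _ in num_samples_per_trial]
    for t in spike_times:
        if not 0 <= t < total:
            raise ValueError(f"spike time {t} is outside the concatenated recording")
        # Index of the last trial whose start sample is <= t
        i = bisect_right(starts, t) - 1
        per_trial[i].append(t - starts[i])
    return per_trial
```

For example, with two trials of 100 and 200 samples, `split_spike_times([10, 99, 100, 250], [100, 200])` returns `[[10, 99], [0, 150]]`. Using `bisect` makes each lookup logarithmic in the number of trials. With only a handful of trials per day, a linear scan would work equally well.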
##### Each path in the spreadsheet should point to the folder containing all recordings from a given day, with the file structure `base_folder/rXXXX/YYYY-MM-DD/`. Recording folder names within it should match the trial names in the spreadsheet.
##### Required spreadsheet columns are: `trial_name`, `path` (animal and date parts only), `probe_type` (`'NP2_openephys'`), `num_channels` (`384`), `Include` (`'Y'`)
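The notebook applies these columns with pandas in the first cell. The same selection logic can be sketched in plain Python, assuming each spreadsheet row is a dict keyed by column name. Note that the notebook code reads the include flag from a column named `Include`. The session ID is built from the first two underscore-separated fields of the trial name, which mirrors how `session_list` is constructed. `group_trials_by_session` is a hypothetical helper, not part of `spelt`.

```python
import os
from collections import defaultdict


def group_trials_by_session(rows, path_to_data, probe_to_sort):
    """Group included trials into sessions keyed by '<date>_<animal>' (hypothetical helper).

    rows: iterable of dicts with 'trial_name', 'path', 'probe_type' and 'Include' keys.
    Returns {session_id: [(trial_name, full_path), ...]} in spreadsheet order.
    """
    sessions = defaultdict(list)
    for row in rows:
        # Keep only rows flagged for inclusion and recorded with the requested probe
        if row["Include"] != "Y" or row["probe_type"] != probe_to_sort:
            continue
        # e.g. '240429_r1536_blackbox_1' -> '240429_r1536'
        session_id = "_".join(row["trial_name"].split("_")[:2])
        # The sheet stores only the animal/date part of the path
        full_path = os.path.join(path_to_data, row["path"])
        sessions[session_id].append((row["trial_name"], full_path))
    return dict(sessions)
```

When `path_to_data` ends in `/`, `os.path.join` produces the same path as the notebook's string concatenation. Unlike concatenation, it also copes with a missing trailing slash.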
##### The script loads each recording as a [SpikeInterface](https://github.com/SpikeInterface) object, attaches probe geometry, spike sorts using [Kilosort 2 (Axona) / Kilosort 4 (Neuropixels)](https://github.com/MouseLand/Kilosort), and allows curation of the output in the [phy](https://github.com/cortex-lab/phy/) template-gui
##### **N.B.** This requires a Python 3.8 environment with SpikeInterface v0.101+ installed
---

In [1]:
sheet_path = 'https://docs.google.com/spreadsheets/d/1cZxgOw7worcVZq8wIPslmU2jD__xm1MXnNgbs1-9ros/edit#gid=0'
path_to_data = '/home/isabella/Documents/isabella/jake/recording_data/'
sorting_suffix = 'sorting_ks4'
probe_to_sort = 'NP2_openephys'

################################################################################################################

import os
import numpy as np
import pandas as pd
import spikeinterface as si
from spelt.session_utils import gs_to_df
from spelt.np2_utils.np2_preprocessing import sort_np2
from spelt.axona_utils.axona_preprocessing import sort_axona
from spelt.sorting_utils.collect_sessions import collect_sessions

# Load & format Google sheet, collect trials and sessions
sheet = gs_to_df(sheet_path)
sheet['path'] = path_to_data + sheet['path']
sheet_inc = sheet[sheet['Include'] == 'Y']
sheet_inc = sheet_inc[sheet_inc['probe_type'] == probe_to_sort]
trial_list = sheet_inc['trial_name'].to_list()
session_list = np.unique(['_'.join(i.split('_')[:2]) for i in trial_list]) # session ID = '<date>_<animal>', e.g. '240429_r1536'; TODO: add area

# Collect recordings for concatenation and sorting
recording_list = collect_sessions(session_list, trial_list, sheet_inc, probe_to_sort)

# Concatenate over a single session and sort
for recording in recording_list:
	# One row per trial; columns as returned by collect_sessions:
	# 0 = recording object, 1 = trial name, 2 = base folder, 3 = probe type
	session = pd.DataFrame(recording)
	recording_name = session.iloc[0,1]
	base_folder = session.iloc[0,2]
	probe_type = session.iloc[0,3]

	# Concatenate all trials from this session into a single recording
	recordings_concat = si.concatenate_recordings(session.iloc[:,0].to_list())
	print(f'Sorting {recordings_concat}')

	if probe_type == 'NP2_openephys':
		# Save concatenated recording to a flat binary file, unless a previous run already wrote it
		concat_path = os.path.join(base_folder, 'concat.dat')
		if os.path.exists(concat_path):
			print(f'{concat_path} already exists, skipping concatenation')
		else:
			si.write_binary_recording(recordings_concat, concat_path)
			print(f'Concatenated recording saved to {concat_path}')
		# Sort concatenated recording (Kilosort 4)
		sorting = sort_np2(recording = recordings_concat, 
				recording_name = recording_name, 
				base_folder = base_folder,
				sorting_suffix = sorting_suffix)

	elif probe_type == '5x12_buz':
		# Sort concatenated Axona recording (Kilosort 2)
		sorting = sort_axona(recording = recordings_concat, 
				recording_name = recording_name, 
				base_folder = base_folder,
				electrode_type = probe_type,
				sorting_suffix = sorting_suffix)

	else:
		raise ValueError(f'Unsupported probe_type: {probe_type}')

	# Save session trial info to .csv alongside the sorting output
	session.to_csv(f'{base_folder}/{recording_name[:6]}_{sorting_suffix}/session.csv')

Loading /home/isabella/Documents/isabella/jake/recording_data/r1536/2024-04-29/240429_r1536_blackbox_1
18288308
Sorting ConcatenateSegmentRecording: 384 channels - 30.0kHz - 1 segments - 18,288,308 samples 
                             609.61s (10.16 minutes) - int16 dtype - 13.08 GiB
/home/isabella/Documents/isabella/jake/recording_data/r1536/2024-04-29/concat.dat already exists, skipping concatenation
Loading recording with SpikeInterface...
number of samples: 18288308
number of channels: 384
numbef of segments: 1
sampling rate: 30000.0
dtype: int16
Preprocessing filters computed in  2.07s; total  2.07s

computing drift
Re-computing universal templates from data.
100%|██████████| 305/305 [06:56<00:00,  1.37s/it]
drift computed in  420.76s; total  422.84s

Extracting spikes using templates
Re-computing universal templates from data.
100%|██████████| 305/305 [05:49<00:00,  1.15s/it]
448447 spikes extracted in  353.89s; total  776.72s

First clustering
100%|██████████| 24/24 [00:29<00:00,  1.23s/it]
478 clusters found, in  29.90s; total  806.62s

Extracting spikes using cluster waveforms
100%|██████████| 305/305 [01:31<00:00,  3.35it/s]
758070 spikes extracted in  91.32s; total  897.94s

Final clustering
100%|██████████| 24/24 [00:32<00:00,  1.34s/it]
436 clusters found, in  32.17s; total  930.12s

Merging clusters
422 units found, in  1.56s; total  931.68s

Saving to phy and computing refractory periods
166 units found with good refractory periods

Total runtime: 936.20s = 00:15:36 h:m:s
kilosort4 run time 936.62s
Recording sorted!
 KS4 found 422 units

Sorting saved to /home/isabella/Documents/isabella/jake/recording_data/r1536/2024-04-29/240429_sorting_ks2_custom/sort
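The caching step in the loop above only tests whether `concat.dat` exists, so a file left half-written by an interrupted run would be silently reused. The sketch below shows a stricter check that compares the file size against the expected size. It assumes `concat.dat` is a flat, headerless int16 binary (samples × channels, written with SpikeInterface's default `byte_offset=0`). The helper names are hypothetical.

```python
import os


def expected_binary_size(num_samples, num_channels, bytes_per_sample=2):
    """Size in bytes of a flat, headerless binary (int16 by default)."""
    return num_samples * num_channels * bytes_per_sample


def cached_binary_is_complete(path, num_samples, num_channels, bytes_per_sample=2):
    """True only if `path` exists and has exactly the expected size (hypothetical helper)."""
    if not os.path.exists(path):
        return False
    return os.path.getsize(path) == expected_binary_size(
        num_samples, num_channels, bytes_per_sample
    )
```

For the session logged above, `expected_binary_size(18_288_308, 384)` is 14,045,420,544 bytes, which matches the 13.08 GiB in the recording summary. Inside the loop, the arguments could be taken from `recordings_concat.get_num_samples()`, `recordings_concat.get_num_channels()` and `recordings_concat.get_dtype().itemsize`.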





#### Unused Code

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.postprocessing as sp
import spikeinterface.widgets as sw

recording_path = '/data/isabella/jake/recording_data/NP2 data/2024-03-15/test/2024-03-15_13-05-49'
sorting_path = '/data/isabella/jake/recording_data/NP2 data/2024-03-15/test/kilosort4'

# Load the raw OpenEphys recording and the phy-curated sorting (dropping noise and MUA clusters)
recording = se.read_openephys(folder_path=recording_path, stream_id='0')
sorting = se.read_phy(sorting_path, exclude_cluster_groups=['noise', 'mua'])

# Build a SortingAnalyzer; each extension must be computed after the ones it depends on
sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
sorting_analyzer.compute('random_spikes')  # subsample spikes per unit
sorting_analyzer.compute('waveforms')      # extract waveforms for the sampled spikes
sorting_analyzer.compute('templates')      # average waveforms into per-unit templates
sp.compute_template_metrics(sorting_analyzer)
unit_locations = sorting_analyzer.compute(input='unit_locations', method='monopolar_triangulation')

# Raster plot of the curated units over the first 10 s
sw.plot_rasters(sorting, time_range=[0, 10])



Loading recording with SpikeInterface...
number of samples: 18613123
number of channels: 384
numbef of segments: 1
sampling rate: 30000.0
dtype: int16
Interpreting binary file as default dtype='int16'. If data was saved in a different format, specify `data_dtype`.
Using GPU for PyTorch computations. Specify `device` to change this.


TypeError: string indices must be integers
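The `exclude_cluster_groups=['noise', 'mua']` argument above keeps only units that were not labelled noise or MUA during phy curation. The standard-library sketch below lets you check which cluster IDs survive curation without loading SpikeInterface. It assumes phy's tab-separated `cluster_group.tsv` with a header row containing `cluster_id` and `group` columns. It is not SpikeInterface's own implementation, and `units_to_keep` is a hypothetical name.

```python
import csv
import io


def units_to_keep(cluster_group_tsv_text, exclude=("noise", "mua")):
    """Return cluster IDs whose curated label is not in `exclude` (hypothetical helper).

    Assumes a header row with 'cluster_id' and 'group' columns, as in phy's cluster_group.tsv.
    """
    reader = csv.DictReader(io.StringIO(cluster_group_tsv_text), delimiter="\t")
    return [
        int(row["cluster_id"])
        for row in reader
        if row["group"].strip() not in exclude
    ]
```

Pass it the text of `cluster_group.tsv` from the sorting folder (`sorting_path` above). Clusters with no entry in the file are not returned at all, so compare the result against the full cluster list if unlabelled units matter.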