# Spike sorting in a scikit-learn Pipeline
This notebook shows how the ``Waveformer`` and ``WaveformSorter`` classes can be chained in a scikit-learn ``Pipeline`` to process a raw trace, extract spike waveforms, and label them. It is structured as follows:

0. **Set up:** load dependencies and define a helper for displaying DataFrames side by side.
1. **Load data:** load traces from a Utah array .ns6 file.
2. **Sort spikes:** chain the provided classes in a scikit-learn pipeline and label the spikes on each channel.
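The pipeline relies on an implicit contract between its steps: `Waveformer` turns a 1-D trace into one row per detected spike, apparently with the timestamp in column 0 and the waveform samples after it (that is what the `ColumnTransformer` slices in section 2 assume), and `WaveformSorter` returns `[timestamp, label]` rows. The `spikesorting` internals are not shown here, so the following is only a hypothetical sketch of that contract. `ThresholdWaveformer` and `KMeansSorter` are made-up names, and the techniques (negative threshold crossing with a median-based noise estimate, then k-means) are swapped in for illustration, not taken from the library.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline

FS = 30_000  # continuous .ns6 data is sampled at 30 kHz


class ThresholdWaveformer(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in for Waveformer: 1-D trace -> rows [timestamp_s, w_0, ..., w_k]."""

    def __init__(self, n_std=4.0, pre=10, post=22, fs=FS):
        self.n_std = n_std  # threshold in units of the estimated noise level
        self.pre = pre      # samples kept before the threshold crossing
        self.post = post    # samples kept from the crossing onwards
        self.fs = fs

    def fit(self, X, y=None):
        return self  # nothing to learn

    def transform(self, X):
        trace = np.asarray(X, dtype=float)
        # Robust noise estimate from the median absolute deviation
        sigma = np.median(np.abs(trace - np.median(trace))) / 0.6745
        below = trace < -self.n_std * sigma
        # Samples where the trace first drops below the negative threshold
        onsets = np.flatnonzero(below[1:] & ~below[:-1]) + 1
        # Keep only crossings with a full window around them
        onsets = onsets[(onsets >= self.pre) & (onsets + self.post <= trace.size)]
        width = self.pre + self.post
        windows = np.array([trace[i - self.pre:i + self.post] for i in onsets]).reshape(-1, width)
        return np.column_stack([onsets / self.fs, windows])


class KMeansSorter(BaseEstimator):
    """Hypothetical stand-in for WaveformSorter: rows [timestamp, features...] -> [timestamp, label]."""

    def __init__(self, n_clusters=2, random_state=0):
        self.n_clusters = n_clusters
        self.random_state = random_state

    def fit(self, X, y=None):
        X = np.asarray(X)
        # Cluster on the features only; column 0 holds the timestamps
        self.kmeans_ = KMeans(n_clusters=self.n_clusters, n_init=10,
                              random_state=self.random_state).fit(X[:, 1:])
        return self

    def predict(self, X):
        X = np.asarray(X)
        labels = self.kmeans_.predict(X[:, 1:])
        return np.column_stack([X[:, 0], labels])


toy_pipeline = Pipeline([
    ('waveformer', ThresholdWaveformer()),
    ('pca', ColumnTransformer([('timestamp', 'passthrough', slice(0, 1)),
                               ('pca', PCA(n_components=3), slice(1, None))])),
    ('sorter', KMeansSorter(n_clusters=2)),
])

# Synthetic 2 s trace: unit-variance noise plus two alternating spike shapes
rng = np.random.default_rng(0)
trace = rng.normal(0.0, 1.0, 2 * FS)
t = np.arange(32)
narrow = -12.0 * np.exp(-((t - 10) ** 2) / 8.0)
wide = -8.0 * np.exp(-((t - 10) ** 2) / 30.0)
for k, start in enumerate(range(1000, 2 * FS - 1000, 1500)):
    trace[start:start + 32] += narrow if k % 2 == 0 else wide

# One [timestamp, label] row per detected spike
out = toy_pipeline.fit(trace).predict(trace)
```

Because both classes follow the usual scikit-learn `fit`/`transform`/`predict` interface, `Pipeline` can chain them the same way it chains `Waveformer` and `WaveformSorter`; the window sizes, threshold and cluster count here are arbitrary.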

## 0. Set up

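```python
import os
import sys
import brpylib

import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA

from spikesorting import Waveformer, WaveformSorter

from IPython.display import display_html


def display_side_by_side(*args):
    """Render several DataFrames next to each other in the notebook output."""
    html_str = ''
    for df in args:
        html_str += df.to_html()
    # Make each opening <table> tag inline so the tables sit on one row
    # (replacing '<table' leaves the closing </table> tags untouched)
    display_html(html_str.replace('<table', '<table style="display:inline"'), raw=True)
```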

## 1. Load data

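```python
import contextlib

# Name of the .ns6 file to load (expected in the current working directory)
FILE_NAME = r'i140703-001.ns6'

# brpylib prints progress messages while reading; redirect_stdout silences them
# and restores the previous stdout afterwards, even if loading fails
with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
    nsx_file = brpylib.NsxFile(os.path.join(os.getcwd(), FILE_NAME))
    data = nsx_file.getdata(elec_ids='all', start_time_s=0, data_time_s=120)  # all electrodes, first 120 s
    nsx_file.close()

# One row per electrode, one column per sample
X = data['data'].squeeze()
```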
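To try the pipeline without the recording, a synthetic array laid out like `X` can stand in for it: one row per channel (section 2 feeds `X[ch]` to the pipeline as a single trace) and one column per sample at the 30 kHz rate of `.ns6` continuous data. `make_synthetic_traces` is a hypothetical helper; the noise level, spike shape and firing rate are made up, and whether the real `Waveformer` detects these spikes depends on its defaults, which are not shown here.

```python
import numpy as np

FS = 30_000  # .ns6 files store continuous data sampled at 30 kHz


def make_synthetic_traces(n_channels=6, duration_s=2.0, spike_rate_hz=20.0, seed=0):
    """Hypothetical helper: noise plus negative-going spikes, shaped like X (channels x samples)."""
    rng = np.random.default_rng(seed)
    n_samples = int(duration_s * FS)
    traces = rng.normal(0.0, 10.0, size=(n_channels, n_samples))  # background noise
    t = np.arange(32)
    template = -80.0 * np.exp(-((t - 8) ** 2) / 6.0)  # a simple made-up spike shape
    n_spikes = int(spike_rate_hz * duration_s)
    for ch in range(n_channels):
        # Random spike onsets, leaving room for the full template at the end
        starts = rng.integers(0, n_samples - t.size, size=n_spikes)
        for s in starts:
            traces[ch, s:s + t.size] += template
    return traces


X = make_synthetic_traces()
print(X.shape)  # (6, 60000)
```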

## 2. Sort spikes

```python
# Build the pipeline: extract waveforms, reduce them to 3 principal components
# (the timestamp column is passed through untouched), then label each waveform
pipeline = Pipeline([
    ('waveformer', Waveformer(visualize=False)),
    ('pca', ColumnTransformer([('timestamp', 'passthrough', slice(0, 1)),
                               ('pca', PCA(n_components=3), slice(1, None))])),
    ('sorter', WaveformSorter(visualize=False))
])

N_CHANNELS = 6

# Sort spikes channel by channel; each output row is [timestamp, label]
output = []
for ch in range(N_CHANNELS):
    output.append(pipeline.fit(X[ch]).predict(X[ch]))

# Gather the timestamps and labels of each channel in a DataFrame
output_df = [pd.DataFrame(out, columns=['timestamps', 'labels']) for out in output]

# Show the first rows of every channel side by side
display_side_by_side(*(df.head() for df in output_df))
```
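The `ColumnTransformer` step assumes the `Waveformer` output has the spike timestamp in column 0 and the waveform samples in the remaining columns. The sketch below feeds it a random matrix with that assumed layout (50 waveforms of 48 samples, sizes made up) to show that the timestamps pass through unchanged while the waveform columns collapse to three principal components.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_spikes, n_samples = 50, 48  # hypothetical sizes
timestamps = np.sort(rng.uniform(0, 120, n_spikes))
waveforms = rng.normal(size=(n_spikes, n_samples))
features = np.column_stack([timestamps, waveforms])  # assumed Waveformer layout

ct = ColumnTransformer([('timestamp', 'passthrough', slice(0, 1)),
                        ('pca', PCA(n_components=3), slice(1, None))])
reduced = ct.fit_transform(features)  # outputs are stacked in transformer order

print(reduced.shape)                           # (50, 4)
print(np.allclose(reduced[:, 0], timestamps))  # True
```

Keeping the timestamp out of the PCA matters: otherwise spike time would leak into the features the sorter clusters on.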

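**Channel 0**

| | timestamps | labels |
|---|---|---|
| 0 | 0.022533 | 0.0 |
| 1 | 0.022533 | 0.0 |
| 2 | 0.033333 | 1.0 |
| 3 | 0.074 | 1.0 |
| 4 | 0.074 | 1.0 |

**Channel 1**

| | timestamps | labels |
|---|---|---|
| 0 | 0.0412 | 0.0 |
| 1 | 0.0568 | 0.0 |
| 2 | 0.065167 | 0.0 |
| 3 | 0.0669 | 0.0 |
| 4 | 0.092633 | 0.0 |

**Channel 2**

| | timestamps | labels |
|---|---|---|
| 0 | 0.0087 | 0.0 |
| 1 | 0.076 | 0.0 |
| 2 | 0.0966 | 1.0 |
| 3 | 0.146533 | 0.0 |
| 4 | 0.175267 | 0.0 |

**Channel 3**

| | timestamps | labels |
|---|---|---|
| 0 | 0.008333 | 1.0 |
| 1 | 0.088033 | 1.0 |
| 2 | 0.1709 | 1.0 |
| 3 | 0.252033 | 1.0 |
| 4 | 0.3022 | 1.0 |

**Channel 4**

| | timestamps | labels |
|---|---|---|
| 0 | 0.185 | 0.0 |
| 1 | 0.3349 | 0.0 |
| 2 | 0.5041 | 0.0 |
| 3 | 0.617433 | 0.0 |
| 4 | 0.660033 | 0.0 |

**Channel 5**

| | timestamps | labels |
|---|---|---|
| 0 | 0.099 | 0.0 |
| 1 | 0.104767 | 0.0 |
| 2 | 0.121767 | 0.0 |
| 3 | 0.1911 | 0.0 |
| 4 | 0.194733 | 1.0 |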
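Once `output_df` holds one DataFrame per channel, stacking them with `pd.concat` gives a quick count of spikes per label on each channel. This is a minimal sketch: `label_counts` is a hypothetical helper, and the toy input values are made up for illustration.

```python
import pandas as pd


def label_counts(dfs):
    """Hypothetical helper: count spikes per label on each channel."""
    # Stack channels into one long DataFrame indexed by (channel, spike)
    spikes = pd.concat(dfs, keys=list(range(len(dfs))), names=['channel', 'spike'])
    spikes['labels'] = spikes['labels'].astype(int)
    # Rows: channels; columns: labels; missing (channel, label) pairs become 0
    return spikes.groupby(['channel', 'labels']).size().unstack(fill_value=0)


# Toy input with the same columns as output_df (values made up)
toy = [
    pd.DataFrame({'timestamps': [0.02, 0.03, 0.07], 'labels': [0.0, 1.0, 1.0]}),
    pd.DataFrame({'timestamps': [0.04, 0.06], 'labels': [0.0, 0.0]}),
]
counts = label_counts(toy)
print(counts)
```

On the notebook's own results, the equivalent call would be `label_counts(output_df)`.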
