# McsPyDataTools Tutorial for data analysis: Simple Tetrode Spike Sorting

This tutorial introduces data analysis with the McsPyDataTools toolbox, using a simple spike sorting algorithm for tetrode data as an example.

In this tutorial we show you how to sort your MCS data with the McsPy package.
The established spike sorting pipeline consists of the following steps:
1. spike detection and collection of spike cutouts
2. feature extraction
3. clustering
4. validation

MCS software provides sophisticated tools for spike detection and cutout collection. When a recording is converted to HDF5 with the MCS DataManager, the results are stored in two streams: a SegmentStream holding the spike cutouts and a TimeStampStream holding the spike timestamps. This Python package requires both streams and extends MCS software with offline spike sorting for tetrodes.
The `Sort()` interface from the sorting module is a callable Python class. It performs PCA for feature extraction and then clusters the features with a method defined by the user. Its usage is explained in the Sorting section of this tutorial.
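
To make the feature extraction and clustering steps concrete, here is a minimal, self-contained sketch on synthetic waveforms. It assumes plain PCA (via `numpy.linalg.svd`) for features and ordinary k-means for clustering; it is an illustration only, not the implementation behind `Sort()`, whose clustering method is chosen by the user. The helper names `pca_features` and `kmeans` are hypothetical.

```python
import numpy as np

def pca_features(cutouts, n_components=3):
    """Project cutouts of shape (n_spikes, n_samples) onto their top principal components."""
    centered = cutouts - cutouts.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

def kmeans(features, k, n_iter=50, seed=0):
    """Plain Lloyd's k-means; returns one integer label per row of features."""
    rng = np.random.default_rng(seed)
    centers = features[rng.choice(len(features), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign each spike to its nearest center (squared Euclidean distance)
        dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its members; keep it if the cluster is empty
        new_centers = np.array([features[labels == j].mean(axis=0) if np.any(labels == j)
                                else centers[j] for j in range(k)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels

# Synthetic data: two hypothetical units with different waveform shapes plus noise
rng = np.random.default_rng(1)
t = np.linspace(-1, 1, 32)
template_a = -80 * np.exp(-(t / 0.1) ** 2)   # tall, narrow spike
template_b = -30 * np.exp(-(t / 0.3) ** 2)   # small, wide spike
cutouts = np.vstack([template_a + rng.normal(0, 5, (100, t.size)),
                     template_b + rng.normal(0, 5, (100, t.size))])

labels = kmeans(pca_features(cutouts), k=2)
```

Validation, the fourth step, would then check whether the resulting units are physiologically plausible.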

First we are going to import all necessary modules.

In [None]:
# SHOW ALL OUTPUTS
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

import os
import numpy as np

# MCS PyData tools
import McsPy.McsData

# IMPORT FROM src
from McsPy.McsSpikeSorting.sorting import Sort

# VISUALIZATION TOOLS
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline

# SUPPRESS WARNINGS
import warnings
warnings.filterwarnings('ignore')

# autoreload modules
%load_ext autoreload
%autoreload 2

And define some handy path variables to locate our test dataset:

In [None]:
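# build relative paths with os.path.join so they work on Windows, Linux and macOS
path2data = os.path.join('..', 'data', 'raw')
path2extdata = os.path.join('..', 'data', 'external')
path2interdata = os.path.join('..', 'data', 'interim')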

## Load data
Next, we use McsPyDataTools to load our test data file. The file is part of the test data provided [here](https://www.multichannelsystems.com/software/multi-channel-datamanager). The toolbox is designed to sort spikes in MCS files that contain **Segment Streams** of type **Cutout**, i.e. spikes that were already detected and cut out with MCS software. Each segment stream entity corresponds to one channel, so we need to define the four entity indices that make up our tetrode. **Note:** we expect every spike to be cut out on all four channels.
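
Because every spike must be cut out on all four channels, the per-channel arrays have to agree in shape before they are combined. The following sketch assumes, as the transpose in the next cell implies, that each channel's `.data` has shape (samples per cutout, number of spikes); `stack_tetrode` is a hypothetical helper, not part of McsPy, and the input here is synthetic.

```python
import numpy as np

def stack_tetrode(channel_cutouts):
    """Stack per-channel cutouts of shape (n_samples, n_spikes) into
    (n_samples, n_spikes, n_channels), refusing mismatched channels."""
    shapes = {np.shape(c) for c in channel_cutouts}
    if len(shapes) != 1:
        # A spike missing on one channel would misalign every later spike
        raise ValueError(f"channels disagree on cutout shape: {sorted(shapes)}")
    # np.stack(..., axis=2) is equivalent to np.array(...).transpose((1, 2, 0))
    return np.stack(channel_cutouts, axis=2)

# Synthetic example: 4 channels, 32 samples per cutout, 10 spikes each
channels = [np.full((32, 10), ch) for ch in range(4)]
tetrode_spikes = stack_tetrode(channels)
print(tetrode_spikes.shape)  # (32, 10, 4)
```

Failing loudly on a shape mismatch is preferable to letting NumPy build a ragged object array, which would only fail later in the sorting step.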

In [None]:
# PARAMETERS
filename = 'AnalogSegmentTimestamp.h5'
tetrode = [0, 1, 2, 3]
dt = 40e-6
fs = 1/dt  # sampling rate of the signal in Hz

file = McsPy.McsData.RawData(os.path.join(path2data, filename))
segment_stream = file.recordings[0].segment_streams[0]
timestamp_stream = file.recordings[0].timestamp_streams[0]

spikes = np.array([segment_stream.segment_entity[tetrode[0]].data,
                   segment_stream.segment_entity[tetrode[1]].data,
                   segment_stream.segment_entity[tetrode[2]].data,
                   segment_stream.segment_entity[tetrode[3]].data])
spikes = spikes.transpose((1,2,0))  # convert the data to the shape [#samples per cutout, #spikes, #channels]

train = np.array([timestamp_stream.timestamp_entity[tetrode[0]].data,
                  timestamp_stream.timestamp_entity[tetrode[1]].data,
                  timestamp_stream.timestamp_entity[tetrode[2]].data,
                  timestamp_stream.timestamp_entity[tetrode[3]].data])

Let's have a look at the data.

In [None]:
spikes.shape

Plot first 25 spike waveforms

In [None]:
w_len = spikes.shape[0]
start = -1*w_len//2
tt = np.arange(start, start+w_len) * dt * 1000
ymax= np.max(spikes)
ymin= np.min(spikes)

_ = plt.figure(figsize=(11, 8))

for i in range(spikes.shape[2]):
    _ = plt.subplot(2,2,i+1)
    _ = plt.plot(tt,spikes[:,0:25,i],'k', linewidth=1, alpha=0.3)  # spikes 0..24
    _ = plt.ylim((ymin, ymax))
    _ = plt.xlim((tt[0],tt[-1]))
    _ = plt.ylabel('ADC value')
    _ = plt.xlabel('time [ms]')  # tt = samples * dt[s] * 1000, i.e. milliseconds
    _ = plt.title('channel {}'.format(i))

_ = plt.tight_layout()
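
The time axis `tt` converts sample indices into time. Since `dt` is given in seconds (40 µs, i.e. a 25 kHz sampling rate), multiplying by 1000 yields milliseconds, so consecutive samples are 0.04 ms apart. A small sketch of the same computation, using the hypothetical helper `cutout_time_axis_ms` and an assumed cutout length of 50 samples:

```python
import numpy as np

def cutout_time_axis_ms(w_len, dt):
    """Time of each cutout sample in milliseconds, mirroring the plotting cell."""
    # Same expression as in the cell above: first sample half a cutout before the midpoint
    start = -1 * w_len // 2
    # dt is in seconds, so multiplying by 1000 gives milliseconds
    return np.arange(start, start + w_len) * dt * 1000

tt_ms = cutout_time_axis_ms(w_len=50, dt=40e-6)
print(tt_ms[0], tt_ms[-1])
```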

Plot largest 25 spike waveforms

In [None]:
idx = np.argsort(np.min(np.min(spikes,axis=2),axis=0))

tt = np.arange(start, start+w_len) * dt * 1000
ymax= np.max(spikes)*1.05
ymin= np.min(spikes)*1.05

_ = plt.figure(figsize=(11, 8))
for i in range(spikes.shape[2]):
    _ = plt.subplot(2,2,i+1)
    _ = plt.plot(tt,spikes[:,idx[0:25],i],'k', linewidth=1, alpha=0.3)
    _ = plt.ylim((ymin, ymax))
    _ = plt.xlim((tt[0],tt[-1]))
    _ = plt.ylabel('ADC value')
    _ = plt.xlabel('time [ms]')  # tt = samples * dt[s] * 1000, i.e. milliseconds
    _ = plt.title('channel {}'.format(i))

_ = plt.tight_layout()
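
The index `idx` above ranks spikes by their most negative sample across all four channels, which assumes that the spikes are dominated by a negative trough. The same ranking on a toy array, wrapped in a hypothetical helper:

```python
import numpy as np

def rank_by_negative_peak(spikes):
    """Return spike indices ordered from the deepest to the shallowest trough.

    spikes has shape (n_samples, n_spikes, n_channels), as in this tutorial.
    """
    # Minimum over channels, then over samples: one trough value per spike
    troughs = spikes.min(axis=2).min(axis=0)
    # argsort is ascending, so the most negative trough comes first
    return np.argsort(troughs)

# Three spikes whose troughs (-10, -50, -30) sit on different channels
demo = np.zeros((5, 3, 4))
demo[2, 0, 1] = -10
demo[2, 1, 3] = -50
demo[1, 2, 0] = -30
order = rank_by_negative_peak(demo)
print(order)  # [1 2 0]
```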

## Sorting

For sorting, the McsPy package provides the callable interface class `Sort()`:

- `sorter = Sort(...)` ... initializes the interface with the model parameters
- `assignments = sorter(segment_stream, timestamp_stream)` ... fits the underlying clustering model and assigns each spike to a detected unit

All model parameters have to be passed during initialization. For the available model parameters and their documentation, please check the cluster module. **Do not** forget to validate your results: a preliminary set of validation tools is available in the validation module.
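
The validation module itself is not shown in this tutorial, so the following is only a hedged illustration of two sanity checks that apply to the output of any sorter: the number of spikes per unit and the fraction of inter-spike intervals shorter than an assumed 2 ms refractory period. The helpers are hypothetical and independent of McsPy; timestamps are assumed to be in microseconds, and the data is synthetic.

```python
import numpy as np

def refractory_violation_rate(timestamps_us, refractory_us=2000):
    """Fraction of inter-spike intervals shorter than the refractory period.

    timestamps_us: spike times of one unit, assumed to be in microseconds.
    """
    isi = np.diff(np.sort(np.asarray(timestamps_us)))
    if isi.size == 0:
        return 0.0
    return float(np.mean(isi < refractory_us))

def summarize_units(assignments, timestamps_us, refractory_us=2000):
    """Map each unit label to (spike count, refractory violation rate)."""
    assignments = np.asarray(assignments)
    timestamps_us = np.asarray(timestamps_us)
    summary = {}
    for unit in np.unique(assignments):
        times = timestamps_us[assignments == unit]
        summary[int(unit)] = (times.size, refractory_violation_rate(times, refractory_us))
    return summary

# Toy data: unit 0 fires every 10 ms, unit 1 has one 1 ms interval
assignments_demo = np.array([0, 0, 0, 1, 1, 1])
times_demo = np.array([0, 10_000, 20_000, 5_000, 6_000, 30_000])
print(summarize_units(assignments_demo, times_demo))  # {0: (3, 0.0), 1: (3, 0.5)}
```

A unit with many refractory violations likely merges spikes from more than one neuron.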

In [None]:
sorter = Sort(method='modt', max_iter=30)
assignments = sorter(segment_stream, timestamp_stream)

Let's visualize the results. We plot the first principal component of each channel against that of every other channel, colored by unit assignment.

In [None]:
# PARAMETERS
num_clusters = np.unique(assignments).shape[0]
b = sorter.features

# SETUP COLOR MAP
norm = mpl.colors.Normalize(vmin=0, vmax=num_clusters)
cmap = plt.cm.jet
m = plt.cm.ScalarMappable(norm=norm, cmap=cmap)

# PLOT
_ = plt.figure(figsize=(20, 10))
_ = plt.suptitle('Scatter plots',fontsize=20)

idx = [0, 3, 6, 9]
p = 1
labels = ['Ch1','Ch2','Ch3','Ch4']
maxb = np.max(np.abs(b))*1.05/2
for i in np.arange(0,4):
    for j in np.arange(i+1,4):
        ax = plt.subplot(2,3,p, aspect='equal')
        _ = plt.scatter(b[:,idx[i]], 
                        b[:,idx[j]], 
                        c=assignments, 
                        alpha=0.3, 
                        s=4, 
                        cmap=cmap)
        _ = plt.xlabel(labels[i])
        _ = plt.ylabel(labels[j])
        _ = plt.xlim((-maxb,maxb))
        _ = plt.ylim((-maxb,maxb))
        _ = ax.set_xticks([])
        _ = ax.set_yticks([])
        p = p+1
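
The scatter plots use columns 0, 3, 6 and 9 of `sorter.features`, which suggests (an assumption, not documented in this tutorial) that `Sort()` keeps three principal components per channel and concatenates them channel by channel. The sketch below builds a feature matrix with that layout using per-channel PCA via `numpy.linalg.svd`; `per_channel_pca` is a hypothetical helper, not part of McsPy, and the input is random data.

```python
import numpy as np

def per_channel_pca(spikes, n_components=3):
    """Concatenate the top principal components of each channel.

    spikes: (n_samples, n_spikes, n_channels). Returns an array of shape
    (n_spikes, n_channels * n_components) in which column c * n_components
    holds the first principal component of channel c.
    """
    blocks = []
    for c in range(spikes.shape[2]):
        x = spikes[:, :, c].T                    # (n_spikes, n_samples)
        x = x - x.mean(axis=0)                   # center each sample position
        _, _, vt = np.linalg.svd(x, full_matrices=False)
        blocks.append(x @ vt[:n_components].T)   # (n_spikes, n_components)
    return np.hstack(blocks)

rng = np.random.default_rng(0)
demo_spikes = rng.normal(size=(32, 100, 4))
features = per_channel_pca(demo_spikes)
first_pc_columns = [c * 3 for c in range(4)]     # [0, 3, 6, 9]
print(features.shape)                            # (100, 12)
```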