# Implementing a Deep Learning Spike Sorting Pipeline

The objective of this project is to implement a spike sorting pipeline using deep learning techniques.

We utilized the following core libraries:
- SpikeInterface is used for handling extracellular data present in NWB files 
- PyTorch is used for building Tensors, data loaders, and neural networks

The general approach is to create a labeled image dataset, which is then used to train different neural network architectures to detect and sort spikes.

---

## Phase 2: A Semi-Supervised Method

### 1. Extract the data in an NWB file

#### Import required modules

In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from pathlib import Path
import os
import pickle

import psutil
import multiprocessing as mp
from tqdm.notebook import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary
import torchvision
from torchvision import utils

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

In [2]:
import spikeinterface as si
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

In [3]:
print(f"SpikeInterface version: {si.__version__}")

SpikeInterface version: 0.97.1


In [4]:
import deepspikesort as dss

#### Read the NWB file

In [5]:
# Base directory under which on-disk caches are looked up, and the name of the
# NWB file that is read below.
base_folder = Path(".")
nwb_file = "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"

In [6]:
# Read the 'ElectricalSeriesAp' electrical series of the NWB file as a recording
# and display its summary.
recording_nwb = se.read_nwb_recording(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
recording_nwb

NwbRecordingExtractor: 768 channels - 1 segments - 30.0kHz - 4172.982s
  file_path: /pscratch/sd/v/vlavan/sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb

In [7]:
recording_nwb.annotate(is_filtered=False)

In [8]:
sorting_nwb = se.read_nwb_sorting(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
sorting_nwb

NwbSortingExtractor: 423 units - 1 segments - 30.0kHz
  file_path: /pscratch/sd/v/vlavan/sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb

#### Preprocess the recording

In [9]:
# Band-pass filter the raw recording with freq_min=300 and freq_max=6000
# (presumably in Hz, following SpikeInterface conventions — TODO confirm).
recording_f = spre.bandpass_filter(recording_nwb, freq_min=300, freq_max=6000)
recording_f

BandpassFilterRecording: 768 channels - 1 segments - 30.0kHz - 4172.982s

In [10]:
# Re-reference the filtered recording with a global reference computed with the
# median operator.
recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')
recording_cmr

CommonReferenceRecording: 768 channels - 1 segments - 30.0kHz - 4172.982s

#### Extract channels and spikes information

In [11]:
recording_slice = dss.channel_slice_electricalseriesap(recording_cmr)
recording_slice

ChannelSliceRecording: 384 channels - 1 segments - 30.0kHz - 4172.982s

In [12]:
channels_table = dss.extract_channels(recording_slice)
display(channels_table)

Unnamed: 0,channel_id,channel_loc_x,channel_loc_y
0,AP0,16.0,0.0
1,AP1,48.0,0.0
2,AP2,0.0,20.0
3,AP3,32.0,20.0
4,AP4,16.0,40.0
...,...,...,...
379,AP379,32.0,3780.0
380,AP380,16.0,3800.0
381,AP381,48.0,3800.0
382,AP382,0.0,3820.0


In [13]:
channels_table['channel_loc_x'].unique()

array([16., 48.,  0., 32.])

In [14]:
# Name of the folder (relative to base_folder) used to cache extracted waveforms.
waveform_folder = 'waveform'

# Job options forwarded to the waveform extraction call.
# NOTE(review): `job_kwargs` is reassigned later in the notebook with a different
# n_jobs for peak detection, so its value depends on which cell ran last.
job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

In [15]:
# Waveforms are cached on disk: load them when the cache folder already exists,
# otherwise extract them from the sliced recording and write the cache.
waveform_path = base_folder / waveform_folder

if waveform_path.is_dir():
    waveform_nwb = si.load_waveforms(waveform_path)
else:
    waveform_nwb = si.extract_waveforms(
        recording_slice,
        sorting_nwb,
        # Write to the same location that the cache check above reads from, so
        # the two stay in sync if base_folder is ever changed from ".".
        waveform_path,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,
        overwrite=True,
        **job_kwargs
    )

In [16]:
waveform_nwb

WaveformExtractor: 384 channels - 423 units - 1 segments
  before:45 after:60 n_per_units:None

In [17]:
spikes_table = dss.extract_spikes(sorting_nwb, waveform_nwb)
display(spikes_table)

Unnamed: 0,unit_id,peak_frame,peak_channel
0,0,471,341
1,1,511,361
2,2,606,354
3,1,680,361
4,3,715,325
...,...,...,...
4604408,372,125188815,21
4604409,41,125188837,155
4604410,102,125188911,325
4604411,316,125188967,326


---

### 2. Match peaks to spikes

In [18]:
# Name of the folder (relative to base_folder) used to cache detected peaks.
peaks_folder = 'peaks'

# Job options forwarded to peak detection.
# NOTE(review): this overwrites the `job_kwargs` defined earlier for waveform
# extraction (n_jobs differs between the two).
job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True)

In [19]:
# Detected peaks are cached on disk: run the detection and save the result only
# when the cache folder is missing, otherwise load the saved peaks.
peaks_path = base_folder / peaks_folder

if not peaks_path.is_dir():
    peaks = detect_peaks(
        recording_slice,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=6,
        **job_kwargs
    )
    dss.save_peaks(peaks, peaks_path)
else:
    peaks = dss.load_peaks(peaks_path)

In [20]:
peaks_table = dss.extract_peaks(recording_slice, peaks)
display(peaks_table)

Unnamed: 0,peak_frame,peak_channel
0,92,326
1,134,186
2,147,348
3,177,337
4,269,330
...,...,...
4385111,125189400,155
4385112,125189402,89
4385113,125189402,269
4385114,125189408,287


In [21]:
peaks_matched_table = dss.match_peaks(peaks_table, spikes_table, channels_table)
display(peaks_matched_table)

matching peaks:   0%|          | 0/4385116 [00:00<?, ?it/s]

Unnamed: 0,unit_id,peak_frame,peak_channel
0,-1,92,326
1,-1,134,186
2,-1,147,348
3,-1,177,337
4,-1,269,330
...,...,...,...
4385111,-1,125189400,155
4385112,-1,125189402,89
4385113,-1,125189402,269
4385114,-1,125189408,287


In [22]:
# Derive two tables from the matched peaks: one for peaks attributed to a unit
# and one for the remaining peaks.
# NOTE(review): presumably the noise table holds the rows with unit_id == -1 —
# confirm against the deepspikesort helpers.
peaks_spikes_table = dss.get_peaks_spikes(peaks_matched_table)
peaks_noise_table = dss.get_peaks_noise(peaks_matched_table)

display(peaks_spikes_table)
display(peaks_noise_table)

Unnamed: 0,unit_id,peak_frame,peak_channel
0,0,470,337
1,1,510,361
2,2,605,354
3,1,679,361
4,3,713,321
...,...,...,...
2592140,12,125188702,334
2592141,41,125188837,155
2592142,102,125188912,325
2592143,316,125188965,326


Unnamed: 0,peak_frame,peak_channel
0,92,326
1,134,186
2,147,348
3,177,337
4,269,330
...,...,...
1792966,125189400,155
1792967,125189402,89
1792968,125189402,269
1792969,125189408,287


---

### 3. Create a dataset from matched peaks

In [23]:
# Describe the CPU topology of the node and how much of it this process may use.
nthreads = psutil.cpu_count(logical=True)
# psutil returns None when the physical core count cannot be determined; fall
# back to the logical count so the integer divisions below stay valid.
ncores = psutil.cpu_count(logical=False) or nthreads
nthreads_per_core = max(1, nthreads // ncores)
# os.sched_getaffinity is not available on every platform; where it is missing,
# assume all logical CPUs are usable by this process.
if hasattr(os, 'sched_getaffinity'):
    nthreads_available = len(os.sched_getaffinity(0))
else:
    nthreads_available = nthreads
ncores_available = nthreads_available // nthreads_per_core

# Sanity checks: the standard library must agree with psutil on the CPU count.
assert nthreads == os.cpu_count()
assert nthreads == mp.cpu_count()

print(f'{nthreads=}')
print(f'{ncores=}')
print(f'{nthreads_per_core=}')
print(f'{nthreads_available=}')
print(f'{ncores_available=}')

nthreads=256
ncores=128
nthreads_per_core=2
nthreads_available=256
ncores_available=128


In [24]:
# Folder holding the peaks dataset, nested inside the peaks cache folder.
peaks_dataset_folder = os.path.join(peaks_folder, 'peaks_dataset')

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race; makedirs also creates the parent folder when it is missing.
os.makedirs(peaks_dataset_folder, exist_ok=True)

In [25]:
# Keep only the unit ids that occur at least 500 times among the matched peaks,
# as a list in ascending order.
unit_counts = peaks_matched_table['unit_id'].value_counts()
frequent_units = unit_counts[unit_counts >= 500]
selected_peaks = sorted(frequent_units.index.to_list())

print(selected_peaks)

[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 146, 147, 148, 149, 150, 151, 153, 154, 156, 157, 159, 160, 161, 162, 163, 165, 166, 167, 169, 171, 172, 173, 174, 175, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 194, 196, 197, 198, 200, 201, 202, 203, 206, 207, 208, 209, 210, 212, 213, 214, 215, 216, 217, 219, 221, 225, 226, 227, 228, 229, 230, 232, 233, 235, 236, 237, 238, 239, 240, 241, 243, 244, 245, 246, 247, 

In [26]:
# One label per selected unit id ('unit_<id>'), plus a lookup from each label to
# its position in the label list.
peak_labels = [f'unit_{unit}' for unit in selected_peaks]
peak_labels_dict = dict(zip(peak_labels, range(len(peak_labels))))
peaks_dataset = dss.TensorDataset(peaks_dataset_folder, peak_labels, peak_labels_dict)

#### Split for training and testing

In [27]:
# Seed for the train/test split. A fixed seed makes the split identical across
# kernel restarts; this matters because training below is checkpointed, and a
# resumed run with a different split would mix test samples into training.
SPLIT_SEED = 0

# 70% of the samples go to training, the remainder to testing.
train_size = int(0.7 * len(peaks_dataset))
test_size = len(peaks_dataset) - train_size
peaks_train_dataset, peaks_test_dataset = torch.utils.data.random_split(
    peaks_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print("Training size:", len(peaks_train_dataset))
print("Testing size:",len(peaks_test_dataset))

Training size: 3006964
Testing size: 1288699


In [28]:
# Build the train and test DataLoaders. Both share the same batch size and
# worker count; only the training loader shuffles its samples.
loader_options = dict(batch_size=64, num_workers=8)

peaks_train_dataloader = DataLoader(peaks_train_dataset, shuffle=True, **loader_options)
peaks_test_dataloader = DataLoader(peaks_test_dataset, **loader_options)

---

### 4. Classify peaks with a CNN 

#### Build the CNN architecture

In [29]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

Using cuda device


In [30]:
class CNNet(nn.Module):
    """Convolutional classifier for peak images.

    The flattened feature size (35328) is fixed, which corresponds to inputs of
    shape (N, 1, 64, 192, 2) as used in the model summary of this notebook.

    Args:
        num_classes: Width of the output layer, i.e. the number of classes.
            Defaults to 316, the value originally hard-coded in the network.

    The forward pass returns log-probabilities of shape (N, num_classes).
    """

    def __init__(self, num_classes=316):
        super().__init__()
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=(9, 3, 2))
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=4)

        self.conv_layer_2_drop = nn.Dropout2d()

        self.flatten = nn.Flatten()
        self.fully_connected_layer_1 = nn.Linear(35328, 500)
        self.fully_connected_layer_2 = nn.Linear(500, num_classes)

    def forward(self, x):
        # The 3D kernel spans the whole last axis, leaving it with size 1;
        # squeeze it so the remaining layers can operate in 2D.
        x = torch.squeeze(self.conv_layer_1(x), 4)
        x = F.relu(F.max_pool2d(x, 2))

        x = F.relu(F.max_pool2d(self.conv_layer_2_drop(self.conv_layer_2(x)), 2))

        x = self.flatten(x)
        x = F.relu(self.fully_connected_layer_1(x))
        x = F.dropout(x, training=self.training)
        # No ReLU on the output layer: clamping the logits to be non-negative
        # would prevent the network from pushing any class score below zero.
        x = self.fully_connected_layer_2(x)
        # log_softmax is idempotent, so these log-probabilities remain valid
        # input for nn.CrossEntropyLoss, which applies log_softmax internally.
        return F.log_softmax(x, dim=1)

# Instantiate the network and move it to the selected device.
model = CNNet().to(device)

# Training configuration: cross-entropy loss optimised with Adam.
learning_rate = 0.0001
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [31]:
# Print a layer-by-layer summary for an input of size (64, 1, 64, 192, 2); the
# leading 64 matches the DataLoader batch size used above.
summary(model, input_size=(64, 1, 64, 192, 2))

Layer (type:depth-idx)                   Output Shape              Param #
CNNet                                    [64, 316]                 --
├─Conv3d: 1-1                            [64, 32, 56, 190, 1]      1,760
├─Conv2d: 1-2                            [64, 64, 25, 92]          32,832
├─Dropout2d: 1-3                         [64, 64, 25, 92]          --
├─Flatten: 1-4                           [64, 35328]               --
├─Linear: 1-5                            [64, 500]                 17,664,500
├─Linear: 1-6                            [64, 316]                 158,316
Total params: 17,857,408
Trainable params: 17,857,408
Non-trainable params: 0
Total mult-adds (G): 7.17
Input size (MB): 6.29
Forward/backward pass size (MB): 250.11
Params size (MB): 71.43
Estimated Total Size (MB): 327.83

#### Train and test the model

In [32]:
# Folder (inside the current working directory) where model checkpoints are stored.
models_folder = os.path.join(os.getcwd(), "models")

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(models_folder, exist_ok=True)

In [33]:
# Bundle the data loaders, device, loss function and optimizer into the
# deepspikesort training helper.
train_peaks = dss.TrainModel(peaks_train_dataloader,
                             peaks_test_dataloader,
                             device,
                             loss_fn,
                             optimizer)

In [34]:
# Train and test the model for one epoch. Judging by the logged output, the
# checkpoint is written to <models_folder>/peak_2.pt and an existing checkpoint
# is looked up before training starts — confirm in dss.TrainModel.
train_peaks.train_test_model(model, 'peak_2', models_folder, epochs=1, classes=selected_peaks)

No checkpoint found.
Start training at epoch 1

Epoch 1
-------------------------------
loss: 6.066647  [    0/3006964]
loss: 3.756690  [ 6400/3006964]
loss: 3.489623  [12800/3006964]
Training paused.
Checkpoint saved: /pscratch/sd/v/vlavan/models/peak_2.pt

Training completed.
