# Implementing a Deep Learning Spike Sorting Pipeline

The objective of this project is to implement a spike sorting pipeline using deep learning techniques.

We used the following core libraries:
- SpikeInterface and PyNWB for extracting the recording and sorting data stored in NWB files
- PyTorch for building tensors, data loaders, and neural networks

In outline, we build a labeled image dataset of short multichannel voltage snippets, which is then used to train a convolutional neural network (CNN) as a spike detector.

## 1. Reading an NWB file

Ground-truth datasets whose spikes have been manually curated by experts are readily available as NWB files. We use the `sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb` file, which can be downloaded from the DANDI archive:
https://api.dandiarchive.org/api/assets/7e4fa468-349c-44a9-a482-26898682eed1/download/

### Import SpikeInterface modules

We followed the `SpikeInterface` usage instructions in this tutorial:
https://github.com/SpikeInterface/spiketutorials/tree/master/Official_Tutorial_SI_0.96_Oct22 

Install the latest version of `SpikeInterface` from source as recommended in the **"From source"** section here: 
https://spikeinterface.readthedocs.io/en/latest/installation.html

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.widgets as sw

In [None]:
print(f"SpikeInterface version: {si.__version__}")

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

### Reading recording and sorting

In [None]:
base_folder = Path(".")
nwb_file_path = "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"

In [None]:
recording_nwb = se.read_nwb_recording(file_path=nwb_file_path, electrical_series_name='ElectricalSeriesAp')
recording_nwb

In [None]:
recording_nwb.annotate(is_filtered=False)

In [None]:
channel_ids = recording_nwb.get_channel_ids()
print(channel_ids)

In [None]:
channel_ids_slice = channel_ids[0:384]
print(channel_ids_slice)

In [None]:
recording_slice = recording_nwb.channel_slice(channel_ids=channel_ids_slice)
recording_slice

In [None]:
sorting_nwb = se.read_nwb_sorting(file_path=nwb_file_path, electrical_series_name='ElectricalSeriesAp')
sorting_nwb

## 2. Preprocessing

In [None]:
recording_f = spre.bandpass_filter(recording_slice, freq_min=300, freq_max=6000)
recording_f

In [None]:
recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')
recording_cmr

In [None]:
fs = recording_cmr.get_sampling_frequency()

In [None]:
recording_sub = recording_cmr.frame_slice(start_frame=0, end_frame=int(300 * fs))  # first 300 s; frames must be integers
recording_sub

In [None]:
sorting_sub = sorting_nwb.frame_slice(start_frame=0, end_frame=int(300 * fs))  # same 300 s window as recording_sub
sorting_sub

## 3. Exploring the data

In [None]:
channel_ids = recording_sub.get_channel_ids()
print(channel_ids)

In [None]:
channel_locations = recording_slice.get_channel_locations()
print(channel_locations)

In [None]:
channel_summary = np.hstack((channel_ids.reshape(-1,1), channel_locations))
print(channel_summary)

In [None]:
column_names = ['channel_id', 'channel_loc_x', 'channel_loc_y']

channel_summary_table = pd.DataFrame(channel_summary, columns=column_names)
display(channel_summary_table)

In [None]:
channel_summary_table['channel_loc_x'].unique()

In [None]:
sw.plot_probe_map(recording_sub, with_channel_ids=True)

Since the NWB file contains both the raw recording and the spike sorted data, we can extract information about the already sorted spikes.

We need these expert-sorted spikes to determine the best (extremum) channel and the frame of each spike, so that we can cut image snippets around them and label those snippets as spikes for training.

Before retrieving information about these spikes, we need to create a `WaveformExtractor` object. `SpikeInterface` provides it with methods for, among other things, computing spike locations and plotting them on the probe.

A `WaveformExtractor` object requires a paired `Recording` and `Sorting` object, which we already have.

More information on waveform extractors can be found here:
https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_4_waveform_extractor.html

In [None]:
waveform_folder = 'waveform'

job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

In [None]:
if (base_folder / waveform_folder).is_dir():
    waveform = si.load_waveforms(base_folder / waveform_folder)
else:
    waveform = si.extract_waveforms(
        recording_sub,
        sorting_sub,
        waveform_folder,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,
        overwrite=True,
        **job_kwargs
    )

In [None]:
waveform

We can retrieve the frame at which each spike occurred (`SpikeInterface` indexes time in frames rather than seconds) with `get_all_spike_trains()`. It returns a list with one `(spike_frames, spike_unit_ids)` tuple per segment; this recording has a single segment, so index `[0]` selects it.

Each individual spike frame is the rounded product of its corresponding spike time and the sampling frequency.
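As a minimal sketch of this conversion and its inverse, assuming a sampling frequency of 30 kHz purely for illustration (the actual value comes from `recording_sub.get_sampling_frequency()`), with hypothetical spike times and hypothetical helper names:

```python
import numpy as np

def times_to_frames(spike_times_s, fs):
    """Convert spike times in seconds to integer frame indices: round(t * fs)."""
    return np.round(np.asarray(spike_times_s, dtype=float) * fs).astype(np.int64)

def frames_to_times(spike_frames, fs):
    """Convert frame indices back to seconds: frame / fs."""
    return np.asarray(spike_frames, dtype=np.int64) / fs

fs = 30000.0                                       # assumed sampling frequency (Hz)
spike_times_s = np.array([0.0157, 1.0, 2.50001])   # hypothetical spike times (s)
frames = times_to_frames(spike_times_s, fs)
print(frames)
print(frames_to_times(frames, fs))
```

Rounding means the round trip is exact only up to one sample period (1 / fs).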

In [None]:
print(sorting_sub.get_all_spike_trains())

In [None]:
# Segment 0 holds (spike_frames, spike_unit_ids)
spike_frames, spike_unit_ids = sorting_sub.get_all_spike_trains()[0]
spikes_table_si = pd.DataFrame({'unit_id': spike_unit_ids, 'spike_frame': spike_frames})
spikes_table_si['unit_id'] = spikes_table_si['unit_id'].astype(int)

display(spikes_table_si)

In [None]:
print(si.get_template_extremum_channel(waveform, outputs="index"))


In [None]:
# Create a new column and map values from the dictionary based on matching keys
spikes_table_si['extremum_channel'] = spikes_table_si['unit_id'].map(si.get_template_extremum_channel(waveform, outputs="index"))
spikes_table_si['spike_number'] = range(len(spikes_table_si))

display(spikes_table_si)

In [None]:
display(channel_summary_table)

In [None]:
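spike_frame = 471  # example spike frame
# 64-sample window: 31 samples before and 33 after the spike frame, across all 384 channels
trace_snippet = recording_sub.get_traces(start_frame=spike_frame-31, end_frame=spike_frame+33)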

In [None]:
trace_snippet.shape

In [None]:
# Split even and odd channels into two columns: (64, 384) -> (64, 192, 2)
trace_reshaped = np.dstack((
    trace_snippet[:, ::2],
    trace_snippet[:, 1::2]
))

trace_reshaped.shape
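The even/odd split above can be checked on synthetic data. The sketch below uses a hypothetical helper `snippet_to_image` and random values standing in for `recording_sub.get_traces`; it shows that the `np.dstack` of even and odd channels is identical to a plain C-order `reshape(64, 192, 2)`, with channel `2k` in column 0 and channel `2k + 1` in column 1 of row `k`.

```python
import numpy as np

N_BEFORE, N_AFTER = 31, 33   # samples before/after the spike frame, as in get_traces above
N_CHANNELS = 384

def snippet_to_image(snippet: np.ndarray) -> np.ndarray:
    """Split interleaved channels: (T, C) -> (T, C // 2, 2), even channels in [..., 0]."""
    return np.dstack((snippet[:, ::2], snippet[:, 1::2]))

# Synthetic stand-in for recording_sub.get_traces(...): 64 samples x 384 channels
rng = np.random.default_rng(0)
snippet = rng.standard_normal((N_BEFORE + N_AFTER, N_CHANNELS)).astype(np.float32)

image = snippet_to_image(snippet)
print(image.shape)  # (64, 192, 2)

# The dstack is equivalent to a C-order reshape
assert np.array_equal(image, snippet.reshape(64, 192, 2))
# Channel 2k lands in column 0 of row k, channel 2k + 1 in column 1
assert np.array_equal(image[:, 10, 0], snippet[:, 20])
assert np.array_equal(image[:, 10, 1], snippet[:, 21])
```

Whether the two columns match physical probe columns depends on the probe's channel ordering, which `channel_summary_table` can be used to confirm.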

## 4. Creating the time-series image dataset

In [None]:
import multiprocessing as mp
import os
import psutil
from tqdm import tqdm

In [None]:
timeseries_path = os.path.join(os.getcwd(), "timeseries")

if not os.path.exists(timeseries_path):
    os.mkdir(timeseries_path)

In [None]:
nthreads = psutil.cpu_count(logical=True)
ncores = psutil.cpu_count(logical=False)
nthreads_per_core = nthreads // ncores
nthreads_available = len(os.sched_getaffinity(0))
ncores_available = nthreads_available // nthreads_per_core

assert nthreads == os.cpu_count()
assert nthreads == mp.cpu_count()

print(f'{nthreads=}')
print(f'{ncores=}')
print(f'{nthreads_per_core=}')
print(f'{nthreads_available=}')
print(f'{ncores_available=}')

### Process spike images

#### Inspect abundant spike units

In [None]:
top_spike_units = spikes_table_si['unit_id'].value_counts().head(20)
print(top_spike_units)

In [None]:
top_spike_units_table = spikes_table_si[spikes_table_si['unit_id'].isin(top_spike_units.index)]
top_spike_units_table = top_spike_units_table.sort_values(by=['unit_id', 'spike_frame'], ascending=True)
display(top_spike_units_table)

In [None]:
top_spike_units = top_spike_units_table['unit_id'].unique()
print(top_spike_units)

In [None]:
top_channels = []
for i in top_spike_units:
    top_channels.append(top_spike_units_table[top_spike_units_table['unit_id']==i]['extremum_channel'].unique()[0])
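Because `top_spike_units_table` is sorted by `unit_id`, the loop above is equivalent to a single `groupby`, which returns the units in the same sorted order. A minimal sketch on a toy table with hypothetical values, where each unit carries one extremum channel as produced by the `map` call earlier:

```python
import pandas as pd

# Hypothetical stand-in for top_spike_units_table
table = pd.DataFrame({
    'unit_id':          [5, 5, 0, 0, 0, 13],
    'spike_frame':      [100, 900, 471, 1200, 5000, 300],
    'extremum_channel': [40, 40, 12, 12, 12, 77],
})

# Every row of a unit has the same extremum channel, so the first row per group is enough;
# groupby sorts the unit ids, matching the order of the sorted table's unique() values
unit_to_channel = table.groupby('unit_id')['extremum_channel'].first()
print(unit_to_channel.to_dict())
```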

In [None]:
# 0, 5, 7, 13, 33, 40, 52, 67, 148, 286

In [None]:
for unit, channel in zip(top_spike_units, top_channels):
    print(f'{unit:<10}{channel:<10}')

In [None]:
selected_spike_units = [0, 5, 7, 13, 33, 40, 52, 67, 148, 286]

#### Define multiprocessing functions

In [None]:
# def process_numpy_images(frames_paths):
#     for frame, path in frames_paths:
#         # Get the trace
#         numpy_image = recording_sub.get_traces(start_frame=frame - 31, end_frame=frame + 33)
#         numpy_image_reshaped = np.dstack((
#             numpy_image[:, ::2],
#             numpy_image[:, 1::2]
#         ))

#         # Save the numpy array to disk
#         image_name = f"frame_{frame}"
#         np.save(os.path.join(path, image_name), numpy_image_reshaped)

In [None]:
# def process_batch(batch, path):
#     batch_frames_paths = [(frame, path) for frame in batch]
#     process_numpy_images(batch_frames_paths)

In [None]:
# # Define the number of processes
# num_processes = 128

# # Define the batch size
# batch_size = 100

# # Create a multiprocessing pool
# pool = mp.Pool(processes=num_processes)

#### Process numpy files

In [None]:
# # Iterate over the units
# for unit in top_spike_units:
#     unit_path = os.path.join(timeseries_path, f"unit_{unit}")
#     if not os.path.exists(unit_path):
#         os.mkdir(unit_path)

#     unit_table = top_spike_units_table[top_spike_units_table['unit_id'] == unit]

#     # Get the number of frames
#     num_frames = 1000

#     # Iterate over the frames in batches
#     for i in range(0, num_frames, batch_size):
#         batch_frames = unit_table['spike_frame'][i:i+batch_size]

#         # Apply multiprocessing to process the batch of frames
#         pool.apply_async(process_batch, args=(batch_frames, unit_path))

### Process noise images

#### Create list of noise frames

In [None]:
# spike_frames = spikes_table_si['spike_frame'].to_list()
# print(spike_frames)

In [None]:
# noise_frames = [noise_frame - 64 for noise_frame in spike_frames]
# print(noise_frames)
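Shifting each spike frame back by 64 places the noise window directly before its spike window, but it does not guarantee that the window is spike-free: another unit may fire inside it. One possible filter, sketched below with a hypothetical helper and toy frames, keeps only candidates whose `[frame - 31, frame + 33)` window contains no sorted spike and starts at or after frame 0. It counts spikes per window with `np.searchsorted` on the sorted spike frames.

```python
import numpy as np

N_BEFORE, N_AFTER = 31, 33  # window [frame - 31, frame + 33), as in get_traces above

def spike_free_noise_frames(spike_frames, offset=64, before=N_BEFORE, after=N_AFTER):
    """Hypothetical helper: return candidate noise frames (spike_frame - offset) whose
    window [frame - before, frame + after) holds no sorted spike and starts at frame >= 0."""
    spikes = np.sort(np.asarray(spike_frames, dtype=np.int64))
    candidates = spikes - offset
    starts, ends = candidates - before, candidates + after
    # Spikes in [start, end) = (# spikes < end) - (# spikes < start)
    n_inside = (np.searchsorted(spikes, ends, side='left')
                - np.searchsorted(spikes, starts, side='left'))
    keep = (n_inside == 0) & (starts >= 0)
    return candidates[keep]

# Toy spike frames: the noise window for the spike at 150 contains the spike at 100
print(spike_free_noise_frames([100, 150, 500]))
```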

#### Define multiprocessing functions

In [None]:
# def process_frame(noise_frame, noise_path):
#     # Get the trace
#     numpy_image = recording_sub.get_traces(start_frame=noise_frame - 31, end_frame=noise_frame + 33)
#     numpy_image_reshaped = np.dstack((
#         numpy_image[:, ::2],
#         numpy_image[:, 1::2]
#     ))

#     # Save the numpy array to disk
#     image_name = f"frame_{noise_frame}"
#     np.save(os.path.join(noise_path, image_name), numpy_image_reshaped)

#### Process numpy files

In [None]:
# noise_path = os.path.join(timeseries_path, "noise")

# if not os.path.exists(noise_path):
#     os.mkdir(noise_path)

In [None]:
# # Set the number of frames
# num_frames = 10000

# # Iterate over the frames
# for i in range(0, num_frames):
#     noise_frame = noise_frames[i]
    
#     # Apply multiprocessing to process the frame
#     pool.apply_async(process_frame, args=(noise_frame, noise_path))

In [None]:
# # Close the multiprocessing pool
# pool.close()
# pool.join()

## 5. Creating the CNN model

### Create tensor dataset

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
subfolders = ['unit_' + str(unit) for unit in selected_spike_units]
subfolders.append("noise")
subfolders

In [None]:
subfolders_dict = {name: index for index, name in enumerate(subfolders)}
subfolders_dict

In [None]:
class NumpyDataset(Dataset):
    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.file_paths = self.get_file_paths()
        self.label_map = subfolders_dict
        
    def get_file_paths(self):
        file_paths = []
        # Iterate over the spike and noise folders
        for subfolder in subfolders:
            subfolder_path = os.path.join(self.folder_path, subfolder)
            # Get all the numpy files in the subfolder
            subfolder_files = [file for file in os.listdir(subfolder_path) if file.endswith('.npy')]
            # Add the full path of each file to the list
            file_paths.extend([os.path.join(subfolder_path, file) for file in subfolder_files])
        return file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load the .npy snippet (64, 192, 2) and add a leading channel dimension -> (1, 64, 192, 2)
        image = torch.from_numpy(np.load(self.file_paths[idx])).unsqueeze(0).float()
        # Extract the folder name from the file path
        folder_name = os.path.dirname(self.file_paths[idx])
        # Extract the label from the folder name
        label = folder_name.split(os.sep)[-1]  # Extract the last folder name
        # Assign the numerical label based on the label map
        label = self.label_map[label]
        return image, label
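The file indexing in `get_file_paths` can be exercised without PyTorch. The sketch below uses a hypothetical helper `index_npy_files` and `pathlib` in place of `os.listdir`. It sorts each folder's files because `os.listdir` returns entries in arbitrary order, which would otherwise make the dataset order depend on the file system. It builds a toy `unit_0/`, `unit_5/`, `noise/` layout in a temporary directory.

```python
import tempfile
from pathlib import Path

import numpy as np

def index_npy_files(root, class_names):
    """Hypothetical helper: list (path, label) pairs for root/<class_name>/*.npy.
    Labels follow the order of class_names, like subfolders_dict."""
    label_map = {name: index for index, name in enumerate(class_names)}
    items = []
    for name in class_names:
        for path in sorted((Path(root) / name).glob('*.npy')):
            # The parent folder name is the class name
            items.append((path, label_map[path.parent.name]))
    return items

# Toy layout mirroring timeseries/unit_<id>/ and timeseries/noise/
class_names = ['unit_0', 'unit_5', 'noise']
with tempfile.TemporaryDirectory() as root:
    for name, frames in [('unit_0', [471, 900]), ('unit_5', [120]), ('noise', [407])]:
        (Path(root) / name).mkdir()
        for frame in frames:
            # np.save appends the .npy extension
            np.save(Path(root) / name / f'frame_{frame}', np.zeros((64, 192, 2), dtype=np.float32))
    items = index_npy_files(root, class_names)
    labels = [label for _, label in items]
    names = [path.name for path, _ in items]
    first_shape = np.load(items[0][0]).shape

print(labels)
print(names)
```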

In [None]:
timeseries_dataset = NumpyDataset(timeseries_path)

In [None]:
from collections import Counter

# Count labels across the full dataset (before the train/test split)
all_classes = [label for _, label in timeseries_dataset]
Counter(all_classes)

### Split for training and testing

In [None]:
# Split the data: 80% for training, 20% for testing
train_size = int(0.8 * len(timeseries_dataset))
test_size = len(timeseries_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(timeseries_dataset, [train_size, test_size])

print("Training size:", len(train_dataset))
print("Testing size:", len(test_dataset))
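`random_split` draws a new random permutation on every run unless it is given a seeded generator (for example `generator=torch.Generator().manual_seed(42)`). A minimal NumPy sketch of the same idea, an 80/20 split of shuffled indices with a fixed seed, using a hypothetical helper `split_indices` and an illustrative dataset size:

```python
import numpy as np

def split_indices(n_items, train_fraction=0.8, seed=0):
    """Hypothetical helper: shuffle indices 0..n_items-1 with a fixed seed, then cut
    them into train/test parts using the same size rule as the cell above."""
    train_size = int(train_fraction * n_items)
    perm = np.random.default_rng(seed).permutation(n_items)
    return perm[:train_size], perm[train_size:]

train_idx, test_idx = split_indices(103, seed=42)
print(len(train_idx), len(test_idx))

# The same seed reproduces the same split
assert np.array_equal(train_idx, split_indices(103, seed=42)[0])
```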

In [None]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=2,
    shuffle=True
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    num_workers=2,
    shuffle=False  # evaluation does not need shuffling
)

### Create the model

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

In [None]:
class CNNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=2)
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=2)
        self.conv_layer_3 = nn.Conv2d(64, 128, kernel_size=2)
        self.conv_layer_2_drop = nn.Dropout2d()
        self.conv_layer_3_drop = nn.Dropout2d()
        self.flatten = nn.Flatten()
        self.fully_connected_layer_1 = nn.Linear(20608, 50)  # 128 * 7 * 23
        self.fully_connected_layer_2 = nn.Linear(50, 11)     # 10 selected units + noise

    def forward(self, x):
        # Input: (batch, 1, 64, 192, 2)
        # Conv3d -> (batch, 32, 63, 191, 1); squeeze -> (batch, 32, 63, 191); pool -> (batch, 32, 31, 95)
        x = F.relu(F.max_pool2d(torch.squeeze(self.conv_layer_1(x), 4), 2))

        # Conv2d -> (batch, 64, 30, 94); pool -> (batch, 64, 15, 47)
        x = F.relu(F.max_pool2d(self.conv_layer_2_drop(self.conv_layer_2(x)), 2))

        # Conv2d -> (batch, 128, 14, 46); pool -> (batch, 128, 7, 23)
        x = F.relu(F.max_pool2d(self.conv_layer_3_drop(self.conv_layer_3(x)), 2))

        x = self.flatten(x)  # (batch, 20608)
        x = F.relu(self.fully_connected_layer_1(x))
        x = F.dropout(x, training=self.training)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # so there is no ReLU on the output layer
        return self.fully_connected_layer_2(x)

model = CNNet().to(device)

In [None]:
summary(model, input_size=(64, 1, 64, 192, 2))

In [None]:
class SpikeDeeptector(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional layers
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=2) 
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=2)
        self.conv_layer_3 = nn.Conv2d(64, 128, kernel_size=2)
        self.conv_layer_4 = nn.Conv2d(128, 256, kernel_size=2)
        # Dropout layers
        self.conv_layer_2_drop = nn.Dropout2d()
        self.conv_layer_3_drop = nn.Dropout2d()
        self.conv_layer_4_drop = nn.Dropout2d()
        # Reshape
        self.flatten = nn.Flatten()
        # Fully connected layers
        self.fully_connected_layer_1 = nn.Linear(41216, 500)
        self.fully_connected_layer_2 = nn.Linear(500, 250)
        self.fully_connected_layer_3 = nn.Linear(250, 125)
        self.fully_connected_layer_4 = nn.Linear(125, 11)  # 10 selected units + noise

    def forward(self, x):
        # Input: (batch, 1, 64, 192, 2)
        # Conv3d -> (batch, 32, 63, 191, 1); squeeze -> (batch, 32, 63, 191); no pooling here
        x = F.relu(torch.squeeze(self.conv_layer_1(x), 4))

        # Conv2d -> (batch, 64, 62, 190); pool -> (batch, 64, 31, 95)
        x = F.relu(F.max_pool2d(self.conv_layer_2_drop(self.conv_layer_2(x)), 2))

        # Conv2d -> (batch, 128, 30, 94); pool -> (batch, 128, 15, 47)
        x = F.relu(F.max_pool2d(self.conv_layer_3_drop(self.conv_layer_3(x)), 2))

        # Conv2d -> (batch, 256, 14, 46); pool -> (batch, 256, 7, 23)
        x = F.relu(F.max_pool2d(self.conv_layer_4_drop(self.conv_layer_4(x)), 2))

        x = self.flatten(x)  # (batch, 41216) = 256 * 7 * 23

        x = F.relu(self.fully_connected_layer_1(x))
        x = F.dropout(x, training=self.training)

        x = F.relu(self.fully_connected_layer_2(x))
        x = F.dropout(x, training=self.training)

        x = F.relu(self.fully_connected_layer_3(x))
        x = F.dropout(x, training=self.training)

        # Raw logits for nn.CrossEntropyLoss (no ReLU on the output layer)
        return self.fully_connected_layer_4(x)

spike_deeptector = SpikeDeeptector().to(device)

In [None]:
summary(spike_deeptector, input_size=(64, 1, 64, 192, 2))
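The `in_features` of the first linear layer (20608 for `CNNet`, 41216 for `SpikeDeeptector`) follow from the standard output-length formulas for unpadded layers: a stride-1 convolution with kernel k maps a length L to L - k + 1, and `max_pool2d(x, 2)` (stride defaults to the kernel size) maps L to floor((L - 2) / 2) + 1. Dropout does not change shapes. The pure-Python sketch below, with hypothetical helper names, recomputes both numbers and can be used to check `in_features` before changing the architecture.

```python
def conv_out(size, kernel=2):
    """Length along one axis after an unpadded, stride-1 convolution."""
    return size - kernel + 1

def pool_out(size, kernel=2):
    """Length after max_pool2d(x, kernel): stride equals the kernel, no padding, floor mode."""
    return (size - kernel) // kernel + 1

def flat_features(conv2d_channels, pool_after_conv3d, height=64, width=192):
    """in_features of the first Linear layer for the CNNet / SpikeDeeptector layout."""
    # Conv3d(1, 32, kernel_size=2) on (64, 192, 2): depth 2 -> 1, which is squeezed away
    height, width, channels = conv_out(height), conv_out(width), 32
    if pool_after_conv3d:
        height, width = pool_out(height), pool_out(width)
    for channels in conv2d_channels:
        # Conv2d(kernel_size=2) followed by max_pool2d(..., 2)
        height, width = pool_out(conv_out(height)), pool_out(conv_out(width))
    return channels * height * width

print(flat_features([64, 128], pool_after_conv3d=True))         # CNNet: 20608
print(flat_features([64, 128, 256], pool_after_conv3d=False))   # SpikeDeeptector: 41216
```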

### Train the model

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import seaborn as sns

In [None]:
# Cross-entropy loss over the 11 classes (10 units + noise)
cost = nn.CrossEntropyLoss()

# Create training function
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, Y) in enumerate(dataloader):
        X = X.to(device)
        Y = Y.to(device)
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, Y)
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss_value = loss.item()
            current = batch * len(X)
            print(f'loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]')

# Create testing/validation function
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for batch, (X, Y) in enumerate(dataloader):
            X, Y = X.to(device), Y.to(device)
            pred = model(X)

            test_loss += loss_fn(pred, Y).item()
            correct += (pred.argmax(1) == Y).type(torch.float).sum().item()

            true_labels.extend(Y.tolist())
            predicted_labels.extend(pred.argmax(1).tolist())

    test_loss /= len(dataloader)  # loss_fn returns per-batch means, so average over batches
    correct /= size

    print(f'\nTest Error:\nacc: {(100 * correct):>0.1f}%, avg loss: {test_loss:>8f}\n')

    return test_loss, true_labels, predicted_labels
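`nn.CrossEntropyLoss` returns the mean loss over a batch by default (`reduction='mean'`), so `test()` accumulates one mean per batch. The sketch below, with hypothetical per-batch losses and hypothetical helper names, compares three ways to normalize that sum: dividing by the number of batches gives the mean over batches, weighting each batch by its size gives the exact per-sample mean, and dividing by the number of samples shrinks the value by roughly the batch size.

```python
def mean_over_batches(batch_means):
    """Average of per-batch mean losses (dividing the sum by the number of batches)."""
    return sum(batch_means) / len(batch_means)

def mean_over_samples(batch_means, batch_sizes):
    """Exact per-sample mean: weight each batch mean by its batch size."""
    total = sum(mean * n for mean, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical per-batch mean losses; 74 samples with batch_size=32, so the last batch is smaller
batch_means = [0.9, 0.7, 0.4]
batch_sizes = [32, 32, 10]

print(mean_over_batches(batch_means))               # equals the per-sample mean when batches are equal size
print(mean_over_samples(batch_means, batch_sizes))  # exact per-sample mean
print(sum(batch_means) / sum(batch_sizes))          # sum / samples: about batch_size times too small
```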

In [None]:
def train_test_model(model):
    # Uses the global `optimizer`, `cost`, and data loaders defined above
    losses = []
    accuracies = []
    epochs = 50

    for epoch in range(epochs):
        print(f'Epoch {epoch+1}\n-------------------------------')

        # Run one training epoch
        train(train_dataloader, model, cost, optimizer)

        # Evaluate on the test set and record loss and accuracy for this epoch
        test_loss, true_labels, predicted_labels = test(test_dataloader, model, cost)
        losses.append(test_loss)
        accuracies.append(accuracy_score(true_labels, predicted_labels))

    # Confusion matrix for the final epoch; tick labels follow the subfolders_dict order
    classes = subfolders
    cm = confusion_matrix(true_labels, predicted_labels, labels=list(range(len(classes))))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

    # Plot test loss and accuracy over epochs
    plt.figure()
    plt.plot(losses, label='Test Loss')
    plt.plot(accuracies, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('Test Loss and Accuracy over Epochs')
    plt.legend()
    plt.show()
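If the dataset was generated with the commented-out settings above (up to 1000 snippets per unit and up to 10000 noise snippets), overall accuracy is dominated by the noise class. Per-class recall, read off the rows of the confusion matrix, shows how well each unit is recovered on its own. A minimal NumPy sketch with a hypothetical helper and a toy matrix:

```python
import numpy as np

def per_class_recall(cm):
    """Recall per class from a confusion matrix whose rows are true labels
    (sklearn's convention): diagonal / row sum, NaN for classes with no samples."""
    cm = np.asarray(cm, dtype=float)
    support = cm.sum(axis=1)
    recall = np.full(len(cm), np.nan)
    np.divide(np.diag(cm), support, out=recall, where=support > 0)
    return recall

# Toy 3-class matrix: two units and a noise class absent from this split
cm = [[8, 1, 1],
      [3, 5, 2],
      [0, 0, 0]]
print(per_class_recall(cm))
```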

In [None]:
# Adam optimizer that updates the CNNet parameters
learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_test_model(model)

In [None]:
# Adam optimizer that updates the SpikeDeeptector parameters
learning_rate = 0.0001
optimizer = torch.optim.Adam(spike_deeptector.parameters(), lr=learning_rate)

train_test_model(spike_deeptector)

In [None]:
models_path = os.path.join(os.getcwd(), "models")

if not os.path.exists(models_path):
    os.mkdir(models_path)

In [None]:
# Save the entire model
torch.save(model, os.path.join(models_path, 'multi_unit_model.pth'))

In [None]:
# Save the entire model
torch.save(spike_deeptector, os.path.join(models_path, 'spike_deeptector.pth'))