# Phase 1: Proof of Principle

The main goal of this initial trial is to investigate whether a neural network can learn to recognize neuronal spiking activity. A large part of this process is unpacking and understanding the data we are working with so that we can turn it into model inputs. We also implement different neural network architectures to test their effectiveness.

---

## 1. Explore the data in an NWB file

Ground-truth datasets whose spikes have been manually curated by experts are readily available as NWB files. We use the `sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb` file, which can be downloaded from the DANDI archive:
https://api.dandiarchive.org/api/assets/7e4fa468-349c-44a9-a482-26898682eed1/download/
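If the file is not yet on disk, it can be fetched with the Python standard library. A minimal sketch follows; `download_if_missing` is a hypothetical helper (not part of `dss` or DANDI's tooling), and it writes to a temporary `.part` file first so that an interrupted download is never mistaken for a complete NWB file.

```python
from pathlib import Path
from urllib.request import urlretrieve

# Asset URL and file name taken from the text above
NWB_URL = "https://api.dandiarchive.org/api/assets/7e4fa468-349c-44a9-a482-26898682eed1/download/"
NWB_FILE = "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"


def download_if_missing(url, file_path):
    """Download `url` to `file_path` unless that file already exists; return the local path."""
    file_path = Path(file_path)
    if not file_path.exists():
        # Download to "<name>.part" first, then rename once the transfer has finished
        part_path = file_path.with_name(file_path.name + ".part")
        urlretrieve(url, part_path)
        part_path.replace(file_path)
    return file_path
```

Calling `download_if_missing(NWB_URL, NWB_FILE)` once before reading the file is enough; later runs return immediately because the file already exists.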

### Import required modules

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from pathlib import Path
import os
import pickle

import psutil
import multiprocessing as mp
from tqdm.notebook import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary
import torchvision
from torchvision import utils

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

Our use of `SpikeInterface` follows this tutorial:
https://github.com/SpikeInterface/spiketutorials/tree/master/Official_Tutorial_SI_0.96_Oct22

Install the latest version of `SpikeInterface` from source as recommended in the **"From source"** section here: 
https://spikeinterface.readthedocs.io/en/latest/installation.html

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

In [None]:
print(f"SpikeInterface version: {si.__version__}")

In [None]:
import importlib
import deepspikesort as dss

### Read the NWB file

In [None]:
base_folder = Path(".")
nwb_file = "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"

In [None]:
recording_nwb = se.read_nwb_recording(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
recording_nwb

In [None]:
recording_nwb.annotate(is_filtered=False)

In [None]:
sorting_nwb = se.read_nwb_sorting(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
sorting_nwb

### Preprocess the recording

In [None]:
recording_f = spre.bandpass_filter(recording_nwb, freq_min=300, freq_max=6000)
recording_f

In [None]:
recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')
recording_cmr
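Conceptually, a global median common reference subtracts, at every frame, the median value across all channels, which removes noise shared by the whole probe. The NumPy sketch below illustrates that idea on a toy array; it is not SpikeInterface's implementation, which applies the operation lazily, chunk by chunk.

```python
import numpy as np


def global_median_reference(traces):
    """Subtract, at every frame, the median across all channels.

    traces: array of shape (num_frames, num_channels)
    """
    return traces - np.median(traces, axis=1, keepdims=True)


# Toy example: 2 frames x 3 channels, with a large shared offset on the second frame
toy_traces = np.array([[1.0, 2.0, 3.0],
                       [10.0, 11.0, 15.0]])
print(global_median_reference(toy_traces))
```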

### Inspect channels on probe

In [None]:
recording_slice = dss.channel_slice_electricalseriesap(recording_cmr)
recording_slice

In [None]:
channels_table = dss.extract_channels(recording_slice)
display(channels_table)

In [None]:
channels_table['channel_loc_x'].unique()

In [None]:
sw.plot_probe_map(recording_slice, with_channel_ids=True)

### Inspect spike events

Since the NWB file contains both the raw recording and the spike-sorted data, we can extract information about the already sorted spikes.

We need these expert-sorted spikes to determine the best channels and frames for plotting our images and to label those images as spikes for training.

Before we can retrieve information about these spikes, we need to create a `WaveformExtractor` object, which `SpikeInterface` uses to compute spike locations and to plot them on the probe.

A `WaveformExtractor` requires a paired `Recording` and `Sorting` object, both of which we already have.

More information on waveform extractors can be found here:
https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_4_waveform_extractor.html

In [None]:
waveform_folder = 'waveform'

job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

In [None]:
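waveform_path = base_folder / waveform_folder

if waveform_path.is_dir():
    # Reuse previously extracted waveforms instead of recomputing them
    waveform_nwb = si.load_waveforms(waveform_path)
else:
    # max_spikes_per_unit=None extracts the waveform of every spike rather than a random subset
    waveform_nwb = si.extract_waveforms(
        recording_slice,
        sorting_nwb,
        waveform_path,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,
        overwrite=True,
        **job_kwargs
    )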

In [None]:
waveform_nwb

`SpikeInterface` indexes spikes by frame (sample index) rather than by time in seconds. The frame at which each spike occurred can be retrieved with the sorting's `get_all_spike_trains()` method, which returns one entry per recording segment, each a pair of arrays holding the spike frames and the corresponding unit IDs.

Each spike frame is the spike time (in seconds) multiplied by the sampling frequency, rounded to the nearest integer.
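To illustrate the time-to-frame relationship, the NumPy sketch below converts per-unit spike times into frames with `round(t * fs)` and merges them into frame-sorted arrays of frames and unit IDs. The helper names, spike times and unit IDs are hypothetical, and the 30 kHz rate is simply the typical sampling rate of Neuropixels AP-band data; this mirrors the idea, not SpikeInterface's own code.

```python
import numpy as np


def spike_times_to_frames(spike_times, sampling_frequency):
    """Convert spike times in seconds to frame (sample) indices by rounding t * fs."""
    return np.round(np.asarray(spike_times) * sampling_frequency).astype(np.int64)


def merge_spike_trains(unit_spike_times, sampling_frequency):
    """Merge {unit_id: spike_times} into (frames, unit_ids) arrays sorted by frame."""
    frames, unit_ids = [], []
    for unit_id, times in unit_spike_times.items():
        unit_frames = spike_times_to_frames(times, sampling_frequency)
        frames.append(unit_frames)
        unit_ids.append(np.full(len(unit_frames), unit_id))
    frames = np.concatenate(frames)
    unit_ids = np.concatenate(unit_ids)
    order = np.argsort(frames, kind="stable")  # chronological order across all units
    return frames[order], unit_ids[order]


# Hypothetical spike times (seconds) for two units, sampled at 30 kHz
example_frames, example_unit_ids = merge_spike_trains({7: [0.10, 0.25], 12: [0.0001, 0.2]}, 30000.0)
print(example_frames, example_unit_ids)
```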

In [None]:
spikes_table = dss.extract_spikes(sorting_nwb, waveform_nwb)
display(spikes_table)

In [None]:
dss.plot_unit_waveform(recording_slice, spikes_table, unit_id=7, num_waveforms=50)

Because the channels on a Neuropixels probe are arranged in a checkerboard pattern, we reshape each trace to mirror that layout: the channels are split into two columns, turning the 2-dimensional trace (frames × channels) into a 3-dimensional array (frames × rows × columns).
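The reshape described above can be sketched in NumPy. The sketch assumes that channels are ordered by depth so that consecutive pairs (0, 1), (2, 3), ... share a row on the probe, and that a trace window is 64 frames of 384 channels, matching the `(1, 64, 192, 2)` model input used later. `to_probe_columns` is a hypothetical helper, not the implementation behind `dss.plot_trace_image`.

```python
import numpy as np


def to_probe_columns(traces, num_columns=2):
    """Reshape a (num_frames, num_channels) trace into (num_frames, num_rows, num_columns).

    Assumes consecutive channels (0, 1), (2, 3), ... sit side by side in one probe row.
    """
    num_frames, num_channels = traces.shape
    if num_channels % num_columns:
        raise ValueError("number of channels must be divisible by num_columns")
    return traces.reshape(num_frames, num_channels // num_columns, num_columns)


# 64 frames x 384 channels -> 64 frames x 192 rows x 2 columns
toy_window = np.arange(64 * 384).reshape(64, 384)
probe_image = to_probe_columns(toy_window)
print(probe_image.shape)
```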

In [None]:
dss.plot_trace_image(recording_slice, 1153)

---

## 2. Create a dataset from sorted spikes

Because this project runs on NERSC, we can use multiprocessing and batching to speed up dataset generation. The cell below reports how many hardware threads and physical cores the node has, and how many of them are available to the current process.

In [None]:
nthreads = psutil.cpu_count(logical=True)
ncores = psutil.cpu_count(logical=False)
nthreads_per_core = nthreads // ncores
nthreads_available = len(os.sched_getaffinity(0))
ncores_available = nthreads_available // nthreads_per_core

assert nthreads == os.cpu_count()
assert nthreads == mp.cpu_count()

print(f'{nthreads=}')
print(f'{ncores=}')
print(f'{nthreads_per_core=}')
print(f'{nthreads_available=}')
print(f'{ncores_available=}')
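The batching strategy can be sketched with `multiprocessing.Pool`. Here `process_batch` is a hypothetical stand-in for the real per-batch work (cutting a window around each spike and saving it as a numpy file), and the 32-frame half-window is an assumption. The sketch assumes the `fork` start method (historically the default on Linux, as on NERSC), which lets a notebook pass functions defined in its own cells to the pool; in the notebook, `ncores_available` from the cell above is a natural value for `processes`.

```python
import multiprocessing as mp


def make_batches(frames, batch_size):
    """Split a list of frames into consecutive batches of at most `batch_size` frames."""
    return [frames[i:i + batch_size] for i in range(0, len(frames), batch_size)]


def process_batch(batch, half_window=32):
    """Hypothetical worker: return the (start, end) frame window around each spike.

    In the real pipeline, each window would be cut from the recording and saved here.
    """
    return [(frame - half_window, frame + half_window) for frame in batch]


def process_in_parallel(frames, batch_size, processes):
    batches = make_batches(list(frames), batch_size)
    with mp.Pool(processes=processes) as pool:
        # pool.map returns results in batch order, so they line up with the input frames
        results = pool.map(process_batch, batches)
    return [window for batch_result in results for window in batch_result]


if __name__ == "__main__":
    print(process_in_parallel([100, 250, 400, 900, 1200], batch_size=2, processes=2))
```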

In [None]:
supervised_dataset_folder = os.path.join(os.getcwd(), "supervised_dataset")

if not os.path.exists(supervised_dataset_folder):
    os.mkdir(supervised_dataset_folder)

### Process traces as numpy files

To build a reasonably large dataset of spikes for our model, we select only units with at least 1000 spikes.

For the noise class, we select frames that fall in the gaps between spikes, taking the frame 64 samples before each sorted spike.

In [None]:
top_spike_units = spikes_table['unit_id'].value_counts()
top_spike_units = top_spike_units[top_spike_units >= 1000]
top_spike_units = top_spike_units.index.tolist()
top_spike_units.sort()

print('Number of units:', len(top_spike_units))
print(top_spike_units)

In [None]:
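noise_folder = os.path.join(supervised_dataset_folder, "noise")

if not os.path.exists(noise_folder):
    os.mkdir(noise_folder)

    # Shift each sorted spike back by 64 frames to get a candidate noise frame,
    # dropping candidates that would fall before the start of the recording
    spike_frames = spikes_table['spike_frame'].to_list()
    noise_frames = [spike_frame - 64 for spike_frame in spike_frames if spike_frame >= 64]
    print('Number of noise frames:', len(noise_frames))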
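Shifting every spike back by 64 frames does not by itself guarantee that the noise frame falls in a gap: another spike may lie close by. A stricter variant, sketched below as an assumption rather than the notebook's method, keeps a candidate only if no sorted spike lies strictly within `window` frames of it and the window fits inside the recording. `select_noise_frames` and its defaults are hypothetical; `np.searchsorted` on the sorted spike frames counts nearby spikes without a Python loop.

```python
import numpy as np


def select_noise_frames(spike_frames, num_frames, window=64, offset=64):
    """Pick noise start frames that lie in gaps between spikes.

    A candidate is taken `offset` frames before each spike and kept only if no spike s
    satisfies candidate - window < s < candidate + window, and the window
    [candidate, candidate + window) lies inside the recording.
    """
    spikes = np.sort(np.asarray(spike_frames, dtype=np.int64))
    candidates = spikes - offset
    candidates = candidates[(candidates >= 0) & (candidates + window <= num_frames)]

    # Indices bracketing the spikes inside the open interval around each candidate
    lower = np.searchsorted(spikes, candidates - window, side="right")
    upper = np.searchsorted(spikes, candidates + window, side="left")
    return np.unique(candidates[upper == lower])


print(select_noise_frames([200, 230, 500], num_frames=1000))
```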

### Convert to tensor dataset

In [None]:
dataset_folders = ['unit_' + str(unit) for unit in top_spike_units[:100]]
dataset_folders.append("noise")
print(dataset_folders)

In [None]:
dataset_folders_dict = {name: index for index, name in enumerate(dataset_folders)}
print(dataset_folders_dict)

In [None]:
spikes_dataset = dss.TensorDataset(supervised_dataset_folder, dataset_folders, dataset_folders_dict)

### Split for training and testing

In [None]:
train_dataset_path = os.path.join(supervised_dataset_folder, "train_dataset_v2.pkl")
test_dataset_path = os.path.join(supervised_dataset_folder, "test_dataset_v2.pkl")

if not (os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path)):
    # Get labels for the entire dataset
    labels = [spikes_dataset[i][1] for i in tqdm(range(len(spikes_dataset)), desc="Getting Labels")]

    # Split the dataset into train and test sets while maintaining class distribution
    train_indices, test_indices = train_test_split(
        range(len(spikes_dataset)),
        test_size=0.3,
        stratify=labels
    )

    # Create train and test datasets using the indices.
    # Index the dataset only once per sample, since each lookup may reload the sample from disk.
    train_dataset = [tuple(spikes_dataset[i]) for i in tqdm(train_indices, desc="Creating Train Dataset")]
    test_dataset = [tuple(spikes_dataset[i]) for i in tqdm(test_indices, desc="Creating Test Dataset")]

    # Cache both splits so later runs can load them directly
    with open(train_dataset_path, 'wb') as f:
        pickle.dump(train_dataset, f)

    with open(test_dataset_path, 'wb') as f:
        pickle.dump(test_dataset, f)

else:
    # Load the cached splits
    with open(train_dataset_path, 'rb') as f:
        train_dataset = pickle.load(f)

    with open(test_dataset_path, 'rb') as f:
        test_dataset = pickle.load(f)
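The `stratify=labels` argument in the split above keeps each class's share of samples the same in the training and test sets, which matters here because units differ widely in spike count. A toy check with hypothetical imbalanced labels (70 samples of class 0, 30 of class 1) and a fixed `random_state` for repeatability:

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 70 samples of class 0 and 30 of class 1
toy_labels = [0] * 70 + [1] * 30
toy_train_idx, toy_test_idx = train_test_split(
    range(len(toy_labels)), test_size=0.3, stratify=toy_labels, random_state=0
)
toy_test_counts = Counter(toy_labels[i] for i in toy_test_idx)
print(toy_test_counts)  # 21 samples of class 0 and 9 of class 1: the 70/30 ratio is preserved
```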

In [None]:
# Get classes and number of items in train and test datasets
train_classes = [label for _, label in train_dataset]
test_classes = [label for _, label in test_dataset]

print('Training dataset:\n', Counter(train_classes), '\n')
print('Testing dataset:\n', Counter(test_classes))

In [None]:
# Create DataLoader instances for train and test datasets
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    num_workers=2,
    shuffle=True  # Shuffle the train dataset during training
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=64,
    num_workers=2
)

---

## 3. Classify spikes and noise with a CNN 

### Build the CNN architecture

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

In [None]:
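class CNNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # 3D convolution over (frames, rows, columns); the kernel spans both probe columns
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=(9, 3, 2))
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=4)

        self.conv_layer_2_drop = nn.Dropout2d()

        self.flatten = nn.Flatten()
        # 64 feature maps x 12 x 46 = 35328 features for a (1, 64, 192, 2) input
        self.fully_connected_layer_1 = nn.Linear(35328, 500)
        self.fully_connected_layer_2 = nn.Linear(500, num_classes)

    def forward(self, x):
        # (N, 1, 64, 192, 2) -> conv -> (N, 32, 56, 190, 1) -> squeeze -> pool -> (N, 32, 28, 95)
        x = F.relu(F.max_pool2d(torch.squeeze(self.conv_layer_1(x), 4), 2))

        # (N, 32, 28, 95) -> conv -> (N, 64, 25, 92) -> pool -> (N, 64, 12, 46)
        x = F.relu(F.max_pool2d(self.conv_layer_2_drop(self.conv_layer_2(x)), 2))

        x = self.flatten(x)
        x = F.relu(self.fully_connected_layer_1(x))
        x = F.dropout(x, training=self.training)
        # No ReLU on the output layer: clipping negative scores to zero would distort the class scores
        x = self.fully_connected_layer_2(x)
        return F.log_softmax(x, dim=1)

# One output per class: the selected units plus the noise class
model = CNNet(num_classes=len(dataset_folders)).to(device)

# Loss function and optimizer settings.
# The model already returns log-probabilities, so NLLLoss is the matching loss.
loss_fn = nn.NLLLoss()
learning_rate = 0.0001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)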
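The input size of `fully_connected_layer_1` (35328) is hard-coded, so it only fits the `(1, 64, 192, 2)` input used in `summary()`. A small pure-Python check of that arithmetic, assuming stride-1 unpadded convolutions and 2 × 2 max pooling with PyTorch's default stride (equal to the kernel size); the helper names are hypothetical:

```python
def conv_out(size, kernel):
    """Output length of a stride-1, unpadded ('valid') convolution."""
    return size - kernel + 1


def pool_out(size, kernel=2):
    """Output length of max pooling whose stride equals the kernel size (floor rounding)."""
    return size // kernel


def cnnet_flat_features(frames=64, rows=192, columns=2):
    # conv_layer_1: kernel (9, 3, 2) over (frames, rows, columns)
    frames, rows, columns = conv_out(frames, 9), conv_out(rows, 3), conv_out(columns, 2)
    if columns != 1:
        raise ValueError("the column axis must collapse to 1 so it can be squeezed")
    frames, rows = pool_out(frames), pool_out(rows)  # first max_pool2d
    # conv_layer_2: 4 x 4 kernel with 64 output channels
    frames, rows = conv_out(frames, 4), conv_out(rows, 4)
    frames, rows = pool_out(frames), pool_out(rows)  # second max_pool2d
    return 64 * frames * rows


print(cnnet_flat_features())  # 64 * 12 * 46 = 35328
```

If the window length or channel count changes, recomputing this value shows what the first linear layer's input size must become.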

In [None]:
summary(model, input_size=(64, 1, 64, 192, 2))

### Train and test the model

In [None]:
models_folder = os.path.join(os.getcwd(), "models")

if not os.path.exists(models_folder):
    os.mkdir(models_folder)

In [None]:
dss = importlib.reload(dss)

In [None]:
train_model = dss.TrainModel(train_dataloader,
                             test_dataloader,
                             device,
                             loss_fn,
                             optimizer)

In [None]:
model_name = "sup_1"

train_model.train_test_model(model, model_name, models_folder, epochs=7, classes=top_spike_units)
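`dss.TrainModel` is not shown in this notebook, so for reference here is a minimal sketch of a generic PyTorch train/evaluate epoch; it is not necessarily how `TrainModel` works. It assumes the model returns log-probabilities, as `CNNet` does, which makes `nn.NLLLoss` the matching loss. A toy linear model and synthetic data stand in for `CNNet` and the spike dataset, and `run_epoch` is a hypothetical name.

```python
import math

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, dataloader, loss_fn, device, optimizer=None):
    """One pass over `dataloader`: trains when an optimizer is given, otherwise evaluates.

    Returns the mean loss and the accuracy over all samples.
    """
    training = optimizer is not None
    model.train(training)  # dropout is active only while training
    total_loss, correct, count = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)  # log-probabilities, shape (batch, num_classes)
            loss = loss_fn(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    return total_loss / count, correct / count


# Toy stand-in data: 2 features, label = whether the first feature is positive
torch.manual_seed(0)
toy_inputs = torch.randn(256, 2)
toy_targets = (toy_inputs[:, 0] > 0).long()
toy_loader = DataLoader(TensorDataset(toy_inputs, toy_targets), batch_size=32, shuffle=True)
toy_model = nn.Sequential(nn.Linear(2, 2), nn.LogSoftmax(dim=1))
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.05)

for epoch in range(10):
    run_epoch(toy_model, toy_loader, nn.NLLLoss(), "cpu", toy_optimizer)
toy_loss, toy_accuracy = run_epoch(toy_model, toy_loader, nn.NLLLoss(), "cpu")
print(f"{toy_loss=:.3f} {toy_accuracy=:.2f}")
```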

### Visualize convolutional layer filters

In [None]:
vis_model = dss.VisualizeModel(model)
vis_model.display_layers_weights()

In [None]:
# Visualize filters for first layer
vis_model.visualize_layer_filters(0, '3D')

In [None]:
# Visualize filters for second layer
vis_model.visualize_layer_filters(1, '2D')
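Before plotting learned filters, it is common to rescale each filter independently to [0, 1] so that filters with small weights are still visible. The sketch below shows that step on random weights shaped like `conv_layer_1` (32 filters of 1 × 9 × 3 × 2); it is an assumption about a typical approach, not necessarily what `dss.VisualizeModel` does. With the notebook's model, the real weights can be obtained as `model.conv_layer_1.weight.detach().cpu().numpy()`.

```python
import numpy as np


def normalize_filters(weights):
    """Min-max scale each filter independently to [0, 1] for display.

    weights: array of shape (num_filters, ...), e.g. (32, 1, 9, 3, 2) for conv_layer_1.
    """
    flat = weights.reshape(weights.shape[0], -1)
    low = flat.min(axis=1, keepdims=True)
    span = flat.max(axis=1, keepdims=True) - low
    span[span == 0] = 1.0  # constant filters map to 0 instead of dividing by zero
    return ((flat - low) / span).reshape(weights.shape)


rng = np.random.default_rng(0)
filters = normalize_filters(rng.normal(size=(32, 1, 9, 3, 2)))
print(filters.shape, filters.min(), filters.max())
```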

### Test the model's confidence

In [None]:
test_model = dss.TestModel(test_dataset, device, model)

In [None]:
test_model.get_confidence_and_probabilities(dataset_folders, 7)
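Because `CNNet` returns log-probabilities, a prediction's confidence can be read off by exponentiating the output and taking the probability of the predicted class. A minimal NumPy sketch on a hypothetical output for 2 samples and 3 classes (with PyTorch tensors, `outputs.exp()` plays the role of `np.exp`); it illustrates the idea and is not the implementation of `dss.TestModel`:

```python
import numpy as np


def confidence_from_log_probs(log_probs):
    """Turn log-probabilities of shape (num_samples, num_classes) into predictions.

    Returns the predicted class index and its probability (the confidence) per sample.
    """
    probs = np.exp(log_probs)
    predicted = probs.argmax(axis=1)
    confidence = probs[np.arange(len(probs)), predicted]
    return predicted, confidence


# Hypothetical model output for 2 samples and 3 classes
example_log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                                     [0.1, 0.3, 0.6]]))
predicted, confidence = confidence_from_log_probs(example_log_probs)
print(predicted, confidence)
```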