# Phase 1: Proof of Principle

The main goal of this initial trial is to investigate whether a model can be trained to learn neuronal spiking activity. A large part of this process is first unpacking and understanding the data so that it can be processed into model inputs. We also implement several neural network architectures to compare their effectiveness.

---

## 1. Explore the data in an NWB file

There are readily available ground-truth datasets in NWB files which contain spikes that have been manually curated by experts. We are going to use the `sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb` file which can be downloaded from the DANDI archive:
https://api.dandiarchive.org/api/assets/7e4fa468-349c-44a9-a482-26898682eed1/download/

### Import required modules

In [None]:
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

We followed the instructions for using `SpikeInterface` based on this tutorial:
https://github.com/SpikeInterface/spiketutorials/tree/master/Official_Tutorial_SI_0.96_Oct22 

Install the latest version of `SpikeInterface` from source as recommended in the **"From source"** section here: 
https://spikeinterface.readthedocs.io/en/latest/installation.html

In [None]:
import spikeinterface.full as si

print(f"SpikeInterface version: {si.__version__}")

In [None]:
import preprocessing
import plotting
import util

### Read the NWB file

In [None]:
nwb_file = "data/sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"

recording_nwb = si.read_nwb(nwb_file, electrical_series_name='ElectricalSeriesAp')
recording_nwb

In [None]:
recording_nwb.annotate(is_filtered=False)

In [None]:
sorting_nwb = si.read_nwb_sorting(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
sorting_nwb

### Preprocess the recording

In [None]:
recording_f = si.bandpass_filter(recording_nwb, freq_min=300, freq_max=6000)
recording_f

In [None]:
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
recording_cmr
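
To make the re-referencing step concrete, the sketch below applies the idea of a global median reference to a toy array with NumPy: at each frame, the median across all channels is subtracted from every channel. The array `traces` and its values are made up for illustration, and this is only a conceptual sketch, not how `SpikeInterface` implements `common_reference` internally.

```python
import numpy as np

# Toy traces with shape (num_frames, num_channels); values are illustrative only.
traces = np.array([
    [10.0, 12.0, 11.0, 50.0],
    [-5.0, -4.0, -6.0, -5.0],
    [0.0, 1.0, 2.0, 3.0],
])

# Median across channels for each frame, kept as a column for broadcasting.
frame_median = np.median(traces, axis=1, keepdims=True)

# Subtracting it removes signal shared by all channels (common-mode noise).
traces_cmr = traces - frame_median

print(traces_cmr)
```

After the subtraction, the median of every frame is zero, while a deflection that is present on only one channel (the value 50 in the first frame) is preserved.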

In [None]:
recording_slice = preprocessing.channel_slice_electricalseriesap(recording_cmr)
recording_slice

### Save extractors to disk

In [None]:
extractors_folder = "extractors/001"

os.makedirs(extractors_folder, exist_ok=True)

In [None]:
preprocessed_folder = os.path.join(extractors_folder, "preprocessed")
job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

if not os.path.exists(preprocessed_folder):    
    recording_slice.save(folder=preprocessed_folder, **job_kwargs)

In [None]:
sorting_folder = os.path.join(extractors_folder, "sorting")

if not os.path.exists(sorting_folder):    
    sorting_nwb.save(folder=sorting_folder)

### Inspect channels on probe

In [None]:
channels = preprocessing.extract_channels(recording_slice)
    
display(pd.DataFrame(channels))

In [None]:
si.plot_probe_map(recording_slice, with_channel_ids=True)

### Inspect spike events

Since we are using an NWB file that contains both the raw recording and the spike-sorted data, we can extract information about the already-sorted spikes.

We need these expert-sorted spikes to determine the best channels and frames for plotting our images and labelling them as spikes for training.

Before we can retrieve information about these spikes, we need to create a `WaveformExtractor` object, which `SpikeInterface` provides for computing the spike locations and plotting them on the probe.

A `WaveformExtractor` object requires a paired `Recording` and `Sorting` object, both of which we already have.

More information on waveform extractors can be found here:
https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_4_waveform_extractor.html

In [None]:
waveform_folder = os.path.join(extractors_folder, "waveform")

if os.path.exists(waveform_folder):
    waveform_nwb = si.load_waveforms(waveform_folder, with_recording=False)
else:
    job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

    waveform_nwb = si.extract_waveforms(
        recording_slice,
        sorting_nwb,
        waveform_folder,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,
        overwrite=True,
        **job_kwargs
    )
    
waveform_nwb

We can retrieve the frame at which each spike occurred (`SpikeInterface` uses frames instead of seconds) with the `get_all_spike_trains()` function, which returns, for each segment, a pair of arrays holding the frame and the unit ID of every spike.

Each individual spike frame is the rounded product of its corresponding spike time and the sampling frequency.
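
As a small worked example of this conversion, the sketch below turns spike times into frames and back. The sampling frequency and spike times are assumed values for illustration; in the notebook the frequency would come from the recording itself, for example via `recording_slice.get_sampling_frequency()`.

```python
import numpy as np

# Assumed values for illustration; the real sampling frequency comes from the recording.
sampling_frequency = 30000.0  # Hz
spike_times = np.array([0.0105, 1.25, 2.00004])  # seconds

# Frame index = spike time x sampling frequency, rounded to the nearest integer.
spike_frames = np.round(spike_times * sampling_frequency).astype(np.int64)

# Converting back gives the time of the nearest sample, not the exact spike time.
recovered_times = spike_frames / sampling_frequency

print(spike_frames)
print(recovered_times)
```

The round trip is lossy by at most half a sample period, which is why the third recovered time differs slightly from its original value.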

In [None]:
spikes_folder = 'spikes/001'

os.makedirs(spikes_folder, exist_ok=True)

In [None]:
spikes = preprocessing.extract_spikes(sorting_nwb, waveform_nwb) 

spikes_file = os.path.join(spikes_folder, "spikes_nwb.npy")

if not os.path.exists(spikes_file):
    np.save(spikes_file, spikes)
    
display(pd.DataFrame(spikes))

In [None]:
plotting.plot_unit_waveform(recording_slice, spikes, unit_id=5, num_waveforms=1)

Because the channels on a Neuropixels probe are arranged in a checkerboard pattern, we reshape each trace to better reflect that layout. This means separating the channels into two columns, which results in a 3-dimensional array.

In [None]:
plotting.plot_trace_image(recording_slice, 471)
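
The reshaping described above can be sketched with NumPy on a toy array. This assumes that consecutive channel indices alternate between the two probe columns and that the trace is stored as `(num_frames, num_channels)`; the actual layout used by `plotting.plot_trace_image` is not shown here and may differ.

```python
import numpy as np

num_frames, num_channels = 4, 8  # small toy sizes; real traces are far larger

# Toy trace where each value encodes its channel index, shape (num_frames, num_channels).
trace = np.tile(np.arange(num_channels), (num_frames, 1))

# Assumption: consecutive channel indices alternate between the two probe columns,
# so pairs (0, 1), (2, 3), ... sit on the same row of the probe.
trace_3d = trace.reshape(num_frames, num_channels // 2, 2)

print(trace_3d.shape)  # (4, 4, 2)
print(trace_3d[0])     # probe rows for the first frame, one column per probe column
```

Under this assumption, `trace_3d[:, :, 0]` holds the even-indexed channels and `trace_3d[:, :, 1]` the odd-indexed ones, so each frame becomes a small two-column image.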

---

## 2. Create a dataset from sorted spikes

A dataset of HDF5 files can be generated using the `generate_dataset.py` script. 

Each file corresponds to one identified unit among the spikes in the NWB file. Each file contains two datasets:
- The frame number at which each sample occurred
- The trace representation of each sample belonging to the unit

The script can also generate an HDF5 file of noise samples (named `unit_-01`) containing as many samples as you specify (up to 1 million).

In [None]:
spike_units = spikes['unit_index']

print(f'Spike units: {len(np.unique(spike_units))}\n')
print(util.format_value_counts(spike_units))
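
`util.format_value_counts` is a project helper whose implementation is not shown here. A rough stand-in for counting spikes per unit, using only NumPy and a made-up `spike_units` array, might look like this:

```python
import numpy as np

# Toy unit indices, one entry per spike; in the notebook this is spikes['unit_index'].
spike_units = np.array([0, 2, 2, 1, 0, 2, 2, 1, 2])

# np.unique with return_counts=True gives each unit and its number of spikes.
units, counts = np.unique(spike_units, return_counts=True)

# Order units from most to least active (stable, so ties keep ascending unit order).
order = np.argsort(-counts, kind="stable")
for unit, count in zip(units[order], counts[order]):
    print(f"unit {unit}: {count} spikes")
```

Knowing the spike count of each unit is useful later, when units are chosen by their minimum and maximum number of samples.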

The script needs to be run with 4 arguments:
- [1] The number associated with the recording to be used
- [2] The type of dataset to be generated - 0 for spikes from the NWB file, 2 for noise samples
- [3] The first unit to include
- [4] The last unit to include

Example: `!python generate_dataset.py 1 0 0 423`

This example command will generate a dataset of spikes from recording number 1 for units 0 through 423.

---

## 3. Classify spikes and noise with a CNN

The classifier model can be run using the `run_classifier.py` script.

In [None]:
!nvidia-smi

The script needs to be run with 8 arguments:

- [1] The number associated with the recording to be used
- [2] The minimum number of samples per unit
- [3] The maximum number of samples per unit 
- [4] The number of units to be classified
- [5] The number of noise samples to include in the training data
- [6] The number of available GPUs for parallel data loading
- [7] The number of epochs for running DSS
- [8] The number to set the session ID

Example: `!python run_classifier.py 1 1000 2000 10 1000 1 200 1`

The example command will run DSS:
- named session 001
- for 200 epochs 
- on 10 units
- each with 1000-2000 samples
- including 1000 noise samples
- using 1 available GPU
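
Because all eight arguments are positional, it is easy to swap two of them by accident. One way to guard against that is to build the command from named values. The helper `build_classifier_command` below is hypothetical and not part of the project; it only reproduces the argument order listed above.

```python
import shlex


def build_classifier_command(recording, min_samples, max_samples, num_units,
                             num_noise, num_gpus, num_epochs, session_id):
    """Return the run_classifier.py command with arguments in the documented order."""
    args = [recording, min_samples, max_samples, num_units,
            num_noise, num_gpus, num_epochs, session_id]
    return shlex.join(["python", "run_classifier.py", *map(str, args)])


# Same values as the example command above.
command = build_classifier_command(
    recording=1, min_samples=1000, max_samples=2000, num_units=10,
    num_noise=1000, num_gpus=1, num_epochs=200, session_id=1,
)
print(command)
```

The resulting string can be passed to a notebook shell cell or to `subprocess.run(shlex.split(command))`.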

The script will also save the classification results to the results folder:
- Accuracy progress plot
- Accuracy and Loss progress log
- Confusion matrix plot