# Phase 1

---

In [1]:
import os
import sys

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

Install the latest version of SpikeInterface from source, following the **From source** section of the [installation guide](https://spikeinterface.readthedocs.io/en/latest/get_started/installation.html).

In [2]:
import spikeinterface.full as si

print(f"SpikeInterface version: {si.__version__}")

SpikeInterface version: 0.100.0.dev0


In [3]:
sys.path.append("..")
import preprocessing
import util

## 1. Read the NWB file

Download the NWB file from session **sub-CSHL049** of the IBL Brain Wide Map Dataset [here](https://dandiarchive.org/dandiset/000409/draft/files?location=sub-CSHL049&page=1).

In [4]:
data_folder = "../data/sub-CSHL049"

os.makedirs(data_folder, exist_ok=True)

In [5]:
nwb_file = os.path.join(data_folder, "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb")

In [6]:
extractors_folder = os.path.join(data_folder, "extractors")

os.makedirs(extractors_folder, exist_ok=True)

### Recording

In [7]:
preprocessed_folder = os.path.join(extractors_folder, "preprocessed")

if not os.path.exists(preprocessed_folder): 
    recording_nwb = si.read_nwb(nwb_file, electrical_series_path='acquisition/ElectricalSeriesAp')
    
    # Preprocess the recording
    recording_f = si.bandpass_filter(recording_nwb, freq_min=300, freq_max=6000)
    recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
    
    # Save the preprocessed recording to disk
    job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)
    recording_cmr.save(folder=preprocessed_folder, **job_kwargs)
else:
    recording_cmr = si.load_extractor(preprocessed_folder)
    
recording_cmr

### Sorting

In [8]:
sorting_folder = os.path.join(extractors_folder, "sorting")

if not os.path.exists(sorting_folder):    
    sorting_nwb = si.read_nwb_sorting(file_path=nwb_file, electrical_series_path='acquisition/ElectricalSeriesAp')
    sorting_nwb.save(folder=sorting_folder)
else:
    sorting_nwb = si.load_extractor(sorting_folder)
    
sorting_nwb

### Sorting Analyzer

In [9]:
analyzer_folder = os.path.join(extractors_folder, "analyzer")

if not os.path.exists(analyzer_folder):
    analyzer_nwb = si.create_sorting_analyzer(
        sorting=sorting_nwb,
        recording=recording_cmr,
        format="memory"
    )
    
    # Compute extensions
    job_kwargs = dict(n_jobs=8, chunk_duration="1s", progress_bar=True)
    compute_dict = {
        'random_spikes': {'method': 'uniform'},
        'waveforms': {'ms_before': 1.0, 'ms_after': 2.0},
        'templates': {'operators': ["average", "median", "std"]}
    }
    analyzer_nwb.compute(compute_dict, **job_kwargs)
    
    # Save the sorting analyzer to disk
    analyzer_nwb.save_as(folder=analyzer_folder, format="binary_folder")
else:
    analyzer_nwb = si.load_sorting_analyzer(analyzer_folder)
    
analyzer_nwb

SortingAnalyzer: 384 channels - 423 units - 1 segments - binary_folder - sparse - has recording
Loaded 3 extensions: random_spikes, templates, waveforms

---

## 2. Extract the spikes

In [10]:
spikes_folder = os.path.join(data_folder, "spikes")

os.makedirs(spikes_folder, exist_ok=True)

In [11]:
spikes_file = os.path.join(spikes_folder, "spikes.npy")

if not os.path.exists(spikes_file):
    spikes = preprocessing.extract_spikes(sorting_nwb, analyzer_nwb)
    np.save(spikes_file, spikes)
else:
    spikes = np.load(spikes_file)
    
display(pd.DataFrame(spikes))

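| | unit_index | sample_index | channel_index |
|---|---|---|---|
| 0 | 0 | 471 | 341 |
| 1 | 1 | 511 | 361 |
| 2 | 2 | 606 | 354 |
| 3 | 1 | 680 | 361 |
| 4 | 3 | 715 | 325 |
| ... | ... | ... | ... |
| 4604408 | 372 | 125188815 | 21 |
| 4604409 | 41 | 125188837 | 155 |
| 4604410 | 102 | 125188911 | 325 |
| 4604411 | 316 | 125188967 | 326 |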


In [12]:
noise_file = os.path.join(spikes_folder, "noise.npy")

if not os.path.exists(noise_file):
    noise = preprocessing.create_noise(recording_cmr, spikes, num_samples=100000)
    np.save(noise_file, noise)
else:
    noise = np.load(noise_file)
    
display(pd.DataFrame(noise))

| | unit_index | sample_index |
|---|---|---|
| 0 | -1 | 77449522 |
| 1 | -1 | 43739775 |
| 2 | -1 | 16400284 |
| 3 | -1 | 11532947 |
| 4 | -1 | 39382641 |
| ... | ... | ... |
| 99995 | -1 | 72007380 |
| 99996 | -1 | 110344177 |
| 99997 | -1 | 82620528 |
| 99998 | -1 | 116572938 |


---

## 3. Create trace dataset

In [14]:
spike_units = spikes['unit_index']

print(util.format_value_counts(spike_units))

000: 47148 	043: 9572  	086: 143422	129: 13499 	172: 20407 	215: 760   	258: 251   	301: 1156  	344: 580   	387: 2235  
001: 4717  	044: 11355 	087: 7581  	130: 26458 	173: 6512  	216: 18455 	259: 10942 	302: 5345  	345: 143   	388: 368   
002: 6304  	045: 27691 	088: 13047 	131: 41464 	174: 2267  	217: 12831 	260: 1385  	303: 16512 	346: 1729  	389: 438   
003: 46380 	046: 1101  	089: 8045  	132: 7994  	175: 3451  	218: 5014  	261: 7806  	304: 49411 	347: 520   	390: 814   
004: 8604  	047: 23870 	090: 6889  	133: 6178  	176: 3872  	219: 5644  	262: 40850 	305: 3071  	348: 458   	391: 101   
005: 20985 	048: 28271 	091: 4942  	134: 6150  	177: 15838 	220: 9779  	263: 1581  	306: 2224  	349: 202   	392: 66    
006: 1081  	049: 15974 	092: 13349 	135: 13253 	178: 10521 	221: 5250  	264: 2127  	307: 3096  	350: 5020  	393: 36    
007: 30788 	050: 38575 	093: 14695 	136: 4929  	179: 6271  	222: 398   	265: 23    	308: 5547  	351: 210   	394: 166   
008: 51920 	051: 18147 	094: 200   	137:

Create a dataset of HDF5 files using `submit_dataset.sh`.

The script takes 2 arguments:
- [1] The ID of the recording
- [2] The type of dataset: 'spikes', 'peaks', or 'noise'
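
As a minimal sketch of how the two arguments fit together, the helper below validates the dataset type and assembles the command line without running it. The function name `build_dataset_command` and the `bash` launcher are assumptions made for illustration; a cluster setup may use a different launcher.

```python
import shlex

# The three dataset types named in the argument list above.
DATASET_TYPES = ("spikes", "peaks", "noise")


def build_dataset_command(recording_id, dataset_type, launcher="bash"):
    """Return the argument list for one submit_dataset.sh call.

    Hypothetical helper: the launcher is an assumption, and nothing is
    executed here.
    """
    if dataset_type not in DATASET_TYPES:
        raise ValueError(
            f"dataset_type must be one of {DATASET_TYPES}, got {dataset_type!r}"
        )
    return [launcher, "submit_dataset.sh", recording_id, dataset_type]


cmd = build_dataset_command("sub-CSHL049", "spikes")
print(shlex.join(cmd))
```

Rejecting an unknown dataset type before submission avoids queueing a job that cannot succeed.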


Each file corresponds to one identified unit among the spikes in the NWB file and contains two datasets:
- The frame number at which each sample occurred
- The trace of each sample belonging to the unit

---

## 4. Classify with CNN 

Run classification on spikes using `run_classify.py` from the main folder (or `submit_run_classify.sh` with SLURM).

You will need to specify 8 arguments:
- [1] The ID of the recording
- [2] The minimum number of samples per unit
- [3] The maximum number of samples per unit ('max' to use the largest per-unit sample count)
- [4] The number of units to be classified
- [5] The number of samples to be used per unit ('all' to use all samples)
- [6] The number of noise samples to include (0 for none, 'all' for all)
- [7] The session name (the prefix of the session ID)
- [8] The session number (the numeric suffix of the session ID)

Example: `python -m phase_1.run_classify sub-CSHL049 1000 5000 3 all 0 sup 0`
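
The eight positional arguments could be parsed along the following lines. This is only a sketch using `argparse`: the argument names and the `int_or_keyword` helper are hypothetical, and the actual `run_classify.py` may read its arguments differently. Only the order and meaning of the arguments are taken from the list above.

```python
import argparse


def int_or_keyword(keyword):
    """Return a converter that accepts an integer or one literal keyword."""
    def parse(value):
        if value == keyword:
            return keyword
        try:
            return int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"expected an integer or '{keyword}', got {value!r}"
            )
    return parse


def build_parser():
    # Hypothetical names; the order follows arguments [1] to [8].
    parser = argparse.ArgumentParser(prog="run_classify")
    parser.add_argument("recording_id")
    parser.add_argument("min_samples", type=int)
    parser.add_argument("max_samples", type=int_or_keyword("max"))
    parser.add_argument("num_units", type=int)
    parser.add_argument("samples_per_unit", type=int_or_keyword("all"))
    parser.add_argument("num_noise", type=int_or_keyword("all"))
    parser.add_argument("session_name")
    parser.add_argument("session_number", type=int)
    return parser


# Parse the example command's arguments instead of reading sys.argv.
args = build_parser().parse_args(
    ["sub-CSHL049", "1000", "5000", "3", "all", "0", "sup", "0"]
)
print(args)
```

Converting the keyword arguments ('max', 'all') at parse time keeps the rest of the script free of string comparisons on raw input.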

The example command will run:
- using recording sub-CSHL049
- on 3 units
- restricted to units with 1000 to 5000 samples
- using all samples per unit
- including 0 noise samples
- in a session named SUP_000
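
The unit filter and the session naming in this example can be illustrated with a short sketch. Both functions are hypothetical: the inclusive bounds, the ascending-ID selection order, and the upper-cased name with a three-digit number are assumptions inferred from the example, not taken from `run_classify.py`. The counts are the first eight entries of the value counts printed in section 3.

```python
def select_units(counts, min_samples, max_samples, num_units):
    """Pick up to num_units unit IDs whose sample count is within range.

    Assumptions: bounds are inclusive and units are taken in ascending
    ID order; the real script may rank or shuffle them differently.
    """
    if max_samples == "max":
        upper = max(counts.values(), default=0)
    else:
        upper = max_samples
    eligible = [u for u in sorted(counts) if min_samples <= counts[u] <= upper]
    return eligible[:num_units]


def format_session_id(name, number):
    """Build an ID such as 'SUP_000' (assumed format)."""
    return f"{name.upper()}_{number:03d}"


# Spike counts for units 0 to 7, copied from the output in section 3.
counts = {0: 47148, 1: 4717, 2: 6304, 3: 46380,
          4: 8604, 5: 20985, 6: 1081, 7: 30788}

selected = select_units(counts, 1000, 5000, 3)
session_id = format_session_id("sup", 0)
print(selected)    # [1, 6]
print(session_id)  # SUP_000
```

With only these eight units, two fall inside the 1000 to 5000 range, so fewer than the requested 3 units are returned; how the real script handles a shortfall is not specified here.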

The script will also save the classification results to the results folder:
- Accuracy and Loss plot
- Accuracy and Loss log
- Confusion matrix 