# Phase 1

---

In [None]:
import os
import sys

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

Install the latest version of SpikeInterface as recommended in the **From source** section of the [installation guide](https://spikeinterface.readthedocs.io/en/latest/get_started/installation.html).

In [None]:
import spikeinterface.full as si
from one.api import ONE

print(f"SpikeInterface version: {si.__version__}")

In [None]:
sys.path.append("..")
import preprocessing
import util

## 1. Read recording session

For this project, we will be using session [sub-CSHL049](https://dandiarchive.org/dandiset/000409/draft/files?location=sub-CSHL049&page=1) of the [IBL Brain Wide Map Dataset](https://dandiarchive.org/dandiset/000409/draft). 

In [None]:
data_folder = "../data/sub-CSHL049"

os.makedirs(data_folder, exist_ok=True)

Using SpikeInterface, we can read and save the data to disk. 

In [None]:
extractors_folder = os.path.join(data_folder, "extractors")

os.makedirs(extractors_folder, exist_ok=True)

To obtain the data, we stream it through the ONE API using the session identifier (`eid`), which is listed in the [metadata](https://api.dandiarchive.org/api/dandisets/000409/versions/draft/assets/7e4fa468-349c-44a9-a482-26898682eed1/).

In [None]:
one = ONE(base_url="https://openalyx.internationalbrainlab.org", password="international", silent=True)

eid = "c99d53e6-c317-4c53-99ba-070b26673ac4"
pids, _ = one.eid2pid(eid)
pid = pids[0]

In [None]:
one_folder = os.path.join(data_folder, "one")

os.makedirs(one_folder, exist_ok=True)

### Recording

In [None]:
preprocessed_folder = os.path.join(extractors_folder, "preprocessed")

if not os.path.exists(preprocessed_folder): 
    recording = si.read_ibl_recording(eid, pid, 'probe00.ap', cache_folder=one_folder)
    
    # Preprocess the recording
    recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
    recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
    
    # Save the preprocessed recording to disk
    job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)
    recording_cmr.save(folder=preprocessed_folder, **job_kwargs)
else:
    recording_cmr = si.load_extractor(preprocessed_folder)
    
recording_cmr

### Sorting

In [None]:
sorting_folder = os.path.join(extractors_folder, "sorting")

if not os.path.exists(sorting_folder):  
    sorting = si.read_ibl_sorting(pid)        
    sorting.save(folder=sorting_folder)
else:
    sorting = si.load_extractor(sorting_folder)
    
sorting

### Sorting Analyzer

In [None]:
analyzer_folder = os.path.join(extractors_folder, "analyzer")

if not os.path.exists(analyzer_folder):
    analyzer = si.create_sorting_analyzer(
        sorting=sorting,
        recording=recording_cmr,
        format="memory"
    )
    
    # Compute extensions
    job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)
    compute_dict = {
        'random_spikes': {'method': 'uniform'},
        'waveforms': {'ms_before': 1.0, 'ms_after': 2.0},
        'templates': {'operators': ["average", "median", "std"]}
    }
    analyzer.compute(compute_dict, **job_kwargs)
    
    # Save the sorting analyzer to disk
    analyzer.save_as(folder=analyzer_folder, format="binary_folder")
else:
    analyzer = si.load_sorting_analyzer(analyzer_folder)
    
analyzer

---

## 2. Extract the spikes

In [None]:
channels_file = os.path.join(extractors_folder, "channels.npy")

if not os.path.exists(channels_file):  # check the channels cache; spikes_file is not defined yet
    channels = preprocessing.extract_channels(recording_cmr)
    np.save(channels_file, channels)
else:
    channels = np.load(channels_file)

display(pd.DataFrame(channels))

In [None]:
spikes_folder = os.path.join(data_folder, "spikes")

os.makedirs(spikes_folder, exist_ok=True)

In [28]:
spikes_file = os.path.join(spikes_folder, "spikes.npy")

if not os.path.exists(spikes_file):
    spikes = preprocessing.extract_spikes(sorting, analyzer, channels)
    np.save(spikes_file, spikes)
else:
    spikes = np.load(spikes_file)
    
display(pd.DataFrame(spikes))

| | spike_index | sample_index | channel_index | channel_location_x | channel_location_y | unit_index |
|---|---|---|---|---|---|---|
| 0 | 0 | 472 | 341 | 48.0 | 3400.0 | 271 |
| 1 | 1 | 511 | 361 | 48.0 | 3600.0 | 306 |
| 2 | 2 | 606 | 354 | 0.0 | 3540.0 | 297 |
| 3 | 3 | 680 | 361 | 48.0 | 3600.0 | 306 |
| 4 | 4 | 715 | 325 | 48.0 | 3240.0 | 235 |
| ... | ... | ... | ... | ... | ... | ... |
| 4604408 | 4604408 | 125188816 | 21 | 48.0 | 200.0 | 26 |
| 4604409 | 4604409 | 125188838 | 155 | 32.0 | 1540.0 | 105 |
| 4604410 | 4604410 | 125188912 | 325 | 48.0 | 3240.0 | 237 |
| 4604411 | 4604411 | 125188967 | 326 | 0.0 | 3260.0 | 239 |


In [None]:
noise_file = os.path.join(spikes_folder, "noise.npy")

if not os.path.exists(noise_file):
    noise = preprocessing.create_noise(recording_cmr, spikes, num_samples=100000)
    np.save(noise_file, noise)
else:
    noise = np.load(noise_file)
    
display(pd.DataFrame(noise))

---

## 3. Create trace dataset

In [None]:
spike_units = spikes['unit_index']

print(util.format_value_counts(spike_units))

Create a dataset of HDF5 files using `submit_dataset.sh`. 

You will need to specify 2 arguments:
- [1] The ID of the recording
- [2] The type of dataset: 'spikes', 'peaks', or 'noise'


Each file corresponds to one identified unit among the spikes from the NWB file and contains two datasets:
- A dataset of frame numbers for when each sample occurred
- A dataset of traces for each sample belonging to the unit
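
The two-dataset layout described above can be sketched with `h5py`. This is only an illustration: the dataset names `frames` and `traces`, the file name, and the trace shape are assumptions, not necessarily what `submit_dataset.sh` produces. The file is kept in memory so that nothing is written to disk.

```python
import h5py
import numpy as np

# Hypothetical layout: the dataset names "frames" and "traces", the file name,
# and the trace shape are assumptions made for illustration only.
rng = np.random.default_rng(0)
frames = np.array([472, 511, 606], dtype=np.int64)        # one frame number per sample
traces = rng.normal(size=(3, 64, 8)).astype(np.float32)   # (samples, time points, channels)

# Create an in-memory HDF5 file for a single unit (nothing touches the disk).
with h5py.File("unit_271.h5", "w", driver="core", backing_store=False) as f:
    f.create_dataset("frames", data=frames)
    f.create_dataset("traces", data=traces)

    # Read both datasets back, as a consumer of the file would.
    loaded_frames = f["frames"][:]
    loaded_traces = f["traces"][:]

print(loaded_frames.shape, loaded_traces.shape)
```

Keeping the frame numbers next to the traces lets each trace be mapped back to its position in the recording.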

---

## 4. Classify with CNN 

Run classification on spikes using `run_classify.py` from the main folder (or `submit_run_classify.sh` with SLURM).

You will need to specify 8 arguments:
- [1] The ID of the recording
- [2] The minimum number of samples per unit
- [3] The maximum number of samples per unit ('max' for the max number of samples per unit)
- [4] The number of units to be classified
- [5] The number of samples to be used per unit ('all' to use all samples)
- [6] The number of noise samples to include (0 for none, 'all' for all)
- [7] The name part of the session ID (e.g. 'sup')
- [8] The number part of the session ID (e.g. 0)
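
As a rough illustration of how arguments [2] to [4] could interact, the sketch below filters units by their sample count and keeps a fixed number of them. The helper `select_units` is hypothetical, and preferring the units with the most samples is an assumption; `run_classify.py` may select units differently.

```python
from collections import Counter


def select_units(unit_indices, min_samples, max_samples, num_units):
    """Pick units whose sample count lies in [min_samples, max_samples].

    Hypothetical helper: preferring the units with the most samples is an
    assumption, not necessarily what run_classify.py does.
    """
    counts = Counter(unit_indices)
    if max_samples == "max":
        # 'max' is treated here as "no upper bound beyond the largest unit".
        max_samples = max(counts.values(), default=0)
    eligible = [(unit, n) for unit, n in counts.items()
                if min_samples <= n <= max_samples]
    # Largest units first; ties are broken by unit index for determinism.
    eligible.sort(key=lambda item: (-item[1], item[0]))
    return [unit for unit, _ in eligible[:num_units]]


# Toy unit labels: unit 271 has 5 samples, 235 has 4, 306 has 3, 297 has 1.
toy = [271] * 5 + [306] * 3 + [297] * 1 + [235] * 4
selected = select_units(toy, min_samples=3, max_samples="max", num_units=2)
print(selected)
```

In the notebook, the `unit_index` column of `spikes` would play the role of `toy`.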

Example: `python -m phase_1.run_classify sub-CSHL049 1000 5000 3 all 0 sup 0`

The example command will run:
- using recording sub-CSHL049
- on 3 units
- with 1000-5000 samples per unit
- using all samples per unit
- including 0 noise samples
- named session SUP_000
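
The mapping from the example command to the settings listed above can be sketched with `argparse`. The argument names below are hypothetical; only their order and meaning come from the list of 8 arguments. The session ID format is inferred from the example name SUP_000 and is likewise an assumption.

```python
import argparse


def int_or_keyword(keyword):
    """Build a converter that accepts an integer or one literal keyword."""
    def parse(value):
        return value if value == keyword else int(value)
    return parse


# Argument names are hypothetical; the order follows the documented list.
parser = argparse.ArgumentParser(prog="run_classify")
parser.add_argument("recording_id")
parser.add_argument("min_samples", type=int)
parser.add_argument("max_samples", type=int_or_keyword("max"))
parser.add_argument("num_units", type=int)
parser.add_argument("samples_per_unit", type=int_or_keyword("all"))
parser.add_argument("noise_samples", type=int_or_keyword("all"))
parser.add_argument("session_name")
parser.add_argument("session_number", type=int)

# Parse the arguments of the example command.
args = parser.parse_args("sub-CSHL049 1000 5000 3 all 0 sup 0".split())

# Assumed format, inferred from the example session name SUP_000.
session_id = f"{args.session_name.upper()}_{args.session_number:03d}"
print(session_id)
```

Accepting either an integer or a keyword in one converter keeps options such as 'all' and 'max' in the same positional slot as their numeric alternatives.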

The script will also save the classification results to the results folder:
- Accuracy and Loss plot
- Accuracy and Loss log
- Confusion matrix 