# Phase 2: Implementing DeepCluster

---

## 1. Extract data from file

### Import required modules

In [None]:
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import os

import warnings
# NOTE(review): this silences every Python warning for the rest of the session,
# including deprecation warnings from the libraries used below — consider narrowing it.
warnings.simplefilter("ignore")

# Use matplotlib's interactive "widget" backend for the figures in this notebook.
%matplotlib widget

In [None]:
import spikeinterface.full as si
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

# Log the installed SpikeInterface version alongside the notebook output for reproducibility.
print(f"SpikeInterface version: {si.__version__}")

In [None]:
import preprocessing
import process_peaks
import comparison
import util

### Read the NWB file

In [None]:
# Load the 'ElectricalSeriesAp' electrical series of this session from its NWB file.
nwb_file = "data/sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"
recording_nwb = si.read_nwb(
    nwb_file,
    electrical_series_name='ElectricalSeriesAp',
)
recording_nwb

In [None]:
# Mark the loaded recording as not yet filtered (band-pass filtering is applied below).
# NOTE(review): presumably set so SpikeInterface's own "is the data filtered?" checks see
# the correct state — confirm.
recording_nwb.annotate(is_filtered=False)

In [None]:
# Load the sorted units stored in the same NWB file, linked to the AP electrical series.
sorting_nwb = si.read_nwb_sorting(
    file_path=nwb_file,
    electrical_series_name='ElectricalSeriesAp',
)
sorting_nwb

### Preprocess the recording

In [None]:
# Band-pass filter the loaded traces between 300 Hz and 6000 Hz.
recording_f = si.bandpass_filter(
    recording_nwb,
    freq_min=300,
    freq_max=6000,
)
recording_f

In [None]:
# Re-reference the filtered traces by subtracting the global (all-channel) median.
recording_cmr = si.common_reference(
    recording_f,
    reference='global',
    operator='median',
)
recording_cmr

In [None]:
# Restrict the re-referenced recording to a channel subset using the project helper.
# NOTE(review): which channels are kept is defined inside
# preprocessing.channel_slice_electricalseriesap — not visible in this notebook.
recording_slice = preprocessing.channel_slice_electricalseriesap(recording_cmr)
recording_slice

In [None]:
# Folder where the saved extractors (preprocessed recording, waveforms) for recording 001 live.
extractors_folder = "extractors/001"
os.makedirs(name=extractors_folder, exist_ok=True)

In [None]:
# Cache the filtered, re-referenced, channel-sliced recording on disk: the first run
# materialises it with `save`, later runs reload the saved copy instead.
preprocessed_folder = os.path.join(extractors_folder, "preprocessed")
job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": True}

if not os.path.exists(preprocessed_folder):
    recording_preprocessed = recording_slice.save(folder=preprocessed_folder, **job_kwargs)
else:
    recording_preprocessed = si.load_extractor(preprocessed_folder)

recording_preprocessed

### Extract channels and spikes

In [None]:
# Collect per-channel information from the preprocessed recording and show it as a table.
channels = preprocessing.extract_channels(recording_preprocessed)
display(pd.DataFrame(data=channels))

In [None]:
# Extract the waveforms of every NWB-sorted unit and cache them on disk; later runs
# reload the cached folder.
#
# Extraction reads from `recording_preprocessed` — the saved binary copy of
# `recording_slice` produced in the cell above — so the filter and common-reference
# steps are not recomputed chunk by chunk, and the waveform folder references the
# cached recording rather than the lazy preprocessing chain. The traces are the same.
waveform_folder = os.path.join(extractors_folder, "waveform")
job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

if os.path.exists(waveform_folder):
    # Reload without re-attaching the recording.
    waveform_nwb = si.load_waveforms(waveform_folder, with_recording=False)
else:
    waveform_nwb = si.extract_waveforms(
        recording_preprocessed,
        sorting_nwb,
        waveform_folder,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,  # no subsampling: keep every spike of every unit
        overwrite=True,
        **job_kwargs
    )

waveform_nwb

In [None]:
# Build the spike table from the NWB sorting and its extracted waveforms, then display it.
spikes = preprocessing.extract_spikes(sorting_nwb, waveform_nwb)
display(pd.DataFrame(data=spikes))

---

## 2. Create a dataset from matched peaks

### Extract peaks

In [None]:
# Folder where detected and matched peaks for recording 001 are cached as .npy files.
peaks_folder = 'peaks/001'
os.makedirs(name=peaks_folder, exist_ok=True)

In [None]:
# Detect negative-going peaks once and cache them in peaks.npy; later runs reload the file.
# NOTE(review): detection runs on `recording_cmr` (before channel slicing), while
# `filter_peaks` below is given `recording_slice` — confirm this is intended.
peaks_file = os.path.join(peaks_folder, "peaks.npy")
job_kwargs = {"chunk_duration": '1s', "n_jobs": 8, "progress_bar": True}

if not os.path.exists(peaks_file):
    peaks = detect_peaks(
        recording_cmr,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=6,
        **job_kwargs,
    )
    np.save(peaks_file, peaks)
else:
    peaks = np.load(peaks_file)

display(pd.DataFrame(data=peaks))

In [None]:
# Filter the detected peaks against the channel-sliced recording using the project helper.
# NOTE(review): presumably drops peaks on channels outside the slice — confirm in process_peaks.
peaks_filtered = process_peaks.filter_peaks(recording_slice, peaks)
display(pd.DataFrame(data=peaks_filtered))

### Match peaks to spikes

In [None]:
# Match the filtered peaks to the NWB spikes once, caching the result in
# peaks_matched.npy so later runs reload it instead of re-matching.
peaks_matched_file = os.path.join(peaks_folder, "peaks_matched.npy")

if not os.path.exists(peaks_matched_file):
    peaks_matched = process_peaks.match_peaks(peaks_filtered, spikes, channels)
    np.save(peaks_matched_file, peaks_matched)
else:
    peaks_matched = np.load(peaks_matched_file)

display(pd.DataFrame(data=peaks_matched))

### Create peaks dataset

A dataset of HDF5 files can be generated using the `generate_dataset.py` script. 

Each file corresponds to one unit among the peaks that were matched to the units of the NWB file, and contains two datasets:
- The frame number at which each sample occurred
- The trace representation of each sample belonging to the unit

In [None]:
# `importlib` is not imported anywhere else in this notebook, so import it here;
# without it this cell fails with NameError.
import importlib

# Re-import the local `util` module so edits to util.py take effect without
# restarting the kernel.
util = importlib.reload(util)

In [None]:
# Report how many distinct units the matched peaks cover, followed by the
# value-count summary produced by util.format_value_counts.
peak_units = peaks_matched['unit_index']

print(f"Peak units: {len(np.unique(peak_units))}\n")
print(util.format_value_counts(peak_units))

The script takes 4 arguments:
- [1] The number of the recording to use
- [2] The type of dataset to generate: 0 for spikes from the NWB file, 1 for peaks found by the peak detection algorithm
- [3] The index of the first unit to process (inclusive)
- [4] The index at which to stop processing units (exclusive)

Example: `!python generate_dataset.py 1 1 0 421`

This example command generates a dataset of peaks from recording number 1 for units 0 through 420.

---

## 3. Running DeepSpikeSort

The DeepSpikeSort algorithm can be run using the `run_dss.py` script.

DeepSpikeSort (DSS) follows the DeepCluster method, repeating the following steps each epoch:

1. Feature Extraction
- Initialize the CNN model with random weights for the first epoch
- Extract features from the layer before the final fully connected (FC) layer
- Preprocess the features using PCA with whitening, followed by L2 normalization

2. Clustering
- Fit a Gaussian mixture model (GMM) to the preprocessed features
- Predict cluster labels for the features

3. Cluster Comparison
- From the second epoch onward, calculate the Adjusted Rand Index (ARI) between the cluster labels of successive epochs
- Use the ARI as the convergence metric

4. Representation Learning
- Create a dataset using the cluster labels for supervised learning
- Train the CNN model on the labelled dataset

In [None]:
# Show the GPUs visible on this machine before choosing the GPU-count argument for DSS.
!nvidia-smi

The script needs to be run with 7 arguments:

- [1] The number associated with the recording to be used
- [2] The minimum number of samples per unit
- [3] The maximum number of samples per unit 
- [4] The number of units to be sorted
- [5] The number of classes to be predicted
- [6] The number of available GPUs for parallel data loading
- [7] The number of epochs for running DSS

Example: `!python run_dss.py 1 3000 4000 5 5 1 200`

The example command will run DSS:
- for 200 epochs 
- on 5 units
- each with 3000-4000 samples
- predicting 5 clusters
- using 1 available GPU

The script will also save the DSS output and results to their respective folders:
- Output
    - Selected units
    - Preprocessed features
    - Cluster labels
    - Corresponding times
- Results
    - ARI progress plot
    - ARI progress log
    - SpikeInterface comparison results
    - Agreement matrix plot

## 4. Compare DeepSpikeSort output

### Create Sorting object from DSS output

In [None]:
# Locate and load the DSS cluster labels and sample times for `recording_num`,
# produced by a run on `num_units` units.
recording_num = 1
num_units = 5

output_folder = f'output/{recording_num:03}'
dss_labels_file = os.path.join(output_folder, f'dss_labels_{num_units:03}.npy')
dss_times_file = os.path.join(output_folder, f'dss_times_{num_units:03}.npy')

dss_labels, dss_times = np.load(dss_labels_file), np.load(dss_times_file)

In [None]:
# Report the number of DSS-labelled samples, followed by the value-count summary of the labels.
print(f"Samples: {len(dss_labels)}\n")
print(util.format_value_counts(dss_labels))

In [None]:
# Create custom NumpySorting object from DeepSpikeSort output
# Wrap the DSS sample times and cluster labels in a custom NumpySorting via the project helper.
# NOTE(review): 30000 is presumably the sampling rate in Hz. It must match the recording and
# the value used for `sorting_peaks`; consider recording_preprocessed.get_sampling_frequency().
sorting_dss = comparison.create_numpy_sorting(
    dss_times,
    dss_labels,
    30000,
)
sorting_dss

### Create Sorting object from NWB file

In [None]:
# Create a boolean mask
selected_units_file = os.path.join(output_folder, f'selected_units_{num_units:03}.npy')
selected_units = np.load(selected_units_file)
mask_selected = np.isin(peaks_matched['unit_index'], [int(unit) for unit in selected_units])

# Filter the array
peaks_selected = peaks_matched[mask_selected]
display(pd.DataFrame(peaks_selected))

In [None]:
# Split the selected peaks into their sample indices and unit labels.
peak_times, peak_units = peaks_selected['sample_index'], peaks_selected['unit_index']

In [None]:
# Report the number of selected peak samples, followed by the value-count summary of their units.
print(f"Samples: {len(peak_units)}\n")
print(util.format_value_counts(peak_units))

In [None]:
# Wrap the selected peak times and unit labels in a NumpySorting for comparison with DSS.
# NOTE(review): 30000 is presumably the sampling rate in Hz. It must match the recording and
# the value used for `sorting_dss`; consider recording_preprocessed.get_sampling_frequency().
sorting_peaks = comparison.create_numpy_sorting(
    peak_times,
    peak_units,
    30000,
)
sorting_peaks

### Compare Sorting objects

In [None]:
# Run the comparison
# Compare the DSS sorting against the sorting built from the matched peaks.
cmp_dss_peaks = si.compare_two_sorters(
    sorting1=sorting_dss,
    sorting1_name='DeepSpikeSort',
    sorting2=sorting_peaks,
    sorting2_name='Peaks',
    verbose=True,
)

In [None]:
# `get_matching` shows which units were paired between the two sortings;
# the first element maps each DSS unit to its matched peaks unit.
# Units without a match are listed as -1.
dss_to_peaks, _ = cmp_dss_peaks.get_matching()
display(dss_to_peaks)

In [None]:
# Inspect the comparison's internal tables: `match_event_count` (matched event
# counts per unit pair) and `agreement_scores` (agreement score per unit pair).
display(cmp_dss_peaks.match_event_count)
display(cmp_dss_peaks.agreement_scores)

In [None]:
# Plot the agreement matrix between DSS units and peaks units to inspect the matching visually.
si.plot_agreement_matrix(cmp_dss_peaks)