# Phase 2: Implementing DeepCluster

---

## 1. Extract the data from an NWB file

### Import required modules

In [None]:
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from pathlib import Path
import os
import re
import pickle

import torch
import torch.nn as nn
from torchinfo import summary

from torch.utils.data import DataLoader
from collections import Counter

from pycave.bayes import GaussianMixture
import warnings
# Suppress all Python warnings for the rest of the session.
# NOTE(review): this also hides warnings that may matter — consider a narrower filter.
warnings.simplefilter("ignore")

# Use the interactive widget backend for matplotlib figures (ipympl).
%matplotlib widget

In [None]:
import spikeinterface.full as si
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

In [None]:
# Record the installed SpikeInterface version alongside the results.
print("SpikeInterface version:", si.__version__)

In [None]:
import preprocessing
import process_peaks
import dataset
import model
import deepcluster

### Read the NWB file

In [None]:
# Directory under which the cached waveform and peak folders are looked up.
base_folder = Path(".")

# NWB session file that the recording and the sorting are read from.
nwb_file = (
    "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"
)

In [None]:
# Open the 'ElectricalSeriesAp' electrical series of the NWB file as a recording.
recording_nwb = si.read_nwb(
    nwb_file,
    electrical_series_name="ElectricalSeriesAp",
)
recording_nwb

In [None]:
# Annotate the raw recording as not yet filtered (it is band-pass filtered in a later cell).
recording_nwb.annotate(is_filtered=False)

### Preprocess the recording

In [None]:
# Band-pass the raw recording between freq_min=300 and freq_max=6000.
recording_f = si.bandpass_filter(
    recording_nwb,
    freq_min=300,
    freq_max=6000,
)
recording_f

In [None]:
# Re-reference the filtered recording with a global median reference.
recording_cmr = si.common_reference(
    recording_f,
    reference="global",
    operator="median",
)
recording_cmr

### Extract channels and spikes information

In [None]:
# Slice the re-referenced recording with the project helper.
# NOTE(review): presumably keeps a subset of channels — the selection rule lives in
# preprocessing.channel_slice_electricalseriesap; confirm there.
recording_slice = preprocessing.channel_slice_electricalseriesap(recording_cmr)
recording_slice

In [None]:
# Build the channel table for the sliced recording via the project helper and show it.
channels_table = preprocessing.extract_channels(recording_slice)
display(channels_table)

In [None]:
# List the distinct values found in the channel_loc_x column.
channels_table['channel_loc_x'].unique()

In [None]:
# Load the sorting stored in the same NWB file.
sorting_nwb = si.read_nwb_sorting(
    file_path=nwb_file,
    electrical_series_name="ElectricalSeriesAp",
)
sorting_nwb

In [None]:
# Name of the folder (looked up under base_folder) that caches extracted waveforms.
waveform_folder = "waveform"

# Job options forwarded to the waveform extraction call.
job_kwargs = {
    "n_jobs": 10,
    "chunk_duration": "1s",
    "progress_bar": True,
}

In [None]:
# Resolve the waveform cache location once so that the existence check, the
# load and the extraction all refer to the same directory. Previously the
# extraction wrote to the cwd-relative `waveform_folder` while the check and
# load used `base_folder / waveform_folder`.
waveform_path = base_folder / waveform_folder

if waveform_path.is_dir():
    # Reuse the cached waveforms without re-attaching the recording.
    waveform_nwb = si.load_waveforms(waveform_path, with_recording=False)
else:
    # No cache yet: extract waveforms around the spikes of sorting_nwb
    # (ms_before=1.5, ms_after=2.0, no per-unit spike limit) into the cache folder.
    waveform_nwb = si.extract_waveforms(
        recording_slice,
        sorting_nwb,
        waveform_path,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,
        overwrite=True,
        **job_kwargs
    )

In [None]:
# Display the loaded (or freshly extracted) waveform object.
waveform_nwb

In [None]:
# Build the spikes table from the sorting and its waveforms via the project helper and show it.
spikes_table = preprocessing.extract_spikes(sorting_nwb, waveform_nwb)
display(spikes_table)

In [None]:
# Name of the folder (looked up under base_folder) that caches detected peaks.
peaks_folder = "peaks"

# Job options forwarded to peak detection; this rebinds the earlier job_kwargs.
job_kwargs = {
    "chunk_duration": "1s",
    "n_jobs": 8,
    "progress_bar": True,
}

In [None]:
# Run peak detection only when no cached result exists; otherwise reload the cache.
if not (base_folder / peaks_folder).is_dir():
    peaks = detect_peaks(
        recording_slice,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=6,
        **job_kwargs
    )
    # Persist the detection so later runs can skip it.
    process_peaks.save_peaks(peaks, base_folder / peaks_folder)
else:
    peaks = process_peaks.load_peaks(base_folder / peaks_folder)

In [None]:
# Turn the detected peaks into a table via the project helper and show it.
peaks_table = process_peaks.extract_peaks(recording_slice, peaks)
display(peaks_table)

---

## 2. Match peaks to spikes

In [None]:
# Cache file for the peak/spike matching. It is resolved under base_folder,
# the same root the detected peaks are saved to, so the cache follows the
# data when base_folder is changed.
peaks_matched_table_file = base_folder / peaks_folder / "peaks_matched_table.pkl"

if peaks_matched_table_file.exists():
    # NOTE: unpickling can execute arbitrary code — only load caches written by this notebook.
    peaks_matched_table = pd.read_pickle(peaks_matched_table_file)
else:
    peaks_matched_table = process_peaks.match_peaks(peaks_table, spikes_table, channels_table)
    # Make sure the target folder exists before writing the cache.
    peaks_matched_table_file.parent.mkdir(parents=True, exist_ok=True)
    peaks_matched_table.to_pickle(peaks_matched_table_file)

display(peaks_matched_table)

## 3. Create a dataset from matched peaks

In [None]:
# Folder for the peaks dataset, nested inside the peaks folder.
peaks_dataset_folder = os.path.join(peaks_folder, 'peaks_dataset')

# Create the folder (and any missing parent) in one step; unlike a separate
# exists-check followed by mkdir, this cannot fail when the folder already exists.
os.makedirs(peaks_dataset_folder, exist_ok=True)

In [None]:
# Count rows per unit_id, keep the units whose count lies in [3000, 4000]
# (both bounds inclusive), then draw three of them with a fixed random state.
unit_counts = peaks_matched_table['unit_id'].value_counts()
filtered_units = unit_counts[unit_counts.between(3000, 4000)]
selected_units = filtered_units.sample(n=3, random_state=1).index.tolist()

print(selected_units)

In [None]:
# Build the dataset for the selected units from the dataset folder,
# shuffled with a fixed seed.
peaks_dataset = dataset.UnsupervisedDataset(
    peaks_dataset_folder,
    selected_units,
    shuffle=True,
    seed=0,
)
print('Dataset:', len(peaks_dataset))

In [None]:
# Tally how many dataset samples carry each label.
classes = peaks_dataset.get_labels()
print(Counter(classes))

In [None]:
# Batch the dataset in groups of 64; no shuffling is requested from the loader.
peaks_dataloader = DataLoader(peaks_dataset, batch_size=64)

---

## 4. Implementing DeepCluster

In [None]:
import importlib

# Reload the project-local modules so that edits made on disk take effect
# without restarting the kernel. Objects created before this cell keep
# referring to the classes of the previously loaded module versions.
preprocessing, process_peaks, dataset, model, deepcluster = [
    importlib.reload(module)
    for module in (preprocessing, process_peaks, dataset, model, deepcluster)
]

In [None]:
# Show the GPU status by running the nvidia-smi shell command.
!nvidia-smi

In [None]:
# One cluster per selected unit: the cluster count is derived from the unit
# selection instead of repeating its hard-coded size, so the two cannot drift.
num_classes = len(selected_units)
deepcluster_model = model.DeepCluster(num_classes)

# Gaussian mixture with full covariance matrices and k-means initialisation;
# its trainer is configured with gpus=[0].
gmm = GaussianMixture(
    num_classes,
    covariance_type="full",
    init_strategy='kmeans',
    trainer_params=dict(gpus=[0]),
)

In [None]:
# First CUDA device and the matching list of device ids.
device = torch.device("cuda:0")
device_ids = [0]

In [None]:
# Training options handed to deepcluster.train_deepcluster. The device settings
# reuse `device` / `device_ids` from the previous cell instead of re-spelling
# the same literals, so the device only has to be changed in one place.
deepcluster_kwargs = dict(
    device=device,
    device_ids=device_ids,
    loss_fn=nn.CrossEntropyLoss(),
    # Adam over the model parameters with a learning rate of 1e-4.
    optimizer=torch.optim.Adam(deepcluster_model.parameters(), lr=0.0001),
    sampling_frequency=30000,
    epochs=200,
)

In [None]:
# Train DeepCluster on the peaks dataset and keep the returned cluster labels.
cluster_labels = deepcluster.train_deepcluster(
    peaks_dataset,
    peaks_dataloader,
    deepcluster_model,
    gmm,
    num_classes,
    deepcluster_kwargs,
)

## 5. Compare DeepSpikeSort output with the NWB sorting

In [None]:
# Derive the peak times from the dataset's image paths using the project helper.
peak_times = deepcluster.get_peak_times(peaks_dataset.image_paths)

In [None]:
# Create a custom NumpySorting object from the peak times and their cluster labels.
# NOTE(review): 30000 repeats the sampling_frequency passed to training — keep the two in sync.
sorting_dss = deepcluster.create_numpy_sorting(peak_times, cluster_labels, 30000)
sorting_dss

In [None]:
# Restrict the NWB sorting to the units that were selected for the dataset.
sorting_selected = sorting_nwb.select_units(unit_ids=selected_units)
sorting_selected

In [None]:
# Run the comparison: DeepSpikeSort output (sorting1) against the selected NWB units (sorting2).
cmp_nwb_dss = si.compare_two_sorters(
    sorting1=sorting_dss,
    sorting2=sorting_selected,
    sorting1_name='DeepSpikeSort',
    sorting2_name='NWB',
)

In [None]:
# Plot the agreement matrix to inspect how the units of the two sortings match.
si.plot_agreement_matrix(cmp_nwb_dss)

In [None]:
# Show two of the comparison's internal dataframes, which help to check the
# matching: match_event_count and agreement_scores.
display(cmp_nwb_dss.match_event_count, cmp_nwb_dss.agreement_scores)

In [None]:
# get_matching() returns the unit matching in both directions; units without
# a match are listed as -1. Both results are bound to real names because
# assigning to `_` would shadow IPython's last-output variable for the rest
# of the session.
dss_to_nwb, nwb_to_dss = cmp_nwb_dss.get_matching()
display(dss_to_nwb)