# Phase 2: A Semi-Supervised Method

---

## 1. Extract the data in an NWB file

### Import required modules

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from pathlib import Path
import os
import pickle

import psutil
import multiprocessing as mp
from tqdm.notebook import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary
import torchvision
from torchvision import utils

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

In [None]:
# Record which SpikeInterface release this notebook is running against.
print("SpikeInterface version:", si.__version__)

In [None]:
import deepspikesort as dss

### Read the NWB file

In [None]:
# Folder under which the cached artefacts (waveforms, peaks) are looked up.
base_folder = Path(".")
# Name of the NWB session file read below; it is passed to the readers as-is,
# i.e. resolved against the current working directory.
nwb_file = "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4_behavior+ecephys+image.nwb"

In [None]:
# Open the 'ElectricalSeriesAp' electrical series of the NWB file as a
# SpikeInterface recording.
recording_nwb = se.read_nwb_recording(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
# Bare expression: shows the recording's representation as the cell output.
recording_nwb

In [None]:
# Flag the recording as unfiltered.
# NOTE(review): presumably required so the band-pass filter applied below is
# not treated as redundant by SpikeInterface - confirm.
recording_nwb.annotate(is_filtered=False)

In [None]:
# Read the spike sorting stored in the same NWB file, tied to the same
# 'ElectricalSeriesAp' electrical series as the recording.
sorting_nwb = se.read_nwb_sorting(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
# Bare expression: shows the sorting's representation as the cell output.
sorting_nwb

### Preprocess the recording

In [None]:
# Band-pass filter the raw recording between freq_min=300 and freq_max=6000
# (Hz, per the SpikeInterface preprocessing API).
recording_f = spre.bandpass_filter(recording_nwb, freq_min=300, freq_max=6000)
recording_f

In [None]:
# Re-reference the filtered recording with a global median reference
# (reference='global', operator='median').
recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')
recording_cmr

### Extract channels and spikes information

In [None]:
# Reduce the preprocessed recording with the deepspikesort helper.
# NOTE(review): presumably keeps only the channels relevant to the AP
# electrical series - confirm in deepspikesort.
recording_slice = dss.channel_slice_electricalseriesap(recording_cmr)
recording_slice

In [None]:
# Build a table describing the channels of the sliced recording; it is
# indexed by column name below (e.g. 'channel_loc_x').
channels_table = dss.extract_channels(recording_slice)
display(channels_table)

In [None]:
# List the distinct values found in the 'channel_loc_x' column.
channels_table['channel_loc_x'].unique()

In [None]:
# Name of the folder where extracted waveforms are cached.
waveform_folder = 'waveform'

# Parallel-processing options forwarded to the waveform extraction call.
job_kwargs = {
    "n_jobs": 10,
    "chunk_duration": "1s",
    "progress_bar": True,
}

In [None]:
# Resolve the cache location once, so the existence check and the extraction
# target always point at the same directory (previously the check used
# base_folder / waveform_folder while extraction wrote to the bare
# waveform_folder, which only coincide when base_folder is ".").
waveform_path = base_folder / waveform_folder

if waveform_path.is_dir():
    # Reuse the waveforms extracted by a previous run.
    waveform_nwb = si.load_waveforms(waveform_path)
else:
    # Cut a window of 1.5 ms before / 2 ms after each spike of the NWB
    # sorting out of the sliced recording, without capping the number of
    # spikes per unit (max_spikes_per_unit=None).
    waveform_nwb = si.extract_waveforms(
        recording_slice,
        sorting_nwb,
        waveform_path,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,
        overwrite=True,
        **job_kwargs
    )

In [None]:
# Bare expression: shows the waveform extractor's representation as the cell output.
waveform_nwb

In [None]:
# Build a table of spikes from the NWB sorting and its extracted waveforms;
# it is matched against the detected peaks later on.
spikes_table = dss.extract_spikes(sorting_nwb, waveform_nwb)
display(spikes_table)

---

## 2. Match peaks to spikes

In [None]:
# Name of the folder where detected peaks are cached.
peaks_folder = 'peaks'

# Parallel-processing options forwarded to peak detection
# (this rebinds the job_kwargs used earlier for waveform extraction).
job_kwargs = {
    "chunk_duration": '1s',
    "n_jobs": 8,
    "progress_bar": True,
}

In [None]:
# Location of the peak cache.
peaks_path = base_folder / peaks_folder

if not peaks_path.is_dir():
    # No cache yet: detect negative-going peaks with the 'locally_exclusive'
    # method at detect_threshold=6, then store them for later runs.
    peaks = detect_peaks(
        recording_slice,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=6,
        **job_kwargs,
    )
    dss.save_peaks(peaks, peaks_path)
else:
    # Reuse the peaks saved by a previous run.
    peaks = dss.load_peaks(peaks_path)

In [None]:
# Turn the detected peaks into a table, using the sliced recording for context.
peaks_table = dss.extract_peaks(recording_slice, peaks)
display(peaks_table)

In [None]:
# Match detected peaks against the sorted spikes, using the channel table.
# The result carries a 'unit_id' column (read further down when selecting units).
peaks_matched_table = dss.match_peaks(peaks_table, spikes_table, channels_table)
display(peaks_matched_table)

In [None]:
# Split the matched table into two views via the deepspikesort helpers.
# NOTE(review): presumably peaks matched to a sorted spike vs. unmatched
# ("noise") peaks - confirm in deepspikesort.
peaks_spikes_table = dss.get_peaks_spikes(peaks_matched_table)
peaks_noise_table = dss.get_peaks_noise(peaks_matched_table)

display(peaks_spikes_table)
display(peaks_noise_table)

---

## 3. Create a dataset from matched peaks

In [None]:
# Logical CPUs (hardware threads) and physical cores of the machine.
# psutil.cpu_count() may return None when the value cannot be determined,
# so fall back to the standard library / to the logical count.
nthreads = psutil.cpu_count(logical=True) or os.cpu_count() or 1
ncores = psutil.cpu_count(logical=False) or nthreads
nthreads_per_core = max(1, nthreads // ncores)

# CPUs this process is actually allowed to run on. os.sched_getaffinity is
# only available on some Unix platforms; elsewhere assume every logical CPU
# is usable.
if hasattr(os, 'sched_getaffinity'):
    nthreads_available = len(os.sched_getaffinity(0))
else:
    nthreads_available = nthreads
ncores_available = nthreads_available // nthreads_per_core

# Sanity checks: the different CPU-count APIs should agree.
assert nthreads == os.cpu_count()
assert nthreads == mp.cpu_count()

print(f'{nthreads=}')
print(f'{ncores=}')
print(f'{nthreads_per_core=}')
print(f'{nthreads_available=}')
print(f'{ncores_available=}')

In [None]:
# Folder holding the per-peak dataset files, nested inside the peaks cache.
# It is anchored under base_folder like the peaks cache itself; with
# base_folder == Path(".") this yields the same string as before.
peaks_dataset_folder = str(base_folder / peaks_folder / 'peaks_dataset')

# Create the folder (and any missing parent) if needed; exist_ok avoids the
# check-then-create race and the failure when the parent folder is absent.
os.makedirs(peaks_dataset_folder, exist_ok=True)

In [None]:
# Number of matched peaks per unit id.
unit_counts = peaks_matched_table['unit_id'].value_counts()

# Keep only the units with at least 500 matched peaks, as a sorted list of ids.
selected_peaks = sorted(unit_counts[unit_counts >= 500].index.to_list())

In [None]:
# One label per selected unit, e.g. "unit_<id>", in the same (sorted) order.
peak_labels = [f'unit_{unit}' for unit in selected_peaks]
# Map every label to its position in peak_labels (the integer class index).
peak_labels_dict = dict(zip(peak_labels, range(len(peak_labels))))
# Dataset reading from peaks_dataset_folder, given the labels and their indices.
peaks_dataset = dss.TensorDataset(
    peaks_dataset_folder,
    peak_labels,
    peak_labels_dict,
)

### Split for training and testing

In [None]:
# 70 % of the samples go to training; the remainder is held out for testing.
n_samples = len(peaks_dataset)
train_size = int(0.7 * n_samples)
test_size = n_samples - train_size

# Random, non-overlapping partition of the dataset into the two subsets.
peaks_train_dataset, peaks_test_dataset = torch.utils.data.random_split(
    peaks_dataset,
    [train_size, test_size],
)

print(f"Training size: {len(peaks_train_dataset)}")
print(f"Testing size: {len(peaks_test_dataset)}")

In [None]:
# Settings shared by the train and test loaders.
_loader_kwargs = {"batch_size": 64, "num_workers": 8}

# Training loader: samples are reshuffled during training.
peaks_train_dataloader = DataLoader(peaks_train_dataset, shuffle=True, **_loader_kwargs)

# Test loader: keeps the default (unshuffled) order.
peaks_test_dataloader = DataLoader(peaks_test_dataset, **_loader_kwargs)

---

## 4. Classify peaks with a CNN 

### Build the CNN architecture

In [None]:
# Default to the CPU and switch to the GPU only when CUDA is usable.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using {device} device')

In [None]:
class CNNet(nn.Module):
    """Convolutional classifier assigning a unit class to each peak sample.

    The network expects input of shape ``(N, 1, 64, 192, 2)``: the 3-D
    convolution collapses the last axis (size 2) to 1, that axis is squeezed
    away, and the remaining layers work on 2-D feature maps. The hard-coded
    35328 input features of the first linear layer (64 channels x 12 x 46)
    are only valid for that input size.

    Args:
        num_classes: Number of output classes. Defaults to 316, the value
            that was previously hard-coded, so ``CNNet()`` is unchanged.

    The forward pass returns log-probabilities of shape ``(N, num_classes)``.
    """

    def __init__(self, num_classes=316):
        super().__init__()
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=(9, 3, 2))
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=4)

        # Channel-wise dropout applied to the output of the second convolution.
        self.conv_layer_2_drop = nn.Dropout2d()

        self.flatten = nn.Flatten()
        self.fully_connected_layer_1 = nn.Linear(35328, 500)
        self.fully_connected_layer_2 = nn.Linear(500, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``."""
        # (N, 1, 64, 192, 2) -> conv3d -> (N, 32, 56, 190, 1); drop the
        # singleton last axis so 2-D pooling/convolution can follow.
        x = torch.squeeze(self.conv_layer_1(x), 4)
        x = F.relu(F.max_pool2d(x, 2))

        x = self.conv_layer_2_drop(self.conv_layer_2(x))
        x = F.relu(F.max_pool2d(x, 2))

        x = self.flatten(x)
        x = F.relu(self.fully_connected_layer_1(x))
        x = F.dropout(x, training=self.training)

        # No ReLU on the output layer: clamping the logits to >= 0 prevents
        # the model from scoring a class below zero and blocks the gradient
        # for every negative pre-activation.
        x = self.fully_connected_layer_2(x)

        # Log-probabilities are kept as the output. Feeding them to
        # nn.CrossEntropyLoss is harmless because log-softmax applied to
        # log-probabilities returns them unchanged.
        return F.log_softmax(x, dim=1)

# Build the network and move its parameters to the selected device.
model = CNNet()
model = model.to(device)

# Choose optimal parameters
learning_rate = 0.0001
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Print a layer-by-layer summary of the model for an input of size
# (64, 1, 64, 192, 2).
# NOTE(review): the leading 64 is presumably the batch size (it matches
# batch_size=64 of the data loaders) - confirm.
summary(model, input_size=(64, 1, 64, 192, 2))

### Train and test the model

In [None]:
# Folder (inside the current working directory) passed to the training
# helper for this model.
models_folder = os.path.join(os.getcwd(), "models")

# exist_ok makes the creation idempotent and avoids the check-then-create race.
os.makedirs(models_folder, exist_ok=True)

In [None]:
# Bundle the train/test loaders, the device, the loss function and the
# optimizer into the deepspikesort training helper.
train_peaks = dss.TrainModel(peaks_train_dataloader,
                             peaks_test_dataloader,
                             device,
                             loss_fn,
                             optimizer)

In [None]:
# Name given to this training run.
# NOTE(review): presumably used to name what is written into models_folder - confirm.
model_name = "peak_2"

# Train and evaluate the model for 15 epochs. Note that `classes` receives the
# raw unit ids (selected_peaks), not the "unit_<id>" label strings.
train_peaks.train_test_model(model, model_name, models_folder, epochs=15, classes=selected_peaks)