# Phase 2: A Semi-Supervised Method

---

## 1. Extract the data from an NWB file

### Import required modules

In [None]:
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from pathlib import Path
import os
import pickle

import psutil
from tqdm.notebook import tqdm

import torch
from collections import Counter
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
import torch.optim as optim

from pycave.bayes import GaussianMixture

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

In [None]:
import spikeinterface.full as si

In [None]:
# Record the SpikeInterface version used for this run.
print("SpikeInterface version:", si.__version__)

In [None]:
import preprocessing
import process_peaks
import dataset
import model
import clustering
import training

### Read the NWB file

In [None]:
# Root folder for the cached intermediate results (waveforms, peaks).
base_folder = Path(".")
# NWB file holding both the raw recording and the reference sorting read below.
nwb_file = (
    "sub-CSHL049_ses-c99d53e6-c317-4c53-99ba-070b26673ac4"
    "_behavior+ecephys+image.nwb"
)

In [None]:
# Load the `ElectricalSeriesAp` electrical series of the NWB file as a recording.
recording_nwb = si.read_nwb(nwb_file, electrical_series_name='ElectricalSeriesAp')
recording_nwb

In [None]:
# Mark the raw recording as not yet filtered; filtering is applied in the
# "Preprocess the recording" cells.
recording_nwb.annotate(is_filtered=False)

In [None]:
# Load the spike sorting stored in the same NWB file; its units are later
# matched against the detected peaks.
sorting_nwb = si.read_nwb_sorting(file_path=nwb_file, electrical_series_name='ElectricalSeriesAp')
sorting_nwb

### Preprocess the recording

In [None]:
# Band-pass filter the raw recording to 300-6000 Hz.
recording_f = si.bandpass_filter(recording_nwb, freq_min=300, freq_max=6000)
recording_f

In [None]:
# Common median reference: subtract the median across all channels ('global').
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
recording_cmr

### Extract channel and spike information

In [None]:
# Restrict the preprocessed recording to the channel subset chosen by the
# project helper (see preprocessing.channel_slice_electricalseriesap).
recording_slice = preprocessing.channel_slice_electricalseriesap(recording_cmr)
recording_slice

In [None]:
# Per-channel metadata table for the sliced recording (has a `channel_loc_x`
# column, inspected in the next cell).
channels_table = preprocessing.extract_channels(recording_slice)
display(channels_table)

In [None]:
# Distinct x positions of the kept channels.
channels_table['channel_loc_x'].unique()

In [None]:
# Folder caching the extracted waveforms (reused if it already exists).
waveform_folder = 'waveform'

# Parallelisation settings for waveform extraction.
job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": True}

In [None]:
# Reuse cached waveforms when the folder exists; otherwise extract them for
# the units of the NWB sorting on the sliced, preprocessed recording.
if (base_folder / waveform_folder).is_dir():
    waveform_nwb = si.load_waveforms(base_folder / waveform_folder, with_recording=False)
else:
    # Write to the same location the branch above checks. Passing the bare
    # `waveform_folder` would write relative to the current directory and
    # miss the cache whenever `base_folder` is not ".".
    waveform_nwb = si.extract_waveforms(
        recording_slice,
        sorting_nwb,
        base_folder / waveform_folder,
        ms_before=1.5,
        ms_after=2.,
        max_spikes_per_unit=None,  # no per-unit cap on extracted spikes
        overwrite=True,
        **job_kwargs
    )

In [None]:
# Display the waveform extractor.
waveform_nwb

In [None]:
# Tabulate the sorted spikes from the sorting and its extracted waveforms
# (project helper).
spikes_table = preprocessing.extract_spikes(sorting_nwb, waveform_nwb)
display(spikes_table)

---

## 2. Match peaks to spikes

In [None]:
# Folder caching the detected peaks and derived tables.
peaks_folder = 'peaks'

# Parallelisation settings for peak detection (8 workers instead of 10).
job_kwargs = {"chunk_duration": "1s", "n_jobs": 8, "progress_bar": True}

In [None]:
# Reuse previously detected peaks when the cache folder exists; otherwise
# detect them on the sliced recording and cache the result.
if (base_folder / peaks_folder).is_dir():
    peaks = process_peaks.load_peaks(base_folder / peaks_folder)
else:
    # `detect_peaks` is not imported anywhere else in this notebook, so the
    # bare call raised NameError on a cache miss; it lives in SpikeInterface's
    # sortingcomponents package.
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks

    # Negative-going peaks above a detection threshold of 6.
    peaks = detect_peaks(
        recording_slice,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=6,
        **job_kwargs,
    )
    process_peaks.save_peaks(peaks, base_folder / peaks_folder)

In [None]:
# Convert the detected peaks into a table using the recording (project helper).
peaks_table = process_peaks.extract_peaks(recording_slice, peaks)
display(peaks_table)

In [None]:
# Cache of the peak-to-spike matching, stored next to the cached peaks. It is
# anchored at `base_folder` like the peaks cache, so both move together.
# The pickle is a local cache written by this notebook; only load trusted files.
peaks_matched_table_file = base_folder / peaks_folder / "peaks_matched_table.pkl"

if peaks_matched_table_file.exists():
    peaks_matched_table = pd.read_pickle(peaks_matched_table_file)
else:
    # Match every detected peak against the sorted spikes (project helper).
    peaks_matched_table = process_peaks.match_peaks(peaks_table, spikes_table, channels_table)
    peaks_matched_table.to_pickle(peaks_matched_table_file)

display(peaks_matched_table)

In [None]:
# Split the matched peaks into peaks attributed to a sorted spike and noise
# peaks, as defined by the process_peaks helpers.
peaks_spikes_table = process_peaks.get_peaks_spikes(peaks_matched_table)
peaks_noise_table = process_peaks.get_peaks_noise(peaks_matched_table)

display(peaks_spikes_table)
display(peaks_noise_table)

---

## 3. Create a dataset from matched peaks

In [None]:
# Folder for the per-peak dataset files. The string value is kept unchanged
# because dataset folder paths derived from it may be cached.
peaks_dataset_folder = os.path.join(peaks_folder, 'peaks_dataset')

# Create it (and a missing `peaks/` parent) without a separate existence check.
os.makedirs(peaks_dataset_folder, exist_ok=True)

In [None]:
# Unit ids (despite the variable name) whose number of matched peaks lies in
# [1000, 3700], bounds inclusive.
selected_peaks = peaks_matched_table['unit_id'].value_counts()
selected_peaks = selected_peaks[selected_peaks.between(1000, 3700)].index.to_list()

print(len(selected_peaks))
print(selected_peaks)

In [None]:
# Dataset over the peak data for the selected unit ids, read from
# peaks_dataset_folder (project helper; sample format defined in dataset.py).
peaks_dataset = dataset.TensorDataset(peaks_dataset_folder, selected_peaks)

### Split for training and testing

In [None]:
# 70/30 random split of the peak dataset. No seed is set, so the split
# differs between runs.
train_size = int(len(peaks_dataset) * 0.7)
test_size = len(peaks_dataset) - train_size
peaks_train_dataset, peaks_test_dataset = torch.utils.data.random_split(
    peaks_dataset, [train_size, test_size]
)

print("Training size:", len(peaks_train_dataset))
print("Testing size:", len(peaks_test_dataset))

In [None]:
# Unshuffled loader over the training split; it feeds feature extraction.
peaks_train_dataloader = DataLoader(peaks_train_dataset, num_workers=8, batch_size=8)

---

## 4. Obtain cluster assignments

### Build the extractor architecture

In [None]:
class Extractor(nn.Module):
    """Convolutional feature extractor mapping a peak sample to a 500-d vector.

    The layer sizes are fixed for inputs of shape ``(B, 1, 64, 192, 2)``. Any
    other spatial size would not produce the 35328 features expected by
    ``fully_connected_layer_1``.
    """

    def __init__(self):
        super().__init__()
        # (B, 1, 64, 192, 2) -> (B, 32, 56, 190, 1): the size-2 kernel along the
        # last axis collapses it, and forward() squeezes it away.
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=(9, 3, 2))
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=4)

        self.conv_layer_2_drop = nn.Dropout2d()

        self.flatten = nn.Flatten()
        # 64 channels * 12 * 46 after the second conv + pool = 35328.
        self.fully_connected_layer_1 = nn.Linear(35328, 500)

        self._initialize_weights()

    def forward(self, x):
        # (B, 32, 56, 190) after the squeeze, pooled to (B, 32, 28, 95).
        x = F.relu(F.max_pool2d(torch.squeeze(self.conv_layer_1(x), 4), 2))
        # (B, 64, 25, 92), pooled to (B, 64, 12, 46); channel dropout is only
        # active in training mode.
        x = F.relu(F.max_pool2d(self.conv_layer_2_drop(self.conv_layer_2(x)), 2))

        x = self.flatten(x)
        return F.relu(self.fully_connected_layer_1(x))

    def _initialize_weights(self):
        """Kaiming-normal conv weights, N(0, 0.01) linear weights, zero biases."""
        for m in self.modules():
            # Conv3d is included so conv_layer_1 gets the same scheme as
            # conv_layer_2 instead of PyTorch's default initialisation.
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # No BatchNorm layers exist yet; kept for future variants.
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

In [None]:
# Create an instance of model
# NOTE(review): this rebinds the name `model`, shadowing the project `model`
# module imported at the top of the notebook. Later `model.<attr>` lookups or
# `importlib.reload(model)` then see this Extractor instance, not the module.
model = Extractor()
# Layer-by-layer summary for a batch of 64 inputs of shape (1, 64, 192, 2).
summary(model, input_size=(64, 1, 64, 192, 2))

### Extract features and folders

In [None]:
# Cache files for the extracted features and the matching folder list.
features_file = os.path.join(peaks_folder, "features_100.npy")
folders_file = os.path.join(peaks_folder, "folders_100.pkl")

# Devices for feature extraction. They are defined unconditionally so later
# cells that read `device` / `device_ids` also work when the cache is hit.
device_ids = [0, 1, 2, 3]
device = torch.device("cuda:0")

# Reuse the cache only when BOTH files exist. Otherwise the features could be
# loaded without their folder labels, and the pickle load would fail.
if os.path.exists(features_file) and os.path.exists(folders_file):
    features = np.load(features_file, allow_pickle=True)
    # Local cache written by this notebook; only unpickle trusted files.
    with open(folders_file, 'rb') as f:
        folders = pickle.load(f)
else:
    # Embed the training split with the freshly initialised extractor.
    features, folders = clustering.extract_features(peaks_train_dataloader, model, device, device_ids)

    np.save(features_file, features)

    with open(folders_file, 'wb') as f:
        pickle.dump(folders, f)

In [None]:
# Inspect the shape of the extracted feature array.
features.shape

In [None]:
# Project helper; n_components=100 presumably reduces the features to 100
# dimensions (e.g. PCA) — confirm in clustering.preprocess_features.
preprocessed_features = clustering.preprocess_features(features, n_components=100)

In [None]:
# Inspect the shape after preprocessing.
preprocessed_features.shape

In [None]:
# Number of distinct folders among the extracted samples.
print(len(set(folders)))

### Generate cluster assignments

In [None]:
# 100-component, full-covariance Gaussian mixture with k-means initialisation,
# trained on GPU 0 (pycave trainer_params).
gmm = GaussianMixture(100, covariance_type="full", init_strategy='kmeans', trainer_params=dict(gpus=[0]))
gmm.fit(preprocessed_features)

In [None]:
# Cluster label for each preprocessed feature row (returned as a torch tensor;
# converted to NumPy a few cells below).
cluster_assignments = gmm.predict(preprocessed_features)

In [None]:
# Inspect the predicted cluster labels.
cluster_assignments

In [None]:
# Convert the predicted labels (a torch tensor) to a NumPy array. The guard
# keeps the cell safe to re-run: an ndarray has no `.cpu()` method.
if isinstance(cluster_assignments, torch.Tensor):
    cluster_assignments = cluster_assignments.cpu().numpy()

## 5. Learn cluster representations

### Create the reassigned dataset

In [None]:
# Map each sample folder to its cluster label. If a folder repeats, the last
# label wins, as with the original loop. This does not leak the `folder` /
# `cluster` loop variables into the notebook namespace.
folder_to_cluster_map = dict(zip(folders, cluster_assignments))

In [None]:
# Inspect the folder -> cluster-label mapping.
print(folder_to_cluster_map)

In [None]:
# Dataset over the selected units, labelled through folder_to_cluster_map
# (project helper). NOTE(review): confirm how folders missing from the map are
# handled; this determines len(clustered_dataset).
clustered_dataset = dataset.ClusteredDataset(
    peaks_dataset_folder,
    selected_peaks,
    folder_to_cluster_map
)

In [None]:
# Size the split from clustered_dataset itself, as the DeepCluster loop does.
# random_split requires the lengths to sum exactly to len(clustered_dataset),
# and because the cluster map only covers the folders in `folders`, this
# dataset need not have the same length as peaks_dataset.
train_size = int(0.7 * len(clustered_dataset))
test_size = len(clustered_dataset) - train_size
clustered_train_dataset, clustered_test_dataset = torch.utils.data.random_split(clustered_dataset, [train_size, test_size])

print("Training size:", len(clustered_train_dataset))
print("Testing size:", len(clustered_test_dataset))

In [None]:
# One loader per split, each with batches of 8 and 8 worker processes.
clustered_train_dataloader, clustered_test_dataloader = (
    DataLoader(split, batch_size=8, num_workers=8)
    for split in (clustered_train_dataset, clustered_test_dataset)
)

### Build the classifier architecture

In [None]:
class Classifier(nn.Module):
    """Extractor-style convolutional trunk followed by a cluster-prediction head.

    For inputs of shape ``(B, 1, 64, 192, 2)`` the forward pass returns raw
    (unnormalised) logits of shape ``(B, num_classes)``, as expected by
    ``nn.CrossEntropyLoss``.

    Args:
        num_classes: Number of output classes. Defaults to 100, matching the
            100-component Gaussian mixture that produces the cluster labels.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.conv_layer_1 = nn.Conv3d(1, 32, kernel_size=(9, 3, 2))
        self.conv_layer_2 = nn.Conv2d(32, 64, kernel_size=4)

        self.conv_layer_2_drop = nn.Dropout2d()

        self.flatten = nn.Flatten()
        # 64 channels * 12 * 46 after the second conv + pool = 35328.
        self.fully_connected_layer_1 = nn.Linear(35328, 500)
        self.fully_connected_layer_2 = nn.Linear(500, num_classes)

    def forward(self, x):
        # Squeeze the size-1 last axis left by the Conv3d, then pool.
        x = F.relu(F.max_pool2d(torch.squeeze(self.conv_layer_1(x), 4), 2))
        x = F.relu(F.max_pool2d(self.conv_layer_2_drop(self.conv_layer_2(x)), 2))

        x = self.flatten(x)
        x = F.relu(self.fully_connected_layer_1(x))
        x = F.dropout(x, training=self.training)
        # Return raw logits: CrossEntropyLoss applies log-softmax itself, and a
        # ReLU here would clamp every negative logit to zero.
        return self.fully_connected_layer_2(x)

In [None]:
# Create an instance of model
classifier = Classifier()
# Layer-by-layer summary for a batch of 64 inputs of shape (1, 64, 192, 2).
summary(classifier, input_size=(64, 1, 64, 192, 2))

### Train the classifier

In [None]:
# Loss and optimizer for training the cluster classifier.
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.0001
# Optimize the classifier's own parameters. `model` is the feature extractor,
# so using its parameters would leave the classifier's weights unchanged.
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

In [None]:
# Directory for model checkpoints, anchored at the current working directory.
models_folder = os.path.join(os.getcwd(), "models")
if not Path(models_folder).exists():
    Path(models_folder).mkdir()

In [None]:
# Use the GPU the classifier is moved to for training (cuda:3), so the trainer
# and the model agree. Defining it here keeps this cell independent of the
# device chosen by earlier cells.
device = torch.device("cuda:3")

# NOTE(review): this passes (train_loader, test_loader, device, loss_fn,
# optimizer), whereas the DeepCluster loop also passes `device_ids` after
# `device` — confirm the current training.TrainModel signature.
train_clusters = training.TrainModel(clustered_train_dataloader,
                                     clustered_test_dataloader,
                                     device,
                                     loss_fn,
                                     optimizer)

In [None]:
# Train and evaluate the cluster classifier for one epoch on cuda:3.
model_name = "classifier"
device = torch.device("cuda:3")
classifier = classifier.to(device)

train_clusters.train_test_model(
    classifier,
    model_name,
    models_folder,
    epochs=1,
    classes=selected_peaks,
)

## 6. Implement DeepCluster

In [None]:
import importlib

# Pick up edits to the local helper modules without restarting the kernel.
dataset = importlib.reload(dataset)
# The global `model` was rebound to an Extractor instance earlier, and
# reload() only accepts modules. Fetch the module by name from sys.modules
# instead, then rebind `model` to it so `model.DeepCluster` resolves.
model = importlib.reload(importlib.import_module("model"))
clustering = importlib.reload(clustering)
training = importlib.reload(training)

In [None]:
# Show GPU utilisation and memory before choosing devices for the DeepCluster run.
!nvidia-smi

In [None]:
# Primary device and the GPU ids passed to the helpers in the DeepCluster loop.
device_ids = list(range(4))
device = torch.device("cuda", 0)

In [None]:
# DeepCluster network from the project `model` module (100 presumably = number
# of clusters). NOTE(review): `model` must refer to the module here, not the
# Extractor instance bound to that name earlier in the notebook.
deep_cluster_model = model.DeepCluster(100)
# Separate 100-component GMM for the DeepCluster loop, trained on GPU 2.
gmm = GaussianMixture(100, covariance_type="full", init_strategy='kmeans', trainer_params=dict(gpus=[2]))

In [None]:
# DeepCluster schedule and optimisation settings.
n_iterations = 3
learning_rate = 1e-4
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(deep_cluster_model.parameters(), lr=learning_rate)

In [None]:
# Same checkpoint directory as in section 5, set up again so this section
# can run on its own.
models_folder = os.path.join(os.getcwd(), "models")
if not Path(models_folder).exists():
    Path(models_folder).mkdir()

In [None]:
# DeepCluster loop. Each iteration:
#   (1) embeds the training peaks with the current network,
#   (2) clusters the embeddings with the GMM,
#   (3) relabels the dataset with the new cluster ids,
#   (4) trains the network to predict those ids.
for iteration in range(n_iterations):
    print(f"Iteration {iteration + 1}")
  
    # Step 1: Feature Extraction
    # Only the training split of peaks_dataset (peaks_train_dataloader) is embedded.
    # NOTE(review): this call passes an extra positional 500 compared with the
    # first extract_features call — confirm the reloaded clustering signature.
    features, folders = clustering.extract_features(peaks_train_dataloader, 500, deep_cluster_model, device, device_ids) 
    preprocessed_features = clustering.preprocess_features(features, n_components=100) 
  
    # Step 2: Clustering
    # The same `gmm` object is refitted on each iteration's features, so its
    # cluster ids need not correspond to the previous iteration's ids.
    # NOTE(review): the network's output layer is kept across iterations; the
    # DeepCluster paper re-initialises the classifier head after each
    # re-clustering for this reason — consider doing the same.
    gmm.fit(preprocessed_features)
    cluster_assignments = gmm.predict(preprocessed_features)
    cluster_assignments = cluster_assignments.cpu().numpy()
  
    # Step 3: Cluster Reassignment
    # Relabel each folder with its new cluster id and redraw the 70/30 split.
    folder_to_cluster_map = {folder: cluster for folder, cluster in zip(folders, cluster_assignments)}
    clustered_dataset = dataset.ClusteredDataset(peaks_dataset_folder, selected_peaks, folder_to_cluster_map)  
    train_size = int(0.7 * len(clustered_dataset))
    test_size = len(clustered_dataset) - train_size
    clustered_train_dataset, clustered_test_dataset = torch.utils.data.random_split(clustered_dataset, [train_size, test_size])
  
    # These loaders use the default num_workers (0), unlike the loaders built
    # earlier with num_workers=8.
    clustered_train_dataloader = DataLoader(clustered_train_dataset, batch_size=8)
    clustered_test_dataloader = DataLoader(clustered_test_dataset, batch_size=8)
  
    # Step 4: Classification
    # NOTE(review): the same model name "sup_dss_1" is passed every iteration,
    # so each iteration presumably overwrites the previous checkpoint. Also,
    # `classes=selected_peaks` passes unit ids although the targets here are
    # GMM cluster ids — confirm what train_test_model uses `classes` for.
    train_clusters = training.TrainModel(clustered_train_dataloader, clustered_test_dataloader, device, device_ids, loss_fn, optimizer)
    train_clusters.train_test_model(deep_cluster_model, "sup_dss_1", models_folder, epochs=1, classes=selected_peaks)