# Phase 2

---

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import numpy as np
import h5py
import pandas as pd

import matplotlib.pyplot as plt

import warnings
# NOTE(review): this silences *all* warnings for the whole session, including
# deprecation warnings from the libraries used below; consider a narrower filter.
warnings.simplefilter("ignore")

# Interactive (ipympl) matplotlib backend.
%matplotlib widget

In [3]:
import spikeinterface.full as si
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

# Record the SpikeInterface version used for this run (0.101.0 in the saved output).
print(f"SpikeInterface version: {si.__version__}")

SpikeInterface version: 0.101.0


In [4]:
import comparison
import evaluation
import process_peaks

sys.path.append("..")
import plotting
import preprocessing
import util

## 1. Extract peaks from recording

In [1]:
# Environment check: report the h5py version and the HDF5 library it was built against.
# (h5py is already imported in the imports cell; re-importing it here is harmless.)
import h5py
print(h5py.__version__)
print(h5py.version.hdf5_version)

3.10.0
1.12.2


In [5]:
# Root folder of the sub-CSHL049 session; the recording, channel and spike paths below are built from it.
data_folder = "../data/sub-CSHL049"

### Load recording from disk

In [6]:
# Load the preprocessed recording that was saved to disk as a SpikeInterface extractor.
preprocessed_folder = os.path.join(data_folder, "extractors/preprocessed")
recording_preprocessed = si.load_extractor(preprocessed_folder)

recording_preprocessed

### Retrieve channels and spikes

In [7]:
# Channel table: channel index and x/y probe location for every channel.
channels_file = os.path.join(data_folder, "channels.npy")
channels = np.load(channels_file)

display(pd.DataFrame(channels))

Unnamed: 0,channel_index,channel_location_x,channel_location_y
0,0,16.0,0.0
1,1,48.0,0.0
2,2,0.0,20.0
3,3,32.0,20.0
4,4,16.0,40.0
...,...,...,...
379,379,32.0,3780.0
380,380,16.0,3800.0
381,381,48.0,3800.0
382,382,0.0,3820.0


In [8]:
# Reference spikes (sample index, channel, location, unit). They are used below to label the detected peaks.
spikes_file = os.path.join(data_folder, "spikes/spikes.npy")
spikes = np.load(spikes_file)

display(pd.DataFrame(spikes))

Unnamed: 0,spike_index,sample_index,channel_index,channel_location_x,channel_location_y,unit_index
0,0,472,341,48.0,3400.0,271
1,1,511,361,48.0,3600.0,306
2,2,606,354,0.0,3540.0,297
3,3,680,361,48.0,3600.0,306
4,4,715,325,48.0,3240.0,235
...,...,...,...,...,...,...
4604408,4604408,125188816,21,48.0,200.0,26
4604409,4604409,125188838,155,32.0,1540.0,105
4604410,4604410,125188912,325,48.0,3240.0,237
4604411,4604411,125188967,326,0.0,3260.0,239


### Detect peaks

In [9]:
# Cache folder for detected and matched peaks. It is derived from `data_folder`
# so the session path is configured in one place only.
peaks_folder = os.path.join(data_folder, "peaks")

# Create the folder up front; the cells below save cached .npy files into it.
os.makedirs(peaks_folder, exist_ok=True)

In [10]:
peaks_file = os.path.join(peaks_folder, "peaks.npy")

# Peak detection runs over the whole recording, so the filtered peaks are cached
# in peaks.npy and reloaded on later runs instead of being recomputed.
if os.path.exists(peaks_file):
    peaks_filtered = np.load(peaks_file)
else:
    # Chunked, parallel processing: 1 s chunks, 10 worker jobs, with a progress bar.
    job_kwargs = dict(chunk_duration='1s', n_jobs=10, progress_bar=True)

    # Detect negative-going peaks with the 'locally_exclusive' method within a
    # 100 um radius. detect_threshold=6 is presumably in multiples of the noise
    # level; see the spikeinterface detect_peaks documentation.
    peaks = detect_peaks(
        recording_preprocessed,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=6,
        radius_um = 100,
        **job_kwargs
    )

    # Project-specific post-filtering of the raw peaks (see process_peaks.filter_peaks).
    peaks_filtered = process_peaks.filter_peaks(recording_preprocessed, peaks, channels)

    np.save(peaks_file, peaks_filtered)

# Columns: peak_index, time, channel_index, channel location (x, y) and amplitude.
display(pd.DataFrame(peaks_filtered))

Unnamed: 0,peak_index,time,channel_index,channel_location_x,channel_location_y,amplitude
0,0,93,326,0,3260,-27
1,1,147,348,16,3480,-40
2,2,177,337,48,3360,-67
3,3,207,6,0,60,-54
4,4,269,330,0,3300,-34
...,...,...,...,...,...,...
3260855,3260855,125189311,222,0,2220,-36
3260856,3260856,125189392,273,48,2720,-24
3260857,3260857,125189402,89,48,880,-37
3260858,3260858,125189402,269,48,2680,-21


---

## 2. Create a dataset from peaks

### Match peaks to spikes

In [11]:
peaks_matched_file = os.path.join(peaks_folder, "peaks_matched.npy")

# Attach a unit label to each filtered peak by matching it against `spikes`.
# The result is cached in peaks_matched.npy. Many peaks carry unit_index -1 in
# the output, which presumably means no matching spike was found.
if os.path.exists(peaks_matched_file):
    peaks_matched = np.load(peaks_matched_file)
else:
    peaks_matched = process_peaks.match_peaks(peaks_filtered, spikes)
    np.save(peaks_matched_file, peaks_matched)

display(pd.DataFrame(peaks_matched))

Unnamed: 0,peak_index,time,channel_index,channel_location_x,channel_location_y,amplitude,unit_index
0,0,93,326,0,3260,-27,-1
1,1,147,348,16,3480,-40,-1
2,2,177,337,48,3360,-67,-1
3,3,207,6,0,60,-54,-1
4,4,269,330,0,3300,-34,-1
...,...,...,...,...,...,...,...
3260855,3260855,125189311,222,0,2220,-36,-1
3260856,3260856,125189392,273,48,2720,-24,-1
3260857,3260857,125189402,89,48,880,-37,-1
3260858,3260858,125189402,269,48,2680,-21,-1


In [114]:
import importlib
import evaluation

# Force a reload of the project `evaluation` module to pick up source edits.
# NOTE(review): `%autoreload 2` is enabled at the top of the notebook, so this
# explicit reload is probably redundant.
evaluation = importlib.reload(evaluation)

In [42]:
# Run output (session date and trial id) whose selected units are inspected below.
output_folder = "output/sub-CSHL049"
session_date = '2024-08-22'
trial_id = 'RELABEL_000'

# Load the unit ids selected for this run.
units_selected_file = os.path.join(output_folder, f'{session_date}/{trial_id}/{trial_id}_units_selected.npy')
# Alternative file name (kept commented out):
# units_selected_file = os.path.join(output_folder, f'{session_date}/{trial_id}/{trial_id}_selected_units.npy')
units_selected = np.load(units_selected_file).astype(int)

# Boolean mask: True for matched peaks that belong to one of the selected units.
mask_selected = np.isin(peaks_matched['unit_index'], units_selected)

# Keep only the peaks of the selected units.
peaks_selected = peaks_matched[mask_selected]
display(pd.DataFrame(peaks_selected))

Unnamed: 0,peak_index,time,channel_index,channel_location_x,channel_location_y,amplitude,unit_index
0,602,47676,359,32,3580,-59,305
1,833,62254,335,32,3340,-55,262
2,840,62574,359,32,3580,-44,305
3,951,71727,336,16,3360,-47,262
4,1231,95624,335,32,3340,-62,262
...,...,...,...,...,...,...,...
12794,3260101,125164030,235,32,2340,-32,165
12795,3260267,125169794,42,0,420,-56,38
12796,3260347,125172733,232,16,2320,-28,165
12797,3260413,125174181,42,0,420,-45,38


In [68]:
# First six matched peaks of unit 165 (one of the selected units).
pd.DataFrame(peaks_matched[peaks_matched['unit_index']==165][:6])

Unnamed: 0,peak_index,time,channel_index,channel_location_x,channel_location_y,amplitude,unit_index
0,3357,184082,225,48,2240,-46,165
1,9478,539109,228,16,2280,-27,165
2,19436,1048448,229,48,2280,-30,165
3,70857,3669831,229,48,2280,-35,165
4,87532,4478887,228,16,2280,-33,165
5,92370,4725699,230,0,2300,-34,165


In [47]:
# Show the selected unit ids as a plain list.
list(units_selected)

[165, 262, 305, 219, 38]

In [115]:
# Interactive viewer: a dropdown picks one of the selected units to plot.
# The meaning of `columns` and `method` is defined by evaluation.interactive_unit_plot.
interactive_plot = evaluation.interactive_unit_plot(recording_preprocessed, peaks_selected, units_selected, channels, columns='single', method='mask')
display(interactive_plot)

interactive(children=(Dropdown(description='Unit:', options=('165', '262', '305', '219', '38'), value='165'), …

### Create peaks dataset

The `generate_dataset.py` script is used here again to create a dataset of HDF5 files, one file per identified unit, from the peaks that were matched to the spikes in the NWB file.

In [12]:
# Number of distinct unit labels among the matched peaks, then the peak count per label
# (the -1 row presumably collects the unmatched peaks).
peak_units = peaks_matched['unit_index']

print(f'Peak units: {len(np.unique(peak_units))}\n')
print(util.format_value_counts(peak_units))

Peak units: 421

-01: 973241	042: 2398  	085: 1516  	129: 8540  	172: 4216  	215: 32292 	258: 8596  	301: 1688  	344: 11332 	388: 1370  
000: 19330 	043: 1245  	086: 5398  	130: 3216  	173: 15629 	216: 1750  	259: 1     	302: 14722 	345: 404   	389: 31    
001: 14270 	044: 10521 	087: 7897  	131: 1     	174: 2576  	217: 59    	260: 4648  	303: 529   	346: 4364  	390: 1184  
002: 6441  	045: 4018  	088: 12222 	132: 814   	175: 8663  	218: 27    	261: 449   	304: 519   	347: 1966  	391: 318   
003: 4416  	046: 1315  	089: 5607  	133: 467   	176: 381   	219: 2057  	262: 2701  	305: 2995  	348: 11475 	392: 1061  
004: 417   	047: 11677 	090: 23    	134: 4001  	177: 800   	220: 34    	263: 7236  	306: 4815  	349: 3925  	393: 2552  
005: 325   	048: 168   	091: 2059  	135: 12393 	178: 183   	221: 489   	264: 6660  	307: 159   	350: 148   	394: 800   
006: 3433  	049: 1165  	092: 10770 	136: 126   	179: 438   	222: 609   	265: 9508  	308: 916   	351: 5779  	395: 936   
007: 20    	050: 4743  

In [13]:
from collections import Counter

def count_labels_in_range(labels, lower_bound, upper_bound):
    """Return how many distinct labels occur between `lower_bound` and `upper_bound` times.

    Both bounds are inclusive. `labels` can be any iterable of hashable values.
    """
    frequencies = Counter(labels).values()
    # One boolean per distinct label; summing them counts the labels in range.
    within_bounds = [lower_bound <= frequency <= upper_bound for frequency in frequencies]
    return sum(within_bounds)

# Example usage: count the units that have between 1000 and 4010 matched peaks (inclusive).
# Note that the -1 label is counted like any other label if its count falls in range.
lower_bound = 1000
upper_bound = 4010

result = count_labels_in_range(peak_units, lower_bound, upper_bound)

print(f"Number of labels appearing between {lower_bound} and {upper_bound} times: {result}")

Number of labels appearing between 1000 and 4010 times: 100


Here, we set the second argument to `1` so that the dataset is built from the peaks found by the peak detection step above.

Example: `!python generate_dataset.py 1 1`

This example command generates a dataset of peaks from recording number 1, covering units 0 to 420.

---

## 3. Run DeepSpikeSort

The DeepSpikeSort algorithm can be run using the `run_dss.py` script.

DeepSpikeSort (DSS) follows the DeepCluster method, using the following steps:

1. Feature Extraction
- Initialize the CNN model with random weights for the first epoch
- Extract features before the final FC layer
- Preprocess features using PCA, whitening and l2-normalization

2. Clustering
- Fit a GMM with the preprocessed features 
- Predict cluster labels for the features

3. Cluster Comparison
- Calculate the ARI (Adjusted Rand Index) between the cluster labels of consecutive epochs, starting from the second epoch
- Use the ARI as the convergence metric

4. Representation Learning
- Create a labelled dataset from the cluster labels for supervised learning
- Train the CNN model on the labelled dataset

The script needs to be run with 7 arguments, in this order (matching the example below):

- [1] The number associated with the recording to be used
- [2] The number of units to be sorted
- [3] The minimum number of samples per unit
- [4] The maximum number of samples per unit
- [5] The number of classes to be predicted
- [6] The number of available GPUs for parallel data loading
- [7] The number of epochs for running DSS

Example: `!python run_dss.py 1 5 3000 4000 5 1 200`

The example command will run DSS:
- using recording 1
- on 5 units
- with 3000-4000 samples per unit
- predicting 5 clusters
- using 1 available GPU
- for 200 epochs 

The script will also save the DSS output and results to their respective folders:
- Output
    - Selected units
    - Preprocessed features
    - Cluster labels
    - Corresponding times
- Results
    - ARI progress plot
    - ARI progress log
    - SpikeInterface comparison results
    - Agreement matrix plot

## 4. Inspect clusters

In [None]:
# Root folder of the DSS run outputs (logs, labels, sortings, HDF5 results) read below.
output_folder = "output/sub-CSHL049"

In [None]:
import re
from typing import List, Tuple

def parse_log_file(file_path: str) -> List[Tuple[int, float, float]]:
    """
    Parse the log file and extract epoch, loss, and accuracy values.

    Matching lines look like ``[12] Loss: 0.345 Accuracy: 0.91``; all other
    lines are skipped. Loss and accuracy may carry a sign or use scientific
    notation (e.g. ``1.5e-04``). A digits-and-dots-only pattern would silently
    drop such lines.

    Args:
        file_path (str): Path to the log file, read as UTF-8 text.

    Returns:
        List[Tuple[int, float, float]]: (epoch, loss, accuracy) tuples in file order.
    """
    # A decimal number with optional sign and optional exponent.
    number = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
    pattern = re.compile(rf'\[(\d+)\]\s+Loss:\s+({number})\s+Accuracy:\s+({number})')
    data = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            match = pattern.search(line)
            if match:
                epoch = int(match.group(1))
                loss = float(match.group(2))
                accuracy = float(match.group(3))
                data.append((epoch, loss, accuracy))

    return data

In [None]:
def create_plot(data: List[Tuple[int, float, float]]) -> None:
    """
    Create a plot showing loss and accuracy over epochs in separate subplots.

    Args:
        data (List[Tuple[int, float, float]]): (epoch, loss, accuracy) tuples,
            e.g. as returned by parse_log_file.

    Raises:
        ValueError: If `data` is empty, for example when no log line matched.
    """
    # zip(*[]) would otherwise fail with an opaque "not enough values to unpack".
    if not data:
        raise ValueError("No (epoch, loss, accuracy) entries to plot: data is empty")

    epochs, losses, accuracies = zip(*data)

    fig, (ax_loss, ax_acc) = plt.subplots(2, 1, figsize=(10, 10), sharex=True)

    # Plot loss
    ax_loss.plot(epochs, losses, color='tab:red', label='Loss')
    ax_loss.set_title('Loss')
    ax_loss.set_ylabel('Loss')
    ax_loss.grid(True)

    # Plot accuracy
    ax_acc.plot(epochs, accuracies, color='tab:blue', label='Accuracy')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Accuracy')
    ax_acc.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
# Plot the per-epoch loss and accuracy logged by the DSS_002 run.
session_date = '2024-08-23'
trial_id = 'DSS_002'

log_file_path = os.path.join(output_folder, f'{session_date}/{trial_id}/{trial_id}_performance_metrics.log')
data = parse_log_file(log_file_path)
create_plot(data)

---

In [None]:
# Select the DSS_000 run for the relabelling checks below.
session_date = '2024-08-23'
trial_id = 'DSS_000'

In [None]:
# Previous and current cluster labels saved by this run (presumably from consecutive epochs).
# NOTE(review): naming is inconsistent (`previous_labels` vs `labels_current`).
previous_labels = np.load(os.path.join(output_folder, f'{session_date}/{trial_id}/previous_labels.npy'))
labels_current = np.load(os.path.join(output_folder, f'{session_date}/{trial_id}/current_labels.npy'))

In [None]:
# Current and previous sortings saved by this run in NPZ format.
current_sorting = si.NpzSortingExtractor(os.path.join(output_folder, f'{session_date}/{trial_id}/current_sorting.npz'))
previous_sorting = si.NpzSortingExtractor(os.path.join(output_folder, f'{session_date}/{trial_id}/previous_sorting.npz'))

In [None]:
# Match current units against previous units and show the agreement matrix.
# delta_time=0 presumably requires exactly coinciding spike times.
comparison_cluster = si.compare_two_sorters(
            sorting1=current_sorting,
            sorting2=previous_sorting,
            sorting1_name="Current",
            sorting2_name="Previous",
            delta_time=0,
            verbose=True
        )

si.plot_agreement_matrix(sorting_comparison=comparison_cluster)

In [None]:
def relabel_clusters(previous_sorting, current_sorting, cluster_labels, seed=None):
    """
    Relabel current cluster ids so that matched clusters reuse their previous ids.

    The current and previous sortings are matched with SpikeInterface, and each
    current unit takes the id of its matched previous unit. Current units with
    no match (-1) receive previous ids that no current unit claimed, in
    shuffled order.

    Args:
        previous_sorting: Sorting built from the previous labels.
        current_sorting: Sorting built from the current labels. Unit ids are
            assumed to be 0..n-1, because candidate ids are drawn from range(n).
        cluster_labels: Iterable of current cluster ids to translate.
        seed: Optional seed for shuffling the unclaimed ids. With None (the
            default), NumPy's global random state is used.

    Returns:
        np.ndarray of int: the relabelled ids. A current unit stays -1 if no
        unclaimed id is left for it.

    Raises:
        KeyError: If `cluster_labels` contains an id that is not a current unit.
    """
    # Compare the two sortings
    comparison = si.compare_two_sorters(
        sorting1=current_sorting,
        sorting2=previous_sorting,
        sorting1_name="Current",
        sorting2_name="Previous",
        delta_time=0,
        verbose=True
    )

    # Copy before editing: get_matching() may return the comparison's own
    # Series, and filling in -1 entries in place would silently change it.
    matching = comparison.get_matching()[0].copy()

    # Previous ids that no matched current unit claimed. Sort them so that a
    # given seed always yields the same assignment.
    all_possible_labels = set(range(len(matching)))
    used_labels = {label for label in matching if label != -1}
    missing_labels = sorted(all_possible_labels - used_labels)

    if seed is None:
        np.random.shuffle(missing_labels)
    else:
        np.random.default_rng(seed).shuffle(missing_labels)

    # Give each unmatched current unit one unclaimed id while any remain.
    for i in range(len(matching)):
        if matching[i] == -1 and missing_labels:
            matching[i] = missing_labels.pop(0)

    # Create the mapping
    label_map = {i: int(matching[i]) for i in range(len(matching))}

    # Apply the mapping to cluster_labels
    return np.array([label_map[label] for label in cluster_labels], dtype=int)

In [None]:
def relabel_clusters_add(previous_sorting, current_sorting, cluster_labels):
    """
    Relabel current clusters to their matched previous ids. Unmatched clusters
    get brand-new ids instead of recycled unused ones.

    Args:
        previous_sorting: Sorting built from the previous labels.
        current_sorting: Sorting built from the current labels.
        cluster_labels: Array-like of non-negative integer current cluster ids.

    Returns:
        np.ndarray of int: the relabelled ids. Matched clusters take their
        previous id, and each unmatched cluster gets a fresh id above every id
        in use. Ids absent from the matching are left unchanged.
    """
    cluster_labels = np.asarray(cluster_labels)
    if cluster_labels.size == 0:
        # Nothing to relabel; max() over an empty array below would raise.
        return np.empty(0, dtype=int)

    # Compare the two sortings
    comparison = si.compare_two_sorters(
        sorting1=current_sorting,
        sorting2=previous_sorting,
        sorting1_name="Current",
        sorting2_name="Previous",
        delta_time=0,
        verbose=False  # Set to True for debugging
    )

    # Matching between current and previous labels (-1 = unmatched)
    matching = comparison.get_matching()[0].astype(int)
    labels_map = {int(current): int(previous) for current, previous in matching.items()}

    # Fresh ids start above every current id and every matched previous id.
    # `default=-1` keeps this working when no current unit matched at all,
    # because max() of an empty sequence would raise ValueError.
    matched_previous = [previous for previous in labels_map.values() if previous != -1]
    max_label = max(max(labels_map, default=-1), max(matched_previous, default=-1))

    new_map = {}
    next_new_label = max_label + 1
    for current, previous in labels_map.items():
        if previous == -1:
            new_map[current] = next_new_label
            next_new_label += 1
        else:
            new_map[current] = previous

    # Lookup table indexed by current id (identity for ids not in the matching).
    # Size it to cover every label, every current id and every new id.
    table_size = max([int(cluster_labels.max()), *new_map.keys(), *new_map.values()]) + 1
    map_array = np.arange(table_size)
    for current, previous in new_map.items():
        map_array[current] = previous

    # Apply the mapping to cluster_labels
    return map_array[cluster_labels]

In [None]:
# Relabel the current labels to follow the previous run's ids and show the resulting label set.
new_labels = relabel_clusters(previous_sorting, current_sorting, labels_current)
np.unique(new_labels)

In [None]:
# Compare the two sortings. Use a distinct name: binding `comparison` here would
# shadow the project `comparison` module imported at the top of the notebook,
# which breaks later calls such as `comparison.create_numpy_sorting`.
comparison_relabel = si.compare_two_sorters(
    sorting1=current_sorting,
    sorting2=previous_sorting,
    sorting1_name="Current",
    sorting2_name="Previous",
    delta_time=0,
    verbose=True
)

# Get the matching between current and previous labels (-1 = unmatched)
matching = comparison_relabel.get_matching()[0].astype(int)
matching

In [None]:
# Confirm that the matching was cast to an integer dtype.
matching.dtype

---

In [None]:
# Use the DSS_000 run to inspect its saved features.
trial_id = 'DSS_000'
session_date = '2024-08-23'

In [None]:
# Load the features saved before PCA and show their shape
# (presumably (n_samples, n_features); TODO confirm).
features = np.load(os.path.join(output_folder, f'{session_date}/{trial_id}/features_before_pca.npy'))
features.shape

In [None]:
def check_zero_features(features):
    """
    Return the indices of feature columns whose values are all exactly zero.

    Args:
        features: 2-D array-like with samples along axis 0 and features along axis 1.

    Returns:
        np.ndarray: ascending column indices where every sample is 0.
    """
    features = np.asarray(features)
    # Test every entry rather than the column sum: a sum can be 0 when positive
    # and negative values cancel, which would flag columns that are not all zero.
    all_zero = np.all(features == 0, axis=0)
    return np.flatnonzero(all_zero)

# Check the loaded features for all-zero columns.
# `features_normalized` is not defined anywhere in this notebook (NameError on a
# fresh kernel), so the loaded `features` array is checked directly.
zero_feature_indices = check_zero_features(features)

# Print results
if len(zero_feature_indices) > 0:
    print(f"Found {len(zero_feature_indices)} features with all zero values.")
    print("Indices of all-zero features:", zero_feature_indices)
else:
    print("No features with all zero values found.")

In [None]:
def check_nan_values(features):
    """Locate the feature columns that contain NaN.

    Args:
        features: 2-D array with samples along axis 0 and features along axis 1.

    Returns:
        (indices, counts): the indices of columns with at least one NaN, and
        the number of NaNs in each of those columns.
    """
    per_column_nans = np.isnan(features).sum(axis=0)
    flagged_columns = np.flatnonzero(per_column_nans > 0)
    return flagged_columns, per_column_nans[flagged_columns]

# Check the loaded features for NaN values.
# `features_normalized` is not defined anywhere in this notebook (NameError on a
# fresh kernel), so the loaded `features` array is checked directly.
nan_indices, nan_counts = check_nan_values(features)

# Print results
if len(nan_indices) > 0:
    print(f"Found {len(nan_indices)} features with NaN values.")
    for idx, count in zip(nan_indices, nan_counts):
        print(f"Feature {idx}: {count} NaN values")
else:
    print("No NaN values found in the array.")

# Get total number of NaN values
total_nans = np.sum(nan_counts)
print(f"Total number of NaN values in the array: {total_nans}")

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Fit a PCA with all components to see how much variance each additional component explains.
pca = PCA()
pca.fit(features)

# Calculate cumulative explained variance ratio
cumulative_variance_ratio = np.cumsum(pca.explained_variance_ratio_)

# Plot the cumulative explained variance ratio
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(cumulative_variance_ratio) + 1), cumulative_variance_ratio, 'bo-')
plt.xlabel('Number of Components')
plt.ylabel('Cumulative Explained Variance Ratio')
plt.title('Explained Variance Ratio vs. Number of Components')
plt.grid(True)
plt.show()

# Find the number of components that explain 95% of the variance
# (argmax gives the first index meeting the threshold; +1 turns it into a count).
n_components_95 = np.argmax(cumulative_variance_ratio >= 0.95) + 1
print(f"Number of components explaining 95% of variance: {n_components_95}")

In [None]:
# Perform PCA with 50 components. `X_pca` (the projected features) is used by
# the clustering and scatter plots below.
pca = PCA(n_components=50)
X_pca = pca.fit_transform(features)

# Get the explained variance ratio
explained_variance_ratio = pca.explained_variance_ratio_
cumulative_variance_ratio = np.cumsum(explained_variance_ratio)

In [None]:
# Number of components fitted in the previous cell. It is derived from the fitted
# ratios rather than hard-coded as 50, so the axes stay correct if
# `PCA(n_components=...)` is changed.
n_fitted = len(explained_variance_ratio)
component_numbers = range(1, n_fitted + 1)

# Plot 1: Cumulative Explained Variance
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(component_numbers, cumulative_variance_ratio, 'bo-')
plt.xlabel('Number of Components')
plt.ylabel('Cumulative Explained Variance Ratio')
plt.title('Cumulative Explained Variance vs. Number of Components')
plt.grid(True)

# Plot 2: Individual Explained Variance
plt.subplot(1, 2, 2)
plt.bar(component_numbers, explained_variance_ratio)
plt.xlabel('Principal Component')
plt.ylabel('Explained Variance Ratio')
plt.title('Explained Variance Ratio per Principal Component')
plt.tight_layout()
plt.show()

# Print the total explained variance of the fitted components
print(f"Total explained variance with {n_fitted} components: {cumulative_variance_ratio[-1]:.4f}")

# Number of components needed to explain 95% of the variance. np.argmax returns 0
# when no entry reaches the threshold, which would wrongly report 1 component,
# so check first that the threshold is reached at all.
if cumulative_variance_ratio[-1] >= 0.95:
    n_components_95 = np.argmax(cumulative_variance_ratio >= 0.95) + 1
    print(f"Number of components explaining 95% of variance: {n_components_95}")
else:
    print(f"The {n_fitted} fitted components explain less than 95% of variance")

In [None]:
from sklearn.metrics import silhouette_score
from sklearn.mixture import GaussianMixture

def calculate_cluster_separation(X_pca, min_components, max_components, n_pcs=4, random_state=42):
    """
    Fit GMMs over a range of component counts and score each one with the silhouette.

    Args:
        X_pca: 2-D array of projected features, one row per sample.
        min_components: Smallest number of GMM components to try (inclusive).
        max_components: Upper bound on the number of components (exclusive, as in range()).
        n_pcs: Number of leading columns of `X_pca` used for clustering and scoring.
        random_state: Seed passed to every GaussianMixture for reproducibility.

    Returns:
        Tuple (silhouette_scores, best_n_components, all_cluster_labels). The two
        lists hold one entry per tried component count.

    Raises:
        ValueError: If the component range is empty.
    """
    n_components_range = range(min_components, max_components)
    if len(n_components_range) == 0:
        raise ValueError(
            f"Empty component range: min_components={min_components}, "
            f"max_components={max_components} (max_components is exclusive)"
        )

    # Slice once, since every fit and score uses the same leading components.
    X_subset = X_pca[:, :n_pcs]

    silhouette_scores = []
    all_cluster_labels = []
    for n_components in n_components_range:
        gmm = GaussianMixture(n_components=n_components, random_state=random_state)

        # Fit the GMM and predict cluster labels
        cluster_labels = gmm.fit_predict(X_subset)
        all_cluster_labels.append(cluster_labels)

        # Calculate silhouette score
        silhouette_avg = silhouette_score(X_subset, cluster_labels)
        silhouette_scores.append(silhouette_avg)

        print(f"For n_components = {n_components}, the average silhouette score is : {silhouette_avg}")

    best_n_components = n_components_range[np.argmax(silhouette_scores)]
    print(f"\nThe best number of components appears to be {best_n_components}")

    return silhouette_scores, best_n_components, all_cluster_labels

# Returns one label array per tried component count (not a single label array).
# silhouette_scores, best_n_clusters, cluster_labels = calculate_cluster_separation(X_pca, 9, 11)

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
from sklearn.mixture import GaussianMixture

# Colormap with 5 distinct colors, one per cluster
colors = ['red', 'yellow', 'green', 'blue', 'purple']
n_bins = 5  # number of distinct colors
cmap = ListedColormap(colors)


def plot_3d_pca(X_pca, labels=None, cmap=cmap):
    """
    Scatter the first three columns of `X_pca` in 3-D, optionally colored by label.

    Args:
        X_pca: 2-D array whose columns 0-2 are plotted as PC1-PC3.
        labels: Optional per-row values mapped to colors. When given, a
            colorbar labelled 'Cluster' is added.
        cmap: Colormap used for `labels`. Defaults to the 5-color map defined above.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], X_pca[:, 2], c=labels, cmap=cmap)

    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('First Three Principal Components')

    if labels is not None:
        plt.colorbar(scatter, label='Cluster')

    plt.show()


# Fit the cluster labels here so that this cell (and the pairwise plots below)
# runs on a fresh kernel. `cluster_labels` is not assigned anywhere else at
# notebook level: the silhouette search above is commented out and would
# return a list of label arrays. As in `calculate_cluster_separation`, a GMM
# is fitted on the first four PCs, with one component per colormap color.
cluster_labels = GaussianMixture(n_components=n_bins, random_state=42).fit_predict(X_pca[:, :4])

plot_3d_pca(X_pca[:, :3], labels=cluster_labels)

In [None]:
# Pairwise scatter plots of the first four PCs, colored by `cluster_labels`.
# NOTE(review): `cluster_labels` must already be defined by an earlier cell.
# Extract the first 4 principal components
first_four_pcs = X_pca[:, :4]

# Create a figure with subplots
fig, axs = plt.subplots(2, 3, figsize=(20, 12))
fig.subplots_adjust(right=0.85)  # Make room for the colorbar
axs = axs.ravel()  # Flatten the 2D array of axes for easier indexing

# Plot each pair of principal components (all 6 unordered pairs of the 4 PCs)
pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
for idx, (i, j) in enumerate(pairs):
    scatter = axs[idx].scatter(first_four_pcs[:, i], first_four_pcs[:, j], 
                               c=cluster_labels, cmap=cmap, alpha=0.6)
    axs[idx].set_xlabel(f'PC{i+1}')
    axs[idx].set_ylabel(f'PC{j+1}')
    axs[idx].set_title(f'PC{i+1} vs PC{j+1}')

# Add a colorbar to the right of the subplots, driven by the last scatter
cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
cbar = fig.colorbar(scatter, cax=cbar_ax)
cbar.set_label('Cluster')

plt.suptitle('Pairwise PCA Component Plots with Cluster Coloring', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 0.85, 0.95])  # Adjust layout to accommodate suptitle
plt.show()

---

In [None]:
# Use the DSS_001 run; the cells below read its HDF5 output.
session_date = '2024-08-23'
trial_id = 'DSS_001'

In [None]:
# Load the run's saved outputs into memory. `features` and `labels` are indexed
# by epoch first (e.g. `features[epoch, :, :3]` below).
hdf5_file = os.path.join(output_folder, f'{session_date}/{trial_id}/{trial_id}_output_data.h5')

with h5py.File(hdf5_file, 'r') as handle:
    features = handle['features'][:]
    labels = handle['labels'][:]
    properties = handle['properties'][:]
    metrics = handle['metrics'][:]

# Epoch inspected by the plots and the sorting comparison below.
epoch = 39

In [None]:
from mpl_toolkits.mplot3d import Axes3D

def plot_3d_pca(X_pca, labels=None, cmap='rainbow'):
    """
    Scatter the first three columns of `X_pca` in 3-D, optionally colored by label.

    NOTE: this cell redefines the `plot_3d_pca` from the PCA section, which uses a
    fixed 5-color map. The explicit `cmap` argument (default 'rainbow', as this
    cell has always used) keeps the color scheme predictable whichever definition
    ran last.

    Args:
        X_pca: 2-D array whose columns 0-2 are plotted as PC1-PC3.
        labels: Optional per-row values mapped to colors. When given, a
            colorbar labelled 'Cluster' is added.
        cmap: Matplotlib colormap (name or object) used for `labels`.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], X_pca[:, 2], c=labels, cmap=cmap)

    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('First Three Principal Components')

    if labels is not None:
        plt.colorbar(scatter, label='Cluster')

    plt.show()

plot_3d_pca(features[epoch, :, :3], labels=labels[epoch, :]) 

In [None]:
# Pairwise scatter plots of the first four feature components at `epoch`,
# colored by that epoch's DSS labels.
# Extract the first 4 principal components
first_four_pcs = features[epoch, :, :4]

# Create a figure with subplots
fig, axs = plt.subplots(2, 3, figsize=(20, 12))
fig.subplots_adjust(right=0.85)  # Make room for the colorbar
axs = axs.ravel()  # Flatten the 2D array of axes for easier indexing

# Plot each pair of principal components (all 6 unordered pairs of the 4 components)
pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
for idx, (i, j) in enumerate(pairs):
    scatter = axs[idx].scatter(first_four_pcs[:, i], first_four_pcs[:, j], 
                               c=labels[epoch, :], cmap='rainbow', alpha=0.6)
    axs[idx].set_xlabel(f'PC{i+1}')
    axs[idx].set_ylabel(f'PC{j+1}')
    axs[idx].set_title(f'PC{i+1} vs PC{j+1}')

# Add a colorbar to the right of the subplots, driven by the last scatter
cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
cbar = fig.colorbar(scatter, cax=cbar_ax)
cbar.set_label('Cluster')

plt.suptitle('Pairwise PCA Component Plots with Cluster Coloring', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 0.85, 0.95])  # Adjust layout to accommodate suptitle
plt.show()

In [None]:
# Animate evaluation.plot_cluster_distribution over epochs 76-80 (inclusivity of the end is defined by animate_plot).
evaluation.animate_plot(evaluation.plot_cluster_distribution, plot_args=(labels,), epoch_start=76, epoch_end=80)

In [None]:
# Animate evaluation.plot_variance over epochs 90-99; the meaning of the extra argument 20 is defined by plot_variance.
evaluation.animate_plot(evaluation.plot_variance, plot_args=(features,20), epoch_start=90, epoch_end=99)

In [None]:
# Plot the clusters at epoch 50 (presumably the third argument) using num_components=2.
evaluation.plot_epoch_clusters(features, labels, 50, num_components=2)

## 5. Compare DeepSpikeSort output

### Create Sorting object from DSS output

In [None]:
# Structured array pairing each sample's saved peak metadata with its DSS
# cluster label at `epoch`.
dss_output = np.zeros(
    len(labels[epoch]),
    dtype=[
        ('peak_index', int),
        ('sample_index', int),
        ('channel_index', int),
        ('amplitude', int),
        ('unit_index', int),
    ],
)

# Copy the peak metadata field by field from the saved properties
for field_name in ('peak_index', 'sample_index', 'channel_index', 'amplitude'):
    dss_output[field_name] = properties[field_name]
dss_output['unit_index'] = labels[epoch]

display(pd.DataFrame(dss_output))

In [None]:
# Duplicate-filtered copy of the DSS output (used later by filter_samples_on_match).
dss_filtered = comparison.filter_samples_duplicate(dss_output)
# NOTE(review): the times/labels below come from the *unfiltered* `dss_output`,
# while `dss_filtered` is what gets split into matched/mismatched samples later.
# Confirm which one the DSS sorting should be built from.
dss_times = dss_output['sample_index']
dss_labels = dss_output['unit_index']

print(f'Samples: {len(dss_labels)}\n')
print(util.format_value_counts(dss_labels))

In [None]:
# Create custom NumpySorting object from DeepSpikeSort output.
# NOTE(review): 30000 is hard-coded; presumably the sampling frequency, which
# should match recording_preprocessed.
sorting_dss = comparison.create_numpy_sorting(dss_times, dss_labels, 30000)
sorting_dss

### Create Sorting object from NWB file

In [None]:
# Load the units selected by the current run (`session_date`/`trial_id`) and keep
# the matched peaks of those units (the same steps as in section 2).
units_selected_file = os.path.join(output_folder, f'{session_date}/{trial_id}/{trial_id}_units_selected.npy')
# Alternative file name (kept commented out):
# units_selected_file = os.path.join(output_folder, f'{session_date}/{trial_id}/{trial_id}_selected_units.npy')
units_selected = np.load(units_selected_file).astype(int)

# Boolean mask: True for matched peaks that belong to one of the selected units.
mask_selected = np.isin(peaks_matched['unit_index'], units_selected)

# Keep only the peaks of the selected units.
peaks_selected = peaks_matched[mask_selected]
display(pd.DataFrame(peaks_selected))

In [None]:
# NOTE(review): duplicate filtering is disabled here. If re-enabled, pick another
# name: assigning `peaks_filtered` would overwrite the detected peaks from section 1.
# peaks_filtered = comparison.filter_samples_duplicate(peaks_selected)
peak_times = peaks_selected['time']
peak_units = peaks_selected['unit_index']

print(f'Samples: {len(peak_units)}\n')
print(util.format_value_counts(peak_units))

In [None]:
# Reference sorting from the matched peaks of the selected units (same hard-coded 30000 as for the DSS sorting).
sorting_peaks = comparison.create_numpy_sorting(peak_times, peak_units, 30000)
sorting_peaks

### Compare Sorting objects

In [None]:
# Run the comparison. `sorting_peaks` holds the reference units matched from the
# NWB spikes and `sorting_dss` holds the DeepSpikeSort clusters. The display
# names must follow the same order as the sorting arguments (they were swapped).
cmp_dss_peaks = si.compare_two_sorters(
    sorting1=sorting_peaks,
    sorting2=sorting_dss,
    sorting1_name='Kilosort',
    sorting2_name='DeepSpikeSort',
    delta_time=0,
    verbose=True
)

In [None]:
# `get_matching` shows which units were matched; unmatched units are listed as -1.
# Index [1] is presumably the sorting2 -> sorting1 direction, i.e. each DSS unit
# mapped to its reference unit, as the name `dss_to_peaks` implies.
dss_to_peaks = cmp_dss_peaks.get_matching()[1]
display(dss_to_peaks)

In [None]:
# Internal comparison tables that help check the matching:
# match_event_count (coincident event counts) and agreement_scores (per unit pair).
display(cmp_dss_peaks.match_event_count)
display(cmp_dss_peaks.agreement_scores)

In [None]:
# Inspect the matching visually through the agreement matrix.
si.plot_agreement_matrix(cmp_dss_peaks)

In [None]:
# Sum every cell of the match_event_count table (total matched events over all unit pairs).
cmp_dss_peaks.match_event_count.to_numpy().sum()

## 6. Inspect mismatched peaks

In [None]:
# Pick one DSS cluster and order its samples by peak index.
cluster_label = 0
dss_output_cluster = np.sort(preprocessing.get_unit(dss_output, cluster_label), order='peak_index')

pd.DataFrame(dss_output_cluster)

In [None]:
# Show the first 10 entries returned by get_duplicate_peaks for this cluster.
duplicate_peaks = evaluation.get_duplicate_peaks(dss_output_cluster)

pd.DataFrame(duplicate_peaks[:10])

In [None]:
# Channel 50 plus its neighbours within 80 (presumably um), as sorted channel indices.
channel_ind = 50
neighbor_channels = np.sort(np.append(preprocessing.get_channel_neighbors(channels, channel_ind, 80)['channel_index'], channel_ind))

neighbor_channels

In [None]:
# Plot the trace waveform at this cluster's first sample on the neighbouring channels.
plotting.plot_trace_waveform(recording_preprocessed, dss_output_cluster['sample_index'][0], neighbor_channels)

In [None]:
# NOTE(review): channel 361 is hard-coded rather than derived from the cluster; the
# trailing positional arguments (False, 25) are interpreted by plotting.plot_unit_waveform.
plotting.plot_unit_waveform(recording_preprocessed, dss_output_cluster, cluster_label, 361, False, 25)

In [None]:
# Map each DSS unit id to its matched reference unit id (-1 = unmatched).
# Use the Series' own index: enumerate() yields positions, which only coincide
# with unit ids when the ids happen to be 0..n-1 in order.
unit_map = dss_to_peaks.to_dict()

labels_st1, labels_st2 = si.do_score_labels(sorting_dss, sorting_peaks, 0, unit_map)
print(labels_st1)
print(labels_st2)

In [None]:
# Scores for labels [0, 1, 2] as computed by comparison.get_scores.
dss_scores = comparison.get_scores(labels_st1, [0,1,2])
pd.DataFrame(dss_scores)

In [None]:
# Split the duplicate-filtered DSS samples into matched / mismatched sets from labels_st1 (see comparison.filter_samples_on_match).
dss_matched, dss_mismatched = comparison.filter_samples_on_match(labels_st1, [0,1,2], dss_filtered)

### Matched samples

In [None]:
# Matched samples of cluster 0.
# NOTE(review): get_unit presumably returns only rows of this cluster, so ordering
# by 'unit_index' has no effect; perhaps 'sample_index' was intended.
cluster_label = 0
dss_matched_cluster = np.sort(preprocessing.get_unit(dss_matched, cluster_label), order='unit_index')
pd.DataFrame(dss_matched_cluster)

In [None]:
# NOTE(review): assumes the matched samples carry an 'index' field. The dss_output
# dtype built above has no such field, so it must come from the comparison helpers; confirm.
row_index = 0
trace_index = dss_matched_cluster['index'][row_index]
trace_index

In [None]:
# Channel(s) of the matched peak(s) at the same sample as the selected DSS sample.
# `peaks_matched` stores the sample in its 'time' field (see its columns above)
# and `times` is not defined in this notebook, so the sample is taken from the
# selected row of `dss_matched_cluster`.
channel_ind = peaks_matched['channel_index'][peaks_matched['time'] == dss_matched_cluster['sample_index'][row_index]]
channel_ind

In [None]:
# Plot this cluster's waveforms on the first matched channel; the trailing arguments (False, 100) are interpreted by plotting.plot_unit_waveform.
plotting.plot_unit_waveform(recording_preprocessed, dss_matched_cluster, cluster_label, channel_ind[0], False, 100)

In [None]:
def plot_trace_image(trace_reshaped):
    """
    Plot each slice along the last axis of a 3-D trace array as a heatmap.

    The first two axes are swapped before display, so the rows of each heatmap
    come from axis 1 of the input. All panels share one color scale (the global
    min/max) and a single horizontal colorbar at the top.

    Args:
        trace_reshaped (np.ndarray): 3-D array whose last axis indexes the
            panels. Presumably (time, channel, panel), given the axis labels
            drawn below; TODO confirm.

    Returns:
        None. The figure is shown with plt.show().

    Raises:
        ValueError: If `trace_reshaped` is not 3-D or has no panels.
    """
    trace_reshaped = np.asarray(trace_reshaped)
    if trace_reshaped.ndim != 3 or trace_reshaped.shape[2] == 0:
        raise ValueError(
            f"expected a 3-D array with at least one panel, got shape {trace_reshaped.shape}"
        )

    trace_transposed = np.transpose(trace_reshaped, (1, 0, 2))

    # Shared color limits so panels are directly comparable
    vmin = trace_transposed.min()
    vmax = trace_transposed.max()

    n_panels = trace_reshaped.shape[2]
    # Keep the original two-column layout for one or two panels, and add columns
    # instead of failing when there are more.
    n_cols = max(2, n_panels)

    plt.figure(figsize=(8, 10))
    for i in range(n_panels):
        plt.subplot(1, n_cols, i + 1)
        image = plt.imshow(trace_transposed[:, :, i], cmap='viridis', vmin=vmin, vmax=vmax)
    # Set x and y labels for the plot
    plt.text(0.5, 0.05, 'time (frames)', ha='center', va='center', transform=plt.gcf().transFigure)
    plt.text(0.01, 0.5, 'channel', ha='center', va='center', rotation='vertical', transform=plt.gcf().transFigure)
    # Add colorbar for the plot (all panels share vmin/vmax, so any image works)
    cax = plt.axes([0.15, 0.95, 0.7, 0.03])  # [left, bottom, width, height]
    plt.colorbar(image, cax=cax, orientation='horizontal')

    plt.show()

In [None]:
# NOTE(review): `peaks_dataset` is not defined anywhere in this notebook, so the second
# line fails on a fresh kernel. Also, plot_trace_waveform receives 9 samples and a single
# channel here, whereas other calls pass one sample and an array of channels; confirm.
plotting.plot_trace_waveform(recording_preprocessed, dss_matched_cluster['sample_index'][:9], channel_ind[0])
plot_trace_image(peaks_dataset[trace_index][0])

### Mismatched samples

In [None]:
# Mismatched samples of cluster 1.
cluster_label = 1
dss_mismatched_cluster = preprocessing.get_unit(dss_mismatched, cluster_label)
pd.DataFrame(dss_mismatched_cluster)

In [None]:
# NOTE(review): assumes the mismatched samples carry an 'index' field, which the
# dss_output dtype built above does not define; confirm.
row_index = 0
trace_index = dss_mismatched_cluster['index'][row_index]
trace_index

In [None]:
# Channel(s) of the matched peak(s) at the same sample as the selected DSS sample.
# `peaks_matched` stores the sample in its 'time' field (see its columns above)
# and `times` is not defined in this notebook, so the sample is taken from the
# selected row of `dss_mismatched_cluster`.
channel_ind = peaks_matched['channel_index'][peaks_matched['time'] == dss_mismatched_cluster['sample_index'][row_index]]
channel_ind

In [None]:
# Neighbouring channels (within 40, presumably um) of the *second* matched channel,
# plus that channel itself, sorted.
# NOTE(review): channel_ind[1] raises IndexError when only one peak shares the sample time.
channel_neighbors = preprocessing.get_channel_neighbors(channels, channel_ind[1], 40)['channel_index']
channel_neighbors = np.sort(np.append(channel_neighbors, channel_ind[1]))
channel_neighbors

In [None]:
# Plot the mismatched sample on its neighbouring channels.
# NOTE(review): `peaks_dataset` is not defined anywhere in this notebook; the second line fails on a fresh kernel.
plotting.plot_trace_waveform(recording_preprocessed, dss_mismatched_cluster['sample_index'][row_index], channel_neighbors)
plot_trace_image(peaks_dataset[trace_index][0])

In [None]:
# Shape of the first epoch's feature matrix (clustered with isosplit6 below).
features[0].shape

In [None]:
from isosplit6 import isosplit6

labels = isosplit6(features[0])

In [None]:
set(labels)