# Advanced Spike Sorting and Clustering for Multi-Electrode Array (MEA) Data
This notebook demonstrates an advanced pipeline for spike sorting and clustering analysis of multi-electrode array (MEA) in vivo electrophysiology data. The analysis includes data handling, preprocessing, spike sorting, feature extraction, dimensionality reduction, clustering, and network-level analyses.

## Objectives
- **Spike Sorting**: Perform advanced spike sorting to classify spikes from densely packed electrodes.
- **Clustering and Visualization**: Apply dimensionality reduction and clustering algorithms to visualize and validate spike sorting results.
- **Spike Train and Connectivity Analysis**: Analyze spike train dynamics and pairwise correlations as a basis for network connectivity analysis.


In [None]:
# Import necessary libraries
import neo  # For data handling
import spikeinterface as si  # Core module for SpikeInterface
import spikeinterface.extractors as se  # For data loading and extraction
import spikeinterface.preprocessing as sp  # For data preprocessing
import spikeinterface.sorters as ss  # For spike sorting algorithms
import spikeinterface.postprocessing as spost  # For postprocessing sorted data
import spikeinterface.qualitymetrics as sq  # For quality control metrics
import elephant.spike_train_correlation as escorr  # For spike train correlations
import elephant.causality.granger as egranger  # For Granger causality analysis
import pyspike as ps  # For spike train synchrony and distance measures
import quantities as pq  # For unit handling
import numpy as np  # For numerical operations
import matplotlib.pyplot as plt  # For static visualization
import plotly.express as px  # For interactive visualization
from sklearn.decomposition import PCA  # For dimensionality reduction
from sklearn.manifold import TSNE  # For t-SNE visualization
from umap import UMAP  # For UMAP visualization
from sklearn.cluster import AffinityPropagation, DBSCAN, AgglomerativeClustering  # For clustering algorithms
import hdbscan  # For HDBSCAN clustering
from neo.io import NeuralynxIO, BlackrockIO  # Example IOs for Neo data loading


## 1. Data Handling
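Load the MEA data with a `Neo` IO class and convert it to a SpikeInterface recording. The IO class depends on the acquisition system's file format (e.g., `NeuralynxIO` or `BlackrockIO`). Only the first analog signal is loaded, and no electrode geometry is attached; sorters such as Kilosort need channel locations, which can be added with `recording.set_probe(...)`.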

In [None]:
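def load_mea_data(file_path, io_type='NeuralynxIO'):
    """
    Load MEA data using Neo and convert to SpikeInterface format.
    
    Args:
    - file_path (str): Path to the data (a directory for NeuralynxIO, a file for BlackrockIO).
    - io_type (str): Type of Neo IO to use ('NeuralynxIO' or 'BlackrockIO').
    
    Returns:
    - recording (si.BaseRecording): In-memory SpikeInterface recording of the first analog signal.
    """
    try:
        # Map each IO name to a constructor so that only the requested reader is instantiated
        io_types = {
            'NeuralynxIO': lambda path: NeuralynxIO(dirname=path),
            'BlackrockIO': lambda path: BlackrockIO(filename=path),
        }
        
        if io_type not in io_types:
            raise ValueError(f"Unsupported IO type: {io_type}")
        
        reader = io_types[io_type](file_path)
        block = reader.read_block()
        analog_signal = block.segments[0].analogsignals[0]
        
        # Neo stores a (time x channels) quantity array; SpikeInterface expects the same
        # layout as a plain NumPy array plus the sampling rate in Hz
        traces = analog_signal.magnitude
        sampling_frequency = float(analog_signal.sampling_rate.rescale('Hz').magnitude)
        recording = si.NumpyRecording(traces_list=[traces], sampling_frequency=sampling_frequency)
        return recording
    except Exception as e:
        print(f"Error loading MEA data: {e}")
        raise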

# Example usage
file_path = 'data/sample_mea_data'  # Replace with your data file path
recording = load_mea_data(file_path)
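When no recording is at hand, a synthetic signal with known spike times is useful for checking the later steps. The sketch below needs only NumPy; the helper `make_synthetic_traces`, its spike shape, and its default parameters are hypothetical. It adds a biphasic waveform at random times on one channel of Gaussian noise and returns traces in the (samples x channels) layout that SpikeInterface's `get_traces()` uses. Such an array can be wrapped with `si.NumpyRecording(traces_list=[traces], sampling_frequency=fs)`.

```python
import numpy as np


def make_synthetic_traces(num_channels=4, duration_s=2.0, fs=20000.0,
                          rate_hz=10.0, amplitude=8.0, seed=0):
    """Gaussian noise plus biphasic spikes at known sample indices (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    num_samples = int(duration_s * fs)
    traces = rng.normal(0.0, 1.0, size=(num_samples, num_channels))

    # 1.5 ms biphasic template: sharp negative trough followed by a slower positive lobe
    t = np.arange(int(0.0015 * fs)) / fs
    template = -np.exp(-((t - 0.0003) / 0.0001) ** 2) + 0.4 * np.exp(-((t - 0.0008) / 0.0002) ** 2)
    template *= amplitude

    # Draw unique spike start samples, leaving room for the template at the end
    num_spikes = rng.poisson(rate_hz * duration_s)
    spike_samples = np.sort(rng.choice(num_samples - template.size, size=num_spikes, replace=False))
    channel = 0
    for s in spike_samples:
        traces[s:s + template.size, channel] += template
    return traces, spike_samples, channel


traces, true_spikes, spike_channel = make_synthetic_traces()
print(traces.shape, len(true_spikes))
```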

## 2. Preprocessing
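Apply a bandpass filter, optional common referencing, and z-score normalization. Common referencing subtracts the across-channel mean (common average reference, CAR) or median (common median reference, CMR) at each time point to remove noise shared by all electrodes. `spikeinterface.preprocessing` does not include an Independent Component Analysis (ICA) step, so ICA-based denoising would need a separate implementation.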

In [None]:
def preprocess_data(recording, freq_min=300, freq_max=6000, noise_reduction='CAR'):
    """
    Apply a bandpass filter, optional common referencing, and z-score normalization.
    
    Args:
    - recording (si.BaseRecording): Raw SpikeInterface recording.
    - freq_min (int): Lower cutoff of the bandpass filter in Hz.
    - freq_max (int): Upper cutoff of the bandpass filter in Hz (must be below the Nyquist frequency).
    - noise_reduction (str or None): 'CAR' (common average reference), 'CMR' (common median reference), or None.
    
    Returns:
    - recording_preprocessed (si.BaseRecording): Preprocessed data.
    """
    try:
        recording_filtered = sp.bandpass_filter(recording, freq_min=freq_min, freq_max=freq_max)
        
        if noise_reduction == 'CAR':
            recording_filtered = sp.common_reference(recording_filtered, reference='global', operator='average')
        elif noise_reduction == 'CMR':
            recording_filtered = sp.common_reference(recording_filtered, reference='global', operator='median')
        elif noise_reduction is not None:
            raise ValueError(f"Unsupported noise reduction technique: {noise_reduction}")
        
        # Normalize each channel to a common scale (the default z-score mode uses median and MAD)
        recording_normalized = sp.zscore(recording_filtered)
        return recording_normalized
    except Exception as e:
        print(f"Error in preprocessing data: {e}")
        raise

# Example usage
recording_preprocessed = preprocess_data(recording)
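Referencing works sample by sample across channels, while normalization works channel by channel across time. The NumPy sketch below makes both steps explicit on a (samples x channels) array. The function names are hypothetical, and this is not SpikeInterface's implementation, which by default estimates its z-score statistics (median and MAD) from random chunks of the recording rather than the full array.

```python
import numpy as np


def common_reference(traces, operator="median"):
    """Subtract the across-channel median (CMR) or mean (CAR) at every time sample."""
    if operator == "median":
        ref = np.median(traces, axis=1, keepdims=True)
    elif operator == "average":
        ref = traces.mean(axis=1, keepdims=True)
    else:
        raise ValueError(f"Unknown operator: {operator}")
    return traces - ref


def robust_zscore(traces):
    """Center each channel on its median and scale by a MAD-based noise estimate."""
    med = np.median(traces, axis=0, keepdims=True)
    # 1.4826 * MAD estimates the standard deviation of Gaussian noise
    mad = np.median(np.abs(traces - med), axis=0, keepdims=True)
    return (traces - med) / (1.4826 * mad)


rng = np.random.default_rng(1)
shared_noise = rng.normal(0, 5, size=(10000, 1))                          # common to all electrodes
private_noise = rng.normal(0, 1, size=(10000, 8)) * np.linspace(1, 2, 8)  # channel-specific scale
raw = shared_noise + private_noise

cleaned = robust_zscore(common_reference(raw, operator="median"))
```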


## 3. Spike Sorting
Run spike sorting with an external sorter such as Kilosort, IronClust, or MountainSort. User-supplied parameters override the sorter's defaults. Kilosort2 and IronClust are MATLAB-based and must be installed separately.

In [None]:
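def perform_spike_sorting(recording, sorter_name='kilosort2', custom_params=None, output_folder='sorting_output'):
    """
    Perform spike sorting on MEA data using an external sorter.
    
    Args:
    - recording (si.BaseRecording): Preprocessed recording data.
    - sorter_name (str): Name of the spike sorting algorithm (e.g., 'kilosort2', 'mountainsort5').
    - custom_params (dict): Optional parameters that override the sorter's defaults.
    - output_folder (str): Folder where the sorter writes its output (overwritten if it exists).
    
    Returns:
    - sorting (si.BaseSorting): Sorted spike data.
    """
    try:
        # Start from the sorter's defaults and override only the keys the user supplies
        sorter_params = ss.get_default_sorter_params(sorter_name)
        sorter_params.update(custom_params or {})
        # `folder` is the output argument name in SpikeInterface >= 0.101
        sorting = ss.run_sorter(sorter_name, recording, folder=output_folder, remove_existing_folder=True, **sorter_params)
        return sorting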
    except Exception as e:
        print(f"Error in spike sorting: {e}")
        raise

# Example usage
sorting = perform_spike_sorting(recording_preprocessed)
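Template-matching sorters such as Kilosort go well beyond this, but many sorters start from the same idea: flag negative threshold crossings relative to a robust noise estimate. The single-channel sketch below is illustrative only; the function name, the 5x noise threshold, and the 1 ms dead time are assumptions, not parameters of any sorter named above.

```python
import numpy as np


def detect_spikes(signal, fs, threshold_factor=5.0, dead_time_s=0.001):
    """Return sample indices of negative threshold crossings on one channel."""
    # Robust noise estimate: median absolute deviation scaled to a Gaussian standard deviation
    noise_sd = np.median(np.abs(signal - np.median(signal))) / 0.6745
    threshold = -threshold_factor * noise_sd

    # A crossing is a sample below threshold whose predecessor is at or above it
    below = signal < threshold
    crossings = np.flatnonzero(below[1:] & ~below[:-1]) + 1

    # Enforce a dead time so that one spike is not counted twice
    dead_samples = int(dead_time_s * fs)
    spikes = []
    for idx in crossings:
        if not spikes or idx - spikes[-1] > dead_samples:
            spikes.append(idx)
    return np.asarray(spikes, dtype=int)


fs = 20000.0
rng = np.random.default_rng(2)
signal = rng.normal(0.0, 1.0, size=int(fs))
true_times = np.array([2000, 7000, 12000, 17000])
signal[true_times] -= 12.0  # one-sample negative deflections standing in for spikes
detected = detect_spikes(signal, fs)
```

Robust statistics matter here: a plain standard deviation would be inflated by the spikes themselves, raising the threshold on busy channels.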

## 4. Postprocessing and Quality Metrics
Extract waveforms and compute quality metrics like Signal-to-Noise Ratio (SNR), Inter-Spike Interval (ISI) violations, and firing rates for validation.

In [None]:
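def postprocess_sorting(sorting, recording, folder='sorting_analyzer'):
    """
    Extract waveforms and compute quality metrics for sorted units.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    - recording (si.BaseRecording): Preprocessed recording data.
    - folder (str): Folder where the analyzer is saved (overwritten if it exists).
    
    Returns:
    - sorting_analyzer (si.SortingAnalyzer): Analyzer holding waveforms, templates, and noise levels.
    - quality_metrics (pandas.DataFrame): Quality metrics, one row per sorted unit.
    """
    try:
        # SortingAnalyzer replaced WaveformExtractor in SpikeInterface 0.101.
        # Dense waveforms (sparse=False) keep every channel, so all spikes share one feature layout.
        sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format='binary_folder', folder=folder, sparse=False, overwrite=True)
        # Waveforms are extracted for a random subset of spikes per unit
        sorting_analyzer.compute('random_spikes')
        sorting_analyzer.compute('waveforms', ms_before=1.5, ms_after=2.5)
        # SNR needs both templates and per-channel noise levels
        sorting_analyzer.compute('templates')
        sorting_analyzer.compute('noise_levels')
        quality_metrics = sq.compute_quality_metrics(sorting_analyzer, metric_names=['snr', 'isi_violation', 'firing_rate'])
        return sorting_analyzer, quality_metrics
    except Exception as e:
        print(f"Error in postprocessing sorting: {e}")
        raise

# Example usage
sorting_analyzer, quality_metrics = postprocess_sorting(sorting, recording_preprocessed)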
print("Quality Metrics:", quality_metrics)
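To make the metrics concrete, here is a simplified NumPy version for a single unit: firing rate as spike count over duration, the fraction of inter-spike intervals shorter than an assumed 1.5 ms refractory period, and SNR as the peak absolute amplitude of the mean waveform over the noise standard deviation. The function names are hypothetical. SpikeInterface's `isi_violation` metric reports a model-based ratio (following Hill et al.) rather than a raw fraction, so these numbers are illustrative and will not match its output.

```python
import numpy as np


def firing_rate(spike_times_s, duration_s):
    """Mean firing rate in Hz."""
    return len(spike_times_s) / duration_s


def isi_violation_fraction(spike_times_s, refractory_s=0.0015):
    """Fraction of inter-spike intervals shorter than the refractory period."""
    isis = np.diff(np.sort(spike_times_s))
    if isis.size == 0:
        return 0.0
    return float(np.mean(isis < refractory_s))


def snr(mean_waveform, noise_sd):
    """Peak absolute amplitude of the mean waveform divided by the noise standard deviation."""
    return float(np.max(np.abs(mean_waveform)) / noise_sd)


spike_times = np.array([0.010, 0.011, 0.050, 0.120, 0.300])  # seconds
print(firing_rate(spike_times, duration_s=1.0))  # → 5.0
print(isi_violation_fraction(spike_times))       # → 0.25 (one of four ISIs is 1 ms)
```

Because a real neuron cannot fire twice within its refractory period, a non-trivial violation fraction usually points to contamination from other neurons or noise.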

## 5. Feature Extraction and Dimensionality Reduction
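Flatten each spike's multichannel waveform into a feature vector and reduce its dimensionality with PCA, t-SNE, or UMAP before clustering. PCA is a linear projection; t-SNE and UMAP are nonlinear embeddings that preserve local neighborhoods but not global distances, so distances between clusters in those embeddings should not be over-interpreted.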

In [None]:
def cluster_spikes(features, method='HDBSCAN', **kwargs):
    """
    Cluster spikes using specified clustering algorithm.
    
    Args:
    - features (np.ndarray): Feature matrix for clustering.
    - method (str): Clustering method ('AffinityPropagation', 'HDBSCAN', 'DBSCAN', 'AgglomerativeClustering').
    
    Returns:
    - labels (np.ndarray): Cluster labels for each spike.
    """
    try:
        if method == 'AffinityPropagation':
            clustering = AffinityPropagation().fit(features)
        elif method == 'HDBSCAN':
            clustering = hdbscan.HDBSCAN().fit(features)
        elif method == 'DBSCAN':
            clustering = DBSCAN(eps=0.5, min_samples=5).fit(features)
        elif method == 'AgglomerativeClustering':
            clustering = AgglomerativeClustering(n_clusters=3).fit(features)
        else:
            raise ValueError(f"Unsupported clustering method: {method}")
        return clustering.labels_
    except Exception as e:
        print(f"Error in clustering spikes: {e}")
        raise

# Example usage
labels = cluster_spikes(reduced_features, method='HDBSCAN')
print("Cluster Labels:", labels)
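PCA centers the feature matrix and projects it onto the directions of largest variance, which are the right singular vectors of the centered matrix. The NumPy sketch below spells out that computation on random toy data; the function name is hypothetical, and scikit-learn's `PCA` follows the same principle but may return components with flipped signs.

```python
import numpy as np


def pca_project(features, n_components=3):
    """Project rows of `features` onto the top principal components via SVD."""
    centered = features - features.mean(axis=0)
    # Rows of vt are orthonormal principal axes, ordered by decreasing singular value
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    components = vt[:n_components]
    explained_variance = s[:n_components] ** 2 / (features.shape[0] - 1)
    return centered @ components.T, components, explained_variance


rng = np.random.default_rng(3)
# 200 toy spikes, each a 30-sample x 4-channel snippet flattened to 120 features
demo_waveforms = rng.normal(size=(200, 30, 4))
demo_features = demo_waveforms.reshape(len(demo_waveforms), -1)
projected, components, variance = pca_project(demo_features, n_components=3)
```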

## 6. Clustering and Validation
Cluster the reduced features with Affinity Propagation, HDBSCAN, DBSCAN, or Agglomerative Clustering and compare the resulting clusters with the sorter's units as an independent check. HDBSCAN and DBSCAN label points they consider noise as `-1`.

In [None]:
def cluster_spikes(features, method='HDBSCAN', **kwargs):
    """
    Cluster spikes using the specified clustering algorithm.
    
    Args:
    - features (np.ndarray): Feature matrix of shape (num_spikes, num_features).
    - method (str): Clustering method ('AffinityPropagation', 'HDBSCAN', 'DBSCAN', 'AgglomerativeClustering').
    - **kwargs: Keyword arguments passed to the clustering estimator (e.g., eps for DBSCAN, n_clusters for AgglomerativeClustering).
    
    Returns:
    - labels (np.ndarray): Cluster label for each spike (-1 marks noise for HDBSCAN and DBSCAN).
    """
    try:
        if method == 'AffinityPropagation':
            clustering = AffinityPropagation(**kwargs).fit(features)
        elif method == 'HDBSCAN':
            clustering = hdbscan.HDBSCAN(**kwargs).fit(features)
        elif method == 'DBSCAN':
            # Default eps/min_samples, overridable through kwargs
            clustering = DBSCAN(**{'eps': 0.5, 'min_samples': 5, **kwargs}).fit(features)
        elif method == 'AgglomerativeClustering':
            # Default of 3 clusters, overridable through kwargs
            clustering = AgglomerativeClustering(**{'n_clusters': 3, **kwargs}).fit(features)
        else:
            raise ValueError(f"Unsupported clustering method: {method}")
        return clustering.labels_
    except Exception as e:
        print(f"Error in clustering spikes: {e}")
        raise

# Example usage
labels = cluster_spikes(reduced_features, method='HDBSCAN')
print("Cluster Labels:", labels)
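One way to quantify agreement between these clusters and the sorter's units is the adjusted Rand index (ARI): 1 for identical partitions up to relabeling and about 0 for chance-level agreement. The NumPy sketch below computes it from the contingency table; `sklearn.metrics.adjusted_rand_score` provides the same measure. The arrays are toy data, and pairing each clustered spike with its sorter unit is assumed to have been done already.

```python
import numpy as np


def comb2(n):
    """Number of unordered pairs (n choose 2), element-wise."""
    return n * (n - 1) / 2.0


def adjusted_rand_index(labels_a, labels_b):
    """Adjusted Rand index between two labelings of the same items (needs at least 2 items)."""
    _, a_idx = np.unique(np.asarray(labels_a), return_inverse=True)
    _, b_idx = np.unique(np.asarray(labels_b), return_inverse=True)
    # Contingency table: number of items in cluster i of A and cluster j of B
    table = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(table, (a_idx, b_idx), 1)

    sum_cells = comb2(table).sum()
    sum_rows = comb2(table.sum(axis=1)).sum()
    sum_cols = comb2(table.sum(axis=0)).sum()
    total = comb2(len(a_idx))

    expected = sum_rows * sum_cols / total
    max_index = 0.5 * (sum_rows + sum_cols)
    if max_index == expected:  # degenerate partitions (e.g., everything in one cluster)
        return 1.0
    return float((sum_cells - expected) / (max_index - expected))


demo_units = np.array([0, 0, 0, 1, 1, 1, 2, 2])
demo_clusters = np.array([5, 5, 5, 7, 7, 7, 9, 9])  # same partition, different label values
print(adjusted_rand_index(demo_units, demo_clusters))  # → 1.0
```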

## 7. Spike Train and Connectivity Analysis
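Analyze pairwise correlations between the sorted units' spike trains. Correlation is undirected; directed connectivity measures such as Granger causality (`elephant.causality.granger`) or the Directed Transfer Function (DTF) can be built on binned or smoothed versions of the same spike trains.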

In [None]:
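from elephant.conversion import BinnedSpikeTrain

def compute_spike_train_correlations(sorting, recording, bin_size=5 * pq.ms):
    """
    Compute the Pearson correlation matrix of binned spike trains using Elephant.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    - recording (si.BaseRecording): Recording the sorting came from (used for its duration).
    - bin_size (pq.Quantity): Width of the bins used to count spikes.
    
    Returns:
    - correlations (np.ndarray): (num_units x num_units) correlation matrix.
    """
    try:
        fs = sorting.get_sampling_frequency()
        # Duration of the first segment in seconds; all spike trains must share the same t_stop
        t_stop = recording.get_num_samples(segment_index=0) / fs
        # SpikeInterface returns spike times as sample indices; convert them to seconds for Neo
        spike_trains = [
            neo.SpikeTrain(sorting.get_unit_spike_train(unit_id, segment_index=0) / fs, t_stop=t_stop, units='s')
            for unit_id in sorting.unit_ids
        ]
        binned = BinnedSpikeTrain(spike_trains, bin_size=bin_size)
        correlations = escorr.correlation_coefficient(binned)
        return correlations
    except Exception as e:
        print(f"Error in computing spike train correlations: {e}")
        raise

# Example usage
correlations = compute_spike_train_correlations(sorting, recording_preprocessed)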
print("Spike Train Correlation Matrix:", correlations)
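A correlation coefficient collapses timing into one number per pair. A cross-correlogram keeps the lag structure: a narrow peak at a short positive lag is a common, though not conclusive, hint of a putative direct connection. The sketch below counts, for every reference spike, the target spikes within ±`window_s`; the function name, the 1 ms bins, and the 20 ms window are illustrative assumptions, and the toy trains are constructed with a fixed 3.5 ms delay.

```python
import numpy as np


def cross_correlogram(ref_times, target_times, window_s=0.02, bin_s=0.001):
    """Counts of (target - reference) spike-time lags within +/- window_s, in bins of bin_s."""
    target_times = np.sort(np.asarray(target_times))
    lag_chunks = []
    for t in np.asarray(ref_times):
        # searchsorted locates the target spikes inside [t - window_s, t + window_s]
        lo = np.searchsorted(target_times, t - window_s, side="left")
        hi = np.searchsorted(target_times, t + window_s, side="right")
        lag_chunks.append(target_times[lo:hi] - t)
    lags = np.concatenate(lag_chunks) if lag_chunks else np.empty(0)
    n_bins = int(round(2 * window_s / bin_s))
    edges = np.linspace(-window_s, window_s, n_bins + 1)
    counts, _ = np.histogram(lags, bins=edges)
    return counts, edges


# Toy example: the target neuron fires 3.5 ms after every reference spike
demo_ref = np.arange(0.1, 10.0, 0.1)
demo_target = demo_ref + 0.0035
ccg_counts, ccg_edges = cross_correlogram(demo_ref, demo_target)
ccg_centers = 0.5 * (ccg_edges[:-1] + ccg_edges[1:])
print("Peak lag (s):", ccg_centers[np.argmax(ccg_counts)])
```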

## 8. Visualization of Clustering Results
Visualize the clustering results using interactive 3D plots.

In [None]:
def plot_clusters(reduced_features, labels):
    """
    Visualize clustering results in 3D.
    
    Args:
    - reduced_features (np.ndarray): Reduced feature matrix.
    - labels (np.ndarray): Cluster labels.
    """
    try:
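        # Cast labels to strings so that Plotly assigns one discrete color per cluster
        fig = px.scatter_3d(x=reduced_features[:, 0], y=reduced_features[:, 1], z=reduced_features[:, 2], color=np.asarray(labels).astype(str))
        # Axis titles of a 3D plot belong to the layout's scene, not to the 2D axes
        fig.update_layout(title='3D Clustering Visualization', scene=dict(xaxis_title='Component 1', yaxis_title='Component 2', zaxis_title='Component 3'))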
        fig.show()
    except Exception as e:
        print(f"Error in plotting clusters: {e}")
        raise

# Example usage
plot_clusters(reduced_features, labels)

## Conclusion
This notebook provides a pipeline for spike sorting and clustering analysis of MEA data, covering data handling, preprocessing, spike sorting, quality control, dimensionality reduction, clustering, and pairwise spike train correlation analysis.