# Spike Sorting and Firing Rate Analysis

## Introduction

This notebook walks through spike sorting and firing rate analysis for in vivo electrophysiology data. It uses Neo, SpikeInterface, and Elephant to load data, sort spikes, analyze spike train dynamics, and visualize the results.

### Objectives
- Detect and sort spikes from raw electrophysiological data.
- Compute firing rates and analyze spike train dynamics.
- Visualize spike train data with raster plots, ISI histograms, and more.

### Methods
- **Spike Detection**: Threshold-based and template matching approaches.
- **Feature Extraction**: Extract key features (e.g., spike width, amplitude).
- **Clustering Algorithms**: K-means, Gaussian Mixture Models (GMM), DBSCAN.
- **Spike Train Analysis**: Interspike intervals (ISI), peri-stimulus time histograms (PSTH), burst detection.
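
As a rough illustration of the threshold-based detection listed above, the sketch below flags downward threshold crossings in a single filtered trace. The function name `detect_spikes_threshold`, the multiplier `k`, and the refractory window are assumptions made for this example. They are not taken from any of the libraries used in this notebook, and the sorters called later apply their own, more elaborate detection.

```python
import numpy as np

def detect_spikes_threshold(signal, fs, k=5.0, refractory_ms=1.0):
    """Return sample indices of negative threshold crossings and the threshold used."""
    # Robust noise estimate: median absolute deviation scaled to a Gaussian sigma
    sigma = np.median(np.abs(signal - np.median(signal))) / 0.6745
    threshold = -k * sigma
    # A crossing is a sample below threshold whose predecessor was not
    below = signal < threshold
    crossings = np.flatnonzero(below[1:] & ~below[:-1]) + 1
    # Drop crossings that fall inside the refractory window of the previous one
    min_gap = int(refractory_ms * 1e-3 * fs)
    kept = []
    for idx in crossings:
        if not kept or idx - kept[-1] >= min_gap:
            kept.append(idx)
    return np.asarray(kept, dtype=int), threshold

# Synthetic trace: unit-variance noise with three large negative deflections
rng = np.random.default_rng(0)
fs = 30000.0
signal = rng.normal(0.0, 1.0, 30000)
true_idx = np.array([5000, 12000, 21000])
signal[true_idx] -= 20.0

spike_idx, threshold = detect_spikes_threshold(signal, fs)
print("Detected spike samples:", spike_idx)
```

The median-based noise estimate is used here because large spikes inflate a plain standard deviation, which would push the threshold too far from the noise floor.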

### Tools
- **Python Libraries**: Neo, SpikeInterface, Elephant, SciPy, PyWavelets, PySpike.
- **Visualization**: Matplotlib, Plotly.

### Outcome
- Sorted spikes, firing rate histograms, raster plots, and autocorrelograms.

## Setup and Installation

Before starting the analysis, we need to install the required libraries if they are not already installed.

In [None]:
# Install necessary libraries (use %pip for Jupyter compatibility)
# seaborn is included because the visualization section imports it
%pip install neo spikeinterface elephant quantities numpy matplotlib seaborn plotly scikit-learn scipy pywavelets

In [None]:
# Import libraries
import neo  # For data handling
import spikeinterface as si  # Core module for SpikeInterface
import spikeinterface.extractors as se  # For data loading and extraction
import spikeinterface.preprocessing as sp  # For data preprocessing
import spikeinterface.sorters as ss  # For spike sorting algorithms
import spikeinterface.postprocessing as spost  # For postprocessing sorted data
import spikeinterface.qualitymetrics as sq  # For quality control metrics
import elephant  # For advanced analysis on spike trains
import elephant.statistics as es  # For statistical measures like firing rates
import elephant.spike_train_correlation as escorr  # For correlation analysis
import elephant.spike_train_generation as estg  # For generating synthetic spike trains
import quantities as pq  # For unit handling
import matplotlib.pyplot as plt  # For static visualization
import plotly.express as px  # For interactive visualization with Plotly
from neo.io import NeuralynxIO  # Example IO for Neo data loading
import numpy as np  # For numerical operations
from scipy.cluster.vq import kmeans  # K-means clustering
from sklearn.mixture import GaussianMixture  # GMM clustering
from sklearn.cluster import DBSCAN  # DBSCAN clustering

## Data Loading and Preprocessing

This section covers loading electrophysiological data with Neo and SpikeInterface, followed by preprocessing steps such as bandpass filtering and common referencing.

In [None]:
import neo
from neo.io import NeuralynxIO
import spikeinterface.extractors as se

# Load electrophysiological data into a SpikeInterface recording.
# se.read_neuralynx wraps Neo's Neuralynx reader, so no manual conversion is needed.
def load_data(file_path):
    # If the folder contains several streams, pass stream_id=... to select one
    recording = se.read_neuralynx(folder_path=file_path)
    return recording

# Example usage
file_path = 'data/sample_data'  # Update with actual path
recording = load_data(file_path)
print("Data Loaded Successfully")

## Preprocessing
Preprocessing reduces noise and improves signal quality. Here we apply a bandpass filter and a common reference to the loaded data.
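
To make the two preprocessing steps concrete, here is a minimal NumPy/SciPy sketch of what a bandpass filter followed by a common median reference does to a multichannel array. It is an illustration under assumed settings (a third-order Butterworth filter, traces shaped as samples by channels, the hypothetical helper name `bandpass_and_cmr`), not the SpikeInterface implementation used in the cell below.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def bandpass_and_cmr(traces, fs, freq_min=300.0, freq_max=3000.0, order=3):
    """Bandpass-filter each channel, then subtract the per-sample median across channels.

    traces has shape (num_samples, num_channels).
    """
    sos = butter(order, [freq_min, freq_max], btype="bandpass", fs=fs, output="sos")
    # Zero-phase filtering along the time axis
    filtered = sosfiltfilt(sos, traces, axis=0)
    # Common median reference: remove the signal shared by all channels at each sample
    return filtered - np.median(filtered, axis=1, keepdims=True)

# Synthetic data: a slow 5 Hz component shared by four channels plus independent noise
rng = np.random.default_rng(1)
fs = 30000.0
t = np.arange(int(fs)) / fs
drift = 50.0 * np.sin(2 * np.pi * 5.0 * t)
traces = drift[:, None] + rng.normal(0.0, 1.0, (t.size, 4))

cleaned = bandpass_and_cmr(traces, fs)
print("Std before:", traces.std(), "after:", cleaned.std())
```

The slow shared component lies far below the 300 Hz cutoff, so the filter removes most of it; the median subtraction then removes whatever is still common to all channels.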

In [None]:
# Load data (SpikeInterface reads Neuralynx files through Neo)
def load_data(file_path):
    """
    Load electrophysiological data into a SpikeInterface recording.
    
    Args:
    - file_path (str): Path to the folder containing the raw Neuralynx data.
    
    Returns:
    - recording (si.BaseRecording): Loaded data as a SpikeInterface recording object.
    """
    # se.read_neuralynx wraps Neo's Neuralynx reader; pass stream_id=... if the folder has several streams
    recording = se.read_neuralynx(folder_path=file_path)
    return recording

# Preprocess data
def preprocess_data(recording, freq_min=300, freq_max=3000, common_ref_type='median'):
    """
    Preprocess the loaded data by applying bandpass filtering and common reference.
    
    Args:
    - recording (si.BaseRecording): Loaded data as a SpikeInterface recording object.
    - freq_min (int): Lower cutoff frequency of the bandpass filter, in Hz.
    - freq_max (int): Upper cutoff frequency of the bandpass filter, in Hz.
    - common_ref_type (str): Operator used for the common reference ('median' or 'average').
    
    Returns:
    - recording_preprocessed (si.BaseRecording): Preprocessed data.
    """
    recording_bp = sp.bandpass_filter(recording, freq_min=freq_min, freq_max=freq_max)
    # reference='global' references each channel to all channels; operator selects median or average
    recording_cmr = sp.common_reference(recording_bp, reference='global', operator=common_ref_type)
    return recording_cmr

# Load and preprocess example data
file_path = 'data/sample_data'  # Adjust this to your dataset path
recording = load_data(file_path)
recording_preprocessed = preprocess_data(recording)

## Spike Sorting
Perform spike sorting on the preprocessed data using various spike sorting algorithms (e.g., Kilosort2, SpykingCircus).

In [None]:
def sort_spikes(recording, sorter_name='kilosort2'):
    """
    Perform spike sorting on the preprocessed data.
    
    Args:
    - recording (si.BaseRecording): Preprocessed recording data.
    - sorter_name (str): Name of the sorting algorithm to use (e.g., 'kilosort2').
    
    Returns:
    - sorting (si.BaseSorting): Sorted spike data.
    """
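    # The selected sorter must be installed separately (Kilosort2 also requires MATLAB)
    sorter_params = ss.get_default_sorter_params(sorter_name)
    # Newer SpikeInterface releases name this argument `folder` instead of `output_folder`
    sorting = ss.run_sorter(sorter_name, recording, output_folder='sorting_output', **sorter_params)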
    return sorting

# Perform spike sorting
sorting = sort_spikes(recording_preprocessed)


## Postprocessing and Quality Metrics

Postprocess sorted spikes to extract waveforms and compute quality metrics such as SNR, ISI violations, and firing rates.
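
To show what an ISI-violation check measures, the sketch below computes the fraction of interspike intervals shorter than an assumed refractory period. This is a simplified illustration: the helper name `isi_violation_fraction` and the 1.5 ms refractory value are assumptions, and the `isi_violation` metric computed by SpikeInterface in the next cell uses its own definition, so the numbers are not directly comparable.

```python
import numpy as np

def isi_violation_fraction(spike_times_s, refractory_s=0.0015):
    """Fraction of interspike intervals shorter than the refractory period."""
    isis = np.diff(np.sort(spike_times_s))
    if isis.size == 0:
        return 0.0
    return float(np.mean(isis < refractory_s))

# Six spikes; two of the five intervals (0.5 ms and 0.8 ms) are shorter than 1.5 ms
spike_times_s = np.array([0.100, 0.1005, 0.300, 0.500, 0.5008, 0.900])
fraction = isi_violation_fraction(spike_times_s)
print("ISI violation fraction:", fraction)
```

A well-isolated single unit should produce very few such intervals, so a high fraction suggests that spikes from more than one neuron were merged into the unit.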


In [None]:
def postprocess_sorting(sorting, recording):
    """
    Postprocess the sorted spikes to extract features and waveforms.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    - recording (si.BaseRecording): Preprocessed recording data.
    
    Returns:
    - waveform_extractor (si.WaveformExtractor): Extracted waveforms.
    """
    # extract_waveforms is the entry point for building a WaveformExtractor in SpikeInterface releases before 0.101
    waveform_extractor = si.extract_waveforms(
        recording, sorting, folder='waveforms',
        ms_before=1.5, ms_after=2.5, overwrite=True,
    )
    return waveform_extractor

def compute_quality_metrics(waveform_extractor):
    """
    Compute quality metrics for sorted units.
    
    Args:
    - waveform_extractor (si.WaveformExtractor): Extracted waveforms.
    
    Returns:
    - metrics (pandas.DataFrame): Quality metrics, one row per sorted unit.
    """
    metrics = sq.compute_quality_metrics(waveform_extractor, metric_names=['snr', 'isi_violation', 'firing_rate'])
    return metrics

# Postprocess sorting and compute quality metrics
waveform_extractor = postprocess_sorting(sorting, recording_preprocessed)
quality_metrics = compute_quality_metrics(waveform_extractor)
print("Quality Metrics:", quality_metrics)

## Feature Extraction and Clustering

Extract spike waveform features and perform clustering using K-means, GMM, or DBSCAN to group similar spike events into units.
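
For reference, one common way to define spike amplitude and width on a single mean waveform is sketched below: amplitude as the trough-to-peak voltage difference and width as the trough-to-peak time. The helper `waveform_features` and these definitions are assumptions chosen for illustration. The extraction cell that follows uses simpler summary statistics instead.

```python
import numpy as np

def waveform_features(waveform, fs):
    """Return (trough-to-peak amplitude, trough-to-peak width in ms) for a 1D waveform."""
    trough_idx = int(np.argmin(waveform))
    # Search for the repolarization peak only after the trough
    peak_idx = trough_idx + int(np.argmax(waveform[trough_idx:]))
    amplitude = waveform[peak_idx] - waveform[trough_idx]
    width_ms = (peak_idx - trough_idx) / fs * 1e3
    return amplitude, width_ms

# Toy waveform sampled at 30 kHz: trough at sample 20, peak at sample 32
fs = 30000.0
waveform = np.zeros(60)
waveform[20] = -80.0
waveform[32] = 25.0

amplitude, width_ms = waveform_features(waveform, fs)
print("Amplitude:", amplitude, "Width (ms):", width_ms)
```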

In [None]:
def extract_features(waveform_extractor):
    """
    Extract features from the sorted spike waveforms for clustering.
    
    Args:
    - waveform_extractor (si.WaveformExtractor): Extracted waveforms.
    
    Returns:
    - features (np.ndarray): Feature matrix (e.g., spike width, amplitude, etc.).
    """
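    # get_waveforms() needs a unit ID; each call returns (num_spikes, num_samples, num_channels)
    widths, amplitudes = [], []
    for unit_id in waveform_extractor.sorting.unit_ids:
        waveforms = waveform_extractor.get_waveforms(unit_id)
        # Per-spike summary statistics, used as crude proxies for width and amplitude
        widths.append(np.mean(np.abs(waveforms), axis=(1, 2)))
        amplitudes.append(np.std(waveforms, axis=(1, 2)))
    spike_width = np.concatenate(widths)
    spike_amplitude = np.concatenate(amplitudes)
    features = np.column_stack((spike_width, spike_amplitude))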
    return features

def cluster_spikes(features, method='kmeans'):
    """
    Cluster spikes using specified clustering algorithm.
    
    Args:
    - features (np.ndarray): Feature matrix for clustering.
    - method (str): Clustering method ('kmeans', 'gmm', 'dbscan').
    
    Returns:
    - labels (np.ndarray): Cluster labels for each spike.
    """
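    if method == 'kmeans':
        # scipy's kmeans returns (codebook, distortion), not labels;
        # vq then assigns each spike to its nearest centroid
        from scipy.cluster.vq import vq
        centroids, _ = kmeans(features, 3)
        labels, _ = vq(features, centroids)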
    elif method == 'gmm':
        gmm = GaussianMixture(n_components=3).fit(features)
        labels = gmm.predict(features)
    elif method == 'dbscan':
        db = DBSCAN(eps=0.5, min_samples=5).fit(features)
        labels = db.labels_
    else:
        raise ValueError("Unsupported clustering method.")
    return labels

# Extract features and perform clustering
features = extract_features(waveform_extractor)
labels = cluster_spikes(features, method='kmeans')
print("Cluster Labels:", labels)


## Spike Train Analysis
Spike train analysis is essential for understanding how neurons encode and process information in vivo. By examining the temporal patterns of spikes recorded from single or limited-channel electrodes, we can gain insights into various aspects of neuronal activity such as firing rates, synchrony, and burst dynamics. This is particularly important in studying sensory processing, motor control, and cognitive functions in behaving animals. Here, we perform spike train analysis to compute firing rates, spike train correlations, interspike intervals (ISI), and burst detection for a comprehensive understanding of neuronal coding and network dynamics.

### Firing Rate Analysis
Firing rate analysis involves calculating the number of spikes per unit of time for each neuron (unit). This metric provides a fundamental measure of neuronal excitability and is often used to quantify how neurons respond to sensory stimuli, motor tasks, or cognitive processes. For in vivo single/limited channel electrophysiology, firing rates can reveal how individual neurons or populations of neurons are modulated by behavior, environmental changes, or experimental conditions. We compute the mean firing rate over time and can also visualize firing rate changes during specific experimental epochs.
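
The difference between a mean rate and a time-resolved rate can be sketched with plain NumPy: count spikes in fixed bins and divide by the bin width. The helper name `binned_firing_rate` and the toy spike times are assumptions for this example.

```python
import numpy as np

def binned_firing_rate(spike_times_s, t_start, t_stop, bin_size_s=0.1):
    """Return firing rate (Hz) per bin and the bin edges (s)."""
    n_bins = int(round((t_stop - t_start) / bin_size_s))
    edges = np.linspace(t_start, t_stop, n_bins + 1)
    counts, _ = np.histogram(spike_times_s, bins=edges)
    # Divide counts by the actual bin widths to get spikes per second
    return counts / np.diff(edges), edges

# Seven spikes in one second: three in the first half, four in the second
spike_times_s = np.array([0.05, 0.15, 0.17, 0.62, 0.65, 0.68, 0.91])
rates_hz, edges = binned_firing_rate(spike_times_s, 0.0, 1.0, bin_size_s=0.5)
mean_rate_hz = spike_times_s.size / (1.0 - 0.0)

print("Binned rates (Hz):", rates_hz)
print("Mean rate (Hz):", mean_rate_hz)
```

Smaller bins resolve faster rate changes but give noisier estimates, so the bin size should be chosen to match the time scale of the experimental epochs of interest.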

### Spike Train Correlations
Spike train correlation analysis helps in understanding the temporal coordination and synchrony between different neurons. In the context of in vivo recordings, analyzing the correlations between spike trains of multiple units can provide evidence for shared synaptic inputs, network connectivity, or coordinated activity patterns that are associated with specific behavioral states. By computing pairwise correlations between spike trains, we can infer functional connectivity and identify neurons that may work together to encode sensory inputs or motor outputs.
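
As an illustration of pairwise correlation on binned spike counts, the sketch below bins three synthetic spike trains on a shared grid and computes their Pearson correlation matrix with NumPy. The helper `binned_correlation`, the 5 ms bin size, and the synthetic units are assumptions for this example.

```python
import numpy as np

def binned_correlation(spike_trains_s, t_stop, bin_size_s=0.005):
    """Pearson correlation matrix of binned spike counts (one row per unit)."""
    n_bins = int(round(t_stop / bin_size_s))
    edges = np.linspace(0.0, t_stop, n_bins + 1)
    # Every unit is binned on the same grid so that bins line up across units
    counts = np.array([np.histogram(st, bins=edges)[0] for st in spike_trains_s])
    return np.corrcoef(counts)

rng = np.random.default_rng(3)
t_stop = 10.0
unit_a = np.sort(rng.uniform(0.0, t_stop, 200))
unit_b = unit_a + 0.0001  # fires 0.1 ms after unit_a: strongly synchronized
unit_c = np.sort(rng.uniform(0.0, t_stop, 200))  # independent of unit_a

corr = binned_correlation([unit_a, unit_b, unit_c], t_stop)
print(np.round(corr, 2))
```

The result depends on the bin size: bins much shorter than the typical lag between two units will miss their coordination, while very long bins mostly reflect slow co-fluctuations in rate.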

### Interspike Interval (ISI) Analysis
Interspike interval (ISI) analysis involves calculating the time intervals between consecutive spikes for each neuron. This analysis provides insights into the regularity and variability of neuronal firing, which can indicate the presence of different firing patterns such as tonic firing, burst firing, or irregular spiking. In the context of in vivo electrophysiology, ISI analysis can reveal the underlying mechanisms of neural coding, such as adaptation, refractory periods, or network oscillations, which are crucial for understanding how the brain processes information in real-time.
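
One simple way to quantify the regularity described above is the coefficient of variation (CV) of the ISIs, the standard deviation divided by the mean. The sketch below, with the assumed helper name `isi_cv`, compares a perfectly regular train with a synthetic Poisson-like train.

```python
import numpy as np

def isi_cv(spike_times_s):
    """Coefficient of variation of the interspike intervals."""
    isis = np.diff(np.sort(spike_times_s))
    if isis.size < 2 or isis.mean() == 0:
        return float("nan")
    return float(isis.std() / isis.mean())

# Regular train: one spike every 100 ms
regular = np.arange(0.0, 10.0, 0.1)
# Poisson-like train: exponentially distributed ISIs with a 100 ms mean
rng = np.random.default_rng(2)
poisson = np.cumsum(rng.exponential(0.1, size=1000))

cv_regular = isi_cv(regular)
cv_poisson = isi_cv(poisson)
print("CV regular:", cv_regular, "CV Poisson-like:", cv_poisson)
```

A CV near 0 indicates clock-like firing, a CV near 1 is what a Poisson process produces, and values well above 1 often accompany bursting.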

### Burst Detection
Burst detection is the process of identifying clusters of spikes that occur within a short time interval. Bursting activity is often seen in certain types of neurons, and it can play a critical role in synaptic plasticity, signal amplification, and rhythmic brain activities. In in vivo electrophysiology studies, burst detection can provide insights into neuronal communication and synchronization, as well as help identify potential pathological conditions such as epilepsy. Detecting bursts and analyzing their properties, such as duration, frequency, and inter-burst intervals, can help elucidate the temporal organization of neural circuits and their functional states.
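
Once bursts have been identified, the properties mentioned above can be summarized in a few lines. The sketch below assumes bursts are given as a list of lists of spike times in seconds; the helper name `burst_statistics` and the example values are illustrative only.

```python
import numpy as np

def burst_statistics(bursts):
    """Summarize bursts given as a list of lists of spike times (s)."""
    starts = np.array([b[0] for b in bursts])
    ends = np.array([b[-1] for b in bursts])
    return {
        "durations": ends - starts,
        "spikes_per_burst": np.array([len(b) for b in bursts]),
        # Interval from the end of one burst to the start of the next
        "inter_burst_intervals": starts[1:] - ends[:-1],
    }

example_bursts = [
    [0.100, 0.102, 0.104],
    [0.500, 0.503],
    [1.200, 1.201, 1.203, 1.204],
]
stats = burst_statistics(example_bursts)
for name, values in stats.items():
    print(name, values)
```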

In [None]:
def calculate_firing_rate(sorting):
    """
    Calculate the mean firing rate of each unit from the sorted spike data.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    
    Returns:
    - firing_rates (dict): Mean firing rate in Hz for each unit.
    """
    fs = sorting.get_sampling_frequency()
    firing_rates = {}
    for unit_id in sorting.unit_ids:
        # get_unit_spike_train returns sample indices; divide by fs to get seconds
        spike_times = sorting.get_unit_spike_train(unit_id) / fs
        if len(spike_times) == 0:
            firing_rates[unit_id] = 0.0
            continue
        # The last spike time stands in for the end of the recording
        spike_train = neo.SpikeTrain(spike_times * pq.s, t_stop=spike_times[-1] * pq.s)
        # mean_firing_rate has no bin size argument; it averages over [t_start, t_stop]
        rate = es.mean_firing_rate(spike_train)
        firing_rates[unit_id] = float(rate.rescale(pq.Hz).magnitude)
    return firing_rates

# Calculate firing rates
firing_rates = calculate_firing_rate(sorting)
print("Firing Rates (Hz):", firing_rates)

In [None]:
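def analyze_spike_train_correlation(sorting, bin_size=5 * pq.ms):
    """
    Analyze spike train correlations between units using binned spike trains.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    - bin_size (Quantity): Bin width used to bin the spike trains.
    
    Returns:
    - correlation_matrix (np.ndarray): Pearson correlation matrix of the binned spike trains.
    """
    from elephant.conversion import BinnedSpikeTrain  # BinnedSpikeTrain lives in elephant.conversion

    fs = sorting.get_sampling_frequency()
    # get_unit_spike_train returns sample indices; divide by fs to get seconds
    spike_times = [sorting.get_unit_spike_train(unit_id) / fs for unit_id in sorting.unit_ids]
    # All trains share one t_stop so they are binned on the same grid
    t_stop = max(st[-1] for st in spike_times if len(st) > 0) * pq.s
    spiketrains = [neo.SpikeTrain(st * pq.s, t_stop=t_stop) for st in spike_times]
    binned = BinnedSpikeTrain(spiketrains, bin_size=bin_size)
    correlation_matrix = escorr.correlation_coefficient(binned)
    return correlation_matrix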

# Perform spike train correlation analysis
correlation_matrix = analyze_spike_train_correlation(sorting)
print("Correlation Matrix:", correlation_matrix)


In [None]:
def compute_isi_histogram(sorting, unit_id):
    """
    Compute the Interspike Interval (ISI) histogram for a given unit.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    - unit_id (int): ID of the unit for which to compute ISI.
    
    Returns:
    - isi_hist (np.ndarray): ISI histogram.
    - bins (np.ndarray): Bin edges for ISI histogram.
    """
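    # get_unit_spike_train returns sample indices; convert to seconds before differencing
    spike_times = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency()
    isi = np.diff(spike_times)
    isi_hist, bins = np.histogram(isi, bins=np.arange(0, 1, 0.01))  # ISI in seconds, 10 ms bins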
    return isi_hist, bins

# Compute and plot ISI histogram for a specific unit
isi_hist, bins = compute_isi_histogram(sorting, sorting.unit_ids[0])
print("ISI Histogram:", isi_hist)

def plot_isi_histogram(isi_hist, bins):
    """
    Plot ISI histogram using Matplotlib.
    
    Args:
    - isi_hist (np.ndarray): ISI histogram.
    - bins (np.ndarray): Bin edges for ISI histogram.
    """
    plt.bar(bins[:-1], isi_hist, width=np.diff(bins))
    plt.xlabel('Interspike Interval (s)')
    plt.ylabel('Count')
    plt.title('ISI Histogram')
    plt.show()

# Plot ISI histogram
plot_isi_histogram(isi_hist, bins)


In [None]:
def detect_bursts(spike_train, burst_threshold=3):
    """
    Detect bursts in a spike train based on a threshold ISI value.
    
    Args:
    - spike_train (neo.SpikeTrain): Spike train data.
    - burst_threshold (float): Threshold ISI (in ms) below which spikes are considered to be in a burst.
    
    Returns:
    - bursts (list of lists): Each list contains the spike times (in s) that form a burst.
    """
    spike_times = spike_train.rescale(pq.s).magnitude
    if len(spike_times) == 0:
        return []
    isi_ms = np.diff(spike_times) * 1e3  # ISIs in ms, to match burst_threshold
    bursts = []
    current_burst = [spike_times[0]]

    for i, interval in enumerate(isi_ms):
        if interval < burst_threshold:
            current_burst.append(spike_times[i + 1])
        else:
            if len(current_burst) > 1:  # Ensure it's a burst, not a single spike
                bursts.append(current_burst)
            current_burst = [spike_times[i + 1]]

    # Include last burst if any
    if len(current_burst) > 1:
        bursts.append(current_burst)
    
    return bursts

# Example burst detection on a spike train (sample indices converted to seconds)
unit_spike_times = sorting.get_unit_spike_train(sorting.unit_ids[0]) / sorting.get_sampling_frequency()
example_spike_train = neo.SpikeTrain(unit_spike_times * pq.s, t_stop=unit_spike_times[-1] * pq.s)
bursts = detect_bursts(example_spike_train)
print("Detected Bursts:", bursts)

def plot_bursts(bursts):
    """
    Plot detected bursts using Matplotlib.
    
    Args:
    - bursts (list of lists): Each list contains spike times that form a burst.
    """
    for i, burst in enumerate(bursts):
        plt.plot(burst, np.full(len(burst), i), '|', markersize=10)
    plt.xlabel('Time (s)')
    plt.ylabel('Burst Number')
    plt.title('Detected Bursts')
    plt.show()

# Plot detected bursts
plot_bursts(bursts)

## Visualization
Visualize raster plots, firing rate histograms, correlation matrices, and ISI histograms.

In [None]:
import matplotlib.pyplot as plt  # For static visualizations
import seaborn as sns  # For enhanced visualization with correlation heatmaps
import plotly.express as px  # For interactive visualizations with Plotly
import numpy as np  # For numerical operations

def plot_firing_rate_histogram(firing_rates):
    """
    Plot histogram of firing rates using Plotly for interactive exploration.
    
    Args:
    - firing_rates (dict): Firing rates of units.
    """
    fig = px.histogram(x=list(firing_rates.values()), labels={'x': 'Firing Rate (Hz)'})
    fig.update_layout(title="Firing Rate Histogram", xaxis_title="Firing Rate (Hz)", yaxis_title="Count")
    fig.show()

def plot_raster(sorting):
    """
    Plot raster plot of the spike sorting results.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    """
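    fs = sorting.get_sampling_frequency()
    # Convert sample indices to seconds so the x-axis matches its label
    spike_times = [sorting.get_unit_spike_train(unit_id) / fs for unit_id in sorting.unit_ids]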
    plt.figure(figsize=(10, 6))
    plt.eventplot(spike_times, colors='black')
    plt.xlabel('Time (s)')
    plt.ylabel('Units')
    plt.title('Raster Plot of Spike Trains')
    plt.show()

def plot_correlation_matrix(correlation_matrix):
    """
    Plot correlation matrix using Seaborn for enhanced visualization.
    
    Args:
    - correlation_matrix (np.ndarray): Correlation matrix of spike trains.
    """
    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation_matrix, annot=True, cmap='viridis', cbar=True)
    plt.title('Spike Train Correlation Matrix')
    plt.xlabel('Neuron Index')
    plt.ylabel('Neuron Index')
    plt.show()

def plot_isi_histogram(isi_hist, bins):
    """
    Plot ISI histogram using Matplotlib.
    
    Args:
    - isi_hist (np.ndarray): ISI histogram.
    - bins (np.ndarray): Bin edges for ISI histogram.
    """
    plt.figure(figsize=(8, 6))
    plt.bar(bins[:-1], isi_hist, width=np.diff(bins), color='blue', edgecolor='black', alpha=0.7)
    plt.xlabel('Interspike Interval (s)')
    plt.ylabel('Count')
    plt.title('Interspike Interval (ISI) Histogram')
    plt.show()

def plot_burst_detection(bursts, spike_train_duration):
    """
    Plot detected bursts in spike trains.
    
    Args:
    - bursts (dict): Burst times for each unit.
    - spike_train_duration (float): Duration of the spike train recording.
    """
    plt.figure(figsize=(10, 6))
    for unit_id, burst_times in bursts.items():
        plt.vlines(burst_times, unit_id - 0.4, unit_id + 0.4, colors='red')
    plt.xlabel('Time (s)')
    plt.ylabel('Units')
    plt.title('Burst Detection in Spike Trains')
    plt.xlim(0, spike_train_duration)
    plt.show()

# Example usage assuming data is prepared
# Replace 'firing_rates', 'sorting', 'correlation_matrix', 'isi_hist', 'bins', and 'bursts' with actual analysis results.

# Example firing rates for visualization
firing_rates = {0: 5.2, 1: 3.5, 2: 7.8, 3: 6.1}

# Generate synthetic data for demonstration
example_sorting = ...  # Replace with actual SpikeInterface sorting object
correlation_matrix = np.random.rand(4, 4)  # Replace with actual correlation matrix
isi_hist = np.random.poisson(5, 100)  # Replace with actual ISI histogram
bins = np.linspace(0, 1, 101)  # Example bin edges for ISI histogram
bursts = {0: [0.2, 0.7, 1.3], 1: [0.5, 1.1], 2: [0.8, 1.5]}  # Replace with actual burst detection data
spike_train_duration = 2.0  # Example duration in seconds

# Plot all visualizations
plot_firing_rate_histogram(firing_rates)
if example_sorting is not Ellipsis:  # Skip the raster until a real sorting object is supplied
    plot_raster(example_sorting)
plot_correlation_matrix(correlation_matrix)
plot_isi_histogram(isi_hist, bins)
plot_burst_detection(bursts, spike_train_duration)


## Conclusion
In this notebook, we performed spike sorting and firing rate analysis using electrophysiological data. We covered data loading, preprocessing, sorting, feature extraction, clustering, spike train analysis, and visualization techniques.