# Single-Unit and Multi-Unit Activity Analysis

This notebook provides an end-to-end pipeline for analyzing single-unit and multi-unit activity (SUA/MUA) in electrophysiological data. The analysis involves loading data, preprocessing, spike sorting, feature extraction, clustering, spike train analysis, and visualization.

### Objectives:
- Isolate and analyze activity from single neurons.
- Aggregate spikes from all recorded neurons for network activity analysis.
- Perform spike-triggered averaging (STA) to relate spiking to sensory stimuli or motor signals.
- Map receptive fields using spike-triggered correlation and decoding approaches.
- Generate detailed neuronal response profiles.
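Aggregating spikes for multi-unit analysis comes down to pooling the spike times of every unit and binning them. The sketch below is a minimal NumPy version. It assumes spike times in seconds held in plain arrays, and `population_rate` is a hypothetical helper rather than part of any library imported in this notebook. For Neo `SpikeTrain` objects, `elephant.statistics.time_histogram` provides a comparable binned count.

```python
import numpy as np


def population_rate(spike_trains, t_start, t_stop, bin_size):
    """Pool spike times (s) from all units and return (bin_edges, rate_hz).

    rate_hz is the pooled spike count per bin divided by the bin width,
    i.e. the summed firing rate of all units in that bin.
    """
    pooled = np.concatenate([np.asarray(st, dtype=float) for st in spike_trains])
    n_bins = int(round((t_stop - t_start) / bin_size))
    edges = t_start + bin_size * np.arange(n_bins + 1)
    counts, _ = np.histogram(pooled, bins=edges)
    return edges, counts / bin_size


# Example: two units recorded for one second, binned at 0.5 s.
edges, rate = population_rate([[0.1, 0.2, 0.6], [0.15, 0.9]], 0.0, 1.0, 0.5)
print(rate)  # [6. 4.]
```

Dividing by the bin width rather than by the number of units gives the total network rate; divide by `len(spike_trains)` instead if a per-unit average is wanted.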

### Methods:
- **Data Handling:** Load electrophysiological data using `Neo` and convert it to `SpikeInterface` format.
- **Preprocessing:** Apply bandpass filtering, optional notch filtering, and common referencing to the data.
- **Spike Sorting:** Use advanced sorting algorithms like `Kilosort` with customizable parameters.
- **Postprocessing and Feature Extraction:** Extract features from sorted spikes using PCA.
- **Clustering:** Perform clustering using methods like `GMM`, `DBSCAN`, and `Agglomerative Clustering`.
- **Spike Train Analysis:** Analyze spike trains for burst detection, synchrony, cross-correlation, and STA.
- **Visualization:** Use `Matplotlib` and `Plotly` for static and interactive visualizations.
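Spike-triggered averaging takes the stimulus (or another continuous signal) in a fixed window before each spike and averages those segments. The minimal NumPy sketch below assumes that the stimulus and the spikes share one sample clock and that spikes are given as sample indices. `compute_sta` is a hypothetical helper; Elephant's `elephant.sta.spike_triggered_average` implements the same idea for Neo `AnalogSignal` and `SpikeTrain` objects.

```python
import numpy as np


def compute_sta(stimulus, spike_indices, window):
    """Average the `window` stimulus samples preceding each spike.

    stimulus: 1-D array on the same sample clock as spike_indices.
    Spikes without a complete pre-spike window (or beyond the stimulus) are skipped.
    """
    stimulus = np.asarray(stimulus, dtype=float)
    valid = [int(i) for i in spike_indices if window <= i <= stimulus.size]
    if not valid:
        raise ValueError("no spike has a complete pre-spike window")
    # Each row is the stimulus segment [i - window, i) before one spike.
    segments = np.stack([stimulus[i - window:i] for i in valid])
    return segments.mean(axis=0)


# The spike at index 1 is skipped because fewer than 2 samples precede it.
sta = compute_sta(np.arange(10), [1, 3, 5, 9], window=2)
print(sta)
```

For Gaussian white-noise stimuli, the STA is a standard first estimate of a unit's linear receptive field.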

## Import Libraries

We start by importing the necessary Python libraries for data handling, preprocessing, spike sorting, and visualization.


In [None]:
# Import necessary libraries
import neo  # For data handling
import spikeinterface as si  # Core module for SpikeInterface
import spikeinterface.extractors as se  # For data loading and extraction
import spikeinterface.preprocessing as sp  # For data preprocessing
import spikeinterface.sorters as ss  # For spike sorting algorithms
import spikeinterface.postprocessing as spost  # For postprocessing sorted data
import spikeinterface.qualitymetrics as sq  # For quality control metrics
import elephant  # For advanced analysis on spike trains
import elephant.statistics as es  # For statistical measures like firing rates
import elephant.sta as esta  # For spike-triggered averaging
import elephant.conversion as econv  # For converting spike trains
import elephant.spike_train_correlation as escorr  # For correlation analysis
import elephant.spectral as esp  # For spectral analysis
import pyspike as ps  # For spike-train synchrony and distance measures (ISI-distance, SPIKE-distance)
import quantities as pq  # For unit handling
import matplotlib.pyplot as plt  # For static visualization
import plotly.express as px  # For interactive visualization
import numpy as np  # For numerical operations
from neo.io import NeuralynxIO, BlackrockIO, NixIO  # Example IO for Neo data loading
from sklearn.decomposition import PCA  # For dimensionality reduction
from sklearn.mixture import GaussianMixture  # For GMM clustering
from sklearn.cluster import DBSCAN, AgglomerativeClustering  # For clustering


## 1. Data Handling

We load the electrophysiological data using `Neo` and convert it to `SpikeInterface` format for further analysis.

In [None]:
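def load_data(file_path, io_type='NeuralynxIO', signal_index=0):
    """
    Load electrophysiological data using Neo and convert to SpikeInterface format.
    
    Args:
    - file_path (str): Path to the raw data (a directory for NeuralynxIO, a file for BlackrockIO/NixIO).
    - io_type (str): Neo IO to use: 'NeuralynxIO', 'BlackrockIO' or 'NixIO'.
    - signal_index (int): Index of the AnalogSignal in the first segment to convert.
    
    Returns:
    - recording (si.BaseRecording): In-memory SpikeInterface recording.
    """
    # Build only the requested reader; each IO expects a different kind of path,
    # so constructing all of them up front would fail.
    io_constructors = {
        'NeuralynxIO': lambda path: NeuralynxIO(dirname=path),
        'BlackrockIO': lambda path: BlackrockIO(filename=path),
        'NixIO': lambda path: NixIO(filename=path, mode='ro'),
    }
    
    if io_type not in io_constructors:
        raise ValueError(f"Unsupported IO type: {io_type}")
    
    reader = io_constructors[io_type](file_path)
    block = reader.read_block()
    segment = block.segments[0]
    analog_signal = segment.analogsignals[signal_index]
    
    # Neo AnalogSignals are (n_samples, n_channels) Quantity arrays; NumpyRecording
    # expects a list of plain arrays (one per segment) in the same layout.
    traces = analog_signal.magnitude
    sampling_frequency = float(analog_signal.sampling_rate.rescale('Hz').magnitude)
    recording = si.NumpyRecording(traces_list=[traces], sampling_frequency=sampling_frequency)
    # Kilosort needs channel locations: attach a probe (e.g. with probeinterface
    # and recording.set_probe(...)) before spike sorting.
    return recording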

# Example Usage:
example_file_path = 'data/sample_data'  # Adjust the path for your dataset
recording = load_data(example_file_path)


## 2. Preprocessing

We preprocess the loaded data by applying bandpass filtering, optional notch filtering, and common referencing to remove noise and improve the quality of spike detection.

In [None]:
def preprocess_data(recording, freq_min=300, freq_max=3000, notch_freq=None, common_ref_type='median'):
    """
    Preprocess the loaded data by applying bandpass filtering, optional notch filtering, and common reference.
    
    Args:
    - recording (si.BaseRecording): Loaded data in SpikeInterface's RecordingExtractor format.
    - freq_min (int): Minimum frequency for bandpass filter.
    - freq_max (int): Maximum frequency for bandpass filter.
    - notch_freq (float): Frequency for notch filter to remove powerline noise. If None, skip.
    - common_ref_type (str): Operator for global common referencing: 'median' or 'average'.
    
    Returns:
    - recording_preprocessed (si.BaseRecording): Preprocessed data.
    """
    # Apply bandpass filter
    recording_bp = sp.bandpass_filter(recording, freq_min=freq_min, freq_max=freq_max)
    
    # Apply notch filter if specified
    if notch_freq is not None:
        recording_notch = sp.notch_filter(recording_bp, freq=notch_freq)
    else:
        recording_notch = recording_bp
    
    # Global common reference: at each sample, subtract the median (or mean) across channels.
    # In SpikeInterface, `reference` selects the scope and `operator` selects the statistic.
    recording_cmr = sp.common_reference(recording_notch, reference='global', operator=common_ref_type)
    
    return recording_cmr

# Example Usage:
recording_preprocessed = preprocess_data(recording)
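Global median referencing subtracts, at each time sample, the median across channels, which removes noise shared by all channels. The NumPy sketch below illustrates the idea on an `(n_samples, n_channels)` array, the trace layout SpikeInterface uses. `common_median_reference` is a hypothetical helper for illustration, not SpikeInterface's implementation.

```python
import numpy as np


def common_median_reference(traces):
    """Subtract the across-channel median from every time sample.

    traces: array of shape (n_samples, n_channels).
    """
    traces = np.asarray(traces, dtype=float)
    return traces - np.median(traces, axis=1, keepdims=True)


x = np.array([[1.0, 2.0, 10.0],
              [0.0, 0.0, 3.0]])
print(common_median_reference(x))
# An offset shared by all channels is removed entirely:
print(np.allclose(common_median_reference(x + 5.0), common_median_reference(x)))  # True
```

The median is preferred over the mean when a few channels carry large spikes or artifacts, because those outliers barely move it.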


## 3. Spike Sorting

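We perform spike sorting with algorithms such as `Kilosort` to detect spikes and assign them to putative neurons. Kilosort2 is MATLAB-based, so it needs either a local MATLAB installation or a container (SpikeInterface's `run_sorter` accepts a `docker_image` argument), and the recording must have a probe attached.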

In [None]:
def sort_spikes(recording, sorter_name='kilosort2', custom_params=None):
    """
    Perform spike sorting on the preprocessed data with configurable parameters.
    
    Args:
    - recording (si.BaseRecording): Preprocessed recording data.
    - sorter_name (str): Name of the sorting algorithm to use (e.g., 'kilosort2').
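    - custom_params (dict): Optional parameters that override the sorter's defaults.
    
    Returns:
    - sorting (si.BaseSorting): Sorted spike data.
    """
    # Start from the sorter defaults and override only the keys that were given.
    sorter_params = ss.get_default_sorter_params(sorter_name)
    sorter_params.update(custom_params or {})
    # The output argument is named `folder` in SpikeInterface >= 0.101 (`output_folder` before).
    sorting = ss.run_sorter(sorter_name, recording, folder='sorting_output',
                            remove_existing_folder=True, **sorter_params)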
    return sorting

# Example Usage:
sorting = sort_spikes(recording_preprocessed)
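Kilosort itself relies on template matching, but the first idea behind most sorters, finding candidate spikes as threshold crossings scaled to a robust noise estimate, fits in a few lines. This is a simplified illustration of spike detection, not Kilosort's algorithm. `detect_spikes`, its default threshold of 5 noise units, and the 30-sample minimum spacing (1 ms at 30 kHz) are assumptions chosen for the example.

```python
import numpy as np


def detect_spikes(trace, threshold=5.0, refractory=30):
    """Return sample indices of negative threshold crossings on one filtered channel.

    Noise is estimated as median(|x|) / 0.6745, a MAD-based estimator that is
    insensitive to the spikes themselves; `threshold` is in units of that noise.
    `refractory` is the minimum spacing, in samples, between accepted events.
    """
    trace = np.asarray(trace, dtype=float)
    sigma = np.median(np.abs(trace)) / 0.6745
    below = trace < -threshold * sigma
    # Onsets: samples below threshold whose previous sample was not.
    previous = np.concatenate(([False], below[:-1]))
    onsets = np.flatnonzero(below & ~previous)
    spikes = []
    for idx in onsets:
        if not spikes or idx - spikes[-1] >= refractory:
            spikes.append(int(idx))
    return np.array(spikes, dtype=int)


trace = 0.5 * np.sin(np.arange(10_000))          # small deterministic "noise"
trace[[1000, 1001, 1010, 5000]] = [-20.0, -15.0, -20.0, -20.0]
print(detect_spikes(trace))  # [1000 5000]
```

The crossing at sample 1010 is dropped because it falls within 30 samples of the accepted event at 1000.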


## 4. Postprocessing and Feature Extraction

We extract features from sorted spike waveforms to facilitate further analysis, such as clustering and visualization.


In [None]:
def postprocess_sorting(sorting, recording):
    """
    Extract waveforms of the sorted spikes.
    
    Args:
    - sorting (si.BaseSorting): Sorted spike data.
    - recording (si.BaseRecording): Preprocessed recording data.
    
    Returns:
    - waveform_extractor (si.WaveformExtractor): Extracted waveforms.
    """
    # Waveforms from 1.5 ms before to 2.5 ms after each spike, kept in memory.
    # sparse=False keeps every channel so all units share one waveform shape.
    # By default only a random subset of spikes per unit (max_spikes_per_unit=500) is extracted.
    # In SpikeInterface >= 0.101 this is a backward-compatible wrapper around the SortingAnalyzer API.
    waveform_extractor = si.extract_waveforms(recording, sorting, mode='memory',
                                              ms_before=1.5, ms_after=2.5, sparse=False)
    return waveform_extractor

def extract_features(waveform_extractor, method='pca', n_components=3):
    """
    Extract features from the sorted spike waveforms for clustering.
    
    Args:
    - waveform_extractor (si.WaveformExtractor): Extracted waveforms.
    - method (str): Feature extraction method: 'pca' or 'waveform' (simple amplitude features).
    - n_components (int): Number of PCA components.
    
    Returns:
    - features (np.ndarray): Feature matrix with one row per extracted spike.
    """
    # get_waveforms() works per unit; stack all units into (n_spikes, n_samples, n_channels).
    waveforms = np.concatenate(
        [waveform_extractor.get_waveforms(unit_id) for unit_id in waveform_extractor.sorting.unit_ids],
        axis=0,
    )
    if method == 'pca':
        pca = PCA(n_components=n_components)
        features = pca.fit_transform(waveforms.reshape(waveforms.shape[0], -1))
    elif method == 'waveform':
        # Simple amplitude features: mean absolute value and standard deviation of each waveform
        mean_abs_amplitude = np.mean(np.abs(waveforms), axis=(1, 2))
        waveform_std = np.std(waveforms, axis=(1, 2))
        features = np.column_stack((mean_abs_amplitude, waveform_std))
    else:
        raise ValueError(f"Unsupported feature extraction method: {method}")
    
    return features

# Example Usage:
waveform_extractor = postprocess_sorting(sorting, recording_preprocessed)
features = extract_features(waveform_extractor)
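The non-PCA branch of `extract_features` summarizes amplitude only. A common shape feature is the trough-to-peak duration, which is often used to separate narrow-spiking from broad-spiking units. The minimal sketch below assumes a 1-D waveform from a single channel (for example a unit's mean waveform on its largest-amplitude channel) and a known sampling frequency. `trough_to_peak_width` is a hypothetical helper; SpikeInterface's postprocessing module also offers template metrics of this kind.

```python
import numpy as np


def trough_to_peak_width(waveform, sampling_frequency):
    """Time in seconds from the waveform minimum to the maximum that follows it.

    waveform: 1-D array from one channel.
    sampling_frequency: samples per second.
    """
    waveform = np.asarray(waveform, dtype=float)
    trough = int(np.argmin(waveform))
    # Search for the peak only after the trough.
    peak = trough + int(np.argmax(waveform[trough:]))
    return (peak - trough) / sampling_frequency


wf = np.array([0.0, -1.0, -5.0, -2.0, 1.0, 3.0, 2.0, 0.0])
print(trough_to_peak_width(wf, sampling_frequency=30_000.0))  # 0.0001
```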


## 5. Clustering and Spike Train Analysis

We cluster the extracted spike features with Gaussian mixture models, DBSCAN, or agglomerative clustering. The resulting spike trains can then be analyzed for bursts, synchrony, and other measures.
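Synchrony between two units is commonly inspected with a cross-correlogram, a histogram of the time differences between their spikes. The minimal NumPy sketch below assumes spike times in seconds. `cross_correlogram` is a hypothetical helper; `elephant.spike_train_correlation.cross_correlation_histogram` offers a comparable computation on binned Neo spike trains.

```python
import numpy as np


def cross_correlogram(spikes_a, spikes_b, bin_size, max_lag):
    """Histogram of spike-time differences (b - a), in seconds, for |lag| <= max_lag.

    Bins are centred on multiples of bin_size, so zero lag has its own bin.
    Returns (bin_centers, counts).
    """
    a = np.sort(np.asarray(spikes_a, dtype=float))
    b = np.sort(np.asarray(spikes_b, dtype=float))
    n_bins = int(round(max_lag / bin_size))
    edges = (np.arange(-n_bins, n_bins + 2) - 0.5) * bin_size
    half_width = (n_bins + 0.5) * bin_size
    lags = []
    for t in a:
        # Binary search restricts the work to spikes of b inside the lag window.
        lo = np.searchsorted(b, t - half_width, side="left")
        hi = np.searchsorted(b, t + half_width, side="right")
        lags.append(b[lo:hi] - t)
    diffs = np.concatenate(lags) if lags else np.empty(0)
    counts, _ = np.histogram(diffs, bins=edges)
    centers = np.arange(-n_bins, n_bins + 1) * bin_size
    return centers, counts


centers, counts = cross_correlogram([1.0, 2.0], [1.0, 2.001, 3.0], bin_size=0.001, max_lag=0.005)
print(counts)  # [0 0 0 0 0 1 1 0 0 0 0]
```

A peak near zero lag suggests synchronous firing, but it should be judged against a baseline (for example spike-time shuffles), since shared rate modulation alone can produce broad peaks.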

In [None]:
def cluster_spikes(features, method='gmm', **kwargs):
    """
    Cluster spikes using specified clustering algorithm.
    
    Args:
    - features (np.ndarray): Feature matrix for clustering.
    - method (str): Clustering method: 'gmm', 'dbscan' or 'agglomerative'.
    - kwargs: Method parameters: n_components (gmm), eps and min_samples (dbscan), n_clusters (agglomerative).
    
    Returns:
    - labels (np.ndarray): Cluster labels for each spike.
    """
    if method == 'gmm':
        gmm = GaussianMixture(n_components=kwargs.get('n_components', 3))
        labels = gmm.fit_predict(features)
    elif method == 'dbscan':
        db = DBSCAN(eps=kwargs.get('eps', 0.5), min_samples=kwargs.get('min_samples', 5))
        labels = db.fit_predict(features)
    elif method == 'agglomerative':
        agc = AgglomerativeClustering(n_clusters=kwargs.get('n_clusters', 3))
        labels = agc.fit_predict(features)
    else:
        raise ValueError(f"Unsupported clustering method: {method}")
    return labels

# Example Usage:
labels = cluster_spikes(features, method='gmm')
print("Cluster Labels:", labels)
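The clustering cell above does not detect bursts. One simple ISI-threshold approach groups consecutive spikes whose inter-spike intervals stay below a cutoff. The sketch assumes spike times in seconds for a single unit; the 10 ms cutoff and 3-spike minimum are illustrative defaults, not values from this notebook, and `detect_bursts` is a hypothetical helper.

```python
import numpy as np


def detect_bursts(spike_times, max_isi=0.01, min_spikes=3):
    """Find bursts as runs of >= min_spikes spikes whose consecutive ISIs are <= max_isi (s).

    Returns a list of (start_time, end_time, n_spikes) tuples.
    """
    t = np.sort(np.asarray(spike_times, dtype=float))
    bursts = []
    start = 0
    for i in range(1, t.size + 1):
        # A run ends at the last spike or where the next interval exceeds max_isi.
        if i == t.size or t[i] - t[i - 1] > max_isi:
            n_spikes = i - start
            if n_spikes >= min_spikes:
                bursts.append((float(t[start]), float(t[i - 1]), n_spikes))
            start = i
    return bursts


print(detect_bursts([0.100, 0.105, 0.110, 0.500, 0.800, 0.805, 0.809, 0.812]))
# [(0.1, 0.11, 3), (0.8, 0.812, 4)]
```

A fixed ISI cutoff is easy to interpret but sensitive to the unit's baseline rate; published burst detectors often adapt the cutoff per unit.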


## 6. Visualization

Visualize the clustering results. The example below uses `Plotly` for an interactive scatter plot of the first two feature dimensions; `Matplotlib` can be used for static figures.

In [None]:
def plot_cluster_features(features, labels):
    """
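    Scatter-plot the first two feature dimensions (e.g. PCA components), colored by cluster label.
    
    Args:
    - features (np.ndarray): Feature matrix with at least two columns.
    - labels (np.ndarray): Cluster labels.
    """
    # String labels give discrete colors instead of a continuous color scale.
    fig = px.scatter(x=features[:, 0], y=features[:, 1], color=np.asarray(labels).astype(str))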
    fig.update_layout(title='Spike Clustering', xaxis_title='Feature 1', yaxis_title='Feature 2')
    fig.show()

# Example Usage:
plot_cluster_features(features, labels)
