# Spike Sorting and Firing Rate Analysis

This notebook is a step-by-step guide to spike sorting and firing rate analysis of in vivo electrophysiology data. It uses Neo for data handling, SpikeInterface for spike sorting, and Elephant for spike-train analysis. The analysis covers data loading, preprocessing, spike detection and sorting, feature extraction, firing rate analysis, and visualization of results.

## Setup and Prerequisites
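The notebook requires Neo, SpikeInterface, Elephant, Matplotlib, and Plotly. Install any missing libraries with the command in the first cell below. The notebook also imports helper functions from a local script, `spike_sorting_firing_rate_analysis.py`, which must be importable (for example, placed in the same directory as the notebook).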

In [None]:
# Install required libraries (if not already installed).
# %pip installs into the environment of the running kernel, unlike !pip.
%pip install neo spikeinterface elephant matplotlib plotly

# Import necessary libraries
import neo
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.curation as scu
import elephant
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
from spike_sorting_firing_rate_analysis import load_data, preprocess_data, perform_spike_sorting, extract_features, analyze_firing_rate, plot_results

# Set plotly rendering in Jupyter Notebooks
import plotly.io as pio
pio.renderers.default = 'notebook'

## Data Loading
The first step is to load the raw electrophysiology data. The `load_data` function from the helper script (`spike_sorting_firing_rate_analysis.py`) reads any file format supported by Neo's I/O modules and returns a Neo `Block`, which holds one or more `Segment` objects containing the recorded signals.

In [None]:
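# Load the data using the Neo library and the load_data function from the script
file_path = 'path/to/your/electrophysiology/data'
block = load_data(file_path)

# Display a summary of the loaded data: one line per segment (e.g., trial or recording epoch)
for i, segment in enumerate(block.segments):
    print(f'Segment {i}: {len(segment.analogsignals)} analog signal(s), '
          f'{len(segment.spiketrains)} spike train(s)')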

## Preprocessing

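Preprocessing reduces noise and isolates the frequency band that contains spikes. The `preprocess_data` function applies a band-pass filter to the raw data. For extracellular spike analysis, a pass band of roughly 300 Hz to 6 kHz is a common choice: it removes slow local field potential (LFP) fluctuations as well as high-frequency noise.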
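To illustrate what such a band-pass step can look like, the sketch below designs a Butterworth filter with SciPy and applies it forward and backward (zero-phase). The 300 to 6000 Hz band, the filter order, and the helper name `bandpass` are assumptions for illustration; `preprocess_data` in the helper script may use different settings.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt


def bandpass(data, fs, low_hz=300.0, high_hz=6000.0, order=3):
    """Zero-phase Butterworth band-pass filter along axis 0 (time).

    Hypothetical helper: the cut-offs and order are common choices for
    extracellular spike data, not values taken from the helper script.
    """
    # Second-order sections are numerically more stable than (b, a) coefficients
    sos = butter(order, [low_hz, high_hz], btype="bandpass", fs=fs, output="sos")
    # Filtering forward and backward cancels phase shifts that would distort spike shapes
    return sosfiltfilt(sos, data, axis=0)


# Synthetic example: a slow 5 Hz "LFP" wave plus a fast 1 kHz component
fs = 30_000.0
t = np.arange(0, 1.0, 1 / fs)
slow = 200.0 * np.sin(2 * np.pi * 5 * t)
fast = 20.0 * np.sin(2 * np.pi * 1000 * t)
filtered = bandpass(slow + fast, fs)

# The slow component is removed; the 1 kHz component passes almost unchanged
print(np.std(slow + fast), np.std(filtered))
```

Zero-phase filtering is a deliberate choice here: a causal filter would shift and reshape spike waveforms, which matters for the waveform-based features computed later.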

In [None]:
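# Preprocess the data (e.g., band-pass filtering)
preprocessed_data = preprocess_data(block)

# Plot raw and preprocessed signals for comparison (first channel only).
# Assumes preprocess_data returns an array with the same number of samples as the raw signal.
raw_signal = block.segments[0].analogsignals[0]
time_ms = raw_signal.times.rescale('ms').magnitude
filtered = np.asarray(preprocessed_data)

plt.figure(figsize=(10, 6))
plt.plot(time_ms, raw_signal.magnitude[:, 0], label='Raw Data', alpha=0.5)
plt.plot(time_ms, filtered[:, 0] if filtered.ndim > 1 else filtered, label='Preprocessed Data', alpha=0.75)
plt.legend()
plt.title('Raw vs. Preprocessed Data (channel 0)')
plt.xlabel('Time (ms)')
plt.ylabel(f'Amplitude ({raw_signal.units.dimensionality.string})')
plt.show()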

## Spike Sorting

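SpikeInterface provides a common interface to many spike sorting algorithms. The `perform_spike_sorting` function runs a sorter of your choice (e.g., Kilosort or HDSort) on the preprocessed data. External sorters such as Kilosort and HDSort must be installed separately or run through a Docker or Singularity container.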
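Sorters differ internally, but most begin by detecting threshold crossings on the filtered signal. The sketch below shows only that first stage, on a single channel, using a median-absolute-deviation (MAD) noise estimate. The threshold factor of 5, the 1 ms dead time, and the helper name `detect_spikes` are illustrative assumptions; this is not what Kilosort, HDSort, or any other specific sorter does. To see which sorters are available in your environment, SpikeInterface provides `ss.installed_sorters()`.

```python
import numpy as np


def detect_spikes(signal, fs, k=5.0, dead_time_ms=1.0):
    """Return sample indices of negative threshold crossings on one channel.

    Hypothetical helper for illustration. Noise is estimated as
    median(|x|) / 0.6745, a robust estimate of the standard deviation
    that is barely affected by the spikes themselves.
    """
    noise_sd = np.median(np.abs(signal)) / 0.6745
    threshold = -k * noise_sd
    # Indices where the signal first drops below the (negative) threshold
    below = signal < threshold
    crossings = np.flatnonzero(below[1:] & ~below[:-1]) + 1
    # Enforce a dead time so a single spike is not detected twice
    dead = int(dead_time_ms * fs / 1000.0)
    spikes = []
    for idx in crossings:
        if not spikes or idx - spikes[-1] >= dead:
            spikes.append(idx)
    return np.array(spikes, dtype=int)


# Synthetic trace: Gaussian noise plus three negative "spikes"
rng = np.random.default_rng(0)
fs = 30_000.0
trace = rng.normal(0.0, 10.0, size=30_000)
true_spikes = [5_000, 12_000, 25_000]
for s in true_spikes:
    trace[s:s + 30] -= 100.0 * np.hanning(30)

print(detect_spikes(trace, fs))
```

Real sorters then extract a waveform snippet around each detected event and cluster the snippets into units, which is where the feature extraction step below comes in.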

In [None]:
# Perform spike sorting using SpikeInterface
sorting = perform_spike_sorting(preprocessed_data)

# Display sorting results
print(sorting)

## Feature Extraction

Once spikes are sorted, we extract features from the spike waveforms for further analysis and visualization. The `extract_features` function computes features like principal components or waveform parameters.
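For intuition about the PCA features plotted in the next cell, here is a minimal NumPy sketch that projects waveform snippets onto their first principal components using the singular value decomposition. The input layout (one aligned snippet per row) and the helper name `pca_features` are assumptions; SpikeInterface computes PCA through its own postprocessing tools, and `extract_features` may differ.

```python
import numpy as np


def pca_features(waveforms, n_components=2):
    """Project waveform snippets (n_spikes x n_samples) onto the leading PCs.

    Hypothetical helper: centers each time point across spikes, then uses
    the SVD, whose right singular vectors are the principal axes.
    """
    centered = waveforms - waveforms.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing explained variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


# Two synthetic units with different waveform shapes, plus noise
rng = np.random.default_rng(1)
t = np.linspace(0, 1, 40)
template_a = -np.exp(-((t - 0.3) / 0.05) ** 2)
template_b = -0.5 * np.exp(-((t - 0.5) / 0.1) ** 2)
waveforms = np.vstack([
    template_a + 0.05 * rng.normal(size=(100, t.size)),
    template_b + 0.05 * rng.normal(size=(100, t.size)),
])
labels = np.repeat([0, 1], 100)

features = pca_features(waveforms)
print(features.shape)  # (200, 2)
```

Because the two units differ in shape far more than the noise varies, the first principal component separates them into two clusters.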

In [None]:
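# Extract spike features for further analysis
features = extract_features(sorting)

# Build one label per spike so points can be colored by unit.
# Assumes the rows of `features` are grouped unit by unit, in the order of get_unit_ids().
unit_ids = sorting.get_unit_ids()
labels = np.concatenate([
    np.full(len(sorting.get_unit_spike_train(unit_id=u)), i)
    for i, u in enumerate(unit_ids)
])

# Plot the first two feature dimensions (e.g., PCA components)
plt.figure(figsize=(8, 6))
scatter = plt.scatter(features[:, 0], features[:, 1], c=labels, cmap='viridis', s=10)
plt.colorbar(scatter, label='Unit index')
plt.title('Spike Features (PCA)')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.show()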

## Firing Rate Analysis
The firing rate, the number of spikes a neuron emits per unit time, is the most basic measure of its activity. The `analyze_firing_rate` function computes firing rates from the sorted spike trains and plots them as histograms.
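The two most common firing-rate measures can be sketched in a few lines of NumPy: the mean rate (spike count divided by recording duration) and a binned rate (spike count per bin divided by bin width). Spike times in seconds, the 1 s bin width, and the helper name `binned_rate` are illustrative assumptions; Elephant provides tested implementations of these measures, for example `elephant.statistics.mean_firing_rate` and `elephant.statistics.time_histogram`.

```python
import numpy as np


def binned_rate(spike_times_s, duration_s, bin_s=1.0):
    """Firing rate in Hz per time bin (hypothetical helper).

    Assumes duration_s is a multiple of bin_s, so every bin has full width.
    """
    edges = np.arange(0.0, duration_s + bin_s, bin_s)
    counts, edges = np.histogram(spike_times_s, bins=edges)
    return counts / bin_s, edges


spike_times = np.array([0.1, 0.4, 0.45, 1.2, 2.5, 2.6, 2.7, 2.9])  # seconds
duration = 3.0

mean_rate = spike_times.size / duration  # 8 spikes / 3 s
rates, edges = binned_rate(spike_times, duration)
print(mean_rate)  # 2.6666666666666665
print(rates)      # [3. 1. 4.]
```

SpikeInterface returns spike trains as sample indices, so convert them to seconds first, e.g. `sorting.get_unit_spike_train(unit_id=u) / sorting.get_sampling_frequency()`.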

In [None]:
# Analyze firing rates
firing_rate_results = analyze_firing_rate(sorting)

# Plot firing rate histograms
firing_rate_results['histogram'].show()

## Visualization of Results

Plotly produces interactive figures that support zooming, panning, and hovering over individual data points. The cell below draws an interactive raster plot with one row per sorted unit and one marker per spike.

In [None]:
# Interactive raster plot using Plotly
fs = sorting.get_sampling_frequency()  # sampling rate in Hz
raster_plot = go.Figure()
for unit_id in sorting.get_unit_ids():
    # Spike trains are stored as sample indices; convert them to milliseconds
    spike_times_ms = sorting.get_unit_spike_train(unit_id=unit_id) / fs * 1000.0
    raster_plot.add_trace(go.Scatter(x=spike_times_ms, y=[unit_id] * len(spike_times_ms),
                                     mode='markers', marker=dict(symbol='line-ns-open', size=10),
                                     name=f'Unit {unit_id}'))

raster_plot.update_layout(title='Raster Plot of Sorted Spikes', xaxis_title='Time (ms)', yaxis_title='Unit ID')
raster_plot.show()