In [None]:
# Notebook 02: Signal Analysis

This notebook analyses common in vivo electrophysiology data using spike sorting and LFP analysis, and extracts key metrics such as spike rates and inter-spike intervals.

### Objectives:
1. Perform spike sorting to detect and classify neural spikes.
2. Analyse local field potentials (LFPs) for power spectra and event-related potentials.
3. Extract key metrics, including spike rates and inter-spike intervals (ISIs), and perform time-frequency analysis.

# Notebook 02: Loading Preprocessed Data
import numpy as np
import matplotlib.pyplot as plt

# Read the cleaned recording written by the preprocessing stage
data = np.load('data/processed/cleaned_data.npy')

# Quick visual sanity check of the first 20,000 samples
# (one second of signal if the recording is sampled at 20 kHz -- TODO confirm)
preview = data[:20000]
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(preview)
ax.set_title('Preprocessed Electrophysiology Data')
ax.set_xlabel('Time (samples)')
ax.set_ylabel('Amplitude')
plt.show()

# Notebook 02: Spike Sorting
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

# Detect spikes as negative threshold crossings. The threshold is -5 times the
# standard deviation of the whole trace.
spike_threshold = -5 * np.std(data)

# Keep only the FIRST sample of each below-threshold run. Taking every sample
# below the threshold would report one "spike" per sample, so a single spike
# that spans several samples would be extracted and clustered several times.
# NOTE(review): assumes `data` is a 1-D single-channel trace -- confirm.
below_threshold = data < spike_threshold
previous_below = np.concatenate(([False], below_threshold[:-1]))
spike_indices = np.flatnonzero(below_threshold & ~previous_below)

# Extract spike waveforms: `window_size` samples before and after each spike.
# Spikes whose window would run past either end of the recording are skipped;
# `waveform_indices` holds the spikes that were kept and is aligned row-for-row
# with `spike_waveforms`, `spike_features` and `labels`.
window_size = 30
has_full_window = (spike_indices >= window_size) & (spike_indices + window_size <= len(data))
waveform_indices = spike_indices[has_full_window]
window_offsets = np.arange(-window_size, window_size)
# Fancy indexing yields shape (n_spikes, 2 * window_size), even when n_spikes == 0.
spike_waveforms = data[waveform_indices[:, None] + window_offsets]

# Fail early with a clear message instead of an opaque scikit-learn error.
n_clusters = 3
if len(spike_waveforms) < n_clusters:
    raise ValueError(
        f'Only {len(spike_waveforms)} spike waveform(s) detected; '
        f'at least {n_clusters} are needed for PCA and k-means clustering.'
    )

# Feature extraction using PCA (project each waveform onto 2 components)
pca = PCA(n_components=2)
spike_features = pca.fit_transform(spike_waveforms)

# Clustering spikes using k-means. A fixed seed and an explicit number of
# initialisations make the saved labels reproducible from run to run.
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
labels = kmeans.fit_predict(spike_features)

# Plot the clustered spikes in PCA space
plt.figure(figsize=(10, 6))
for label in np.unique(labels):
    in_cluster = labels == label
    plt.scatter(spike_features[in_cluster, 0], spike_features[in_cluster, 1], label=f'Cluster {label}')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Spike Clustering using PCA and K-Means')
plt.legend()
plt.show()

# Save the spike features and labels for further analysis
np.save('data/processed/spike_features.npy', spike_features)
np.save('data/processed/spike_labels.npy', labels)

# Notebook 02: LFP Analysis
from scipy.signal import welch

# Calculate the power spectral density (PSD) using the Welch method.
# With fs = 20 kHz and 2048-sample segments the frequency resolution is
# fs / nperseg, i.e. roughly 9.8 Hz per bin.
frequencies, psd = welch(data, fs=20000, nperseg=2048)

# Plot the PSD on a logarithmic y-axis. The plotted values are linear power
# densities (signal units squared per Hz), not decibels, so the axis label
# must not claim dB.
plt.figure(figsize=(10, 6))
plt.semilogy(frequencies, psd)
plt.title('Power Spectral Density (LFP)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power/Frequency (units$^2$/Hz)')
plt.show()

# Notebook 02: Other Key Metrics Extraction
# Calculate inter-spike intervals (ISIs)
SAMPLING_RATE_HZ = 20000  # sampling rate assumed for the recording (20 kHz)

# Collapse runs of consecutive sample indices into a single event (the first
# sample of each run). Without this, every sample of one spike that stays below
# the detection threshold would count as its own "spike" and the histogram
# would be dominated by one-sample (0.05 ms) intervals.
starts_new_event = np.ones(len(spike_indices), dtype=bool)
starts_new_event[1:] = np.diff(spike_indices) > 1
spike_onsets = spike_indices[starts_new_event]

# Convert the gaps between successive events from samples to seconds
isi = np.diff(spike_onsets) / SAMPLING_RATE_HZ

# Plot the ISI histogram
plt.figure(figsize=(10, 6))
plt.hist(isi, bins=50, alpha=0.75)
plt.title('Inter-Spike Interval (ISI) Histogram')
plt.xlabel('Time (s)')
plt.ylabel('Count')
plt.show()

# Time-Frequency Analysis (Spectrogram)
from scipy.signal import spectrogram

SAMPLING_RATE_HZ = 20000   # sampling rate assumed for the recording (20 kHz)
MAX_DISPLAY_FREQ_HZ = 100  # upper edge of the low-frequency band shown below

# Use one-second segments so the frequency resolution (fs / nperseg) is about
# 1 Hz. The previous 1024-sample segments gave ~19.5 Hz bins, leaving only a
# handful of frequency rows inside the 0-100 Hz band that is displayed.
# NOTE(review): assumes `data` is a 1-D single-channel trace -- confirm.
nperseg = min(SAMPLING_RATE_HZ, len(data))
frequencies, times, Sxx = spectrogram(data, fs=SAMPLING_RATE_HZ, nperseg=nperseg)

# Restrict plotting to the displayed band so the colour scale is set by the
# low-frequency content rather than by the full 0-10 kHz range, and floor the
# power before taking the log so zero-power bins do not become -inf.
low_band = frequencies <= MAX_DISPLAY_FREQ_HZ
power_db = 10 * np.log10(np.maximum(Sxx[low_band], np.finfo(Sxx.dtype).tiny))

# Plot the spectrogram
plt.figure(figsize=(10, 6))
plt.pcolormesh(times, frequencies[low_band], power_db, shading='gouraud')
plt.title('Spectrogram (Time-Frequency Analysis)')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.ylim([0, MAX_DISPLAY_FREQ_HZ])  # Focus on lower frequencies
plt.colorbar(label='Power (dB)')
plt.show()
