### Filtering, spike detection, and spike sorting to obtain single-unit activity (multi-unit activity, MUA, is discarded)

In [None]:
import pandas as pd
import os
import numpy as np
from scipy.signal import find_peaks, butter, filtfilt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

#### Filtering

In [None]:
# TODO (author's note): the wrapper functions may not be needed -- use filtfilt.

# Band-pass settings used to isolate the frequency range of interest.
fs = 30_000      # sampling frequency in Hz
lowcut = 300     # lower band edge in Hz
highcut = 6_000  # upper band edge in Hz

def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper band edges in Hz.
    fs : float
        Sampling frequency in Hz.
    order : int, optional
        Order handed to ``scipy.signal.butter`` (default 4).

    Returns
    -------
    tuple
        The ``(b, a)`` numerator and denominator coefficient arrays
        returned by ``scipy.signal.butter``.
    """
    # butter() expects band edges as fractions of the Nyquist frequency
    # (half the sampling rate) when no ``fs`` argument is given.
    half_rate = fs * 0.5
    normalized_band = [edge / half_rate for edge in (lowcut, highcut)]
    return butter(order, normalized_band, btype='band')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Band-pass filter ``data`` with a zero-phase Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt`` (one forward and one backward pass), so the
    output is not phase-shifted relative to the input. An earlier version
    called ``lfilter``, which this file never imports and which therefore
    raised ``NameError``.

    Parameters
    ----------
    data : array_like
        Signal to filter. Filtering runs along the last axis, which is the
        ``filtfilt`` default.
    lowcut, highcut : float
        Lower and upper band edges in Hz.
    fs : float
        Sampling frequency in Hz.
    order : int, optional
        Order passed to ``butter_bandpass`` (default 4). The filter is
        applied twice, so the magnitude response is the square of the
        single-pass response.

    Returns
    -------
    numpy.ndarray
        The filtered signal, with the same shape as ``data``.

    Raises
    ------
    ValueError
        Raised by ``filtfilt`` when ``data`` is not longer than its default
        edge padding.
    """
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(b, a, data)

# Band-pass the recording using the cutoff and sampling-rate settings.
# NOTE(review): `data` is not defined in the visible cells -- presumably the
# raw recording is loaded elsewhere; confirm before running this cell.
filtered_data = butter_bandpass_filter(
    data, lowcut=lowcut, highcut=highcut, fs=fs
)

#### Find peaks

In [None]:
# Detect candidate spikes: local maxima of the filtered trace whose height
# is at least `threshold`.
threshold = 3.0  # minimum peak height; adjust as needed
# find_peaks returns (indices, properties); only the sample indices are kept.
spike_peaks = find_peaks(filtered_data, height=threshold)[0]

#### Waveforms and clustering

In [None]:
# Tunable parameters for this cell -- adjust as needed.
window_size = 30   # samples taken on each side of a peak
n_components = 3   # number of principal components kept per waveform
n_clusters = 4     # number of K-means clusters

# Cut a fixed-length snippet around every detected peak: `window_size`
# samples before the peak and `window_size` samples from the peak onward,
# so each row holds 2 * window_size samples.
spike_waveforms = np.array([
    filtered_data[peak - window_size: peak + window_size]
    for peak in spike_peaks
    # Skip peaks whose window would run off either end of the trace.
    if window_size <= peak < len(filtered_data) - window_size
])

# Reduce each waveform to its leading principal components.
pca = PCA(n_components=n_components)
waveform_features = pca.fit_transform(spike_waveforms)

# Group the waveforms in PCA space with K-means; the seed is fixed so that
# repeated runs use the same random state.
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
cluster_labels = kmeans.fit_predict(waveform_features)

#### Visualization

In [None]:
# Plot the average waveform of every cluster on a shared axis.
for cluster_id in range(n_clusters):
    # Boolean mask selecting the waveforms assigned to this cluster.
    in_cluster = cluster_labels == cluster_id
    plt.plot(
        spike_waveforms[in_cluster].mean(axis=0),
        label=f'Cluster {cluster_id + 1}',  # legend labels are 1-based
    )

plt.xlabel('Time (samples)')
plt.ylabel('Amplitude')
plt.legend()
plt.show()