### Filtering, spike detection, and spike sorting to isolate single-unit activity (multi-unit activity, MUA, is discarded)

In [1]:
import pandas as pd
import os
import numpy as np
from scipy.signal import find_peaks, butter, filtfilt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt

### Filtering functions

In [None]:
def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper band edges, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Order passed to ``scipy.signal.butter`` (default 4).

    Returns
    -------
    tuple
        The ``(b, a)`` numerator and denominator coefficients returned by
        ``scipy.signal.butter``.
    """
    # butter() expects band edges as fractions of the Nyquist frequency.
    nyq = 0.5 * fs
    normalized_band = [lowcut / nyq, highcut / nyq]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return numerator, denominator

def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Band-pass filter ``data`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt`` (forward-backward filtering).

    Parameters
    ----------
    data : array_like
        Signal to filter.
    lowcut, highcut : float
        Lower and upper band edges, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Filter order forwarded to ``butter_bandpass`` (default 4).

    Returns
    -------
    The filtered signal as returned by ``filtfilt``.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(*coefficients, data)

In [None]:
# Spike-sorting pipeline: for every neuron recording that exists on disk,
# band-pass filter the voltage trace, detect negative-going spikes, cut a
# waveform snippet around each spike, reduce the snippets with PCA, cluster
# them with KMeans (cluster count chosen by silhouette score), print per-cluster
# quality metrics (refractory violations, SNR) and plot the mean waveform of
# each cluster on 3x3 figure pages saved as PNG files.

lowcut = 300  # Low cutoff frequency in Hz
highcut = 6000  # High cutoff frequency in Hz
fs = 12500  # Sampling frequency in Hz
window_size = 25  # half-width (in samples) of the snippet cut around each spike
n_components = 3  # number of PCA components
cluster_numbers = list(range(2, 6))  # candidate cluster counts tried per neuron
data_dir = '/Users/jazlynn/Downloads/neurons-csv-format'  # folder holding neuron_XXX.csv
output_dir = 'Group5_spikesorting_QC'  # folder receiving the QC figures
n_neurons = 360
panels_per_figure = 9  # 3x3 grid of axes per saved figure
refractory_samples = fs / 1000  # 1 ms expressed in samples (12.5 at 12.5 kHz)

# PCA needs at least n_components samples, and silhouette_score needs strictly
# more samples than clusters; neurons below this count are not clustered.
min_spikes = max(max(cluster_numbers) + 1, n_components)

figure_idx = 0
fig, axes = None, None

results_dict = {f'neuron_{i}': None for i in range(1, n_neurons + 1)}
peaks_dict = {f'neuron_{i}': None for i in range(1, n_neurons + 1)}
clusters_dict = {f'neuron_{i}': None for i in range(1, n_neurons + 1)}

# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

# Time axis (ms) shared by every waveform plot.
waveform_time_ms = np.arange(-window_size, window_size) / (fs / 1000)

for neuron in range(1, n_neurons + 1):
    file = os.path.join(data_dir, 'neuron_' + str(neuron).zfill(3) + '.csv')
    if not os.path.exists(file):
        continue

    data = pd.read_csv(file)
    filtered_data = butter_bandpass_filter(data['Voltage'], lowcut, highcut, fs)

    # Spikes are negative deflections: search for peaks in the inverted trace
    # that exceed the magnitude of the 0.7th percentile of the filtered signal.
    spike_peaks, _ = find_peaks(-filtered_data, height=-np.percentile(filtered_data, 0.7))

    # Drop peaks whose snippet would run past either end of the recording.
    fits_in_trace = (spike_peaks - window_size >= 0) & (spike_peaks + window_size < len(filtered_data))
    spike_peaks = spike_peaks[fits_in_trace]
    spike_waveforms = np.array(
        [filtered_data[peak - window_size: peak + window_size] for peak in spike_peaks]
    )
    peaks_dict['neuron_' + str(neuron)] = spike_peaks

    if len(spike_peaks) < min_spikes:
        # Too few spikes for PCA / KMeans / silhouette; leave clusters_dict as None.
        print(f'neuron {neuron}: only {len(spike_peaks)} spikes detected, skipping clustering')
        continue

    # Apply PCA for dimensionality reduction
    pca = PCA(n_components=n_components)
    waveform_features = pca.fit_transform(spike_waveforms)

    # Fit one KMeans per candidate cluster count and keep each fit so the
    # winner does not have to be refit (random_state is fixed, so a refit
    # would give the same labels anyway).
    si_score = []
    fits_by_n_clusters = {}
    for n_clusters in cluster_numbers:
        candidate = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
        candidate_labels = candidate.fit_predict(waveform_features)
        fits_by_n_clusters[n_clusters] = (candidate, candidate_labels)
        si_score.append(silhouette_score(waveform_features, candidate_labels, metric='euclidean'))

    opt_n_clusters = cluster_numbers[int(np.argmax(si_score))]
    kmeans, cluster_labels = fits_by_n_clusters[opt_n_clusters]
    clusters_dict['neuron_' + str(neuron)] = cluster_labels

    mean_waveforms = []
    for final_cluster in range(opt_n_clusters):
        in_cluster = cluster_labels == final_cluster

        ### Refractory violations ###
        # Count consecutive spikes of this cluster closer together than 1 ms.
        cluster_spikes = spike_peaks[in_cluster]
        refractory_violation_counter = int(np.sum(np.diff(cluster_spikes) < refractory_samples))

        if refractory_violation_counter > 0:
            print('neuron ' + str(neuron) + ' cluster ' + str(final_cluster) + ' has ' + str(refractory_violation_counter) + ' refractory violations')

        ### SNR ###
        # Select the waveforms of *this* cluster (the original indexed with a
        # stale/undefined loop variable here).
        cluster_waveforms = spike_waveforms[in_cluster]
        mean_waveform = np.mean(cluster_waveforms, axis=0)
        # NOTE(review): the denominator is the std of the mean waveform itself,
        # not a noise estimate - confirm this is the intended SNR definition.
        snr = (np.max(mean_waveform) - np.min(mean_waveform)) / (np.std(mean_waveform) * 2)
        print('SNR of neuron ' + str(neuron) + ' cluster ' + str(final_cluster) + ': ' + f'{snr:.2f}')

        mean_waveforms.append((mean_waveform, np.sum(in_cluster)))

    ### PLOTTING ###
    if figure_idx % panels_per_figure == 0:
        # Starting a new page: flush the previous one first.
        if fig is not None:
            fig.tight_layout()
            # Save the figure
            fig.savefig(os.path.join(output_dir, f'figure_{figure_idx//panels_per_figure}.png'))
            plt.close(fig)
        fig, axes = plt.subplots(3, 3, figsize=(15, 15))
        # NOTE(review): both numbers in this title are the page number; a
        # range may have been intended - confirm before changing the text.
        plt.suptitle(f'Figures {figure_idx//9 + 1}-{figure_idx//9 + 1}')

    i, j = divmod(figure_idx % panels_per_figure, 3)

    for mean_waveform, n_spikes in mean_waveforms:
        axes[i, j].plot(waveform_time_ms, mean_waveform, label=f'{n_spikes} spikes')

    axes[i, j].set_title('neuron ' + str(neuron))
    axes[i, j].set_xlabel('Time (ms)')
    axes[i, j].set_ylabel('Amplitude')
    axes[i, j].legend()
    figure_idx += 1

# last fig
if fig is not None:
    fig.tight_layout()
    # Save the last figure. Full pages are numbered 1, 2, ... as they are
    # flushed inside the loop, so the page still open is number
    # ceil(figure_idx / panels_per_figure); using figure_idx // panels_per_figure
    # would overwrite the previous page whenever the last page is partly filled.
    last_page = (figure_idx - 1) // panels_per_figure + 1
    fig.savefig(os.path.join(output_dir, f'figure_{last_page}.png'))
    plt.close(fig)