# In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

# Spike-sorting pipeline: threshold-detect events in the filtered data, cut a
# fixed window around each event, project the windows onto two principal
# components, cluster them with k-means, plot the result and save it to disk.

# Load the filtered data
# NOTE(review): the indexing below assumes a 1-D (single-channel) array --
# confirm against whatever produced this file.
filtered_data = np.load('data/processed/filtered_data.npy')

# Spike detection using thresholding: flag every sample that falls below
# minus five standard deviations of the whole array.
# NOTE(review): each sub-threshold sample is treated as its own event, so a
# spike that stays below the threshold for several samples yields several
# overlapping waveforms. Consider keeping one index per crossing.
threshold = 5 * np.std(filtered_data)  # Example threshold
spike_indices = np.where(filtered_data < -threshold)[0]

# Extract spike waveforms
window_size = 30  # Number of samples before and after the spike event

# Drop events whose window would run off either end of the data. A negative
# start index wraps around and a stop index past the end is truncated, so
# those slices would come out with the wrong length and the waveform matrix
# would be ragged.
n_samples = filtered_data.shape[0]
has_full_window = (spike_indices >= window_size) & (
    spike_indices + window_size <= n_samples
)
spike_indices = spike_indices[has_full_window]

# One row per spike: samples [i - window_size, i + window_size).
window_offsets = np.arange(-window_size, window_size)
spike_waveforms = filtered_data[spike_indices[:, None] + window_offsets]

n_components = 2
n_clusters = 3  # Adjust clusters based on dataset

# PCA and k-means both need at least as many samples as components/clusters;
# fail early with a readable message instead of an error from deep inside
# scikit-learn.
min_spikes = max(n_components, n_clusters)
if len(spike_waveforms) < min_spikes:
    raise ValueError(
        f'Only {len(spike_waveforms)} spike(s) with a complete window were '
        f'detected; at least {min_spikes} are required for PCA and clustering.'
    )

# Feature extraction using PCA
pca = PCA(n_components=n_components)
spike_features = pca.fit_transform(spike_waveforms)

# Clustering spikes using k-means. The seed is fixed so that the labels saved
# below are reproducible from run to run.
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
labels = kmeans.fit_predict(spike_features)

# Plot the clustered spikes in PCA space
plt.figure(figsize=(10, 6))
for label in np.unique(labels):
    in_cluster = labels == label
    plt.scatter(
        spike_features[in_cluster, 0],
        spike_features[in_cluster, 1],
        label=f'Cluster {label}',
    )
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.legend()
plt.title('Spike Clustering using PCA and K-Means')
plt.show()

# Save the cluster labels and spike features
np.save('data/processed/spike_features.npy', spike_features)
np.save('data/processed/spike_labels.npy', labels)