# 3. Spike Clustering Evaluation

In [None]:
import torch
import numpy as np
from sklearn.decomposition import PCA
from src.utils import load_and_concatenate_npy
from src.models import ShallowAutoencoder
from src.cluster import plot_reconstruction, apply_and_plot_dbscan

In [None]:
# 1. Load Data
# Spike waveforms from these channels are pooled for clustering.
# Edit the channel list here instead of hand-maintaining a list of paths.
SPIKES_DIR = '../data/spikes'
SPIKE_CHANNELS = [9, 16, 33, 11, 40]
spikes_file_paths = [
    f'{SPIKES_DIR}/channel_spikes_{channel}.npy' for channel in SPIKE_CHANNELS
]
# load_and_concatenate_npy is project-local; presumably it stacks the
# per-channel arrays along the first axis — TODO confirm against src.utils.
all_spikes_np = load_and_concatenate_npy(spikes_file_paths)
# .float() casts to float32, presumably to match the default float32
# parameters of ShallowAutoencoder.
all_spikes_tensor = torch.from_numpy(all_spikes_np).float()

In [None]:
# 2. Load the Trained Autoencoder
autoencoder_model = ShallowAutoencoder()
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine. load_state_dict then copies the tensors into the model's own
# (CPU) parameters, so nothing changes on machines that do have a GPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints you
# trust (newer torch versions default to weights_only=True, which limits this).
state_dict = torch.load('best_autoencoder_model.pth', map_location='cpu')
autoencoder_model.load_state_dict(state_dict)
# Inference mode: affects layers such as dropout/batch-norm, if the model has any.
autoencoder_model.eval()

In [None]:
# Plot reconstruction with the loaded model.
# Any forward pass plot_reconstruction performs is pure inference. Running it
# under no_grad (as the embedding step does) avoids building an autograd graph
# over the whole dataset, and keeps outputs free of requires_grad.
with torch.no_grad():
    reconstruction_plot = plot_reconstruction(autoencoder_model, all_spikes_tensor)
# Keep the return value as the cell's last expression so the notebook
# displays it exactly as the bare call did.
reconstruction_plot

In [None]:
# Encode every spike into the autoencoder's latent space with gradients
# disabled (inference only), then cluster and plot that space with DBSCAN.
with torch.no_grad():
    latent = autoencoder_model.encoder(all_spikes_tensor)
autoencoder_embeddings = latent.cpu().numpy()

ae_plot_options = {
    'title': '3D Autoencoder Embeddings with DBSCAN Clusters',
    'x_label': 'Embedding Dimension 1',
    'y_label': 'Embedding Dimension 2',
    'z_label': 'Embedding Dimension 3',
}
# NOTE(review): eps is presumably forwarded to sklearn's DBSCAN as a distance
# in this latent space; confirm that 250 suits the scale of the embeddings.
apply_and_plot_dbscan(
    data_embeddings=autoencoder_embeddings,
    eps=250,
    min_samples=10,
    **ae_plot_options,
)

In [None]:
# 3. Apply PCA and Plot Clusters
# Project the raw spike waveforms onto their first three principal components.
# random_state only matters if sklearn selects a randomized/ARPACK solver, but
# fixing it makes the projection (and therefore the clusters) reproducible.
pca = PCA(n_components=3, random_state=0)
pca_embeddings = pca.fit_transform(all_spikes_np)

# NOTE(review): eps is presumably a DBSCAN distance in this PCA space. PCA
# components of raw waveforms and autoencoder latents are likely scaled
# differently, so reusing eps=250 from the autoencoder cell may not suit both.
apply_and_plot_dbscan(
    data_embeddings=pca_embeddings,
    title='3D PCA Components with DBSCAN Clusters',
    x_label='Principal Component 1',
    y_label='Principal Component 2',
    z_label='Principal Component 3',
    eps=250,
    min_samples=10,
)