# 2. spike classification evaluation

In [None]:
import sys
import os

import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader

# Get the absolute path of the parent directory (project root)
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Add the project root to the Python path
if project_root not in sys.path:
    sys.path.append(project_root)

from src.utils import load_and_concatenate_npy, normalize_data
from src.models import CNNClassifier, GRUClassifier
from src.evaluate import evaluate_model

from scripts.train_cnn import train_cnn
from scripts.train_gru import train_gru


In [2]:
# 1. Train the Models
# This will run the training scripts and save the best models to the project root.
print("--- Starting CNN Model Training ---")
train_cnn()

--- Starting CNN Model Training ---
Loading data...
Successfully loaded file: ../data/spikes/channel_background_9.npy with shape (5704, 42)
Successfully loaded file: ../data/spikes//channel_background_16.npy with shape (4324, 42)
Successfully loaded file: ../data/spikes/channel_background_33.npy with shape (2934, 42)

All arrays concatenated. Final shape: (12962, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_spikes_9.npy with shape (1079, 42)
Successfully loaded file: ../data/spikes/channel_spikes_16.npy with shape (253, 42)
Successfully loaded file: ../data/spikes/channel_spikes_33.npy with shape (1159, 42)

All arrays concatenated. Final shape: (2491, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_background_11.npy with shape (6835, 42)

All arrays concatenated. Final shape: (6835, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_spikes_11.npy with shape (819, 42)

All arrays concatenated. Final shape: (819, 42)
Startin

In [3]:
print("\n--- Starting GRU Model Training ---")
train_gru()


--- Starting GRU Model Training ---
Loading data...
Successfully loaded file: ../data/spikes/channel_background_9.npy with shape (5704, 42)
Successfully loaded file: ../data/spikes/channel_background_16.npy with shape (4324, 42)
Successfully loaded file: ../data/spikes/channel_background_33.npy with shape (2934, 42)

All arrays concatenated. Final shape: (12962, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_spikes_9.npy with shape (1079, 42)
Successfully loaded file: ../data/spikes/channel_spikes_16.npy with shape (253, 42)
Successfully loaded file: ../data/spikes/channel_spikes_33.npy with shape (1159, 42)

All arrays concatenated. Final shape: (2491, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_background_11.npy with shape (6835, 42)

All arrays concatenated. Final shape: (6835, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_spikes_11.npy with shape (819, 42)

All arrays concatenated. Final shape: (819, 42)
Startin

In [4]:
# 2. Data Loading and Preparation (for evaluation)
background_file_paths_test = [
    '../data/spikes/channel_background_40.npy'
]
spikes_file_paths_test = [
    '../data/spikes/channel_spikes_40.npy'
]

In [5]:
def create_test_dataset(background_paths, spikes_paths):
    background_array = load_and_concatenate_npy(background_paths)
    spikes_array = load_and_concatenate_npy(spikes_paths)
    
    X_np = np.concatenate([background_array, spikes_array], axis=0)
    y_np = np.concatenate([
        np.zeros(background_array.shape[0]), 
        np.ones(spikes_array.shape[0])
    ], axis=0)
    
    X_tensor = torch.from_numpy(X_np).float()
    y_tensor = torch.from_numpy(y_np).long()
    return X_tensor, y_tensor

In [6]:
X_test, y_test = create_test_dataset(background_file_paths_test, spikes_file_paths_test)
train_mean = X_test.mean()
train_std = X_test.std()
X_test_normalized = normalize_data(X_test, train_mean, train_std).unsqueeze(1)
test_dataset = TensorDataset(X_test_normalized, y_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Loading data...
Successfully loaded file: ../data/spikes/channel_background_40.npy with shape (4073, 42)

All arrays concatenated. Final shape: (4073, 42)
Loading data...
Successfully loaded file: ../data/spikes/channel_spikes_40.npy with shape (108, 42)

All arrays concatenated. Final shape: (108, 42)


In [12]:
%%time

# Evaluate CNN model
cnn_model = CNNClassifier(input_size=X_test.shape[1], num_classes=2)
print("Evaluating CNN model...")
evaluate_model(cnn_model, '../model/best_cnn_model.pth', test_loader, 'Test')

Evaluating CNN model...

--- Best Model Test Metrics ---
Accuracy: 0.9823
Macro Precision: 0.7967
Macro Recall: 0.9909
Macro F1 Score: 0.8678
CPU times: total: 312 ms
Wall time: 411 ms


In [None]:
%%time

# Evaluate GRU model
gru_model = GRUClassifier(input_size=X_test.shape[1], num_classes=2)
print("\nEvaluating GRU model...")
evaluate_model(gru_model, '../model/best_gru_model.pth', test_loader, 'Test')


Evaluating GRU model...

--- Best Model Test Metrics ---
Accuracy: 0.9735
Macro Precision: 0.7461
Macro Recall: 0.9729
Macro F1 Score: 0.8202
CPU times: total: 922 ms
Wall time: 1.36 s
