# 2. Spike Classification Evaluation

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from src.utils import load_and_concatenate_npy, normalize_data
from src.models import CNNClassifier, GRUClassifier
from src.train import evaluate_model
import numpy as np

In [None]:
# 1. Data Loading and Preparation (for evaluation)
# Recordings used only for evaluation. create_test_dataset labels the
# background segments 0 and the spike segments 1.
background_file_paths_test = [
    # NOTE(review): this path previously read
    # '../data/spikeshannel_background_40.npy', which looks like this path
    # with the '/c' dropped. Confirm the background file really lives here.
    '../data/spikes/channel_background_40.npy',
]
spikes_file_paths_test = [
    '../data/spikes/channel_spikes_40.npy',
]

In [None]:
def create_test_dataset(background_paths, spikes_paths):
    """Build labelled evaluation tensors from background and spike recordings.

    Background rows are labelled 0 and spike rows 1; background rows come
    first in both returned tensors.

    Args:
        background_paths: Paths forwarded to ``load_and_concatenate_npy`` for
            the background (class 0) segments.
        spikes_paths: Paths forwarded to ``load_and_concatenate_npy`` for the
            spike (class 1) segments.

    Returns:
        ``(X, y)``: ``X`` is a float32 tensor of the stacked segments
        (stacked along axis 0) and ``y`` is an int64 tensor with one label
        per row of ``X``.
    """
    negatives = load_and_concatenate_npy(background_paths)
    positives = load_and_concatenate_npy(spikes_paths)

    # One label per row: zeros for background, ones for spikes. The labels
    # start out as float64 and are cast to int64 below, as the loss expects.
    n_negative, n_positive = negatives.shape[0], positives.shape[0]
    labels = np.concatenate((np.zeros(n_negative), np.ones(n_positive)))

    samples = np.concatenate((negatives, positives))

    features = torch.from_numpy(samples).float()
    targets = torch.from_numpy(labels).long()
    return features, targets

In [None]:
# Build the evaluation tensors and the batched loader used by both models.
X_test, y_test = create_test_dataset(background_file_paths_test, spikes_file_paths_test)

# NOTE(review): these statistics are computed from the evaluation data itself,
# not from the training data the models were fit on. That leaks test-set
# information into preprocessing and may not match the normalization used at
# training time.
# TODO: load the mean/std saved during training and use them here instead.
test_set_mean = X_test.mean()
test_set_std = X_test.std()

# unsqueeze(1) inserts a size-1 axis at dim 1, presumably the input-channel
# axis the models expect. TODO: confirm against CNNClassifier/GRUClassifier.
X_test_normalized = normalize_data(X_test, test_set_mean, test_set_std).unsqueeze(1)
test_dataset = TensorDataset(X_test_normalized, y_test)
# shuffle=False keeps batches in the same row order as X_test / y_test.
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [None]:
# Evaluate CNN model
# input_size is the length of dim 1 of X_test (before the unsqueeze);
# num_classes=2 matches the 0/1 labels from create_test_dataset.
cnn_model = CNNClassifier(
    input_size=X_test.shape[1],
    num_classes=2,
)
print("Evaluating CNN model...")
# evaluate_model receives the checkpoint path; presumably it loads those
# weights into cnn_model before scoring. Confirm in src.train.
evaluate_model(
    cnn_model,
    'best_cnn_model.pth',
    test_loader,
    'Test',
)

In [None]:
# Evaluate GRU model
# Built with the same input_size / num_classes as the CNN above.
# NOTE(review): test_loader yields inputs with an extra size-1 axis at dim 1
# (from unsqueeze(1)). Whether GRUClassifier reads that as a sequence axis or
# a channel axis is not visible here. TODO: confirm the expected layout.
gru_model = GRUClassifier(
    input_size=X_test.shape[1],
    num_classes=2,
)
print("\nEvaluating GRU model...")
# Presumably loads the weights from the checkpoint path before scoring;
# confirm in src.train.
evaluate_model(
    gru_model,
    'best_gru_model.pth',
    test_loader,
    'Test',
)