In [4]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import spikeinterface as si
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm


import sys
import spikeinterface as si
import matplotlib.pyplot as plt
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.qualitymetrics as sqm
import json
import probeinterface

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_dummy_probe, generate_linear_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
from torch.nn.functional import max_pool1d

from torch_geometric.nn import GATConv

import torch.nn.functional as F
from pathlib import Path


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Subset
from probeinterface import write_probeinterface, read_probeinterface
import networkx as nx
import pickle

In [5]:
def count_array2_in_range_of_array1(array1, array2, threshold=5):

    sorted_array1 = np.sort(array1)
    
    lefts = array2 - threshold
    rights = array2 + threshold
    
    left_indices = np.searchsorted(sorted_array1, lefts, side='left')
    
    right_indices = np.searchsorted(sorted_array1, rights, side='right')
    
    has_within_range = right_indices > left_indices
    
    count = np.sum(has_within_range)
    
    return count


def detect_local_maxima_in_window(data, window_size=20, std_multiplier=2):

    """
    在每个滑动窗口范围内检测局部最大值的索引，并确保最大值大于两倍的标准差。

    参数:
    data : numpy.ndarray
        输入数据，形状为 (n_rows, n_columns)。
    window_size : int
        滑动窗口的大小，用于定义局部范围，默认为 20。
    std_multiplier : float
        标准差的倍数，用于筛选局部最大值，默认为 2。

    返回:
    local_maxima_indices : list of numpy.ndarray
        每行局部最大值的索引列表，每个元素是对应行局部最大值的索引数组。
    """
    local_maxima_indices = []

    for row in data:
        maxima_indices = []
        row_std = np.std(row.astype(np.float32))
        threshold = std_multiplier * row_std

        for start in range(0, len(row), window_size):
            end = min(start + window_size, len(row))
            window = row[start:end]
            
            if len(window) > 0:
                local_max_index = np.argmax(window)
                local_max_value = window[local_max_index]
                
                if local_max_value > threshold:
                    maxima_indices.append(start + local_max_index)  
        
        local_maxima_indices.extend(maxima_indices)
        local_maxima_indices = list(set(local_maxima_indices))  

    return local_maxima_indices


def cluster_label_array1_based_on_array2(array1, array2, threshold=5):

    """
    根据 array2 的 'time' 和 'cluster' 对 array1 进行标记。
    如果 array1 中的某个值在 threshold 范围内存在于 array2 的 'time' 中，则标记为对应的 'cluster' 值，否则为 0。
    
    参数:
    array1 : numpy.ndarray
        要标记的数组。
    array2 : numpy.ndarray
        包含 'time' 和 'cluster' 的二维数组。
        第一列为 'time'，第二列为 'cluster'。
    threshold : int
        判断范围的阈值。
    
    返回:
    labels : numpy.ndarray
        长度为 len(array1) 的标签数组，值为 array2 中的 'cluster' 或 0。
    """

    array2 = np.array(array2.iloc[:, [5, 1]])
    sorted_indices = np.argsort(array2[:, 0])
    sorted_array2 = array2[sorted_indices]
    
    labels = np.zeros(len(array1), dtype=int)
    
    # 遍历 array1 中的每个元素
    for i, value in enumerate(array1):
        # 计算当前值的范围
        left = value - threshold
        right = value + threshold
        
        left_index = np.searchsorted(sorted_array2[:, 0], left, side='left')
        right_index = np.searchsorted(sorted_array2[:, 0], right, side='right')
        
        # 如果范围内存在值，则标记为对应的 'cluster'
        if right_index > left_index:
            # 获取范围内的第一个匹配值的 'cluster'
            labels[i] = sorted_array2[left_index, 1]
    
    return labels


def label_array1_based_on_array2(array1, array2, threshold=5):

    """
    根据 array2 的值对 array1 进行标记。
    如果 array1 中的某个值在 threshold 范围内存在于 array2 中，则标记为 1，否则为 0。
    
    参数:
    array1 : numpy.ndarray
        要标记的数组。
    array2 : numpy.ndarray
        用于判断的数组。
    threshold : int
        判断范围的阈值。
    
    返回:
    labels : numpy.ndarray
        长度为 len(array1) 的标签数组，值为 0 或 1。
    """
    # 对 array2 进行排序以加速搜索
    sorted_array2 = np.sort(array2)
    
    # 初始化标签数组，默认值为 0
    labels = np.zeros(len(array1), dtype=int)
    
    # 遍历 array1 中的每个元素
    for i, value in enumerate(array1):
        # 计算当前值的范围
        left = value - threshold
        right = value + threshold
        
        # 使用二分搜索判断范围内是否存在值
        left_index = np.searchsorted(sorted_array2, left, side='left')
        right_index = np.searchsorted(sorted_array2, right, side='right')
        
        # 如果范围内存在值，则标记为 1
        if right_index > left_index:
            labels[i] = 1
    
    return labels


def extract_windows(data, indices, window_size=61):
    """
    根据给定的时间点索引提取窗口。
    
    参数:
    data : numpy.ndarray
        输入数据，形状为 (n_channels, time)
    indices : numpy.ndarray
        时间点索引数组，用于指定需要提取窗口的中心点
    window_size : int
        窗口长度，默认为61（对应time-30到time+31）
    
    返回:
    windows : numpy.ndarray
        提取的窗口数据，形状为 (len(indices), n_channels, window_size)
    """
    n_channels, time_length = data.shape
    half_window = window_size // 2

    if np.any(indices < half_window) or np.any(indices >= time_length - half_window):
        raise ValueError("Some indices are out of bounds for the given window size.")

    windows = []
    for idx in indices:
        window = data[:, idx - half_window:idx + half_window + 1]
        windows.append(window)

    windows = np.array(windows)
    return windows

In [6]:
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
    
class Spike_Detection_MLP(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, n_channels, time_window):
        super(Spike_Detection_MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, 16)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(16, output_size)
        self.sigmoid = nn.Sigmoid()  

        self.n_channels = n_channels
        self.time_window = time_window
    def forward(self, x):
        x = x.reshape(-1, self.n_channels * self.time_window)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

In [None]:
recording_raw = se.MEArecRecordingExtractor(file_path=f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_1/Neuronexus_32_50_cell_recordings.h5')
recording_f = spre.bandpass_filter(recording_raw, freq_min=300, freq_max=3000)
recording_f = spre.common_reference(recording_f, reference="global", operator="median")
spike_inf = pd.read_csv(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_1/spike_inf.csv")

total_frames = int(1000 * 10000)
chunk_size = 100000  
window_size = 31
half_window = window_size // 2

all_valid_indices = []
all_windows = []

for start_frame in range(0, total_frames, chunk_size):
    end_frame = min(start_frame + chunk_size, total_frames)
    
    data_chunk = recording_f.get_traces(
        start_frame=start_frame,
        end_frame=end_frame
    )  # shape: (n_channels, chunk_size)
    
    threshold_result = detect_local_maxima_in_window(
        data_chunk.T,  
        std_multiplier=0.7
    )
    
    threshold_result = np.array(threshold_result) + start_frame
    valid_indices = threshold_result[
        (threshold_result >= start_frame + half_window + 1) & 
        (threshold_result < end_frame - half_window)
    ]
    
    for idx in valid_indices:
        rel_idx = idx - start_frame
        window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
        all_windows.append(window)
    
    all_valid_indices.extend(valid_indices)

all_valid_indices = np.array(all_valid_indices)
all_windows = np.stack(all_windows)  

labels = label_array1_based_on_array2(all_valid_indices, spike_inf['time'], threshold=1)

labels = np.array(labels) 
indices_0 = np.where(labels == 0)[0] 
indices_1 = np.where(labels == 1)[0] 

target_0_count = len(indices_1) 

if len(indices_0) > target_0_count:
    sampled_indices_0 = np.random.choice(indices_0, target_0_count, replace=False)
else:
    sampled_indices_0 = indices_0  

final_indices = np.concatenate([sampled_indices_0, indices_1])

np.random.shuffle(final_indices)

sampled_windows = all_windows[final_indices]
sampled_labels = labels[final_indices]

dataset = CustomDataset(sampled_windows, sampled_labels)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

batch_size = 1024 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


all_valid_indices_val = []
all_windows_val = []

for start_frame in range(total_frames, 1600 * 10000, chunk_size):
    end_frame = min(start_frame + chunk_size, 1600 * 10000)
    
    data_chunk = recording_f.get_traces(
        start_frame=start_frame,
        end_frame=end_frame
    )  # shape: (n_channels, chunk_size)
    
    threshold_result = detect_local_maxima_in_window(
        data_chunk.T,  
        std_multiplier=0.7
    )
    
    threshold_result = np.array(threshold_result) + start_frame
    valid_indices = threshold_result[
        (threshold_result >= start_frame + half_window + 1) & 
        (threshold_result < end_frame - half_window)
    ]
    
    for idx in valid_indices:
        rel_idx = idx - start_frame
        window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
        all_windows_val.append(window)
    
    all_valid_indices_val.extend(valid_indices)

all_valid_indices_val = np.array(all_valid_indices_val)
all_windows_val = np.stack(all_windows_val)  

labels_val = label_array1_based_on_array2(all_valid_indices_val, spike_inf['time'], threshold=1)
labels_val = np.array(labels_val) 
dataset = CustomDataset(all_windows_val, labels_val)
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


input_size = sampled_windows.shape[1] * sampled_windows.shape[2]
hidden_size1 = 128
hidden_size2 = 32
output_size = 1  
device = 'cuda'

val_results = {}
val_results['tpr'] = []
val_results['tnr'] = []
val_results['accuracy'] = []
for trail in range(1, 6):
    criterion = nn.BCELoss()  

    model = Spike_Detection_MLP(input_size, hidden_size1, hidden_size2, 
                                    output_size, n_channels=sampled_windows.shape[1], time_window= sampled_windows.shape[2])
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    num_epochs = 50
    tpr_best = 0
    i = 0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_data, batch_labels in train_loader:
            batch_labels = batch_labels.float().unsqueeze(1)

            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)

            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

        model.eval()
        correct = 0
        total = 0

        true_positive = 0
        true_negative = 0
        false_positive = 0
        false_negative = 0

        with torch.no_grad():
            for batch_data, batch_labels in val_loader:
                batch_labels = batch_labels.float().unsqueeze(1)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)

                outputs = model(batch_data)
                predicted = (outputs > 0.5).float()  
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
                true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
                true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
                false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
                false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()


        tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
        tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0
        accuracy = correct / total
        print(f"Test Accuracy: {100 * correct / total:.2f}%")
        # print(f"Test TPR: {100 * tpr:.2f}%")
        # print(f"Test TNR: {100 * tnr:.2f}%")

        if tpr > tpr_best:
            tpr_best = tpr
            i = 0
            torch.save(model, f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_1/trail_{trail}.pth')
            # print(f"Best model saved with TPR: {tpr_best:.4f}")
            # print("_" * 60)

        else:
            i += 1
            if i == 3:
                print(f"Training stopped after {epoch+1} epochs with best TPR: {tpr_best:.4f}")
                print("_" * 60)
                break
    val_results['tpr'].append(tpr_best)
    val_results['tnr'].append(tnr)
    val_results['accuracy'].append(accuracy)

with open("/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/eval_results/setting_1/val_results.pkl", "wb") as f:
    pickle.dump(val_results, f)

Test Accuracy: 97.66%
Test Accuracy: 98.45%
Test Accuracy: 98.68%
Test Accuracy: 98.90%
Test Accuracy: 98.98%
Test Accuracy: 99.16%
Test Accuracy: 99.11%
Test Accuracy: 99.23%
Test Accuracy: 99.24%
Test Accuracy: 99.26%
Test Accuracy: 99.36%
Test Accuracy: 99.29%
Test Accuracy: 99.33%
Training stopped after 13 epochs with best TPR: 0.9916
____________________________________________________________
Test Accuracy: 97.59%
Test Accuracy: 98.49%
Test Accuracy: 98.62%
Test Accuracy: 98.92%
Test Accuracy: 99.05%
Test Accuracy: 99.06%
Test Accuracy: 99.16%
Test Accuracy: 99.19%
Test Accuracy: 99.15%
Test Accuracy: 99.21%
Test Accuracy: 99.23%
Test Accuracy: 99.36%
Test Accuracy: 99.23%
Test Accuracy: 99.32%
Test Accuracy: 99.33%
Test Accuracy: 99.33%
Training stopped after 16 epochs with best TPR: 0.9931
____________________________________________________________
Test Accuracy: 97.51%
Test Accuracy: 98.27%
Test Accuracy: 98.64%
Test Accuracy: 98.91%
Test Accuracy: 98.90%
Test Accuracy: 99.04

In [13]:
with open("/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/eval_results/setting_1/val_results.pkl", "wb") as f:
    pickle.dump(val_results, f)

In [7]:
for file in os.listdir("/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_6_neuron_type/recordings"):
    print(f"Processing file: {file}")
    recording_raw = se.MEArecRecordingExtractor(file_path=f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_6_neuron_type/recordings/{file}')
    recording_f = spre.bandpass_filter(recording_raw, freq_min=300, freq_max=3000)
    recording_f = spre.common_reference(recording_f, reference="global", operator="median")
    file_seg = file.split(".")[0]
    spike_inf = pd.read_csv(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_6_neuron_type/spike_inf/{file_seg}_spike_inf.csv")

    total_frames = int(600 * 10000)
    chunk_size = 100000  
    window_size = 31
    half_window = window_size // 2

    all_valid_indices = []
    all_windows = []

    for start_frame in range(0, total_frames, chunk_size):
        end_frame = min(start_frame + chunk_size, total_frames)
        
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )  # shape: (n_channels, chunk_size)
        
        threshold_result = detect_local_maxima_in_window(
            data_chunk.T,  
            std_multiplier=0.7
        )
        
        threshold_result = np.array(threshold_result) + start_frame
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) & 
            (threshold_result < end_frame - half_window)
        ]
        
        for idx in valid_indices:
            rel_idx = idx - start_frame
            window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
            all_windows.append(window)
        
        all_valid_indices.extend(valid_indices)

    all_valid_indices = np.array(all_valid_indices)
    all_windows = np.stack(all_windows)  

    labels = label_array1_based_on_array2(all_valid_indices, spike_inf['time'], threshold=1)

    labels = np.array(labels) 
    indices_0 = np.where(labels == 0)[0] 
    indices_1 = np.where(labels == 1)[0] 

    target_0_count = len(indices_1) 

    if len(indices_0) > target_0_count:
        sampled_indices_0 = np.random.choice(indices_0, target_0_count, replace=False)
    else:
        sampled_indices_0 = indices_0  

    final_indices = np.concatenate([sampled_indices_0, indices_1])

    np.random.shuffle(final_indices)

    sampled_windows = all_windows[final_indices]
    sampled_labels = labels[final_indices]

    dataset = CustomDataset(sampled_windows, sampled_labels)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    batch_size = 1024 
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


    all_valid_indices_val = []
    all_windows_val = []

    for start_frame in range(1200 * 10000, 1600 * 10000, chunk_size):
        end_frame = min(start_frame + chunk_size, 1600 * 10000)
        
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )  # shape: (n_channels, chunk_size)
        
        threshold_result = detect_local_maxima_in_window(
            data_chunk.T,  
            std_multiplier=0.7
        )
        
        threshold_result = np.array(threshold_result) + start_frame
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) & 
            (threshold_result < end_frame - half_window)
        ]
        
        for idx in valid_indices:
            rel_idx = idx - start_frame
            window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
            all_windows_val.append(window)
        
        all_valid_indices_val.extend(valid_indices)

    all_valid_indices_val = np.array(all_valid_indices_val)
    all_windows_val = np.stack(all_windows_val)  

    labels_val = label_array1_based_on_array2(all_valid_indices_val, spike_inf['time'], threshold=1)
    labels_val = np.array(labels_val) 
    dataset = CustomDataset(all_windows_val, labels_val)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


    input_size = sampled_windows.shape[1] * sampled_windows.shape[2]
    hidden_size1 = 128
    hidden_size2 = 32
    output_size = 1  
    device = 'cuda'

    val_results = {}
    val_results['tpr'] = []
    val_results['tnr'] = []
    val_results['accuracy'] = []
    print(f"Training model for file segment: {file_seg}")
    os.makedirs(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_6/{file_seg}", exist_ok=True)
    for trail in range(1, 6):
        criterion = nn.BCELoss()  

        model = Spike_Detection_MLP(input_size, hidden_size1, hidden_size2, 
                                        output_size, n_channels=sampled_windows.shape[1], time_window= sampled_windows.shape[2])
        model = model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=0.0001)

        num_epochs = 50
        tpr_best = 0
        i = 0
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            for batch_data, batch_labels in train_loader:
                batch_labels = batch_labels.float().unsqueeze(1)

                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)

                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

            model.eval()
            correct = 0
            total = 0

            true_positive = 0
            true_negative = 0
            false_positive = 0
            false_negative = 0

            with torch.no_grad():
                for batch_data, batch_labels in test_loader:
                    batch_labels = batch_labels.float().unsqueeze(1)
                    batch_data = batch_data.to(device)
                    batch_labels = batch_labels.to(device)

                    outputs = model(batch_data)
                    predicted = (outputs > 0.5).float()  
                    total += batch_labels.size(0)
                    correct += (predicted == batch_labels).sum().item()
                    true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
                    true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
                    false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
                    false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()


            tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
            tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0
            accuracy = correct / total
            # print(f"Test Accuracy: {100 * correct / total:.2f}%")
            # print(f"Test TPR: {100 * tpr:.2f}%")
            # print(f"Test TNR: {100 * tnr:.2f}%")

            if tpr > tpr_best:
                tpr_best = tpr
                i = 0
                torch.save(model, f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_6/{file_seg}/trail_{trail}.pth')
                # print(f"Best model saved with TPR: {tpr_best:.4f}")
                # print("_" * 60)

            else:
                i += 1
                if i == 3:
                    print(f"Training stopped after {epoch+1} epochs with best TPR: {tpr_best:.4f}")
                    print("_" * 60)
                    break

        val_results['tpr'].append(tpr_best)
        val_results['tnr'].append(tnr)
        val_results['accuracy'].append(accuracy)

    with open(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/eval_results/setting_6/val_results_{file_seg}.pkl", "wb") as f:
        pickle.dump(val_results, f)

    del all_windows, all_windows_val

Processing file: Neuronexus_32_50_cell_cell_type_ChC_recording.h5
Training model for file segment: Neuronexus_32_50_cell_cell_type_ChC_recording
Training stopped after 11 epochs with best TPR: 0.7366
____________________________________________________________
Training stopped after 8 epochs with best TPR: 0.7355
____________________________________________________________
Training stopped after 12 epochs with best TPR: 0.7370
____________________________________________________________
Training stopped after 10 epochs with best TPR: 0.7317
____________________________________________________________
Training stopped after 12 epochs with best TPR: 0.7345
____________________________________________________________
Processing file: Neuronexus_32_50_cell_cell_type_DBC_recording.h5
Training model for file segment: Neuronexus_32_50_cell_cell_type_DBC_recording
Training stopped after 13 epochs with best TPR: 0.7555
____________________________________________________________
Training stoppe

In [11]:
for file in os.listdir("/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_2_neuron_num/recordings"):
    print(f"Processing file: {file}")
    recording_raw = se.MEArecRecordingExtractor(file_path=f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_2_neuron_num/recordings/{file}')
    recording_f = spre.bandpass_filter(recording_raw, freq_min=300, freq_max=3000)
    recording_f = spre.common_reference(recording_f, reference="global", operator="median")
    file_seg = file.split(".")[0]
    spike_inf = pd.read_csv(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_2_neuron_num/spike_inf/{file_seg}_spike_inf.csv")

    total_frames = int(600 * 10000)
    chunk_size = 100000  
    window_size = 31
    half_window = window_size // 2

    all_valid_indices = []
    all_windows = []

    for start_frame in range(0, total_frames, chunk_size):
        end_frame = min(start_frame + chunk_size, total_frames)
        
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )  # shape: (n_channels, chunk_size)
        
        threshold_result = detect_local_maxima_in_window(
            data_chunk.T,  
            std_multiplier=0.7
        )
        
        threshold_result = np.array(threshold_result) + start_frame
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) & 
            (threshold_result < end_frame - half_window)
        ]
        
        for idx in valid_indices:
            rel_idx = idx - start_frame
            window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
            all_windows.append(window)
        
        all_valid_indices.extend(valid_indices)

    all_valid_indices = np.array(all_valid_indices)
    all_windows = np.stack(all_windows)  

    labels = label_array1_based_on_array2(all_valid_indices, spike_inf['time'], threshold=1)

    labels = np.array(labels) 
    indices_0 = np.where(labels == 0)[0] 
    indices_1 = np.where(labels == 1)[0] 

    target_0_count = len(indices_1) 

    if len(indices_0) > target_0_count:
        sampled_indices_0 = np.random.choice(indices_0, target_0_count, replace=False)
    else:
        sampled_indices_0 = indices_0  

    final_indices = np.concatenate([sampled_indices_0, indices_1])

    np.random.shuffle(final_indices)

    sampled_windows = all_windows[final_indices]
    sampled_labels = labels[final_indices]

    dataset = CustomDataset(sampled_windows, sampled_labels)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    batch_size = 1024 
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


    all_valid_indices_val = []
    all_windows_val = []

    for start_frame in range(1200 * 10000, 1600 * 10000, chunk_size):
        end_frame = min(start_frame + chunk_size, 1600 * 10000)
        
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )  # shape: (n_channels, chunk_size)
        
        threshold_result = detect_local_maxima_in_window(
            data_chunk.T,  
            std_multiplier=0.7
        )
        
        threshold_result = np.array(threshold_result) + start_frame
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) & 
            (threshold_result < end_frame - half_window)
        ]
        
        for idx in valid_indices:
            rel_idx = idx - start_frame
            window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
            all_windows_val.append(window)
        
        all_valid_indices_val.extend(valid_indices)

    all_valid_indices_val = np.array(all_valid_indices_val)
    all_windows_val = np.stack(all_windows_val)  

    labels_val = label_array1_based_on_array2(all_valid_indices_val, spike_inf['time'], threshold=1)
    labels_val = np.array(labels_val) 
    dataset = CustomDataset(all_windows_val, labels_val)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


    input_size = sampled_windows.shape[1] * sampled_windows.shape[2]
    hidden_size1 = 128
    hidden_size2 = 32
    output_size = 1  
    device = 'cuda'

    val_results = {}
    val_results['tpr'] = []
    val_results['tnr'] = []
    val_results['accuracy'] = []
    print(f"Training model for file segment: {file_seg}")
    os.makedirs(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_2/{file_seg}", exist_ok=True)
    for trail in range(1, 6):
        criterion = nn.BCELoss()  

        model = Spike_Detection_MLP(input_size, hidden_size1, hidden_size2, 
                                        output_size, n_channels=sampled_windows.shape[1], time_window= sampled_windows.shape[2])
        model = model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=0.0001)

        num_epochs = 50
        tpr_best = 0
        i = 0
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            for batch_data, batch_labels in train_loader:
                batch_labels = batch_labels.float().unsqueeze(1)

                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)

                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

            model.eval()
            correct = 0
            total = 0

            true_positive = 0
            true_negative = 0
            false_positive = 0
            false_negative = 0

            with torch.no_grad():
                for batch_data, batch_labels in test_loader:
                    batch_labels = batch_labels.float().unsqueeze(1)
                    batch_data = batch_data.to(device)
                    batch_labels = batch_labels.to(device)

                    outputs = model(batch_data)
                    predicted = (outputs > 0.5).float()  
                    total += batch_labels.size(0)
                    correct += (predicted == batch_labels).sum().item()
                    true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
                    true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
                    false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
                    false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()


            tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
            tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0
            accuracy = correct / total
            # print(f"Test Accuracy: {100 * correct / total:.2f}%")
            # print(f"Test TPR: {100 * tpr:.2f}%")
            # print(f"Test TNR: {100 * tnr:.2f}%")

            if tpr > tpr_best:
                tpr_best = tpr
                i = 0
                torch.save(model, f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_2/{file_seg}/trail_{trail}.pth')
                # print(f"Best model saved with TPR: {tpr_best:.4f}")
                # print("_" * 60)

            else:
                i += 1
                if i == 3:
                    print(f"Training stopped after {epoch+1} epochs with best TPR: {tpr_best:.4f}")
                    print("_" * 60)
                    break

        val_results['tpr'].append(tpr_best)
        val_results['tnr'].append(tnr)
        val_results['accuracy'].append(accuracy)

    with open(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/eval_results/setting_2/val_results_{file_seg}.pkl", "wb") as f:
        pickle.dump(val_results, f)

    del all_windows, all_windows_val

Processing file: Neuronexus_32_20_recording.h5
Training model for file segment: Neuronexus_32_20_recording
Training stopped after 13 epochs with best TPR: 0.9959
____________________________________________________________
Training stopped after 15 epochs with best TPR: 0.9960
____________________________________________________________
Training stopped after 16 epochs with best TPR: 0.9966
____________________________________________________________
Training stopped after 16 epochs with best TPR: 0.9962
____________________________________________________________
Training stopped after 9 epochs with best TPR: 0.9951
____________________________________________________________
Processing file: Neuronexus_32_40_recording.h5
Training model for file segment: Neuronexus_32_40_recording
Training stopped after 12 epochs with best TPR: 0.9910
____________________________________________________________
Training stopped after 17 epochs with best TPR: 0.9919
____________________________________

In [9]:
for file in os.listdir("/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_4_minimum_neuron_distance/recordings"):
    print(f"Processing file: {file}")
    recording_raw = se.MEArecRecordingExtractor(file_path=f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_4_minimum_neuron_distance/recordings/{file}')
    recording_f = spre.bandpass_filter(recording_raw, freq_min=300, freq_max=3000)
    recording_f = spre.common_reference(recording_f, reference="global", operator="median")
    file_seg = file.split(".")[0]
    spike_inf = pd.read_csv(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_4_minimum_neuron_distance/spike_inf/{file_seg}_spike_inf.csv")

    total_frames = int(600 * 10000)
    chunk_size = 100000  
    window_size = 31
    half_window = window_size // 2

    all_valid_indices = []
    all_windows = []

    for start_frame in range(0, total_frames, chunk_size):
        end_frame = min(start_frame + chunk_size, total_frames)
        
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )  # shape: (n_channels, chunk_size)
        
        threshold_result = detect_local_maxima_in_window(
            data_chunk.T,  
            std_multiplier=0.7
        )
        
        threshold_result = np.array(threshold_result) + start_frame
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) & 
            (threshold_result < end_frame - half_window)
        ]
        
        for idx in valid_indices:
            rel_idx = idx - start_frame
            window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
            all_windows.append(window)
        
        all_valid_indices.extend(valid_indices)

    all_valid_indices = np.array(all_valid_indices)
    all_windows = np.stack(all_windows)  

    labels = label_array1_based_on_array2(all_valid_indices, spike_inf['time'], threshold=1)

    labels = np.array(labels) 
    indices_0 = np.where(labels == 0)[0] 
    indices_1 = np.where(labels == 1)[0] 

    target_0_count = len(indices_1) 

    if len(indices_0) > target_0_count:
        sampled_indices_0 = np.random.choice(indices_0, target_0_count, replace=False)
    else:
        sampled_indices_0 = indices_0  

    final_indices = np.concatenate([sampled_indices_0, indices_1])

    np.random.shuffle(final_indices)

    sampled_windows = all_windows[final_indices]
    sampled_labels = labels[final_indices]

    dataset = CustomDataset(sampled_windows, sampled_labels)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    batch_size = 1024 
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


    all_valid_indices_val = []
    all_windows_val = []

    for start_frame in range(1200 * 10000, 1600 * 10000, chunk_size):
        end_frame = min(start_frame + chunk_size, 1600 * 10000)
        
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )  # shape: (n_channels, chunk_size)
        
        threshold_result = detect_local_maxima_in_window(
            data_chunk.T,  
            std_multiplier=0.7
        )
        
        threshold_result = np.array(threshold_result) + start_frame
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) & 
            (threshold_result < end_frame - half_window)
        ]
        
        for idx in valid_indices:
            rel_idx = idx - start_frame
            window = data_chunk.T[:, rel_idx-half_window : rel_idx+half_window+1]
            all_windows_val.append(window)
        
        all_valid_indices_val.extend(valid_indices)

    all_valid_indices_val = np.array(all_valid_indices_val)
    all_windows_val = np.stack(all_windows_val)  

    labels_val = label_array1_based_on_array2(all_valid_indices_val, spike_inf['time'], threshold=1)
    labels_val = np.array(labels_val) 
    dataset = CustomDataset(all_windows_val, labels_val)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


    input_size = sampled_windows.shape[1] * sampled_windows.shape[2]
    hidden_size1 = 128
    hidden_size2 = 32
    output_size = 1  
    device = 'cuda'

    val_results = {}
    val_results['tpr'] = []
    val_results['tnr'] = []
    val_results['accuracy'] = []
    print(f"Training model for file segment: {file_seg}")
    os.makedirs(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_4/{file_seg}", exist_ok=True)
    for trail in range(1, 6):
        criterion = nn.BCELoss()  

        model = Spike_Detection_MLP(input_size, hidden_size1, hidden_size2, 
                                        output_size, n_channels=sampled_windows.shape[1], time_window= sampled_windows.shape[2])
        model = model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=0.0001)

        num_epochs = 50
        tpr_best = 0
        i = 0
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            for batch_data, batch_labels in train_loader:
                batch_labels = batch_labels.float().unsqueeze(1)

                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)

                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

            model.eval()
            correct = 0
            total = 0

            true_positive = 0
            true_negative = 0
            false_positive = 0
            false_negative = 0

            with torch.no_grad():
                for batch_data, batch_labels in test_loader:
                    batch_labels = batch_labels.float().unsqueeze(1)
                    batch_data = batch_data.to(device)
                    batch_labels = batch_labels.to(device)

                    outputs = model(batch_data)
                    predicted = (outputs > 0.5).float()  
                    total += batch_labels.size(0)
                    correct += (predicted == batch_labels).sum().item()
                    true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
                    true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
                    false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
                    false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()


            tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
            tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0
            accuracy = correct / total
            # print(f"Test Accuracy: {100 * correct / total:.2f}%")
            # print(f"Test TPR: {100 * tpr:.2f}%")
            # print(f"Test TNR: {100 * tnr:.2f}%")

            if tpr > tpr_best:
                tpr_best = tpr
                i = 0
                torch.save(model, f'/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/train_results/setting_4/{file_seg}/trail_{trail}.pth')
                # print(f"Best model saved with TPR: {tpr_best:.4f}")
                # print("_" * 60)

            else:
                i += 1
                if i == 3:
                    print(f"Training stopped after {epoch+1} epochs with best TPR: {tpr_best:.4f}")
                    print("_" * 60)
                    break

        val_results['tpr'].append(tpr_best)
        val_results['tnr'].append(tnr)
        val_results['accuracy'].append(accuracy)

    with open(f"/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_detection/eval_results/setting_4/val_results_{file_seg}.pkl", "wb") as f:
        pickle.dump(val_results, f)

    del all_windows, all_windows_val

Processing file: Neuronexus_32_50_cell_min_distance_15_recording.h5
Training model for file segment: Neuronexus_32_50_cell_min_distance_15_recording
Training stopped after 9 epochs with best TPR: 0.9693
____________________________________________________________
Training stopped after 9 epochs with best TPR: 0.9710
____________________________________________________________
Training stopped after 11 epochs with best TPR: 0.9720
____________________________________________________________
Training stopped after 11 epochs with best TPR: 0.9714
____________________________________________________________
Training stopped after 14 epochs with best TPR: 0.9725
____________________________________________________________
Processing file: Neuronexus_32_50_cell_min_distance_10_recording.h5
Training model for file segment: Neuronexus_32_50_cell_min_distance_10_recording
Training stopped after 15 epochs with best TPR: 0.9671
____________________________________________________________
Training

- Evaluation

In [None]:
recording_raw = se.MEArecRecordingExtractor(file_path='/media/ubuntu/sda/Spike_Sorting/simulation/recording/recording_Neuronexus_3600s_50cells_1.h5')

recording_f = spre.bandpass_filter(recording_raw, freq_min=300, freq_max=3000)
recording_f = spre.common_reference(recording_f, reference="global", operator="median")
spike_inf = pd.read_csv("/media/ubuntu/sda/Spike_Sorting/simulation/sorting/recording_Neuronexus_3600s_50cells_1.csv")

hidden_size1 = 128
hidden_size2 = 32
output_size = 1  
device = 'cuda'

data_temp = recording_f.get_traces(start_frame=700 * 10000, end_frame= 1400 * 10000).T

In [None]:
threshold_result = detect_local_maxima_in_window(data_temp)
threshold_result = np.array(threshold_result)
valid_indices = threshold_result[(threshold_result > 45)]
valid_indices = valid_indices[valid_indices < data_temp.shape[1] - 46]

labels = label_array1_based_on_array2(valid_indices, spike_inf['time'], threshold=1)

labels = np.array(labels) 

sampled_data = extract_windows(data_temp, valid_indices, window_size=91)
sampled_labels = labels

dataset = CustomDataset(sampled_data, sampled_labels)

batch_size = 1024 
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

input_size = sampled_data.shape[1] * sampled_data.shape[2]



accuracy_list = []  
tpr_list = []
tnr_list = []

model = torch.load("/media/ubuntu/sda/Spike_Sorting/simulation/Neuronexus/spike_detection_model/spike_detection_model_1.pth", weights_only=False)
model = model.to(device)

model.eval()
correct = 0
total = 0

true_positive = 0
true_negative = 0
false_positive = 0
false_negative = 0

with torch.no_grad():
    for batch_data, batch_labels in test_loader:
        batch_labels = batch_labels.float().unsqueeze(1)
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_data)
        predicted = (outputs > 0.5).float()  
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()
        true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
        true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
        false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
        false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()


tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0

print(f"Test Accuracy: {100 * correct / total:.2f}%")
print(f"Test TPR: {100 * tpr:.2f}%")
print(f"Test TNR: {100 * tnr:.2f}%")



Test Accuracy: 96.65%
Test TPR: 98.95%
Test TNR: 94.40%
