In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import spikeinterface as si
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm


import sys
import spikeinterface as si
import matplotlib.pyplot as plt
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.qualitymetrics as sqm
import json
import probeinterface

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_dummy_probe, generate_linear_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
from torch.nn.functional import max_pool1d


import torch.nn.functional as F
from pathlib import Path
import networkx as nx


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Subset

# from function.Function import *

In [2]:
def count_array2_in_range_of_array1(array1, array2, threshold=5):
    """
    Count the entries of ``array2`` that have at least one entry of ``array1``
    within ``threshold`` of them (bounds inclusive).

    Parameters:
    array1 : array-like
        Reference values. They are sorted internally, so input order is irrelevant.
    array2 : numpy.ndarray
        Query values; must support element-wise arithmetic with ``threshold``.
    threshold : int
        Half-width of the search interval placed around every ``array2`` value.

    Returns:
    count : numpy integer
        Number of ``array2`` entries whose interval
        ``[value - threshold, value + threshold]`` contains an ``array1`` entry.
    """
    reference = np.sort(array1)

    # Index of the first reference value >= lower bound ...
    first_hit = np.searchsorted(reference, array2 - threshold, side='left')
    # ... and one past the last reference value <= upper bound.
    past_last_hit = np.searchsorted(reference, array2 + threshold, side='right')

    # A non-empty [first_hit, past_last_hit) slice means something fell inside the interval.
    return np.sum(past_last_hit > first_hit)


def detect_local_maxima_in_window(data, window_size=20, std_multiplier=2):
    """
    Find, for every row, the position of the maximum inside each consecutive
    non-overlapping window, keeping only maxima that exceed
    ``std_multiplier`` times that row's standard deviation.

    The positions found in all rows are merged into ONE flat collection of
    column indices (a position detected on several rows appears once).

    Parameters:
    data : numpy.ndarray
        Input data of shape (n_rows, n_columns).
    window_size : int
        Length of each non-overlapping window, default 20. The last window of
        a row may be shorter when n_columns is not a multiple of window_size.
    std_multiplier : float
        Multiple of the per-row standard deviation a window maximum must
        exceed to be kept, default 2.

    Returns:
    local_maxima_indices : list of int
        Sorted, de-duplicated column indices of the retained window maxima,
        pooled over all rows. Empty list when nothing qualifies.

    Raises:
    ValueError
        If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer.")

    per_row_hits = []

    for row in data:
        row = np.asarray(row)
        n_samples = row.shape[0]
        if n_samples == 0:
            continue

        # The threshold is computed per row, exactly as a row-by-row scan would do.
        threshold = std_multiplier * np.std(row)

        n_full = n_samples // window_size
        covered = n_full * window_size

        if n_full > 0:
            # All complete windows of the row at once: one window per matrix row.
            blocks = row[:covered].reshape(n_full, window_size)
            offsets = blocks.argmax(axis=1)  # first maximum inside each window
            peaks = blocks[np.arange(n_full), offsets]
            keep = peaks > threshold
            starts = np.arange(n_full) * window_size
            per_row_hits.append(starts[keep] + offsets[keep])

        if covered < n_samples:
            # Trailing partial window (shorter than window_size).
            tail = row[covered:]
            tail_offset = int(np.argmax(tail))
            if tail[tail_offset] > threshold:
                per_row_hits.append(np.array([covered + tail_offset]))

    if not per_row_hits:
        return []

    # np.unique both de-duplicates across rows and returns the indices sorted.
    return np.unique(np.concatenate(per_row_hits)).tolist()


def cluster_label_array1_based_on_array2(array1, array2, threshold=5, time_col=5, cluster_col=1):
    """
    Label every value of ``array1`` with the cluster of a nearby spike taken
    from the table ``array2``.

    A value is matched when at least one spike time lies inside
    ``[value - threshold, value + threshold]``. When several spike times fall
    inside that interval, the cluster of the EARLIEST one (smallest time) is
    used, not the nearest one. Unmatched values are labelled 0.

    Parameters:
    array1 : array-like
        Values to label.
    array2 : pandas.DataFrame
        Spike table. Columns are selected by POSITION via ``.iloc``.
    threshold : int
        Half-width of the matching interval.
    time_col : int
        Positional index of the spike-time column, default 5.
    cluster_col : int
        Positional index of the cluster column, default 1.

    Returns:
    labels : numpy.ndarray
        Integer array of length ``len(array1)`` holding the matched cluster
        value, or 0 where no spike time is within range.
    """
    pairs = np.array(array2.iloc[:, [time_col, cluster_col]])

    # Stable sort so that spikes sharing the same time keep their table order,
    # which makes the "first match" deterministic.
    order = np.argsort(pairs[:, 0], kind="stable")
    times = pairs[order, 0]
    clusters = pairs[order, 1]

    queries = np.asarray(array1)
    labels = np.zeros(len(queries), dtype=int)

    # One vectorised binary search for all queries instead of one per element.
    first_hit = np.searchsorted(times, queries - threshold, side='left')
    past_last_hit = np.searchsorted(times, queries + threshold, side='right')

    matched = past_last_hit > first_hit
    labels[matched] = clusters[first_hit[matched]]

    return labels


def label_array1_based_on_array2(array1, array2, threshold=5):
    """
    Mark every value of ``array1`` that has a value of ``array2`` nearby.

    A value gets label 1 when at least one entry of ``array2`` lies inside
    ``[value - threshold, value + threshold]`` (bounds inclusive), else 0.

    Parameters:
    array1 : array-like
        Values to label.
    array2 : array-like
        Reference values used for the proximity test.
    threshold : int
        Half-width of the matching interval.

    Returns:
    labels : numpy.ndarray
        Integer array of length ``len(array1)`` containing 0 or 1.
    """
    # Sort the reference values once so each lookup is a binary search.
    reference = np.sort(array2)
    queries = np.asarray(array1)

    # A single vectorised searchsorted call handles every query at once;
    # the result is identical to testing the values one by one.
    first_hit = np.searchsorted(reference, queries - threshold, side='left')
    past_last_hit = np.searchsorted(reference, queries + threshold, side='right')

    # A non-empty [first_hit, past_last_hit) slice means a reference value is in range.
    return (past_last_hit > first_hit).astype(int)


def extract_windows(data, indices, window_size=61):
    """
    Cut fixed-length windows out of ``data`` around the given sample indices.

    For a centre index ``idx`` the window covers the columns
    ``[idx - window_size // 2, idx - window_size // 2 + window_size)``.
    With the default ``window_size=61`` this is ``idx - 30`` to ``idx + 30``
    inclusive. Every window has exactly ``window_size`` samples, for odd and
    even sizes alike.

    Parameters:
    data : numpy.ndarray
        Input data of shape (n_channels, time).
    indices : array-like of int
        Sample indices used as window centres.
    window_size : int
        Number of samples per window, default 61.

    Returns:
    windows : numpy.ndarray
        Array of shape (len(indices), n_channels, window_size).

    Raises:
    ValueError
        If ``window_size`` is smaller than 1, or if any window would extend
        past the start or the end of ``data``.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer.")

    n_channels, time_length = data.shape
    half_window = window_size // 2

    indices = np.asarray(indices)
    if indices.size == 0:
        # Keep the documented 3-D shape even when there is nothing to extract.
        return np.empty((0, n_channels, window_size), dtype=data.dtype)

    starts = indices - half_window
    stops = starts + window_size  # exclusive end, so each slice has window_size samples

    if np.any(starts < 0) or np.any(stops > time_length):
        raise ValueError("Some indices are out of bounds for the given window size.")

    return np.stack([data[:, begin:end] for begin, end in zip(starts, stops)])

In [3]:
# Read the probe description and keep the columns at positions 1 and 2 of its
# table as a plain NumPy array (one row per contact). These rows are used
# below as coordinates for pairwise distances -- presumably x/y positions,
# TODO confirm the column order of to_dataframe().
probe_30channel = read_probeinterface('/media/ubuntu/sda/data/probe.json')
probe_30channel_df = np.array(probe_30channel.to_dataframe().iloc[:, 1:3])

In [4]:
# Smallest distance allowed in the matrix; also lifts the diagonal above zero.
eps = 1e-5
# Contacts farther apart than this are not connected. The value sits just
# above sqrt(50**2 + 100**2), i.e. a (50, 100) offset still counts as adjacent.
distance_threshold = np.sqrt(50 * 50 + 100 * 100 + 1)

# Pairwise Euclidean distance between every pair of contact rows.
dist_matrix = np.linalg.norm(probe_30channel_df[:, np.newaxis] - probe_30channel_df, axis=2)
np.fill_diagonal(dist_matrix, 0)
# Floor all distances at eps (in place).
np.maximum(dist_matrix, eps, out=dist_matrix)

# Despite its name, inv_dist is a 0/1 adjacency matrix: 1 for every pair of
# distinct contacts whose distance does not exceed the threshold.
inv_dist = (dist_matrix > 0).astype(int)
np.fill_diagonal(inv_dist, 0)
if distance_threshold is not None:
    inv_dist[dist_matrix > distance_threshold] = 0

# Group contacts into maximal cliques of the adjacency graph, keyed by the
# order in which networkx yields them.
graph = nx.from_numpy_array(inv_dist)
maximal_cliques = list(nx.find_cliques(graph))
cliques_dict = dict(enumerate(maximal_cliques))

In [5]:
# Load the raw recording from the Blackrock .ns4 file.
recording_raw = se.read_blackrock(file_path='/media/ubuntu/sda/data/mouse6/ns4/natural_image/mouse6_021322_natural_image001.ns4')
# Channels "98", "31" and "32" are excluded from the recording used for detection.
recording_recorded = recording_raw.remove_channels(["98", '31', '32'])
# Channel "98" alone is kept as its own recording (presumably the stimulus
# trace, judging by the variable name -- TODO confirm).
recording_stimulated = recording_raw.channel_slice(['98'])

# Band-pass filter (freq_min=750, freq_max=3000; presumably Hz), then
# re-reference every channel against the global median.
recording_f = spre.bandpass_filter(recording_recorded, freq_min=750, freq_max=3000)
recording_f = spre.common_reference(recording_f, reference="global", operator="median")

# Pull the traces as float32 and transpose them. Downstream code treats axis 0
# as channels and axis 1 as the sample index (assumes get_traces() returns
# (samples, channels) -- TODO confirm).
data = recording_f.get_traces().astype("float32").T

# Candidate peak positions: per-window maxima above the per-row threshold,
# pooled over all rows of `data` (default window_size and std_multiplier).
threshold_result = detect_local_maxima_in_window(data)

In [6]:
# Keep only candidates with a margin of more than 30 samples from the start and
# more than 31 from the end of the recording, so a window can be cut around
# each of them without running out of bounds.
threshold_result = np.array(threshold_result)
valid_indices = threshold_result[(threshold_result > 30) & (threshold_result < data.shape[1] - 31)]

In [7]:
# Load the spike table from CSV; its 'time' column serves as the reference
# when labelling candidates (presumably spike-sorting output -- TODO confirm).
spike_inf = pd.read_csv("/media/ubuntu/sda/Spike_Sorting/sorting_results/021322/spike_inf.csv")

In [8]:
# Label each candidate index 1 when spike_inf['time'] holds a value within
# +/-1 of it, otherwise 0.
labels = label_array1_based_on_array2(valid_indices, spike_inf['time'], threshold=1)

In [9]:
# Flatten the cliques, in ascending key order, into one channel ordering so
# that members of the same clique end up next to each other.
reordered_indices = []
for key in sorted(cliques_dict):
    reordered_indices += cliques_dict[key]

In [10]:
class CustomDataset(Dataset):
    """Pairs each sample of ``data`` with the entry of ``labels`` at the same position."""

    def __init__(self, data, labels):
        # Both containers are stored as given; no copy or conversion is made.
        self.data, self.labels = data, labels

    def __len__(self):
        """Number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at ``idx``."""
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target
    

# Seeded generator so the class balancing below yields the same sample set
# after every kernel restart.
BALANCE_SEED = 0
rng = np.random.default_rng(BALANCE_SEED)

labels = np.array(labels)
indices_0 = np.where(labels == 0)[0]  # positions of negative candidates
indices_1 = np.where(labels == 1)[0]  # positions of positive candidates

# Down-sample the negatives to the number of positives. When there are not
# more negatives than positives, all negatives are kept and the classes stay
# unbalanced.
target_0_count = len(indices_1)

if len(indices_0) > target_0_count:
    sampled_indices_0 = rng.choice(indices_0, target_0_count, replace=False)
else:
    sampled_indices_0 = indices_0

final_indices = np.concatenate([sampled_indices_0, indices_1])

# Shuffle in place so the two classes are interleaved.
rng.shuffle(final_indices)


# Cut a 31-sample window around every selected candidate.
sampled_data = extract_windows(data, valid_indices[final_indices], window_size=31)
sampled_labels = labels[final_indices]

# Reorder the channel axis so channels of the same clique are adjacent.
reordered_data = sampled_data[:, reordered_indices, :]

In [14]:
# Display the extracted windows. This prints the array itself;
# `sampled_data.shape` gives a more compact overview.
sampled_data

array([[[ -13.,   -2.,    3., ...,  -16.,    1.,   22.],
        [  91.,   64.,    2., ...,   38.,  126.,   42.],
        [  37.,   27.,    1., ...,    0.,   43.,   28.],
        ...,
        [ -19.,   -2.,   11., ...,    7.,  -16.,  -21.],
        [  -5.,   -2.,   -3., ...,   29.,   32.,    8.],
        [ -17.,    2.,   33., ...,   -4.,   -5.,   -4.]],

       [[  25.,   31.,    2., ...,   39.,   45.,   13.],
        [ -63.,  -58.,  -44., ...,  -41.,  -13.,   54.],
        [ -22.,  -19.,  -18., ...,    2.,    8.,   23.],
        ...,
        [  14.,   34.,   25., ...,   -2.,   50.,   14.],
        [ -34.,    8.,   16., ...,  -13.,   -2.,   25.],
        [ -18.,   34.,   29., ...,   32.,    9.,   -1.]],

       [[ -41.,  -49.,  -18., ...,   18.,   -7.,  -17.],
        [  27.,   13.,   -8., ...,  -80.,   -7.,   61.],
        [  16.,   31.,   15., ...,   -2.,    7.,   14.],
        ...,
        [ -23.,    3.,   -2., ...,   30.,  -25.,  -49.],
        [ -33.,  -18.,    1., ...,    4.,  -2

In [11]:
dataset = CustomDataset(reordered_data, sampled_labels)

# 80 % / 20 % train/test split. The generator is seeded so the membership of
# both subsets is the same on every run.
SPLIT_SEED = 0
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 1024
# Training batches are reshuffled every epoch; the test order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

class Spike_Detection_MLP(nn.Module):
    """
    Three-layer fully connected network with a sigmoid output.

    Parameters:
    input_size : int
        Number of features per sample after flattening.
    hidden_size1 : int
        Width of the first hidden layer.
    hidden_size2 : int
        Width of the second hidden layer.
    output_size : int
        Number of output units; each one is passed through a sigmoid.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        # self.conv1 = nn.Conv2d(
        #             in_channels=1,
        #             out_channels=12,
        #             kernel_size=(6, 1),
        #             stride=(6, 1),
        #             padding=0,
        #             bias=True
        #         )

        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Flatten ``x`` to rows of ``input_size`` features and return the
        sigmoid activations, shape (n_rows, output_size), values in [0, 1].
        """
        # x = x.unsqueeze(1)
        # x = self.conv1(x)
        # x = x.mean(dim=1)

        # Flatten to the width the first layer was built with, instead of a
        # fixed 31 * 30, so the model works for any `input_size`.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x

# Flattened size of one sample: 30 x 31 values (31 matches the window length
# used when the windows were extracted; 30 is presumably the channel count --
# TODO confirm).
input_size = 30 * 31
hidden_size1 = 128
hidden_size2 = 32
output_size = 1
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [12]:
# Train the MLP and evaluate it on the test loader after every epoch.
# NOTE: accuracy_list / tpr_list / tnr_list, tpr_best and i are only referenced
# by the commented-out early-stopping / bookkeeping code below the loop.
accuracy_list = []  
tpr_list = []
tnr_list = []


#for trail in range(1, 6):
# Binary cross-entropy on probabilities (the model already ends in a sigmoid).
criterion = nn.BCELoss()  

model = Spike_Detection_MLP(input_size, hidden_size1, hidden_size2, output_size)
model = model.to(device)

optimizer = optim.Adam(
    model.parameters(),
    lr=0.00001,
    weight_decay=0.0001 
)

num_epochs = 80
tpr_best = 0
i = 0
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for batch_data, batch_labels in train_loader:
        # Labels become float with shape (batch, 1) to match the model output.
        batch_labels = batch_labels.float().unsqueeze(1)

        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

    # ---- evaluation pass on the test loader ----
    model.eval()
    correct = 0
    total = 0

    # Confusion-matrix counts, reset every epoch.
    true_positive = 0
    true_negative = 0
    false_positive = 0
    false_negative = 0

    with torch.no_grad():
        for batch_data, batch_labels in test_loader:
            batch_labels = batch_labels.float().unsqueeze(1)
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)

            outputs = model(batch_data)
            # Probabilities above 0.5 count as the positive class.
            predicted = (outputs > 0.5).float()  
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()
            true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
            true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
            false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
            false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()


    # True-positive / true-negative rates; 0 when the class is absent from the test set.
    tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0

    print(f"Test Accuracy: {100 * correct / total:.2f}%")
    print(f"Test TPR: {100 * tpr:.2f}%")
    print(f"Test TNR: {100 * tnr:.2f}%")

    # if tpr > tpr_best:
    #     tpr_best = tpr
    #     i = 0
    #     torch.save(model, f'new_model/spike_detection_model_{trail}.pth')
    #     print(f"Best model saved with TPR: {tpr_best:.4f}")
    #     print("_" * 60)

    # else:
    #     i += 1
    #     if i == 3:
    #         print(f"Training stopped after {epoch+1} epochs with best TPR: {tpr_best:.4f}")
    #         print("_" * 60)
    #         break
    
    # all_labels = []
    # predicted_labels = []

    # model.eval()
    # with torch.no_grad():
    #     for batch_data, batch_labels in test_loader:
    #         batch_labels = batch_labels.float().unsqueeze(1)
    #         batch_data = batch_data.to(device)
    #         batch_labels = batch_labels.to(device)

    #         outputs = model(batch_data)
    #         predicted = (outputs > 0.5).float()  
            
    #         all_labels.extend(batch_labels.cpu().numpy())
    #         predicted_labels.extend(predicted.cpu().numpy())

    # all_labels = np.array(all_labels)
    # all_labels = np.concatenate(all_labels, axis=0).astype(int)
    # predicted_labels = np.array(predicted_labels)
    # predicted_labels = np.concatenate(predicted_labels, axis=0).astype(int) 

    # true_positive += ((predicted == 1) & (batch_labels == 1)).sum().item()
    # true_negative += ((predicted == 0) & (batch_labels == 0)).sum().item()
    # false_positive += ((predicted == 1) & (batch_labels == 0)).sum().item()
    # false_negative += ((predicted == 0) & (batch_labels == 1)).sum().item()

    # accuracy = 100 * (true_positive + true_negative) / len(all_labels)
    # tpr = true_positive / (true_positive + false_negative) * 100 if (true_positive + false_negative) > 0 else 0
    # tnr = true_negative / (true_negative + false_positive) * 100 if (true_negative + false_positive) > 0 else 0

    # accuracy_list.append(accuracy)
    # tpr_list.append(tpr)
    # tnr_list.append(tnr)

Epoch [1/80], Loss: 0.6463
Test Accuracy: 72.93%
Test TPR: 71.52%
Test TNR: 74.35%
Epoch [2/80], Loss: 0.5097
Test Accuracy: 77.09%
Test TPR: 74.95%
Test TNR: 79.23%
Epoch [3/80], Loss: 0.4694
Test Accuracy: 78.98%
Test TPR: 77.01%
Test TNR: 80.95%
Epoch [4/80], Loss: 0.4459
Test Accuracy: 80.16%
Test TPR: 78.03%
Test TNR: 82.28%
Epoch [5/80], Loss: 0.4295
Test Accuracy: 80.95%
Test TPR: 79.02%
Test TNR: 82.88%
Epoch [6/80], Loss: 0.4170
Test Accuracy: 81.53%
Test TPR: 79.08%
Test TNR: 83.99%
Epoch [7/80], Loss: 0.4068
Test Accuracy: 81.98%
Test TPR: 79.67%
Test TNR: 84.30%
Epoch [8/80], Loss: 0.3983
Test Accuracy: 82.37%
Test TPR: 80.34%
Test TNR: 84.40%
Epoch [9/80], Loss: 0.3909
Test Accuracy: 82.70%
Test TPR: 80.79%
Test TNR: 84.62%
Epoch [10/80], Loss: 0.3844
Test Accuracy: 82.98%
Test TPR: 81.03%
Test TNR: 84.94%
Epoch [11/80], Loss: 0.3787
Test Accuracy: 83.24%
Test TPR: 81.87%
Test TNR: 84.62%
Epoch [12/80], Loss: 0.3734
Test Accuracy: 83.48%
Test TPR: 82.68%
Test TNR: 84.29%
E