In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import spikeinterface as si
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm


import sys
import spikeinterface as si
import matplotlib.pyplot as plt
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.qualitymetrics as sqm
import json
import probeinterface

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_dummy_probe, generate_linear_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
from torch.nn.functional import max_pool1d


import torch.nn.functional as F
from pathlib import Path


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Subset

# from function.Function import *

In [2]:
def count_array2_in_range_of_array1(array1, array2, threshold=5):
    """
    Count how many values of ``array2`` have at least one ``array1`` value
    inside the closed interval ``[value - threshold, value + threshold]``.

    Parameters
    ----------
    array1 : array-like
        Reference values; a sorted copy is searched.
    array2 : numpy.ndarray
        Query values (must support ``array2 - threshold`` arithmetic).
    threshold : int or float
        Half-width of the inclusive matching interval, default 5.

    Returns
    -------
    count : numpy integer
        Number of ``array2`` entries with a match in ``array1``.
    """
    reference = np.sort(array1)

    # Position of the first reference value >= lower bound, and the position
    # just past the last reference value <= upper bound. A non-empty gap
    # between the two means at least one reference value lies in the interval.
    first_inside = np.searchsorted(reference, array2 - threshold, side='left')
    past_inside = np.searchsorted(reference, array2 + threshold, side='right')

    return np.sum(past_inside > first_inside)


def detect_local_maxima_in_window(data, window_size=20, std_multiplier=2):
    """
    Find per-window maxima that exceed a per-row standard-deviation threshold.

    Every row of ``data`` is cut into consecutive, non-overlapping windows of
    ``window_size`` samples (the last window may be shorter). In each window
    the position of the largest value is kept when that value is strictly
    greater than ``std_multiplier * np.std(row)``. Raw values are compared
    (no absolute value), so only positive-going peaks can pass. The positions
    found in all rows are merged into a single de-duplicated collection.

    Parameters
    ----------
    data : numpy.ndarray
        2-D input iterated row by row, shape (n_rows, n_columns).
    window_size : int
        Length of each non-overlapping window, default 20.
    std_multiplier : float
        Multiple of the row's standard deviation used as threshold, default 2.

    Returns
    -------
    local_maxima_indices : list of int
        Sorted, unique column indices at which any row had a window maximum
        above its threshold. Empty when nothing passes or ``data`` has no rows.

    Raises
    ------
    ValueError
        If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer.")

    # A set gives de-duplication across rows without rebuilding a list per row.
    unique_indices = set()

    for row in data:
        threshold = std_multiplier * np.std(row)

        for start in range(0, len(row), window_size):
            # Slicing past the end is safe; the final window is just shorter.
            window = row[start:start + window_size]
            peak_offset = int(np.argmax(window))

            if window[peak_offset] > threshold:
                unique_indices.add(start + peak_offset)

    # Sorted output makes the result deterministic and time-ordered.
    return sorted(unique_indices)


def cluster_label_array1_based_on_array2(array1, array2, threshold=5, time_col=5, cluster_col=1):
    """
    Label each value of ``array1`` with the cluster of a nearby event in ``array2``.

    For every value ``v`` in ``array1`` the event times of ``array2`` are
    searched for entries inside the closed interval
    ``[v - threshold, v + threshold]``. If at least one exists, ``v`` receives
    the cluster of the EARLIEST event in that interval (not necessarily the
    closest one); otherwise it receives 0.

    Parameters
    ----------
    array1 : array-like
        1-D values to label.
    array2 : pandas.DataFrame
        Table holding one event per row, with a time column and a cluster
        column.
    threshold : int or float
        Half-width of the inclusive matching interval, default 5.
    time_col : int or str
        Column holding the event times. An int is a column position, a str is
        a column label. Defaults to position 5 (the historical behaviour).
    cluster_col : int or str
        Column holding the cluster ids, same convention as ``time_col``.
        Defaults to position 1 (the historical behaviour).

    Returns
    -------
    labels : numpy.ndarray
        Integer array of length ``len(array1)`` holding the matched cluster
        id, or 0 where no event lies within ``threshold``.
    """
    def _column_values(selector):
        # Strings select a column by label, anything else by position.
        if isinstance(selector, str):
            return array2[selector].to_numpy()
        return array2.iloc[:, selector].to_numpy()

    times = _column_values(time_col)
    clusters = _column_values(cluster_col)

    # Stable sort: events sharing the same time keep their original row order,
    # so the chosen cluster is deterministic when times are tied.
    order = np.argsort(times, kind='stable')
    sorted_times = times[order]
    sorted_clusters = clusters[order]

    queries = np.asarray(array1)
    first_inside = np.searchsorted(sorted_times, queries - threshold, side='left')
    past_inside = np.searchsorted(sorted_times, queries + threshold, side='right')
    matched = past_inside > first_inside

    labels = np.zeros(len(queries), dtype=int)
    # first_inside points at the earliest event inside each matched interval.
    labels[matched] = sorted_clusters[first_inside[matched]]

    return labels


def label_array1_based_on_array2(array1, array2, threshold=5):
    """
    Flag each value of ``array1`` that has a value of ``array2`` close to it.

    A value ``v`` is labelled 1 when ``array2`` contains at least one entry in
    the closed interval ``[v - threshold, v + threshold]``, and 0 otherwise.

    Parameters
    ----------
    array1 : array-like
        1-D values to label.
    array2 : array-like
        Reference values; a sorted copy is searched.
    threshold : int or float
        Half-width of the inclusive matching interval, default 5.

    Returns
    -------
    labels : numpy.ndarray
        Integer array with one 0/1 entry per element of ``array1``.
    """
    reference = np.sort(array2)
    queries = np.asarray(array1)

    # One vectorised binary search for all queries instead of a Python loop:
    # a non-empty span between the two insertion points means a match exists.
    first_inside = np.searchsorted(reference, queries - threshold, side='left')
    past_inside = np.searchsorted(reference, queries + threshold, side='right')

    return (past_inside > first_inside).astype(int)


def extract_windows(data, indices, window_size=61):
    """
    Cut fixed-width windows out of ``data`` around the given centre samples.

    For each centre ``idx`` the slice
    ``data[:, idx - half : idx + half + 1]`` is taken, where
    ``half = window_size // 2``. The extracted width is therefore
    ``2 * half + 1``: equal to ``window_size`` when it is odd, and
    ``window_size + 1`` when it is even.

    Parameters
    ----------
    data : numpy.ndarray
        Input of shape (n_channels, time).
    indices : array-like of int
        Centre sample of every window to extract.
    window_size : int
        Nominal window length, default 61.

    Returns
    -------
    windows : numpy.ndarray
        Array of shape (len(indices), n_channels, 2 * (window_size // 2) + 1).
        With no indices the first dimension is 0 and the remaining dimensions
        are still well defined.

    Raises
    ------
    ValueError
        If any window would extend past either end of ``data``.
    """
    n_channels, time_length = data.shape
    half_window = window_size // 2

    # Accept plain lists as well as arrays for the bounds check below.
    centers = np.asarray(indices)

    if centers.size == 0:
        # Keep a consistent 3-D shape so callers can concatenate safely.
        return np.empty((0, n_channels, 2 * half_window + 1), dtype=data.dtype)

    if np.any(centers < half_window) or np.any(centers >= time_length - half_window):
        raise ValueError("Some indices are out of bounds for the given window size.")

    return np.stack([
        data[:, center - half_window:center + half_window + 1]
        for center in centers
    ])

In [3]:
class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample array with its label."""

    def __init__(self, data, labels):
        """Store ``data`` and ``labels`` by reference (no copy, no validation)."""
        self.data = data
        self.labels = labels

    def __len__(self):
        """Number of samples, taken from ``data``."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        """Return ``(sample cast to float32, label)`` for position ``idx``."""
        sample = self.data[idx].astype(np.float32)
        target = self.labels[idx]
        return sample, target
    
class Spike_Detection_MLP(nn.Module):
    """
    Four-layer fully connected binary classifier with a sigmoid output.

    Layer widths are ``input_size -> hidden_size1 -> hidden_size2 -> 16 ->
    output_size`` with ReLU between the linear layers. ``forward`` flattens
    every sample to ``n_channels * time_window`` values before ``fc1``, so
    ``input_size`` has to equal that product for the first layer to accept
    the flattened input.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, n_channels, time_window):
        super().__init__()

        # Layers are created in a fixed order (fc1 ... sigmoid); the attribute
        # names are what forward() and the parameter names are built from.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, 16)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(16, output_size)
        self.sigmoid = nn.Sigmoid()

        # Input geometry used by forward() to flatten each sample.
        self.n_channels = n_channels
        self.time_window = time_window

    def forward(self, x):
        """Flatten ``x`` to (-1, n_channels * time_window) and return sigmoid scores."""
        flat = x.reshape(-1, self.n_channels * self.time_window)
        hidden = self.relu1(self.fc1(flat))
        hidden = self.relu2(self.fc2(hidden))
        hidden = self.relu3(self.fc3(hidden))
        return self.sigmoid(self.fc4(hidden))

In [4]:
# Recording sessions to evaluate. Each string is substituted into the raw-data
# and sorting-result paths of the cells below (presumably MMDDYY - TODO confirm).
eval_date_list = [
    '021322', '022522', '031722', '042422',
    '052422', '062422', '072322', '082322',
    '092422', '102122', '112022', '122022',
]

In [5]:
# ---------------------------------------------------------------------------
# Evaluate every trained detector ("trail" 1-5) on every recording session.
#
# For each date the filtered recording is read chunk by chunk, candidate
# events are detected with detect_local_maxima_in_window, a window is cut
# around each candidate and labelled against the sorted spike times, and all
# models classify the same windows. Results are collected per date as lists
# (one entry per trail) in tpr_dict / tnr_dict / accuracy_dict.
# ---------------------------------------------------------------------------
RAW_DATA_DIR = '/media/ubuntu/sda/data/mouse6/ns4/natural_image'
PROJECT_DIR = '/media/ubuntu/sda/Spike_Sorting/paper_architecture/01_real_data/01_flexible_probe_30_channels'

chunk_size = 100000            # samples read from the recording per iteration
window_size = 31               # samples per classified window
half_window = window_size // 2
batch_size = 1024
detection_std_multiplier = 0.7  # candidate threshold, in per-channel chunk std units
label_tolerance = 1            # max sample distance to a sorted spike for a positive label
trail_ids = range(1, 6)
# Fall back to the CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tpr_dict = {}
tnr_dict = {}
accuracy_dict = {}

for date in eval_date_list:
    print(f"Processing date: {date}")

    recording_raw = se.read_blackrock(file_path=f'{RAW_DATA_DIR}/mouse6_{date}_natural_image_001.ns4')
    recording_recorded = recording_raw.remove_channels(["98", '31', '32'])
    recording_stimulated = recording_raw.channel_slice(['98'])  # not used by the evaluation below

    recording_f = spre.bandpass_filter(recording_recorded, freq_min=300, freq_max=3000)
    recording_f = spre.common_reference(recording_f, reference="global", operator="median")
    spike_inf = pd.read_csv(f"{PROJECT_DIR}/kilosort_spike_sorting/sorting_results/{date}/spike_inf.csv")
    spike_times = spike_inf['time'].to_numpy()

    # Ask the recording for its length instead of deriving it from the
    # duration and an assumed sampling rate (which can also truncate a frame).
    total_frames = recording_f.get_num_samples()

    # Load all models once per date so every chunk is filtered, detected and
    # labelled a single time and then scored by each model.
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    models = {}
    for trail in trail_ids:
        print(f"Processing trail: {trail}")
        model = torch.load(
            f'{PROJECT_DIR}/spike_detection/train_results/trail_{trail}.pth',
            map_location=device,
            weights_only=False,
        )
        models[trail] = model.to(device).eval()

    confusion = {trail: {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} for trail in trail_ids}

    for start_frame in range(0, total_frames, chunk_size):
        end_frame = min(start_frame + chunk_size, total_frames)

        # get_traces returns (n_samples, n_channels); the detector and the
        # window slicing below work on (n_channels, n_samples).
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )
        chunk_by_channel = data_chunk.T

        threshold_result = np.asarray(
            detect_local_maxima_in_window(chunk_by_channel, std_multiplier=detection_std_multiplier),
            dtype=np.int64,
        ) + start_frame

        # Candidates too close to either chunk edge cannot provide a full
        # window and are dropped (events near chunk boundaries are therefore
        # never evaluated).
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) &
            (threshold_result < end_frame - half_window)
        ]
        if valid_indices.size == 0:
            # Nothing to classify in this chunk; np.stack would fail on an empty list.
            continue

        all_windows = np.stack([
            chunk_by_channel[:, rel_idx - half_window:rel_idx + half_window + 1]
            for rel_idx in valid_indices - start_frame
        ])

        labels = label_array1_based_on_array2(valid_indices, spike_times, threshold=label_tolerance)

        dataset = CustomDataset(all_windows, labels)
        val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        with torch.no_grad():
            for batch_data, batch_labels in val_loader:
                batch_data = batch_data.to(device)
                is_spike = batch_labels.to(device).unsqueeze(1) == 1

                for trail, model in models.items():
                    predicted_spike = model(batch_data) > 0.5
                    counts = confusion[trail]
                    counts['tp'] += (predicted_spike & is_spike).sum().item()
                    counts['tn'] += (~predicted_spike & ~is_spike).sum().item()
                    counts['fp'] += (predicted_spike & ~is_spike).sum().item()
                    counts['fn'] += (~predicted_spike & is_spike).sum().item()

    tpr_dict[date] = []
    tnr_dict[date] = []
    accuracy_dict[date] = []

    for trail in trail_ids:
        counts = confusion[trail]
        true_positive, true_negative = counts['tp'], counts['tn']
        false_positive, false_negative = counts['fp'], counts['fn']

        correct = true_positive + true_negative
        total = correct + false_positive + false_negative

        tpr = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
        tnr = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0
        accuracy = correct / total if total > 0 else 0

        tpr_dict[date].append(tpr)
        tnr_dict[date].append(tnr)
        accuracy_dict[date].append(accuracy)

Processing date: 021322
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 022522
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 031722
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 042422
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 052422
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 062422
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 072322
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processing date: 082322
Processing trail: 1
Processing trail: 2
Processing trail: 3
Processing trail: 4
Processing trail: 5
Processi

In [9]:
import pickle

# Persist the three per-date metric dictionaries, one pickle file each.
pickle_targets = {
    "/media/ubuntu/sda/Spike_Sorting/paper_architecture/01_real_data/01_flexible_probe_30_channels/spike_detection/eval_results/tpr_dict.pkl": tpr_dict,
    "/media/ubuntu/sda/Spike_Sorting/paper_architecture/01_real_data/01_flexible_probe_30_channels/spike_detection/eval_results/tnr_dict.pkl": tnr_dict,
    "/media/ubuntu/sda/Spike_Sorting/paper_architecture/01_real_data/01_flexible_probe_30_channels/spike_detection/eval_results/accuracy_dict.pkl": accuracy_dict,
}

for pkl_path, metric_dict in pickle_targets.items():
    with open(pkl_path, "wb") as f:
        pickle.dump(metric_dict, f)

- Figure Generation

In [25]:
import umap
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.colors import ListedColormap


In [28]:
# ---------------------------------------------------------------------------
# UMAP view of the detector's fc3 embedding, two pages per date (true labels
# and predicted labels). Only the first chunk of each recording is used.
# ---------------------------------------------------------------------------
pdf_path = "embedding_visualization.pdf"

RAW_DATA_DIR = '/media/ubuntu/sda/data/mouse6/ns4/natural_image'
PROJECT_DIR = '/media/ubuntu/sda/Spike_Sorting/paper_architecture/01_real_data/01_flexible_probe_30_channels'

chunk_size = 100000            # samples taken from the start of each recording
window_size = 31               # samples per classified window
half_window = window_size // 2
batch_size = 1024
# Fall back to the CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The same checkpoint is used for every date, so load it once.
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
model = torch.load(
    f'{PROJECT_DIR}/spike_detection/train_results/trail_1.pth',
    map_location=device,
    weights_only=False,
)
model = model.to(device)
model.eval()

custom_cmap = ListedColormap(['lightgray', 'orange'])

with PdfPages(pdf_path) as pdf:
    for date in eval_date_list:
        print(f"Processing date: {date}")

        recording_raw = se.read_blackrock(file_path=f'{RAW_DATA_DIR}/mouse6_{date}_natural_image_001.ns4')
        recording_recorded = recording_raw.remove_channels(["98", '31', '32'])
        recording_stimulated = recording_raw.channel_slice(['98'])  # not used by the plots below

        recording_f = spre.bandpass_filter(recording_recorded, freq_min=300, freq_max=3000)
        recording_f = spre.common_reference(recording_f, reference="global", operator="median")
        spike_inf = pd.read_csv(f"{PROJECT_DIR}/kilosort_spike_sorting/sorting_results/{date}/spike_inf.csv")

        # Ask the recording for its length instead of deriving it from the
        # duration and an assumed sampling rate.
        total_frames = recording_f.get_num_samples()

        start_frame = 0
        end_frame = min(start_frame + chunk_size, total_frames)

        # get_traces returns (n_samples, n_channels); the detector and the
        # window slicing below work on (n_channels, n_samples).
        data_chunk = recording_f.get_traces(
            start_frame=start_frame,
            end_frame=end_frame
        )
        chunk_by_channel = data_chunk.T

        threshold_result = np.asarray(
            detect_local_maxima_in_window(chunk_by_channel, std_multiplier=0.7),
            dtype=np.int64,
        ) + start_frame

        # Candidates too close to either chunk edge cannot provide a full window.
        valid_indices = threshold_result[
            (threshold_result >= start_frame + half_window + 1) &
            (threshold_result < end_frame - half_window)
        ]
        if valid_indices.size == 0:
            # np.stack and UMAP both need at least one window.
            print(f"No candidate windows found for {date}; skipping.")
            continue

        all_windows = np.stack([
            chunk_by_channel[:, rel_idx - half_window:rel_idx + half_window + 1]
            for rel_idx in valid_indices - start_frame
        ])

        labels = label_array1_based_on_array2(valid_indices, spike_inf['time'], threshold=1)

        dataset = CustomDataset(all_windows, labels)
        val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        embeddings = []
        predicted_labels = []

        # Capture the output of fc3 (before relu3) during the regular forward
        # pass instead of re-implementing the network layer by layer.
        fc3_hook = model.fc3.register_forward_hook(
            lambda _module, _inputs, output: embeddings.append(output.cpu().numpy())
        )
        try:
            with torch.no_grad():
                for batch_data, _batch_labels in val_loader:
                    outputs = model(batch_data.to(device))
                    predicted_labels.append((outputs > 0.5).float().cpu().numpy())
        finally:
            fc3_hook.remove()

        embeddings = np.vstack(embeddings)
        predicted_labels = np.vstack(predicted_labels).flatten()
        # The loader does not shuffle, so labels are already aligned with the embeddings.
        true_labels = np.asarray(labels, dtype=float)

        umap_model = umap.UMAP(n_components=2, random_state=42)
        embeddings_2d = umap_model.fit_transform(embeddings)

        for label_kind, point_labels in (("True Labels", true_labels),
                                         ("Predicted Labels", predicted_labels)):
            fig, ax = plt.subplots(figsize=(8, 8))
            # vmin/vmax pin 0 -> lightgray and 1 -> orange even when a date
            # contains a single class only.
            ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=point_labels,
                       cmap=custom_cmap, vmin=0, vmax=1, alpha=1, s=0.8)
            ax.set_title(f"{date} UMAP Visualization - {label_kind}")
            ax.grid(False)
            pdf.savefig(fig)
            plt.close(fig)

# The PDF is only complete once the PdfPages context has been closed.
print(f"Visualization saved to {pdf_path}")

Processing date: 021322
Visualization saved to embedding_visualization.pdf
Processing date: 022522
Visualization saved to embedding_visualization.pdf
Processing date: 031722
Visualization saved to embedding_visualization.pdf
Processing date: 042422
Visualization saved to embedding_visualization.pdf
Processing date: 052422
Visualization saved to embedding_visualization.pdf
Processing date: 062422
Visualization saved to embedding_visualization.pdf
Processing date: 072322
Visualization saved to embedding_visualization.pdf
Processing date: 082322
Visualization saved to embedding_visualization.pdf
Processing date: 092422
Visualization saved to embedding_visualization.pdf
Processing date: 102122
Visualization saved to embedding_visualization.pdf
Processing date: 112022
Visualization saved to embedding_visualization.pdf
Processing date: 122022
Visualization saved to embedding_visualization.pdf
