In [5]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm

from sklearn.decomposition import PCA
import umap
import random

import sys
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre

import matplotlib.pyplot as plt
import json

from probeinterface import write_prb, read_prb

import torch.nn.functional as F
from pathlib import Path

import pickle
from scipy.spatial.distance import cdist
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from sklearn.metrics import accuracy_score
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances

In [89]:
def detect_local_maxima_in_window(data, window_size=20, std_multiplier=2):
    """
    Find above-threshold peaks inside non-overlapping windows of every row.

    Each row is cut into consecutive, non-overlapping windows of `window_size`
    samples (the last window may be shorter). Within each window the position of
    the maximum (the first one on ties) is kept when that maximum is strictly
    greater than `std_multiplier * std(row)`. Indices found in all rows are pooled.

    Parameters:
    data : array-like, shape (n_rows, n_columns)
        Input signal; every row is scanned independently.
    window_size : int
        Window length in samples, default 20. Must be >= 1.
    std_multiplier : float
        Threshold factor applied to each row's standard deviation, default 2.

    Returns:
    local_maxima_indices : list of int
        Sorted, de-duplicated column indices of the detected peaks, pooled over all rows.

    Raises:
    ValueError
        If `window_size` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer.")

    peak_indices = set()

    for row in data:
        row = np.asarray(row)
        n_samples = len(row)
        if n_samples == 0:
            continue
        threshold = std_multiplier * np.std(row)

        # All complete windows in one vectorised pass.
        n_full = n_samples // window_size
        if n_full:
            blocks = row[:n_full * window_size].reshape(n_full, window_size)
            offsets = np.argmax(blocks, axis=1)
            peaks = blocks[np.arange(n_full), offsets]
            hit = np.flatnonzero(peaks > threshold)
            peak_indices.update((hit * window_size + offsets[hit]).tolist())

        # Trailing partial window when the row length is not a multiple of window_size.
        tail_start = n_full * window_size
        if tail_start < n_samples:
            tail = row[tail_start:]
            offset = int(np.argmax(tail))
            if tail[offset] > threshold:
                peak_indices.add(tail_start + offset)

    return sorted(peak_indices)

def cluster_label_array1_based_on_array2(array1, array2, threshold=5,
                                         cluster_column='cluster'):

    """
    Label each value of `array1` with the cluster of a nearby reference event.

    A value v receives the `cluster_column` entry of the earliest reference event
    (by 'time') lying in [v - threshold, v + threshold]; values with no such event
    receive -1.

    Parameters:
    array1 : array-like of numbers
        Values (e.g. event times) to label.
    array2 : mapping / DataFrame
        Must support `array2['time']` and `array2[cluster_column]` as equal-length columns.
    threshold : number
        Half-width of the matching interval (inclusive on both ends).
    cluster_column : str
        Name of the column in `array2` holding the labels to copy.

    Returns:
    labels : numpy.ndarray of int, shape (len(array1),)
        Matched cluster label, or -1 when no reference time is within `threshold`.
    """
    # Stack first so both columns share one dtype, then order by time.
    table = np.array((array2['time'], array2[cluster_column])).T
    table = table[np.argsort(table[:, 0])]
    sorted_times = table[:, 0]
    sorted_clusters = table[:, 1]

    values = np.asarray(array1)
    left = np.searchsorted(sorted_times, values - threshold, side='left')
    right = np.searchsorted(sorted_times, values + threshold, side='right')

    labels = np.full(len(values), -1, dtype=int)
    # A non-empty [left, right) range means at least one reference time is in the window;
    # take the earliest one.
    hit = right > left
    labels[hit] = sorted_clusters[left[hit]]

    return labels

def extract_windows(data, indices, window_size=61):
    """
    Cut a fixed-length window around each centre index.

    For a centre t the window covers samples
    [t - window_size // 2, t - window_size // 2 + window_size), so the default
    61 spans t-30 .. t+30 inclusive, and every window has exactly `window_size` samples.

    Parameters:
    data : numpy.ndarray
        Input array of shape (n_channels, time).
    indices : array-like of int
        Centre positions along the time axis.
    window_size : int
        Window length in samples, default 61.

    Returns:
    windows : numpy.ndarray
        Array of shape (len(indices), n_channels, window_size).

    Raises:
    ValueError
        If any window would extend past either end of `data`.
    """
    n_channels, time_length = data.shape
    indices = np.asarray(indices)
    half_window = window_size // 2
    # Exclusive end offset relative to the centre; equals half_window + 1 for odd sizes.
    right_extent = window_size - half_window

    if np.any(indices < half_window) or np.any(indices > time_length - right_extent):
        raise ValueError("Some indices are out of bounds for the given window size.")

    if indices.size == 0:
        return np.empty((0, n_channels, window_size), dtype=data.dtype)

    return np.stack([data[:, idx - half_window:idx + right_extent] for idx in indices])

def compute_cluster_average(sample_data, labels):
    """
    Average the samples that share each label.

    Parameters:
    - sample_data: np.ndarray whose first axis indexes samples (e.g. (n, channels, time)).
    - labels: array-like of length n; labels[i] is the cluster of sample_data[i].
      Position-based; any pandas index is ignored.

    Returns:
    - cluster_averages: dict mapping each distinct label (in sorted order) to the mean
      of its samples along axis 0, i.e. an array of shape sample_data.shape[1:].

    Raises:
    - ValueError if labels and sample_data have different lengths.
    """
    labels = np.asarray(labels)
    if len(labels) != len(sample_data):
        raise ValueError("labels and sample_data must have the same length.")

    return {
        cluster: sample_data[labels == cluster].mean(axis=0)
        for cluster in np.unique(labels)
    }

In [97]:
def process_cluster_averages(cluster_averages, channel_position, n_neighbors=5):
    """
    Keep, per cluster, the channels closest to its strongest channel.

    The anchor is the channel with the largest absolute peak among channels that
    have a known (non-NaN) position. Channels are then ranked by Euclidean distance
    to the anchor, and the first n_neighbors + 1 are kept, anchor included.

    Parameters:
    - cluster_averages: dict, cluster_id -> average waveform of shape (channels, time).
    - channel_position: dict {channel_index: (x, y)}; channel indices are row
      positions in the waveform matrices.
    - n_neighbors: number of neighbouring channels to keep besides the anchor.

    Returns:
    - processed_data: dict, cluster_id -> {
        'waveform': rows of the selected channels, shape (<= n_neighbors+1, time),
        'channel_ids': list of the selected channel indices, nearest first,
      }. Clusters with no positioned channel are omitted.
    """
    processed_data = {}

    for cluster, avg_matrix in cluster_averages.items():
        n_channels = avg_matrix.shape[0]
        positions = np.array(
            [channel_position.get(ch, (np.nan, np.nan)) for ch in range(n_channels)],
            dtype=float,
        )
        valid_mask = ~np.isnan(positions).any(axis=1)
        if not valid_mask.any():
            continue
        valid_indices = np.flatnonzero(valid_mask)
        valid_positions = positions[valid_mask]

        # Choose the peak among positioned channels only, so the anchor always has coordinates.
        channel_peaks = np.max(np.abs(avg_matrix), axis=1)
        anchor = int(np.argmax(channel_peaks[valid_mask]))

        distances = cdist(valid_positions, valid_positions[anchor:anchor + 1]).flatten()
        closest = np.argsort(distances)[:n_neighbors + 1]
        selected_indices = valid_indices[closest]

        processed_data[cluster] = {
            'waveform': avg_matrix[selected_indices, :],
            'channel_ids': selected_indices.tolist(),
        }

    return processed_data

def calculate_position(row, channel_position):
    """
    Estimate a cluster's (x, y) location as the energy-weighted centroid of its channels.

    Each selected channel is weighted by the sum of squares of its waveform samples.

    Parameters:
    - row: mapping / DataFrame row with 'waveform' (channels, time) and
      'channel_ids' (channel index for each waveform row).
    - channel_position: dict {channel_id: (x, y)}; unknown channels count as (0, 0).

    Returns:
    - pd.Series with 'position_1' (x) and 'position_2' (y); both 0 when total weight is 0.
    """
    energy = (row['waveform'] ** 2).sum(axis=1)

    # (x, y, weight) for every selected channel, in channel_ids order.
    contributions = [
        (*channel_position.get(ch_id, (0, 0)), energy[k])
        for k, ch_id in enumerate(row['channel_ids'])
    ]

    total_weight = sum(w for _, _, w in contributions)
    if total_weight == 0:
        return pd.Series({'position_1': 0, 'position_2': 0})

    weighted_x = sum(x * w for x, _, w in contributions)
    weighted_y = sum(y * w for _, y, w in contributions)
    return pd.Series({'position_1': weighted_x / total_weight,
                      'position_2': weighted_y / total_weight})

def calculate_position_waveform(row, channel_position, power=2):
    """
    Interpolate a waveform at the cluster position by inverse-distance weighting (IDW).

    Parameters:
    - row: mapping / DataFrame row with 'waveform' (channels, time), 'channel_ids',
      'position_1' (x) and 'position_2' (y).
    - channel_position: dict {channel_id: (x, y)}; channels missing from it get NaN
      coordinates, which propagate NaN into the result.
    - power: IDW exponent; weights are 1 / distance**power, normalised to sum to 1.

    Returns:
    - 1-D array with one value per time sample (waveform.shape[1]). If the target
      coincides exactly with a channel, that channel's waveform is returned unchanged.
    """
    waveform = np.asarray(row['waveform'])
    channel_ids = row['channel_ids']

    positions = np.array(
        [channel_position.get(ch_id, (np.nan, np.nan)) for ch_id in channel_ids], dtype=float
    )
    target_pos = np.array([[row['position_1'], row['position_2']]], dtype=float)
    distances = cdist(positions, target_pos).flatten()

    # Check for an exact hit before dividing, so 1 / 0 is never evaluated.
    on_channel = np.flatnonzero(distances == 0)
    if on_channel.size:
        return waveform[on_channel[0], :]

    weights = 1.0 / distances ** power
    weights /= weights.sum()

    # Weighted sum over channels for every time sample at once.
    return weights @ waveform

def predict_new(feature, kmeans):
    """
    Assign each feature row to its nearest fitted KMeans centroid.

    Parameters:
    - feature: array-like of shape (n_samples, n_features).
    - kmeans: fitted estimator exposing `cluster_centers_`.

    Returns:
    - numpy.ndarray of shape (n_samples,): index of the closest centroid per row
      (first one on ties), using `pairwise_distances` with its default metric.
    """
    centers = kmeans.cluster_centers_
    distance_to_centers = pairwise_distances(feature, centers)
    return distance_to_centers.argmin(axis=1)

def judge_cluster_reality(row, neuron_inf, position_threshold=10, waveform_threshold=0.95):
    """
    Match a predicted cluster to a reference neuron by position and waveform shape.

    Candidates are the reference rows whose position_1 and position_2 each lie within
    `position_threshold` of the row's. Among them, the one with the highest Pearson
    correlation between 'position_waveform's is chosen, provided that correlation
    is strictly above `waveform_threshold`.

    Parameters:
    - row: mapping / DataFrame row with 'position_1', 'position_2', 'position_waveform'.
    - neuron_inf: DataFrame with 'cluster', 'position_1', 'position_2', 'position_waveform'.
    - position_threshold: per-axis position tolerance (default 10, as before).
    - waveform_threshold: minimum correlation required for a match (default 0.95).

    Returns:
    - The matching reference 'cluster' value, or None when no candidate qualifies.
    """
    from scipy.stats import pearsonr

    near = (
        ((neuron_inf['position_1'] - row['position_1']).abs() <= position_threshold)
        & ((neuron_inf['position_2'] - row['position_2']).abs() <= position_threshold)
    )
    candidate_neurons = neuron_inf[near]
    if candidate_neurons.empty:
        return None

    row_waveform = row['position_waveform']
    best_match = None
    best_corr = -1

    for _, candidate in candidate_neurons.iterrows():
        corr, _ = pearsonr(row_waveform, candidate['position_waveform'])
        # Strict '>' keeps the first candidate on equal correlation.
        if corr > waveform_threshold and corr > best_corr:
            best_corr = corr
            best_match = candidate['cluster']

    return best_match

In [8]:
class CustomDataset(Dataset):
    """Map-style dataset pairing samples with integer class labels."""

    def __init__(self, data, labels):
        # data: indexable array of samples; each item is returned cast to float32.
        self.data = data
        # Convert labels once here. Converting inside __getitem__ copied the whole
        # label array on every item and re-wrapped an existing tensor.
        if isinstance(labels, torch.Tensor):
            self.labels = labels.to(torch.long)
        else:
            self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (sample as float32 numpy array, label as a 0-d long tensor)."""
        return self.data[idx].astype(np.float32), self.labels[idx]

class Spike_Classification_MLP(nn.Module):
    """Three fully connected layers with ReLU in between, producing raw class scores."""

    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, num_classes)

    def forward(self, x):
        """Flatten x to (-1, input_size) and return (batch, num_classes) unnormalised scores."""
        # Read the width from fc1 rather than hard-coding 31 * 32, so other input sizes work.
        # Not storing a new attribute keeps previously pickled models loadable.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

In [22]:
DATA_DIR = '/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/data_generation/setting_1'

# Band-pass 300-3000 Hz, then global median common reference.
recording_raw = se.MEArecRecordingExtractor(file_path=f'{DATA_DIR}/Neuronexus_32_50_cell_recordings.h5')
recording_f = spre.common_reference(
    spre.bandpass_filter(recording_raw, freq_min=300, freq_max=3000),
    reference="global",
    operator="median",
)

spike_inf = pd.read_csv(f"{DATA_DIR}/spike_inf.csv")

# Re-code neuron ids to contiguous class indices 0..K-1 (in sorted id order).
unique_clusters = np.unique(spike_inf['Neuron'])
cluster_to_index = {cluster: idx for idx, cluster in enumerate(unique_clusters)}
spike_inf['label'] = spike_inf['Neuron'].map(cluster_to_index)


In [25]:
total_frames = int(1000 * 10000)   # frames [0, total_frames) are used to build reference templates
chunk_size = 100000
window_size = 31
half_window = window_size // 2
n_recording_frames = recording_f.get_num_frames()

all_labels = []
all_windows = []

for start_frame in range(0, total_frames, chunk_size):
    end_frame = min(start_frame + chunk_size, total_frames)

    # Read half_window extra frames on each side. Otherwise every spike within
    # half_window of a chunk boundary is silently dropped.
    read_start = max(start_frame - half_window, 0)
    read_end = min(end_frame + half_window + 1, n_recording_frames)
    traces = recording_f.get_traces(start_frame=read_start, end_frame=read_end).T  # -> (n_channels, n_frames)

    spike_times = spike_inf['time']
    in_chunk = (
        (spike_times >= start_frame) & (spike_times < end_frame)
        # the full window must also fit inside the recording itself
        & (spike_times >= half_window) & (spike_times + half_window < n_recording_frames)
    )
    spike_inf_temp = spike_inf[in_chunk]

    for idx in spike_inf_temp['time']:
        rel_idx = idx - read_start
        all_windows.append(traces[:, rel_idx - half_window: rel_idx + half_window + 1])

    all_labels.extend(spike_inf_temp['label'])

all_labels = np.array(all_labels)
all_windows = np.stack(all_windows)
assert len(all_windows) == len(all_labels)


In [26]:
probe = recording_f.get_probegroup().to_dataframe()
# channel index -> (x, y); assumes probe rows are listed in channel order (TODO confirm)
channel_position = {
    ch: (x, y)
    for ch, (x, y) in enumerate(zip(probe['x'].to_numpy(), probe['y'].to_numpy()))
}

In [None]:
cluster_averages = compute_cluster_average(all_windows, all_labels)
processed_averages = process_cluster_averages(
    cluster_averages, channel_position=channel_position, n_neighbors=5
)

# One row per ground-truth neuron: template on the selected channels plus their ids.
records = []
for cluster_id, entry in processed_averages.items():
    records.append({"cluster": cluster_id, "waveform": entry['waveform'], "channel_ids": entry['channel_ids']})
neuron_inf = pd.DataFrame(records)

neuron_inf[['position_1', 'position_2']] = neuron_inf.apply(
    calculate_position, axis=1, channel_position=channel_position
)
neuron_inf['position_waveform'] = neuron_inf.apply(
    calculate_position_waveform, axis=1, channel_position=channel_position, power=2
)

In [84]:
EVAL_START_FRAME = 1000 * 10000
EVAL_END_FRAME = 1600 * 10000
EVAL_HALF_WINDOW = 31 // 2  # must match the window_size later passed to extract_windows

data = recording_f.get_traces(
    start_frame=EVAL_START_FRAME,
    end_frame=EVAL_END_FRAME,
)  # (n_frames, n_channels); transposed below for channel-wise processing

threshold_result = detect_local_maxima_in_window(
    data.T,
    std_multiplier=0.7,
)
threshold_result = np.array(threshold_result) + EVAL_START_FRAME

# Keep peaks whose full window fits inside `data` (extract_windows raises otherwise).
# Uses the actual returned length and the exact lower bound (the old "+ 15 + 1" dropped a valid index).
n_eval_frames = data.shape[0]
valid_indices = threshold_result[
    (threshold_result >= EVAL_START_FRAME + EVAL_HALF_WINDOW)
    & (threshold_result < EVAL_START_FRAME + n_eval_frames - EVAL_HALF_WINDOW)
]

cluster_labels = cluster_label_array1_based_on_array2(valid_indices, spike_inf, threshold=2, cluster_column='label')

potent_spike_inf = pd.DataFrame({'time': valid_indices, 'label': cluster_labels})
# Copy so later column assignments act on an independent frame, not a filtered view.
potent_spike_inf = potent_spike_inf[potent_spike_inf['label'] != -1].copy()

In [86]:
# Convert to times relative to `data`. The absolute times are kept in 'time_abs',
# so re-running this cell does not subtract the offset twice.
if 'time_abs' not in potent_spike_inf.columns:
    potent_spike_inf['time_abs'] = potent_spike_inf['time']
potent_spike_inf['time'] = potent_spike_inf['time_abs'] - 1000 * 10000

index_to_cluster = {idx: cluster for cluster, idx in cluster_to_index.items()}
potent_spike_inf['Neuron'] = potent_spike_inf['label'].map(index_to_cluster).to_numpy()

sampled_data = extract_windows(data.T, potent_spike_inf['time'].to_numpy(), window_size=31)

val_dataset = CustomDataset(sampled_data, potent_spike_inf['label'].values)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

In [87]:
MODEL_PATH = '/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_classification/train_results/spike_classification_model_1.pth'
# weights_only=False unpickles arbitrary Python objects: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, weights_only=False)

In [151]:
device = 'cuda'
# Separate name so the training-set `all_labels` is not clobbered (re-running earlier cells stays valid).
eval_labels = []
latent_value = []

print("Start Eval...")
model.to(device)
model.eval()
with torch.no_grad():
    for batch_data, batch_labels in val_loader:
        batch_data = batch_data.to(device)
        # model(...) runs the same reshape -> fc1 -> relu1 -> fc2 -> relu2 -> fc3 pipeline;
        # the fc3 scores serve as the embedding for clustering.
        latent_value.append(model(batch_data).cpu())
        eval_labels.extend(batch_labels.numpy())

eval_labels = np.array(eval_labels)
latent_value = torch.cat(latent_value, dim=0).numpy()

print('Start KMeans...')
# Seeded, and capped at the number of available rows (random.sample raised when fewer than 50000).
rng = np.random.default_rng(42)
n_fit = min(50000, len(latent_value))
subset_idx = rng.choice(len(latent_value), size=n_fit, replace=False)
latent_value_subset = latent_value[subset_idx, :]
final_kmeans = KMeans(n_clusters=80, n_init=10, random_state=42).fit(latent_value_subset)



Start Eval...
Start KMeans...


In [152]:
# Nearest-centroid cluster for every evaluated spike, including those outside the KMeans fit subset.
predicted_labels = predict_new(latent_value, final_kmeans)
potent_spike_inf = potent_spike_inf.reset_index(drop=True)
potent_spike_inf['cluster_predicted'] = predicted_labels

In [153]:
potent_spike_inf = potent_spike_inf.reset_index(drop=True)
predicted_ids = potent_spike_inf['cluster_predicted'].to_numpy()
cluster_averages = compute_cluster_average(sampled_data, predicted_ids)
# Same neighbourhood size (5) as the ground-truth templates.
processed_averages = process_cluster_averages(
    cluster_averages, channel_position=channel_position, n_neighbors=5
)


In [154]:
# One row per predicted cluster, built exactly like `neuron_inf` for the ground truth.
rows = []
for cluster_id, entry in processed_averages.items():
    rows.append({"cluster": cluster_id, "waveform": entry['waveform'], "channel_ids": entry['channel_ids']})
result_df = pd.DataFrame(rows)

result_df[['position_1', 'position_2']] = result_df.apply(
    calculate_position, axis=1, args=(channel_position,)
)
result_df['position_waveform'] = result_df.apply(
    calculate_position_waveform, axis=1, args=(channel_position, 2)
)

In [155]:
# Match each predicted cluster to a ground-truth neuron; None when no template is close
# enough in position with a sufficiently correlated waveform.
# (The former `result_df['label'] = 1` was a dead store, immediately overwritten here.)
result_df['label'] = result_df.apply(
    lambda row: judge_cluster_reality(row, neuron_inf), axis=1
)



In [156]:
# Drop unmatched clusters and make both id columns integral.
result_df = result_df[result_df['label'].notna()].astype({'label': int, 'cluster': int})

# predicted cluster id -> matched ground-truth label (first row wins on duplicates)
cluster_label_map = result_df.drop_duplicates('cluster').set_index('cluster')['label']

# Spikes whose predicted cluster found no match are labelled -1.
mask = potent_spike_inf['cluster_predicted'].isin(cluster_label_map.index)
potent_spike_inf['label_predicted'] = (
    potent_spike_inf['cluster_predicted']
    .map(cluster_label_map)
    .where(mask, -1)
    .astype(int)
)

In [161]:
HEATMAP_CSV = "/media/ubuntu/sda/Spike_Sorting/paper_architecture/02_simulation_data/01_Neuroxenus_32_channels/spike_classification/eval_results/setting_1/heatmap_data.csv"

# Counts of (true label, predicted label), then each predicted-label column scaled to sum to 1.
counts = pd.crosstab(potent_spike_inf['label'], potent_spike_inf['label_predicted'])
a = counts.div(counts.sum(axis=0), axis=1)
a.to_csv(HEATMAP_CSV)