In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import spikeinterface as si
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, pairwise_distances
import umap
import random

import sys
import spikeinterface as si
import matplotlib.pyplot as plt
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.qualitymetrics as sqm
import json
import probeinterface

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_dummy_probe, generate_linear_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
from torch.nn.functional import max_pool1d


import torch.nn.functional as F
from pathlib import Path


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Subset

In [2]:
# Planar position [x, y] of each recording channel, keyed by integer channel index.
# calculate_position / calculate_position_waveform read these as channel coordinates
# (units are not stated here — presumably µm; TODO confirm against the probe layout).
# Entries are listed by probe group, using the group names of `channel_indices` defined later.
channel_position = {
    # group "5"
    0: [650, 0],
    2: [650, 50],
    4: [650, 100],
    6: [600, 100],
    8: [600, 50],
    10: [600, 0],
    # group "1"
    1: [0, 0],
    3: [0, 50],
    5: [0, 100],
    7: [50, 100],
    9: [50, 50],
    11: [50, 0],
    # group "2"
    13: [150, 200], 
    15: [150, 250],
    17: [150, 300],
    19: [200, 300],
    21: [200, 250],
    23: [200, 200],
    # group "4"
    12: [500, 200],
    14: [500, 250],
    16: [500, 300],
    18: [450, 300],
    20: [450, 250],
    22: [450, 200],
    # group "3"
    24: [350, 400],
    26: [350, 450],
    28: [350, 500],
    25: [300, 400],
    27: [300, 450],
    29: [300, 500]}

In [None]:
def count_array2_in_range_of_array1(array1, array2, threshold=5):
    """Count the entries of ``array2`` that have at least one ``array1`` value within ±``threshold``.

    Parameters
    ----------
    array1 : numpy.ndarray
        Reference values (any order).
    array2 : numpy.ndarray
        Query values; each is tested against the closed interval
        [value - threshold, value + threshold].
    threshold : int or float
        Half-width of the matching interval.

    Returns
    -------
    numpy integer
        Number of ``array2`` entries with a match.
    """
    reference = np.sort(array1)
    # [lo, hi) is the slice of `reference` falling inside each query's closed interval.
    lo = np.searchsorted(reference, array2 - threshold, side='left')
    hi = np.searchsorted(reference, array2 + threshold, side='right')
    return np.sum(hi > lo)


def detect_local_maxima_in_window(data, window_size=20, std_multiplier=2):
    """Find candidate spike times as per-window maxima that exceed a per-row amplitude threshold.

    Each row (channel) is cut into consecutive, non-overlapping windows of ``window_size``
    samples; the last window may be shorter. In every window, the position of the largest value
    is kept if that value is strictly greater than ``std_multiplier * std(row)``. Positions from
    all rows are then merged and de-duplicated.

    Parameters
    ----------
    data : numpy.ndarray
        Array of shape (n_rows, n_columns). A 1-D array is treated as a single row.
    window_size : int
        Length of each non-overlapping window. Default 20.
    std_multiplier : float
        Multiple of the row's standard deviation that a window maximum must exceed. Default 2.

    Returns
    -------
    list of int
        Sorted, unique column indices at which at least one row had a qualifying window maximum.
    """
    data = np.atleast_2d(np.asarray(data))
    per_row_hits = []

    for row in data:
        n = row.shape[0]
        if n == 0:
            continue
        threshold = std_multiplier * np.std(row)

        # Pad the ragged last window with the row minimum so every window has the same length.
        # argmax returns the first maximum, so padding is never chosen over a real sample
        # of equal value.
        n_windows = -(-n // window_size)
        pad = n_windows * window_size - n
        if pad:
            row = np.concatenate([row, np.full(pad, row.min(), dtype=row.dtype)])
        windows = row.reshape(n_windows, window_size)

        local_arg = windows.argmax(axis=1)
        local_max = windows[np.arange(n_windows), local_arg]
        keep = local_max > threshold
        per_row_hits.append(np.arange(n_windows)[keep] * window_size + local_arg[keep])

    if not per_row_hits:
        return []
    return np.unique(np.concatenate(per_row_hits)).tolist()


def cluster_label_array1_based_on_array2(array1, array2, threshold=5):
    """Label each value of ``array1`` with the cluster of a nearby spike in ``array2``.

    A value ``v`` of ``array1`` is matched when some ``array2['time']`` lies in the closed
    interval [v - threshold, v + threshold]. It receives the ``'cluster'`` of the earliest
    such time (after sorting ``array2`` by time). Unmatched values get -1.

    Parameters
    ----------
    array1 : array-like of numbers
        Times to label.
    array2 : pandas.DataFrame or mapping
        Table with ``'time'`` and ``'cluster'`` columns of equal length.
    threshold : int or float
        Half-width of the matching interval.

    Returns
    -------
    numpy.ndarray of int
        Array of length ``len(array1)`` holding the matched cluster, or -1 where there is no match.
    """
    times = np.asarray(array2['time'])
    clusters = np.asarray(array2['cluster'])
    order = np.argsort(times)
    sorted_times = times[order]
    sorted_clusters = clusters[order]

    values = np.asarray(array1)
    # Vectorized binary search: [left, right) indexes the sorted times inside each interval.
    left = np.searchsorted(sorted_times, values - threshold, side='left')
    right = np.searchsorted(sorted_times, values + threshold, side='right')
    hit = right > left

    labels = np.full(len(values), -1, dtype=int)
    # The first time inside the interval decides the label.
    labels[hit] = sorted_clusters[left[hit]]
    return labels


def label_array1_based_on_array2(array1, array2, threshold=5):
    """Flag which values of ``array1`` have a value of ``array2`` within ±``threshold``.

    Parameters
    ----------
    array1 : array-like of numbers
        Values to flag.
    array2 : numpy.ndarray
        Reference values (any order).
    threshold : int or float
        Half-width of the closed matching interval.

    Returns
    -------
    numpy.ndarray of int
        Array of length ``len(array1)``: 1 where a reference value lies in
        [value - threshold, value + threshold], else 0.
    """
    reference = np.sort(array2)
    queries = np.asarray(array1)
    lo = np.searchsorted(reference, queries - threshold, side='left')
    hi = np.searchsorted(reference, queries + threshold, side='right')
    # A non-empty [lo, hi) slice means at least one reference value is in range.
    return (hi > lo).astype(int)


def extract_windows(data, indices, window_size=61):
    """Extract a fixed-length window around each given time index.

    For a centre ``idx`` the window covers samples
    ``idx - window_size // 2`` up to (but excluding) ``idx - window_size // 2 + window_size``.
    For the default 61 that is idx-30 .. idx+30 inclusive. Even sizes now also return exactly
    ``window_size`` samples; the original slice returned one extra sample for even sizes.

    Parameters
    ----------
    data : numpy.ndarray
        Array of shape (n_channels, time).
    indices : array-like of int
        Centre indices (numpy array, list or pandas Series).
    window_size : int
        Window length. Default 61.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(indices), n_channels, window_size), with the dtype of ``data``.

    Raises
    ------
    ValueError
        If any window would extend past either end of ``data``.
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    n_channels, time_length = data.shape
    half_window = window_size // 2
    # Number of samples from the centre (inclusive) to the window's end.
    after = window_size - half_window

    if np.any(indices < half_window) or np.any(indices > time_length - after):
        raise ValueError("Some indices are out of bounds for the given window size.")

    offsets = np.arange(-half_window, after)
    # Gather as (n_channels, n_indices, window_size), then put the index axis first.
    windows = data[:, indices[:, None] + offsets]
    return np.ascontiguousarray(windows.transpose(1, 0, 2))
def calculate_position(row):
    """Estimate a unit's (x, y) location as the energy-weighted centroid of its channel group.

    The channel list comes from the module-level ``channel_indices``, keyed by
    ``str(row['probe_group'])``. Column ``j`` of ``row['waveform']`` belongs to the j-th channel
    of that list and is weighted by its summed squared amplitude. Channel coordinates come from
    the module-level ``channel_position``; channels missing there count as [0, 0].

    Returns a pandas Series with 'position_1' (x) and 'position_2' (y). Both are 0 when every
    channel has zero energy.
    """
    group_channels = channel_indices[str(row['probe_group'])]
    wf = row['waveform']

    # One energy value per channel column, in the same order as group_channels.
    energies = [np.sum(wf[:, col] ** 2) for col in range(len(group_channels))]

    weighted_x = 0
    weighted_y = 0
    total_energy = 0
    for ch, energy in zip(group_channels, energies):
        px, py = channel_position.get(ch, [0, 0])
        weighted_x += px * energy
        weighted_y += py * energy
        total_energy += energy

    if total_energy == 0:
        return pd.Series({'position_1': 0, 'position_2': 0})

    return pd.Series({'position_1': weighted_x / total_energy,
                      'position_2': weighted_y / total_energy})

def calculate_position_waveform(row, channel_position, channel_indices, power=2):
    """Synthesize the waveform at a unit's estimated position by inverse-distance weighting (IDW).

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Needs 'position_1' (x), 'position_2' (y), 'probe_group' and 'waveform'. Column ``j``
        of 'waveform' is the trace of the j-th channel in the group's channel list.
    channel_position : dict
        Channel index -> [x, y]. Channels missing here (or with NaN x) are ignored.
    channel_indices : dict
        Probe-group name (string) -> list of channel indices.
    power : float
        IDW exponent; weights are ``1 / distance ** power``.

    Returns
    -------
    numpy.ndarray
        1-D array with one value per waveform sample (row of 'waveform'):
        - the channel's own trace when the position coincides exactly with a channel;
        - zeros when no channel has a known position;
        - otherwise the IDW-weighted mix of the channel traces.
    """
    x_target = row['position_1']
    y_target = row['position_2']
    channels = channel_indices[str(row['probe_group'])]
    waveforms = np.asarray(row['waveform'])
    n_samples = waveforms.shape[0]

    # Keep each usable channel's waveform column next to its distance, so that skipping an
    # unknown channel does not shift the weights onto the wrong columns.
    columns = []
    distances = []
    for col, channel in enumerate(channels):
        x_channel, y_channel = channel_position.get(channel, [np.nan, np.nan])
        if np.isnan(x_channel):
            continue
        columns.append(col)
        distances.append(np.sqrt((x_target - x_channel) ** 2 + (y_target - y_channel) ** 2))

    if not distances:
        return np.zeros(n_samples)

    distances = np.asarray(distances, dtype=float)
    # An exact hit on a channel makes IDW undefined (1/0), so return that channel's trace.
    # This check must run on an ndarray: the original compared a Python list to 0, which is
    # always False, and produced NaN waveforms instead.
    zero_hits = np.flatnonzero(distances == 0)
    if zero_hits.size:
        return waveforms[:, columns[zero_hits[0]]]

    weights = 1.0 / distances ** power
    weights /= weights.sum()
    return waveforms[:, columns] @ weights


def compute_cluster_average(sample_data, potent_spike_inf, cluster_column='cluster_predicted'):
    """Average ``sample_data`` over the rows belonging to each cluster.

    The index *labels* of ``potent_spike_inf`` are used directly as positional row indices into
    ``sample_data``. This is only correct when those labels are row positions in ``sample_data``
    (e.g. after a reset to 0..N-1).

    Parameters
    ----------
    sample_data : numpy.ndarray
        Array whose first axis indexes samples (documented as (n, 30, 61)).
    potent_spike_inf : pandas.DataFrame
        Table holding one cluster id per row in ``cluster_column``.
    cluster_column : str
        Name of the column with cluster ids.

    Returns
    -------
    dict
        Cluster id -> mean over axis 0 of that cluster's rows. Keys follow order of first appearance.
    """
    column = potent_spike_inf[cluster_column]
    row_positions = potent_spike_inf.index
    return {
        cluster: sample_data[row_positions[(column == cluster).to_numpy()]].mean(axis=0)
        for cluster in column.unique()
    }

def judge_cluster_reality(row, neuron_inf, position_threshold=10, waveform_threshold=0.95):
    """Match a cluster to a reference neuron by position and waveform similarity.

    Candidates are the rows of ``neuron_inf`` whose 'position_1' and 'position_2' each differ
    from the row's by at most ``position_threshold``. Among those, return the 'cluster' of the
    candidate whose 'position_waveform' has the highest Pearson correlation with
    ``row['position_waveform']``, provided it exceeds ``waveform_threshold``. Ties keep the
    first candidate.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Needs 'position_1', 'position_2' and 'position_waveform'.
    neuron_inf : pandas.DataFrame
        Reference table with 'position_1', 'position_2', 'position_waveform' and 'cluster'.
    position_threshold : float
        Maximum per-axis position difference. Default 10 (the previously hard-coded value).
    waveform_threshold : float
        Correlation a match must strictly exceed. Default 0.95 (the previously hard-coded value).

    Returns
    -------
    The matching reference 'cluster' value, or None when nothing qualifies.
    """
    from scipy.stats import pearsonr

    near = (
        (abs(neuron_inf['position_1'] - row['position_1']) <= position_threshold)
        & (abs(neuron_inf['position_2'] - row['position_2']) <= position_threshold)
    )
    candidate_neurons = neuron_inf[near]
    if candidate_neurons.empty:
        return None

    row_waveform = row['position_waveform']
    best_match = None
    best_corr = -1
    for _, candidate in candidate_neurons.iterrows():
        corr, _ = pearsonr(row_waveform, candidate['position_waveform'])
        # NaN correlations (constant waveforms) fail both comparisons and are skipped.
        if corr > waveform_threshold and corr > best_corr:
            best_corr = corr
            best_match = candidate['cluster']

    return best_match


def process_cluster_averages(cluster_averages, channel_indices):
    """Crop each cluster's average to the channel group containing its peak channel.

    The peak channel is the row of the average matrix holding the largest value. The first
    group in ``channel_indices`` that lists this channel is selected, and only that group's
    rows are kept. Clusters whose peak channel belongs to no group are dropped.

    Parameters
    ----------
    cluster_averages : dict
        Cluster id -> 2-D array (channels, samples), documented as (30, 61).
    channel_indices : dict
        Group name -> list of channel (row) indices.

    Returns
    -------
    dict
        ``f"{cluster}_{group}"`` -> the selected rows (documented as (6, 61)).
    """
    not_found = object()

    processed = {}
    for cluster, avg_matrix in cluster_averages.items():
        peak_channel = np.argmax(avg_matrix.max(axis=1))
        group = next(
            (name for name, members in channel_indices.items() if peak_channel in members),
            not_found,
        )
        if group is not_found:
            continue
        processed[f"{cluster}_{group}"] = avg_matrix[channel_indices[group], :]

    return processed

def predict_new(feature, kmeans):
    """Assign each row of ``feature`` to its nearest centre in ``kmeans.cluster_centers_``.

    Nearest is measured by Euclidean distance (computed as squared distance, which has the
    same argmin). The computation uses NumPy directly, so the function needs no extra import
    (the original called ``pairwise_distances``, which was never imported in this notebook).

    Parameters
    ----------
    feature : array-like, shape (n_points, n_features)
    kmeans : object with a ``cluster_centers_`` array of shape (n_centers, n_features)

    Returns
    -------
    numpy.ndarray of int
        Index of the nearest centre for each point.
    """
    points = np.asarray(feature, dtype=np.float64)
    centers = np.asarray(kmeans.cluster_centers_, dtype=np.float64)
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, evaluated as an (n_points, n_centers) matrix.
    sq_dist = (
        np.einsum('ij,ij->i', points, points)[:, None]
        - 2.0 * points @ centers.T
        + np.einsum('ij,ij->i', centers, centers)[None, :]
    )
    return np.argmin(sq_dist, axis=1)

In [4]:
# Load the training session (021322). Channels '98', '31', '32' are dropped from the recorded set,
# and '98' is kept separately as `recording_stimulated` (presumably a stimulus/sync channel — TODO confirm).
# NOTE(review): absolute local path — consider a configurable DATA_DIR.
recording_raw = se.read_blackrock(file_path='/media/ubuntu/sda/data/mouse6/ns4/natural_image/mouse6_021322_natural_image001.ns4')
recording_recorded = recording_raw.remove_channels(["98", '31', '32'])
recording_stimulated = recording_raw.channel_slice(['98'])

# 300–3000 Hz band-pass, then global median common reference.
recording_f = spre.bandpass_filter(recording_recorded, freq_min=300, freq_max=3000)
recording_f = spre.common_reference(recording_f, reference="global", operator="median")
# Transposed so `data` is indexed (channel, sample); later code uses data.shape[1] as the time length.
data = recording_f.get_traces().astype("float32").T


# Candidate spike times: union over channels of per-window maxima above 2 x the channel std.
threshold_result = detect_local_maxima_in_window(data)


threshold_result = np.array(threshold_result)
# Keep only centres far enough from both ends for a 61-sample window
# (these bounds are one sample stricter than extract_windows requires).
valid_indices = threshold_result[(threshold_result > 30)]
valid_indices = valid_indices[valid_indices < data.shape[1] - 31]

In [5]:
# Reference sorting for the training session; drop clusters with fewer than 8000 spikes.
spike_inf = pd.read_csv("/home/ubuntu/Documents/jct/project/code/Spike_Sorting/sorting_results/021322/spike_inf.csv")
removed_spikes = spike_inf['cluster'].value_counts()
# Despite the name, this holds the *cluster ids* to remove (those with < 8000 spikes).
removed_spikes = removed_spikes[removed_spikes < 8000].index
spike_inf = spike_inf[~spike_inf['cluster'].isin(removed_spikes)]

In [6]:
unique_clusters = np.unique(spike_inf['cluster'])
# Remap sorter cluster ids to contiguous class indices 0..K-1, as required by CrossEntropyLoss.
cluster_to_index = {cluster: idx for idx, cluster in enumerate(unique_clusters)}
# Vectorized .map instead of a Python loop over every spike. `.assign` returns a new frame,
# which avoids assigning a column into the filtered slice produced by the previous cell.
spike_inf = spike_inf.assign(cluster=spike_inf['cluster'].map(cluster_to_index).to_numpy())

In [7]:
# Label each candidate peak with the cluster of the earliest sorted spike within ±1 sample (-1 if none).
# NOTE(review): the validation session later uses threshold=2 — confirm the mismatch is intended.
cluster_labels = cluster_label_array1_based_on_array2(valid_indices, spike_inf, threshold=1)

In [8]:
# One row per candidate peak with its time and matched cluster (-1 = unmatched). The frame is
# built column-wise: the previous 2 x N frame + transpose created one pandas column per peak,
# which is very slow and memory-hungry for long recordings.
potent_spike_inf = pd.DataFrame({'time': valid_indices, 'cluster': cluster_labels})

In [9]:
# Keep only peaks matched to a sorted spike. The index is not reset, so original row labels remain.
potent_spike_inf = potent_spike_inf[potent_spike_inf['cluster'] != -1]

In [10]:
class CustomDataset(Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``labels[i]``."""

    def __init__(self, data, labels):
        # Stored as given; indexing is delegated to the underlying containers.
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target
# Build a class-balanced training set: at most MAX_SPIKES_PER_CLUSTER randomly chosen matched
# peaks per cluster, then an 80/20 train/test split.
indices = potent_spike_inf['time'].values

labels = potent_spike_inf['cluster'].values

MAX_SPIKES_PER_CLUSTER = 8000
SPLIT_SEED = 0  # seeds the per-cluster subsample, the shuffle and random_split for reproducibility
split_rng = np.random.default_rng(SPLIT_SEED)

balanced_indices = []
for cluster in spike_inf['cluster'].unique():
    cluster_indices = np.where(labels == cluster)[0]
    if len(cluster_indices) > MAX_SPIKES_PER_CLUSTER:
        sampled_indices = split_rng.choice(cluster_indices, MAX_SPIKES_PER_CLUSTER, replace=False)
    else:
        sampled_indices = cluster_indices
    balanced_indices.extend(sampled_indices)

balanced_indices = np.asarray(balanced_indices, dtype=int)
split_rng.shuffle(balanced_indices)

# Extract windows only for the selected peaks. Extracting every matched peak first (each window
# is n_channels x 61 float32) and then subsampling wastes a lot of memory.
balanced_data = extract_windows(data, indices[balanced_indices], window_size=61)
balanced_labels = labels[balanced_indices]

dataset = CustomDataset(balanced_data, balanced_labels)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

batch_size = 1024
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [11]:
class MLP(nn.Module):
    """Three-layer fully-connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Each input sample is flattened to ``input_size`` features; the output is one logit per class.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, num_classes)  # one logit per class

    def forward(self, x):
        # Flatten to the first layer's input width instead of the hard-coded 61 * 30.
        # fc1.in_features is read (rather than a new attribute) so that models pickled
        # from the earlier class definition still run.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# Use the GPU when available; fall back to CPU so the notebook still runs elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
input_size = balanced_data.shape[1] * balanced_data.shape[2]
hidden_size1 = 64
hidden_size2 = 50
# Labels are class indices 0..K-1, so the output layer needs max(label) + 1 logits. nunique()
# is smaller whenever some class has no matched peaks, which would put labels out of range
# for CrossEntropyLoss.
num_classes = int(potent_spike_inf['cluster'].max()) + 1
model = MLP(input_size, hidden_size1, hidden_size2, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)

In [None]:

# Train for a fixed number of epochs, reporting mean batch loss plus train and test accuracy
# each epoch. Accuracy is accumulated as correct/total counts on the fly instead of collecting
# every label and prediction into Python lists.
num_epochs = 210
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    train_correct = 0
    train_seen = 0
    for batch_data, batch_labels in train_loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Running train accuracy from this batch's pre-update predictions.
        train_correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
        train_seen += batch_labels.numel()

    accuracy = train_correct / max(train_seen, 1)
    print(f"Train Accuracy: {accuracy * 100:.2f}%")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

    model.eval()
    test_correct = 0
    test_seen = 0
    with torch.no_grad():
        for batch_data, batch_labels in test_loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            predicted = model(batch_data).argmax(dim=1)
            test_correct += (predicted == batch_labels).sum().item()
            test_seen += batch_labels.numel()

    accuracy = test_correct / max(test_seen, 1)
    print(f"Test Accuracy: {accuracy * 100:.2f}%")


Train Accuracy: 4.07%
Epoch [1/210], Loss: 10.1799
Test Accuracy: 5.58%
Train Accuracy: 7.30%
Epoch [2/210], Loss: 6.7954
Test Accuracy: 9.32%
Train Accuracy: 11.02%
Epoch [3/210], Loss: 5.2177
Test Accuracy: 12.87%
Train Accuracy: 14.85%
Epoch [4/210], Loss: 4.3346
Test Accuracy: 16.52%
Train Accuracy: 18.57%
Epoch [5/210], Loss: 3.7668
Test Accuracy: 20.27%
Train Accuracy: 22.49%
Epoch [6/210], Loss: 3.3626
Test Accuracy: 24.05%
Train Accuracy: 26.57%
Epoch [7/210], Loss: 3.0454
Test Accuracy: 27.94%
Train Accuracy: 30.87%
Epoch [8/210], Loss: 2.7819
Test Accuracy: 32.01%
Train Accuracy: 35.12%
Epoch [9/210], Loss: 2.5555
Test Accuracy: 36.05%
Train Accuracy: 39.34%
Epoch [10/210], Loss: 2.3529
Test Accuracy: 39.97%
Train Accuracy: 43.35%
Epoch [11/210], Loss: 2.1735
Test Accuracy: 43.54%
Train Accuracy: 47.08%
Epoch [12/210], Loss: 2.0128
Test Accuracy: 47.04%
Train Accuracy: 50.46%
Epoch [13/210], Loss: 1.8712
Test Accuracy: 50.15%
Train Accuracy: 53.68%
Epoch [14/210], Loss: 1.745

In [13]:
# Pickles the whole nn.Module (class reference + weights) to a path relative to the working directory.
# NOTE(review): model.state_dict() is more portable; the load cell below reads from a different, absolute path.
torch.save(model, 'spike_classification_model.pth')

In [14]:
# The checkpoint is a whole pickled nn.Module (saved with torch.save(model, ...)), so it needs
# weights_only=False: since PyTorch 2.6, torch.load defaults to weights_only=True and rejects such
# files. This unpickles arbitrary Python objects — only load checkpoints you trust.
# map_location lets a GPU-saved checkpoint load onto the current `device`.
model = torch.load(
    "/home/ubuntu/Documents/jct/project/code/Spike_Sorting/spike_classification/spike_classification_model.pth",
    map_location=device,
    weights_only=False,
)

In [13]:
# Run the model over the held-out test split and collect, per sample: the true label, the
# predicted class, the largest logit, and the penultimate-layer (relu2) activations.
all_labels = []
predicted_labels = []
latent_value = []
output_max = []

with torch.no_grad():
    for batch_data, batch_labels in test_loader:
        # Labels become float column vectors (N, 1); a later cell flattens them with np.concatenate.
        batch_labels = batch_labels.float().unsqueeze(1)
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_data)
        predicted = torch.argmax(outputs, dim=1)  

        # Largest logit per sample.
        outputs = outputs.cpu().numpy().max(axis = 1)
        output_max.extend(list(outputs))

        # Re-run the first two layers by hand to get the penultimate-layer embedding.
        batch_data = batch_data.reshape(-1, 61 * 30)
        batch_data = model.fc1(batch_data)
        batch_data = model.relu1(batch_data)
        batch_data = model.fc2(batch_data)
        batch_data = model.relu2(batch_data)
        latent_value.append(batch_data.cpu())  
        
        all_labels.extend(batch_labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

output_max = np.array(output_max)
all_labels = np.array(all_labels)
predicted_labels = np.array(predicted_labels)
latent_value = torch.cat(latent_value, dim=0).numpy()

In [14]:
# Random subset of test-set latents for the UMAP plots. It is capped at the number available,
# because random.sample raises ValueError when asked for more items than exist. Seeded for reproducibility.
num_samples = min(50000, len(latent_value))
indices = random.Random(42).sample(range(len(latent_value)), num_samples)
latent_value_subset = latent_value[indices, :]
# all_labels is (N, 1); ravel() flattens the selected rows to 1-D (also safe for an empty selection).
all_labels_subset = all_labels[indices].ravel().astype(int).astype(str)
all_predictions_subset = predicted_labels[indices].astype(int).astype(str)
output_max_subset = output_max[indices]

In [None]:
# 2-D embedding of the test-set latent subset: PCA to 20 components, then UMAP.
pca = PCA(n_components=20)
pca_data = pca.fit_transform(latent_value_subset)

# min_dist=0 with few neighbours favours tight, well-separated clumps.
umap_reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors= 5, min_dist=0)
umap_data = umap_reducer.fit_transform(pca_data)

In [16]:
# Ground-truth view of the test-set embedding: one PDF page per true cluster label, with that
# cluster highlighted over all other points (light grey).
# NOTE(review): `unique_dates` / `date` actually hold cluster labels (strings), and `colors` is not used below.
unique_dates = np.unique(all_labels_subset)
# tab10 has 10 colours, so colours repeat once there are more than 10 clusters.
color_map = {date: plt.cm.tab10(i % 10) for i, date in enumerate(unique_dates)} 
colors = [color_map[date] for date in all_labels_subset]

# Shared axis limits so every page shows the same frame.
x_min, x_max = umap_data[:, 0].min(), umap_data[:, 0].max()
y_min, y_max = umap_data[:, 1].min(), umap_data[:, 1].max()

with PdfPages('test_dataset_clusters_visualization_gt.pdf') as pdf:
    for cluster in unique_dates:
        plt.figure(figsize=(8, 8))
        
        cluster_points = umap_data[all_labels_subset == cluster]
        other_points = umap_data[all_labels_subset != cluster]
        
        plt.scatter(other_points[:, 0], other_points[:, 1], c='lightgray', s=0.1, alpha=0.5, label='Other Clusters')
        
        plt.scatter(cluster_points[:, 0], cluster_points[:, 1], c=color_map[cluster], s=0.1, alpha=0.7, label=f'Cluster {cluster}')
        
        centroid = cluster_points.mean(axis=0)  # centroid of the highlighted cluster, used to place its label
        plt.text(centroid[0], centroid[1], f'Cluster {cluster}', 
                 fontsize=12, color=color_map[cluster], ha='center', va='center')
        
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        
        plt.title(f'Cluster {cluster}', fontsize=14)
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.legend(loc='upper right', fontsize=8)
        
        # One page per cluster; close each figure so figures do not pile up in memory.
        pdf.savefig()
        plt.close()

In [17]:
# Predicted-label view of the test-set embedding: one PDF page per predicted class, with that
# class highlighted over all other points (light grey).
# NOTE(review): `unique_dates` / `date` actually hold class labels (strings), and `colors` is not used below.
unique_dates = np.unique(all_predictions_subset)
# tab10 has 10 colours, so colours repeat once there are more than 10 classes.
color_map = {date: plt.cm.tab10(i % 10) for i, date in enumerate(unique_dates)} 
colors = [color_map[date] for date in all_predictions_subset]

# Shared axis limits so every page shows the same frame.
x_min, x_max = umap_data[:, 0].min(), umap_data[:, 0].max()
y_min, y_max = umap_data[:, 1].min(), umap_data[:, 1].max()

with PdfPages('test_dataset_clusters_visualization_predicted.pdf') as pdf:
    for cluster in unique_dates:
        plt.figure(figsize=(8, 8))
        
        cluster_points = umap_data[all_predictions_subset == cluster]
        other_points = umap_data[all_predictions_subset != cluster]
        
        plt.scatter(other_points[:, 0], other_points[:, 1], c='lightgray', s=0.1, alpha=0.5, label='Other Clusters')
        
        plt.scatter(cluster_points[:, 0], cluster_points[:, 1], c=color_map[cluster], s=0.1, alpha=0.7, label=f'Cluster {cluster}')
        
        # Label placed at the highlighted class's centroid.
        centroid = cluster_points.mean(axis=0)  
        plt.text(centroid[0], centroid[1], f'Cluster {cluster}', 
                 fontsize=12, color=color_map[cluster], ha='center', va='center')
        
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        
        plt.title(f'Cluster {cluster}', fontsize=14)
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.legend(loc='upper right', fontsize=8)
        
        # One page per class; close each figure so figures do not pile up in memory.
        pdf.savefig()
        plt.close()

In [18]:
# Validation session (022522): same channel selection and preprocessing as the training session.
# NOTE(review): absolute local path; this also overwrites recording_raw / recording_recorded / recording_stimulated.
recording_raw = se.read_blackrock(file_path='/media/ubuntu/sda/data/mouse6/ns4/natural_image/mouse6_022522_natural_image_001.ns4')
recording_recorded = recording_raw.remove_channels(['98', '31', '32'])
recording_stimulated = recording_raw.channel_slice(['98'])

# 300–3000 Hz band-pass, then global median common reference.
recording_f_22522 = spre.bandpass_filter(recording_recorded, freq_min=300, freq_max=3000)
recording_f_22522 = spre.common_reference(recording_f_22522, reference="global", operator="median")

In [19]:
# Transposed to (channel, sample), matching the training-session `data`.
data_22522 = recording_f_22522.get_traces().astype("float32").T

In [20]:
# Same peak detection and edge filtering as for the training session.
threshold_result_22522 = detect_local_maxima_in_window(data_22522)
threshold_result_22522 = np.array(threshold_result_22522)
valid_indices_22522 = threshold_result_22522[(threshold_result_22522 > 30)]
valid_indices_22522 = valid_indices_22522[valid_indices_22522 < data_22522.shape[1] - 31]

In [21]:
# Reference sorting for the validation session; drop clusters with fewer than 8000 spikes.
# NOTE(review): unlike the training session, these cluster ids are not remapped to 0..K-1, so they
# are not in the classifier's label space; they appear to be used only as reference labels below.
spike_inf_22522 = pd.read_csv("/home/ubuntu/Documents/jct/project/code/Spike_Sorting/sorting_results/022522/spike_inf.csv")

# Despite the name, `removed_spikes` ends up holding the cluster ids to drop.
removed_spikes = spike_inf_22522['cluster'].value_counts()
removed_spikes = removed_spikes[removed_spikes < 8000].index
spike_inf_22522 = spike_inf_22522[~spike_inf_22522['cluster'].isin(removed_spikes)]

In [22]:
# Drop sorted spikes too close to either end of the recording for a 61-sample window.
spike_inf_22522 = spike_inf_22522[(spike_inf_22522['time'] > 30) & (spike_inf_22522['time'] < data_22522.shape[1] - 31)]

In [23]:
# Only referenced by the commented-out filter in the next cell.
spike_inf_22522_aligned = pd.read_csv("/home/ubuntu/Documents/jct/project/code/Spike_Sorting/sorting_results/022522/spike_inf_aligned.csv")

In [48]:
#spike_inf_22522.loc[~spike_inf_22522['cluster'].isin(spike_inf_22522_aligned['cluster']), 'cluster'] = -2

In [24]:
# Label each validation peak with the earliest sorted spike's cluster within ±2 samples (-1 if none).
# NOTE(review): the training session used threshold=1 — confirm the mismatch is intended.
cluster_labels_22522 = cluster_label_array1_based_on_array2(valid_indices_22522, spike_inf_22522, threshold=2)

In [107]:
# One row per validation peak with its time and matched cluster (-1 = unmatched). Built
# column-wise: a 2 x N frame + transpose creates one pandas column per peak, which is very slow.
potent_spike_inf = pd.DataFrame({'time': valid_indices_22522, 'cluster': cluster_labels_22522})

In [108]:
# Keep matched peaks only; the index is reset later, before index labels are used as positions.
potent_spike_inf = potent_spike_inf[potent_spike_inf['cluster'] != -1]

In [27]:
# Windows around every matched validation peak, in potent_spike_inf row order: (n_peaks, n_channels, 61).
sampled_data_22522 = extract_windows(data_22522, potent_spike_inf['time'], window_size=61)

In [36]:
# shuffle=False keeps batches in potent_spike_inf row order; later cells write predictions back by position.
val_dataset = CustomDataset(sampled_data_22522, potent_spike_inf['cluster'].values)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

In [37]:
# Penultimate-layer (relu2) embeddings and reference labels for every validation peak.
# The original also ran a full forward pass per batch whose result was discarded; it is removed here.
all_labels = []
predicted_labels = []
latent_value = []

model.eval()
with torch.no_grad():
    for batch_data, batch_labels in val_loader:
        # Labels kept as float (N, 1) column vectors; a later cell flattens them.
        batch_labels = batch_labels.float().unsqueeze(1)
        batch_data = batch_data.to(device)

        hidden = batch_data.reshape(-1, model.fc1.in_features)
        hidden = model.relu1(model.fc1(hidden))
        hidden = model.relu2(model.fc2(hidden))
        latent_value.append(hidden.cpu())

        all_labels.extend(batch_labels.numpy())

all_labels = np.array(all_labels)
latent_value = torch.cat(latent_value, dim=0).numpy()

In [125]:
# Fit KMeans on a random subsample of *this* session's latents (`latent_value` above).
# Do not use `latent_value_subset` here: read top to bottom, it still holds the earlier
# training-session test subset at this point (it is only redefined for this session further
# down), so the fit depended on out-of-order execution.
kmeans_fit_rows = np.random.default_rng(42).choice(
    len(latent_value), size=min(100000, len(latent_value)), replace=False
)
final_kmeans = KMeans(n_clusters=50, n_init=10, random_state=42).fit(latent_value[kmeans_fit_rows])

In [126]:
# Nearest-centre assignment for every validation latent vector, not just the subset KMeans was fitted on.
predicted_labels = predict_new(latent_value, final_kmeans)

In [127]:
import random
# Random subset of this session's latents for the UMAP plots. It is capped at the number available,
# because random.sample raises ValueError when asked for more items than exist. Seeded for reproducibility.
num_samples = min(100000, len(latent_value))
indices = random.Random(42).sample(range(len(latent_value)), num_samples)
latent_value_subset = latent_value[indices, :]
# all_labels is (N, 1); ravel() flattens the selected rows to 1-D.
all_labels_subset = all_labels[indices].ravel().astype(int).astype(str)
all_predictions_subset = predicted_labels[indices].astype(int).astype(str)

In [61]:
# NOTE(review): these imports repeat the top import cell.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import umap
import random

# 2-D embedding of the validation latent subset: PCA to 30 components, then UMAP.
pca = PCA(n_components=30)
pca_data = pca.fit_transform(latent_value_subset)

umap_reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors= 5, min_dist=0)
umap_data = umap_reducer.fit_transform(pca_data)

In [62]:
# Reference-label view of the validation embedding: one PDF page per reference cluster, with that
# cluster highlighted over all other points (light grey).
# NOTE(review): `unique_dates` / `date` actually hold cluster labels (strings), and `colors` is not used below.
unique_dates = np.unique(all_labels_subset)
# tab10 has 10 colours, so colours repeat once there are more than 10 clusters.
color_map = {date: plt.cm.tab10(i % 10) for i, date in enumerate(unique_dates)} 
colors = [color_map[date] for date in all_labels_subset]

# Shared axis limits so every page shows the same frame.
x_min, x_max = umap_data[:, 0].min(), umap_data[:, 0].max()
y_min, y_max = umap_data[:, 1].min(), umap_data[:, 1].max()

with PdfPages('clusters_visualization_gt.pdf') as pdf:
    for cluster in unique_dates:
        plt.figure(figsize=(8, 8))
        
        cluster_points = umap_data[all_labels_subset == cluster]
        other_points = umap_data[all_labels_subset != cluster]
        
        plt.scatter(other_points[:, 0], other_points[:, 1], c='lightgray', s=0.1, alpha=0.5, label='Other Clusters')
        
        plt.scatter(cluster_points[:, 0], cluster_points[:, 1], c=color_map[cluster], s=0.1, alpha=0.7, label=f'Cluster {cluster}')
        
        # Label placed at the highlighted cluster's centroid.
        centroid = cluster_points.mean(axis=0) 
        plt.text(centroid[0], centroid[1], f'Cluster {cluster}', 
                 fontsize=12, color=color_map[cluster], ha='center', va='center')
        
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        
        plt.title(f'Cluster {cluster}', fontsize=14)
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.legend(loc='upper right', fontsize=8)
        
        # One page per cluster; close each figure so figures do not pile up in memory.
        pdf.savefig()
        plt.close()

In [63]:


# KMeans-label view of the validation embedding: one PDF page per KMeans cluster, with that
# cluster highlighted over all other points (light grey).
# NOTE(review): `unique_dates` / `date` actually hold cluster labels (strings), and `colors` is not used below.
unique_dates = np.unique(all_predictions_subset)
# tab10 has 10 colours, so colours repeat once there are more than 10 clusters.
color_map = {date: plt.cm.tab10(i % 10) for i, date in enumerate(unique_dates)} 
colors = [color_map[date] for date in all_predictions_subset]

# Shared axis limits so every page shows the same frame.
x_min, x_max = umap_data[:, 0].min(), umap_data[:, 0].max()
y_min, y_max = umap_data[:, 1].min(), umap_data[:, 1].max()

with PdfPages('clusters_visualization.pdf') as pdf:
    for cluster in unique_dates:
        plt.figure(figsize=(8, 8))
        
        cluster_points = umap_data[all_predictions_subset == cluster]
        other_points = umap_data[all_predictions_subset != cluster]
        
        plt.scatter(other_points[:, 0], other_points[:, 1], c='lightgray', s=0.1, alpha=0.5, label='Other Clusters')
        
        plt.scatter(cluster_points[:, 0], cluster_points[:, 1], c=color_map[cluster], s=0.1, alpha=0.7, label=f'Cluster {cluster}')
        
        # Label placed at the highlighted cluster's centroid.
        centroid = cluster_points.mean(axis=0)  
        plt.text(centroid[0], centroid[1], f'Cluster {cluster}', 
                 fontsize=12, color=color_map[cluster], ha='center', va='center')
        
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        
        plt.title(f'Cluster {cluster}', fontsize=14)
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.legend(loc='upper right', fontsize=8)
        
        # One page per cluster; close each figure so figures do not pile up in memory.
        pdf.savefig()
        plt.close()

print("PDF file 'clusters_visualization.pdf' has been created.")

PDF file 'clusters_visualization.pdf' has been created.


In [128]:
# Row-aligned with potent_spike_inf: val_loader walks it in order (shuffle=False) and predict_new keeps that order.
potent_spike_inf['cluster_predicted'] = predicted_labels

In [129]:
# Make index labels equal row positions: compute_cluster_average indexes sampled_data_22522 by these labels.
potent_spike_inf.index = range(len(potent_spike_inf))

In [202]:
# Fixed-seed 10,000-row subset used for the per-cluster average windows (raises if fewer rows exist).
potent_spike_inf_subset = potent_spike_inf.sample(n = 10000, random_state= 30)

In [204]:
# Mean window per predicted (KMeans) cluster over the sampled rows.
cluster_averages = compute_cluster_average(sampled_data_22522, potent_spike_inf_subset)

In [206]:
# Probe groups: group name (string key) -> the six channel indices in that group.
# Keys are strings because callers look groups up with str(row['probe_group']).
# calculate_position reads this as a module-level global; it is also passed explicitly
# to process_cluster_averages and calculate_position_waveform.
channel_indices = {
        "1": [1, 3, 5, 7, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10]
        }

In [207]:
# Keep only each cluster's peak channel group; keys become "<cluster>_<group>".
processed_averages = process_cluster_averages(cluster_averages, channel_indices)

In [208]:
import pickle
# Reference units stored with the 021322 sorting results. judge_cluster_reality expects the columns
# 'position_1', 'position_2', 'position_waveform' and 'cluster'.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('/home/ubuntu/Documents/jct/project/code/Spike_Sorting/sorting_results/021322/neuron_inf.pkl', 'rb') as f:
    neuron_inf = pickle.load(f)

In [209]:
# One row per processed cluster, split from its "<cluster>_<group>" key (both parts are strings here).
# value.T puts samples on rows and channels on columns, the layout that calculate_position /
# calculate_position_waveform index as waveform[:, channel_column].
df = pd.DataFrame([
    {"cluster": key.split("_")[0], "probe_group": key.split("_")[1], "waveform": value.T}
    for key, value in processed_averages.items()
])

In [210]:
# Energy-weighted centroid position per cluster, then an IDW-interpolated waveform at that position.
df[['position_1', 'position_2']] = df.apply(calculate_position, axis=1)
df['position_waveform'] = df.apply(
    lambda row: calculate_position_waveform(row, channel_position, channel_indices), axis=1
)

In [212]:
# NOTE(review): this constant assignment is immediately overwritten by the apply below.
df['label'] = 1
# Matched neuron_inf 'cluster' (within 10 units in x and y, waveform correlation > 0.95), or None.
df['label'] = df.apply(
    lambda row: judge_cluster_reality(row, neuron_inf), axis=1
)

In [213]:
# Keep only clusters that matched a reference neuron.
df = df[~df['label'].isna()]

In [266]:
# Cluster ids were split from string keys; cast to int so they compare equal to cluster_predicted.
df['cluster'] = df['cluster'].astype(int)

In [267]:
# Propagate each matched cluster's neuron_inf label to every spike predicted in that cluster;
# spikes whose predicted cluster has no match get -1. A vectorized map replaces the per-row
# iterrows() + boolean-mask scan, which cost O(n_spikes * n_clusters).
# drop_duplicates keeps the first row per cluster, matching the earlier `.values[0]` choice.
label_by_cluster = df.drop_duplicates('cluster').set_index('cluster')['label']
mapped_labels = potent_spike_inf['cluster_predicted'].map(label_by_cluster).fillna(-1)
# map() introduces NaN (hence float dtype) for unmatched rows; cast back to integers
# whenever every value is integral.
if pd.api.types.is_float_dtype(mapped_labels) and (mapped_labels % 1 == 0).all():
    mapped_labels = mapped_labels.astype(np.int64)
potent_spike_inf['label'] = mapped_labels