In [3]:
#from spike_sorting_utils.basic_util import *
 
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm

from sklearn.decomposition import PCA
import umap
import random

import sys
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre

import matplotlib.pyplot as plt
import json

from probeinterface import write_prb, read_prb

import torch.nn.functional as F
from pathlib import Path

import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from sklearn.metrics import accuracy_score
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances
import time


In [2]:
class SpikeDataset(Dataset):
    """Minimal ``Dataset`` adapter around an indexable container of samples.

    Samples are handed back exactly as they are stored; nothing is converted
    or copied here.
    """

    def __init__(self, data):
        # Only a reference is kept; the container has to support len() and [].
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        samples = self.data
        return samples[idx]
    
class Spike_Detection_MLP(nn.Module):
    """Three-layer MLP whose output is squashed by a sigmoid.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
        output_size: number of sigmoid outputs per sample.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Flatten ``x`` to ``(batch, in_features)`` and return sigmoid outputs.

        Returns a tensor of shape ``(batch, output_size)`` with values in [0, 1].
        """
        # The flatten width used to be hard-coded as 61 * 30, which ignored
        # `input_size` and broke for any other width. It is read from fc1 (not
        # from a new attribute) so that models pickled with the previous class
        # definition keep working after being reloaded.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x
    
class Spike_Classification_MLP(nn.Module):
    """Three-layer MLP that returns one raw score per class.

    No activation is applied after the last linear layer, so the outputs are
    unnormalised scores.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
        num_classes: number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super().__init__()
        # Attribute names are part of the interface: other cells call
        # fc1 / relu1 / fc2 / relu2 directly to read hidden activations.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, num_classes)

    def forward(self, x):
        """Flatten ``x`` to ``(batch, in_features)`` and return class scores.

        Returns a tensor of shape ``(batch, num_classes)``.
        """
        # The flatten width used to be hard-coded as 61 * 30, which ignored
        # `input_size`. Reading it from fc1 keeps previously pickled models
        # working, because no new attribute is required.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

In [3]:
# Both checkpoints are used directly as callable modules further down, so they
# are whole pickled models rather than state dicts. Presumably they are
# instances of the MLP classes defined in this notebook, which then have to be
# defined before this cell runs — confirm.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-module pickles — confirm against the installed version.
detection_checkpoint = "/home/ubuntu/Documents/jct/project/code/Spike_Sorting/spike_detection/new_model/spike_detection_model_1.pth"
classification_checkpoint = "/home/ubuntu/Documents/jct/project/code/Spike_Sorting/spike_classification/new_model/spike_classification_model_1.pth"

spike_detection_model = torch.load(detection_checkpoint)
spike_classification_model = torch.load(classification_checkpoint)

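- Data input: read the Blackrock recording, band-pass filter and re-reference it, then extract candidate spike windows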

In [4]:
# NOTE(review): detect_local_maxima_in_window and extract_windows are neither
# defined nor imported in the visible notebook. Presumably they come from
# spike_sorting_utils.basic_util, whose import is commented out in the import
# cell — confirm, otherwise this cell raises NameError on a fresh kernel.

recording_raw = se.read_blackrock(
    file_path='/media/ubuntu/sda/data/mouse6/ns4/natural_image/mouse6_022522_natural_image_001.ns4'
)
# Channel '98' is split off on its own; '98', '31' and '32' are removed from
# the recording that is analysed.
recording_stimulated = recording_raw.channel_slice(['98'])
recording_recorded = recording_raw.remove_channels(['98', '31', '32'])

# 300-3000 Hz band-pass, then global median re-referencing.
recording_f_22522 = spre.common_reference(
    spre.bandpass_filter(recording_recorded, freq_min=300, freq_max=3000),
    reference="global",
    operator="median",
)

# After the transpose, axis 1 is the one the candidate indices refer to
# (its length bounds them below).
data_22522 = recording_f_22522.get_traces().astype("float32").T

threshold_result_22522 = np.array(detect_local_maxima_in_window(data_22522))

# Drop candidates too close to either end of the trace. The 30 / 31 margins
# presumably match the 61-sample window taken by extract_windows — confirm.
in_bounds = (threshold_result_22522 > 30) & (threshold_result_22522 < data_22522.shape[1] - 31)
valid_indices_22522 = threshold_result_22522[in_bounds]

potent_spike_inf = pd.DataFrame(data=valid_indices_22522, columns=['time'])
data_input = extract_windows(data_22522, valid_indices_22522)

In [5]:
# Batch the candidate windows for inference. Order is preserved
# (shuffle=False) because the predictions are written back to
# potent_spike_inf by position.
val_dataset = SpikeDataset(data_input)
val_loader = DataLoader(dataset=val_dataset, batch_size=1024, shuffle=False)

- Spike detection: score every candidate window with the detection network and keep those predicted as spikes

In [6]:
# Run the detection network over every candidate window and threshold its
# output at 0.5.

# Fall back to the CPU when no GPU is visible instead of failing on a
# hard-coded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
spike_detection_model = spike_detection_model.to(device)
spike_detection_model.eval()  # standard for inference

predicted_batches = []

with torch.no_grad():
    for batch_data in val_loader:
        batch_data = batch_data.to(device)

        outputs = spike_detection_model(batch_data)
        predicted = (outputs > 0.5).float()

        # One array per batch, joined once at the end.
        predicted_batches.append(predicted.cpu().numpy())

# Same layout as stacking the per-sample rows: one row per candidate window.
predicted_labels = (
    np.concatenate(predicted_batches, axis=0) if predicted_batches else np.array([])
)

In [7]:
# Attach the detector's verdict to every candidate, then keep the accepted ones.
potent_spike_inf['spike_detection_label'] = predicted_labels

# .copy() makes the filtered frame independent of the original, so columns
# added to it later are not written into a slice (SettingWithCopy, which the
# global warning filter would otherwise hide).
potent_spike_inf = potent_spike_inf[potent_spike_inf['spike_detection_label'] == 1].copy()

# Positions (along axis 0) of the accepted candidates; used to subset the
# window array so it stays aligned with potent_spike_inf.
potent_spikes = np.where(predicted_labels == 1)[0]


In [8]:
# Keep only the windows whose positions are listed in potent_spikes (the
# candidates the detector accepted).
# NOTE(review): this overwrites data_input, so re-running the cell indexes the
# already-filtered array; re-run the data-input cell first if that happens.
data_input = data_input[potent_spikes, :, :]

In [9]:
# Rebuild the dataset and loader over the current (filtered) data_input.
# Order is preserved (shuffle=False) so that per-window results line up with
# the rows of data_input.
val_dataset = SpikeDataset(data_input)
val_loader = DataLoader(dataset=val_dataset, batch_size=1024, shuffle=False)

In [10]:
# Collect the activations after relu2 of the classification network; these
# are the embeddings that get clustered afterwards.

# The classification model was never moved to `device` while the batches
# are, which fails with a device mismatch if the checkpoint was saved on
# another device. Move it explicitly and switch to inference mode.
spike_classification_model = spike_classification_model.to(device)
spike_classification_model.eval()

# Flatten width taken from the first layer instead of a hard-coded 61 * 30.
flat_width = spike_classification_model.fc1.in_features

latent_batches = []

with torch.no_grad():
    for batch_data in val_loader:
        batch_data = batch_data.to(device)
        batch_data = batch_data.reshape(-1, flat_width)
        # Run the network up to (and including) relu2; fc3 is skipped on purpose.
        batch_data = spike_classification_model.fc1(batch_data)
        batch_data = spike_classification_model.relu1(batch_data)
        batch_data = spike_classification_model.fc2(batch_data)
        batch_data = spike_classification_model.relu2(batch_data)
        latent_batches.append(batch_data.cpu())

latent_value = torch.cat(latent_batches, dim=0).numpy()

In [11]:
# Fit KMeans on a random subset of the embeddings.
# - The subset size is capped at the number of embeddings: random.sample
#   raises ValueError when asked for more items than the population holds.
# - A dedicated seeded generator makes the subset reproducible (the global
#   `random` state was never seeded) without touching that global state.
subset_size = min(10000, len(latent_value))
subset_rows = random.Random(42).sample(range(len(latent_value)), subset_size)

latent_value_subset = latent_value[subset_rows, :]
final_kmeans = KMeans(n_clusters=50, n_init=10, random_state=42).fit(latent_value_subset)

In [12]:
# Load the pickled reference object that judge_cluster_reality receives later.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# NOTE(review): the path points at a "021322" results folder while the
# recording file name contains "022522" — confirm this is the intended reference.
neuron_inf_path = Path('/home/ubuntu/Documents/jct/project/code/Spike_Sorting/sorting_results/021322/neuron_inf.pkl')

with neuron_inf_path.open('rb') as f:
    neuron_inf = pickle.load(f)

In [13]:
# Channel indices per group key (presumably probe groups / shanks — confirm).
channel_indices = {
    "1": list(range(1, 12, 2)),    # 1, 3, 5, 7, 9, 11
    "2": list(range(13, 24, 2)),   # 13, 15, 17, 19, 21, 23
    "3": list(range(24, 30)),      # 24, 25, 26, 27, 28, 29
    "4": list(range(12, 23, 2)),   # 12, 14, 16, 18, 20, 22
    "5": list(range(0, 11, 2)),    # 0, 2, 4, 6, 8, 10
}

# Two-element position per channel index (units not stated here — presumably
# micrometres, confirm). Insertion order is kept exactly as before.
channel_position = {
    # channels listed under key "5"
    0: [650, 0], 2: [650, 50], 4: [650, 100],
    6: [600, 100], 8: [600, 50], 10: [600, 0],
    # channels listed under key "1"
    1: [0, 0], 3: [0, 50], 5: [0, 100],
    7: [50, 100], 9: [50, 50], 11: [50, 0],
    # channels listed under key "2"
    13: [150, 200], 15: [150, 250], 17: [150, 300],
    19: [200, 300], 21: [200, 250], 23: [200, 200],
    # channels listed under key "4"
    12: [500, 200], 14: [500, 250], 16: [500, 300],
    18: [450, 300], 20: [450, 250], 22: [450, 200],
    # channels listed under key "3"
    24: [350, 400], 26: [350, 450], 28: [350, 500],
    25: [300, 400], 27: [300, 450], 29: [300, 500],
}

In [17]:
# NOTE(review): predict_new, compute_cluster_average, process_cluster_averages,
# calculate_position, calculate_position_waveform and judge_cluster_reality are
# neither defined nor imported in the visible cells. Presumably they come from
# spike_sorting_utils.basic_util (import commented out) or from cells that were
# removed (the execution counter jumps from 13 to 17) — confirm.

# Assign every embedding to a cluster of the fitted KMeans model.
predicted_labels = predict_new(latent_value, final_kmeans)

# Re-index 0..N-1 before adding the column: this produces a fresh frame, so
# the assignment does not write into a filtered slice, and the rows line up
# with data_input by position.
potent_spike_inf = potent_spike_inf.reset_index(drop=True)
potent_spike_inf['cluster_predicted'] = predicted_labels

cluster_averages = compute_cluster_average(data_input, potent_spike_inf)
processed_averages = process_cluster_averages(cluster_averages, channel_indices)

# Keys of processed_averages are split on "_" into a cluster part and a
# probe-group part.
df = pd.DataFrame([
    {"cluster": key.split("_")[0], "probe_group": key.split("_")[1], "waveform": value.T}
    for key, value in processed_averages.items()
])

df[['position_1', 'position_2']] = df.apply(
    lambda row: calculate_position(row, channel_indices, channel_position), axis=1
)
df['position_waveform'] = df.apply(
    lambda row: calculate_position_waveform(row, channel_position, channel_indices), axis=1
)

# The placeholder column is kept because judge_cluster_reality receives the
# whole row and may read it.
df['label'] = 1
df['label'] = df.apply(
    lambda row: judge_cluster_reality(row, neuron_inf), axis=1
)

# Rows with a missing label are dropped; .copy() avoids assigning into a slice.
df = df[df['label'].notna()].copy()
df['cluster'] = df['cluster'].astype(int)

# Propagate cluster labels to the individual spikes in one vectorised lookup
# (replaces a per-row loop that filtered `df` once per spike). For each
# cluster the first remaining row's label is used; spikes whose cluster has
# no labelled row get '-1'.
first_label_per_cluster = df.drop_duplicates(subset='cluster').set_index('cluster')['label']
mapped_labels = potent_spike_inf['cluster_predicted'].map(first_label_per_cluster)
potent_spike_inf['label'] = mapped_labels.astype(object).where(mapped_labels.notna(), '-1')
