In [13]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import spikeinterface as si
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm

In [14]:
# Recording channels that belong to each probe group, keyed by group id (as a string).
# The order of each list is the column order of the per-group waveform slices.
# Group "1" uses channel 7 (not 6): channel 6 already belongs to group "5", and
# channel 7 at (50, 100) is the position that completes group 1's layout below.
channel_indices = {
    "1": [1, 3, 5, 7, 9, 11],
    "2": [13, 15, 17, 19, 21, 23],
    "3": [24, 25, 26, 27, 28, 29],
    "4": [12, 14, 16, 18, 20, 22],
    "5": [0, 2, 4, 6, 8, 10],
}

# [x, y] coordinate of every channel (units are not stated in this notebook —
# presumably the probe geometry's native unit; TODO confirm).
channel_position = {
    # group "5"
    0: [650, 0],
    2: [650, 50],
    4: [650, 100],
    6: [600, 100],
    8: [600, 50],
    10: [600, 0],
    # group "1"
    1: [0, 0],
    3: [0, 50],
    5: [0, 100],
    7: [50, 100],
    9: [50, 50],
    11: [50, 0],
    # group "2"
    13: [150, 200],
    15: [150, 250],
    17: [150, 300],
    19: [200, 300],
    21: [200, 250],
    23: [200, 200],
    # group "4"
    12: [500, 200],
    14: [500, 250],
    16: [500, 300],
    18: [450, 300],
    20: [450, 250],
    22: [450, 200],
    # group "3"
    24: [350, 400],
    26: [350, 450],
    28: [350, 500],
    25: [300, 400],
    27: [300, 450],
    29: [300, 500],
}

In [15]:
def calculate_position(row):
    """Estimate a cluster's position as an energy-weighted centroid of its channels.

    Reads ``row['probe_group']`` and ``row['waveform']``. Column ``j`` of the
    waveform is paired with the ``j``-th channel of the row's probe group
    (looked up in the module-level ``channel_indices``), and each channel's
    coordinate (module-level ``channel_position``; ``[0, 0]`` when the channel
    is unknown) is weighted by the sum of squared samples in that column.

    Returns a ``pd.Series`` with ``position_1`` (x) and ``position_2`` (y).
    Both are 0 when the total energy is zero.
    """
    group_channels = channel_indices[str(row['probe_group'])]
    template = row['waveform']

    # Energy (sum of squared samples) of each waveform column, in channel order.
    energies = [np.sum(template[:, col] ** 2) for col in range(len(group_channels))]

    weighted_x = 0
    weighted_y = 0
    total_energy = 0
    for channel, energy in zip(group_channels, energies):
        pos_x, pos_y = channel_position.get(channel, [0, 0])
        weighted_x += pos_x * energy
        weighted_y += pos_y * energy
        total_energy += energy

    # A silent template has no defined centroid; fall back to the origin.
    if total_energy == 0:
        return pd.Series({'position_1': 0, 'position_2': 0})

    return pd.Series({
        'position_1': weighted_x / total_energy,
        'position_2': weighted_y / total_energy,
    })

def calculate_position_waveform(row, channel_position, channel_indices, power=2):
    """Synthesize the waveform at a cluster's position by inverse-distance weighting.

    Parameters
    ----------
    row : mapping
        Must provide ``position_1`` / ``position_2`` (target x / y),
        ``probe_group`` and ``waveform``. The waveform is indexed as
        ``(n_samples, n_channels)`` with columns in the order of
        ``channel_indices[str(probe_group)]``.
    channel_position : dict
        Maps channel id to ``[x, y]``.
    channel_indices : dict
        Maps probe-group id (string) to the list of channel ids of that group.
    power : float, default 2
        Exponent of the inverse-distance weights.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``n_samples``. If the target coincides with a
        channel, that channel's waveform is returned unchanged; if no channel
        of the group has a known position, zeros are returned.
    """
    x_target = row['position_1']
    y_target = row['position_2']
    channels = channel_indices[str(row['probe_group'])]
    waveforms = np.asarray(row['waveform'])
    n_samples = waveforms.shape[0]

    # Track which waveform columns survive, so weights stay aligned with the
    # columns even when a channel without a known position is skipped.
    kept_columns = []
    distances = []
    for column, channel in enumerate(channels):
        x_channel, y_channel = channel_position.get(channel, [np.nan, np.nan])
        if np.isnan(x_channel) or np.isnan(y_channel):
            continue
        kept_columns.append(column)
        distances.append(np.sqrt((x_target - x_channel) ** 2 + (y_target - y_channel) ** 2))

    if not distances:
        return np.zeros(n_samples)

    distances = np.asarray(distances, dtype=float)
    kept_waveforms = waveforms[:, kept_columns]

    # Exact hit on a channel: IDW weights would be infinite, so return that
    # channel's waveform directly. This must be checked *before* dividing, and
    # on an array (comparing a Python list to 0 is always False).
    exact_hits = np.flatnonzero(distances == 0)
    if exact_hits.size:
        return np.array(kept_waveforms[:, exact_hits[0]], dtype=float)

    # IDW: weights proportional to 1 / d**power, normalised to sum to 1.
    weights = 1.0 / distances ** power
    weights /= weights.sum()

    return kept_waveforms @ weights

In [16]:
def get_spike_inf(file_path, snr_threshold=3, min_spikes_high_snr=5000, min_spikes_low_snr=8000):
    """Build the per-cluster table for one sorting run and filter it by quality.

    Reads, relative to ``file_path``:

    * ``analyzer_kilosort4_binary/extensions/quality_metrics/metrics.csv``
      (22 columns, renamed positionally below),
    * ``kilosort/sorter_output/spike_clusters.npy`` and ``spike_positions.npy``
      (one entry per spike; the first two position columns are used as x / y),
    * ``kilosort/sorter_output/templates.npy`` (one template per metrics row,
      indexed here as ``(n_templates, n_samples, n_channels)``).

    Each cluster gets the mean position of its spikes (NaN when it has none),
    a probe group derived from the mean x position, and its template restricted
    to the channels of that group.

    Parameters
    ----------
    file_path : str
        Root directory of the sorting run.
    snr_threshold : float, default 3
    min_spikes_high_snr : int, default 5000
        Clusters with ``snr > snr_threshold`` need more spikes than this.
    min_spikes_low_snr : int, default 8000
        Clusters with ``snr <= snr_threshold`` need more spikes than this.

    Returns
    -------
    pandas.DataFrame
        Independent copy of the rows that pass the quality filter; the original
        row index is preserved.

    Raises
    ------
    ValueError
        If the number of templates differs from the number of metrics rows.
    """
    sorter_dir = os.path.join(file_path, "kilosort", "sorter_output")
    metrics_path = os.path.join(
        file_path, "analyzer_kilosort4_binary", "extensions", "quality_metrics", "metrics.csv"
    )

    cluster_inf = pd.read_csv(metrics_path)
    cluster_inf.columns = ['cluster', 'num_spikes', 'firing_rate', 'presence_ratio', 'snr',
                           'isi_violations_ratio', 'isi_violations_count', 'rp_contamination',
                           'rp_violations', 'sliding_rp_violation', 'amplitude_cutoff',
                           'amplitude_median', 'amplitude_cv_median', 'amplitude_cv_range',
                           'sync_spike_2', 'sync_spike_4', 'sync_spike_8', 'firing_range',
                           'drift_ptp', 'drift_std', 'drift_mad', 'sd_ratio']
    cluster_inf['cluster'] = cluster_inf['cluster'].astype(str)

    # Only cluster ids and positions are needed here; the per-spike template
    # ids, spike times and PC features are not used for the cluster table.
    spike_clusters = np.load(os.path.join(sorter_dir, "spike_clusters.npy")).ravel().astype(str)
    spike_positions = np.load(os.path.join(sorter_dir, "spike_positions.npy")).astype(float)
    spike_inf = pd.DataFrame({
        'cluster': spike_clusters,
        'position_1': spike_positions[:, 0],
        'position_2': spike_positions[:, 1],
    })

    # One grouped pass instead of re-filtering the spike table per cluster.
    # Clusters without spikes get NaN (a float), which compares safely below.
    mean_positions = spike_inf.groupby('cluster')[['position_1', 'position_2']].mean()
    cluster_inf['position_1'] = cluster_inf['cluster'].map(mean_positions['position_1'])
    cluster_inf['position_2'] = cluster_inf['cluster'].map(mean_positions['position_2'])

    # Probe group from mean x. Boundaries are half-open (lower, upper] so a
    # value exactly on 250 / 400 / 550 lands in a neighbouring group instead of
    # silently falling back to "1". NaN positions keep the default "1".
    x_position = cluster_inf['position_1'].to_numpy(dtype=float)
    cluster_inf['probe_group'] = np.select(
        [x_position > 550, x_position > 400, x_position > 250, x_position > 100],
        ["5", "4", "3", "2"],
        default="1",
    )

    group_channels = {
        "1": [1, 3, 5, 7, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10],
    }

    templates = np.load(os.path.join(sorter_dir, "templates.npy"))
    # NOTE(review): templates are matched to metrics rows by position, which
    # assumes row k of metrics.csv describes template k — confirm for runs
    # where units were removed or merged after sorting.
    if templates.shape[0] != len(cluster_inf):
        raise ValueError(
            f"templates.npy has {templates.shape[0]} templates but metrics.csv has "
            f"{len(cluster_inf)} clusters"
        )
    cluster_inf['waveform'] = [
        template[:, group_channels[group]]
        for template, group in zip(templates, cluster_inf['probe_group'])
    ]

    # Quality filter. The low-SNR branch uses <= so a cluster whose SNR equals
    # the threshold exactly is not rejected by both branches.
    high_snr = (cluster_inf['snr'] > snr_threshold) & (cluster_inf['num_spikes'] > min_spikes_high_snr)
    low_snr = (cluster_inf['snr'] <= snr_threshold) & (cluster_inf['num_spikes'] > min_spikes_low_snr)

    # Return a copy so callers can add columns without chained-assignment issues.
    return cluster_inf[high_snr | low_snr].copy()

In [6]:
# Load the quality-filtered per-cluster table for this sorting run.
cluster_inf = get_spike_inf(
    "/home/ubuntu/Documents/jct/project/code/Spike_Sorting/whole_segment_rep1"
)

In [7]:
# Replace the spike-derived positions with the energy-weighted centroid of each
# cluster's template.
cluster_inf[['position_1', 'position_2']] = cluster_inf.apply(calculate_position, axis=1)

# Every cluster starts as 'Neuron_1'; rows after the first are relabelled below.
cluster_inf['Neuron'] = 'Neuron_1'
current_max_neuron = 1  # highest neuron number handed out so far

cluster_inf = cluster_inf.reset_index(drop=True)

# Greedy grouping in row order: a cluster joins the neuron of the most recent
# earlier cluster lying within 10 units in both x and y, otherwise it starts
# a new neuron.
for row_pos in range(1, len(cluster_inf)):
    earlier_rows = cluster_inf.iloc[:row_pos]
    x_gap = (earlier_rows['position_1'] - cluster_inf.at[row_pos, 'position_1']).abs()
    y_gap = (earlier_rows['position_2'] - cluster_inf.at[row_pos, 'position_2']).abs()
    neighbours = earlier_rows[(x_gap < 10) & (y_gap < 10)]

    if neighbours.empty:
        current_max_neuron += 1
        cluster_inf.at[row_pos, 'Neuron'] = f'Neuron_{current_max_neuron}'
    else:
        cluster_inf.at[row_pos, 'Neuron'] = neighbours['Neuron'].iloc[-1]

# Inverse-distance-weighted waveform at each cluster's estimated position.
cluster_inf['position_waveform'] = None
for row_label in cluster_inf.index:
    cluster_inf.at[row_label, 'position_waveform'] = calculate_position_waveform(
        cluster_inf.loc[row_label], channel_position, channel_indices, power=2
    )

In [8]:
from scipy.stats import pearsonr

# Clusters were grouped into neurons by position only. Here each multi-cluster
# neuron is split again by waveform similarity:
#   * the first cluster of a neuron keeps the neuron's label;
#   * every later cluster joins the first sub-group whose representative (its
#     first member) it correlates with at r >= 0.9 (a NaN correlation is
#     treated as "not dissimilar", i.e. no split);
#   * otherwise it starts a new sub-group with a fresh label.
# The counter is incremented BEFORE a label is minted, because
# `current_max_neuron` is the highest number already in use.
for neuron_label in cluster_inf['Neuron'].unique():
    members = cluster_inf[cluster_inf['Neuron'] == neuron_label]
    if len(members) < 2:
        continue

    member_waveforms = np.stack(members['position_waveform'].values)

    subgroup_reps = [0]                 # position (within `members`) of each sub-group's first cluster
    subgroup_labels = [neuron_label]    # label of each sub-group, same order

    for member_pos in range(1, len(member_waveforms)):
        assigned_label = None
        for rep_pos, rep_label in zip(subgroup_reps, subgroup_labels):
            correlation = pearsonr(member_waveforms[member_pos], member_waveforms[rep_pos])[0]
            if not correlation < 0.9:
                assigned_label = rep_label
                break

        if assigned_label is None:
            current_max_neuron += 1
            assigned_label = f'Neuron_{current_max_neuron}'
            subgroup_reps.append(member_pos)
            subgroup_labels.append(assigned_label)

        cluster_inf.at[members.index[member_pos], 'Neuron'] = assigned_label
            

In [9]:
# Attach to every cluster the channel list of its probe group; clusters whose
# group is not a key of `channel_indices` keep None.
cluster_inf['channel_id'] = None
for row_label, group in cluster_inf['probe_group'].items():
    if group in channel_indices:
        cluster_inf.at[row_label, 'channel_id'] = channel_indices[group]


In [10]:
# Collapse clusters into one row per neuron. Positions, the position waveform
# and the per-channel template are averaged over the neuron's clusters (for a
# single-cluster neuron the mean equals that cluster's values); channel ids and
# the cluster id are taken from the neuron's first cluster.
# Rows are collected in a list and the frame is built once, in order of first
# appearance of each neuron, which also yields a correctly-shaped empty frame
# when there are no clusters.
neuron_records = []
for neuron_label, members in cluster_inf.groupby('Neuron', sort=False):
    neuron_records.append({
        'Neuron': neuron_label,
        'position_1': members['position_1'].mean(),
        'position_2': members['position_2'].mean(),
        # Stack explicitly: averaging an object Series of arrays is not reliable.
        'position_waveform': np.stack(members['position_waveform'].values).mean(axis=0),
        'channel_id': members['channel_id'].iloc[0],
        'channel_waveform': np.stack(members['waveform'].values).mean(axis=0),
        'cluster': members['cluster'].values[0],
    })

neuron_inf = pd.DataFrame(
    neuron_records,
    columns=['Neuron', 'position_1', 'position_2', 'position_waveform',
             'channel_id', 'channel_waveform', 'cluster'],
)

In [192]:
import pickle

# Persist the per-neuron table; it holds array-valued cells, so it is pickled
# rather than written as text.
with open('neuron_inf.pkl', mode='wb') as pickle_file:
    pickle.dump(neuron_inf, pickle_file)

In [193]:
# NOTE(review): this cell reads "whole_segment", while the cluster table above
# was loaded from "whole_segment_rep1" — confirm the two are meant to differ.
file_path = "/home/ubuntu/Documents/jct/project/code/Spike_Sorting/whole_segment"

# Per-spike sorter outputs, one row per spike in each frame.
spike_clusters = pd.DataFrame(
    np.load(f"{file_path}/kilosort/sorter_output/spike_clusters.npy").astype(str)
)
spike_positions = pd.DataFrame(
    np.load(f"{file_path}/kilosort/sorter_output/spike_positions.npy").astype(float)
)
spike_templates = pd.DataFrame(
    np.load(f"{file_path}/kilosort/sorter_output/spike_templates.npy")
)
spike_times = pd.DataFrame(
    np.load(f"{file_path}/kilosort/sorter_output/spike_times.npy").astype(int)
)
# Keep only index 0 along the second axis of tF.
tf = pd.DataFrame(np.load(f"{file_path}/kilosort/sorter_output/tF.npy")[:, 0, :])

# Side-by-side join on the spike index, then name the columns positionally.
spike_inf = pd.concat(
    [spike_clusters, spike_positions, spike_templates, spike_times, tf],
    axis=1,
)
spike_inf.columns = [
    'cluster',
    'position_1', 'position_2',
    'templates',
    'time',
    'PC_1', 'PC_2', 'PC_3', 'PC_4', 'PC_5', 'PC_6',
]

In [194]:
# Export the per-spike table as tab-separated text, without the row index.
spike_inf.to_csv(
    'spike_inf.tsv',
    sep='\t',
    index=False,
)