In [31]:
from pathlib import Path
from kilosort.io import load_ops
import sys
import spikeinterface as si
import matplotlib.pyplot as plt

import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.qualitymetrics as sqm
import json
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
from kilosort import io
warnings.filterwarnings('ignore')

global_job_kwargs = dict(n_jobs = 4)
si.set_global_job_kwargs(**global_job_kwargs)

In [32]:
import os
import pandas as pd
import numpy as np

def get_spike_inf(file_path):
    cluster_inf = pd.read_csv(file_path + "/analyzer_kilosort4_binary/extensions/quality_metrics/metrics.csv")
    cluster_inf.columns = ['cluster', 'num_spikes', 'firing_rate', 'presence_ratio', 'snr',
                           'isi_violations_ratio', 'isi_violations_count', 'rp_contamination',
                           'rp_violations', 'sliding_rp_violation', 'amplitude_cutoff',
                           'amplitude_median', 'amplitude_cv_median', 'amplitude_cv_range',
                           'sync_spike_2', 'sync_spike_4', 'sync_spike_8', 'firing_range',
                           'drift_ptp', 'drift_std', 'drift_mad', 'sd_ratio']
    
    cluster_inf['cluster'] = cluster_inf['cluster'].astype(str)
    cluster_inf['position_1'] = None
    cluster_inf['position_2'] = None

    spike_clusters = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_clusters.npy").astype(str))
    spike_positions = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_positions.npy").astype(float))
    spike_templates = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_templates.npy"))
    spike_times = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_times.npy").astype(int))
    tf = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/tF.npy")[:, 0, :])

    spike_inf = pd.concat((spike_clusters, spike_positions, spike_templates, spike_times, tf), axis=1)
    spike_inf.columns = ['cluster', 'position_1', 'position_2', 'templates', 'time', 'PC_1', 'PC_2', 'PC_3', 'PC_4', 'PC_5', 'PC_6']

    for i in spike_inf['cluster'].value_counts().index:
        temp = spike_inf[spike_inf['cluster'] == i]
        cluster_inf.loc[cluster_inf['cluster'] == i, 'position_1'] = np.mean(temp['position_1'])
        cluster_inf.loc[cluster_inf['cluster'] == i, 'position_2'] = np.mean(temp['position_2'])

    cluster_inf['probe_group'] = "1"

    for i in spike_inf['cluster'].value_counts().index:
        cluster_rows = cluster_inf[cluster_inf['cluster'] == i]
        if (cluster_rows['position_1'] > 100).any() and (cluster_rows['position_1'] < 250).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "2"
        elif (cluster_rows['position_1'] > 250).any() and (cluster_rows['position_1'] < 400).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "3"
        elif (cluster_rows['position_1'] > 400).any() and (cluster_rows['position_1'] < 550).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "4"
        elif (cluster_rows['position_1'] > 550).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "5"

    waveform = np.load(file_path + "/kilosort4/sorter_output/templates.npy")
    cluster_inf['waveform'] = [waveform[i] for i in range(waveform.shape[0])]

    cluster_inf = cluster_inf[((cluster_inf['snr'] > 3) & (cluster_inf['num_spikes'] > int(5000))) | ((cluster_inf['snr'] < 3) & (cluster_inf['num_spikes'] > 8000))]
    spike_inf = spike_inf[spike_inf['cluster'].isin(list(cluster_inf['cluster']))]
    spike_inf = spike_inf[spike_inf['time'] > 200]
    cluster_inf['date'] = date
    spike_inf['date'] = date
    
    channel_indices = {
        "1": [1, 3, 5, 6, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10]
        }

    for index, row in cluster_inf.iterrows():
        probe_group = row['probe_group']
        if probe_group in channel_indices:
            selected_channels = channel_indices[probe_group]
            cluster_inf.at[index, 'waveform'] = row['waveform'][:, selected_channels]

    return cluster_inf, spike_inf

all_cluster_inf = pd.DataFrame()
all_spike_inf = pd.DataFrame()

for date in os.listdir("/media/ubuntu/sda/data/sort_output/mouse2/natural_image"):
    cluster_inf, spike_inf = get_spike_inf(file_path=f"/media/ubuntu/sda/data/sort_output/mouse2/natural_image/{date}")
    all_cluster_inf = pd.concat([all_cluster_inf, cluster_inf], ignore_index=True)
    all_spike_inf = pd.concat([all_spike_inf, spike_inf], ignore_index=True)

In [33]:
channel_indices = {
        "1": [1, 3, 5, 6, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10]
        }

channel_position = {
    0: [650, 0],
    2: [650, 50],
    4: [650, 100],
    6: [600, 100],
    8: [600, 50],
    10: [600, 0],
    1: [0, 0],
    3: [0, 50],
    5: [0, 100],
    7: [50, 100],
    9: [50, 50],
    11: [50, 0],
    13: [150, 200], 
    15: [150, 250],
    17: [150, 300],
    19: [200, 300],
    21: [200, 250],
    23: [200, 200],
    12: [500, 200],
    14: [500, 250],
    16: [500, 300],
    18: [450, 300],
    20: [450, 250],
    22: [450, 200],
    24: [350, 400],
    26: [350, 450],
    28: [350, 500],
    25: [300, 400],
    27: [300, 450],
    29: [300, 500]
}

In [34]:
def calculate_position(row):
    probe_group = str(row['probe_group'])
    channels = channel_indices[probe_group]
    waveform = row['waveform'] 
    
    a_squared = [np.sum(waveform[:, j]**2) for j in range(len(channels))]
    
    sum_x_a = 0
    sum_y_a = 0
    sum_a = 0
    
    for j, channel in enumerate(channels):
        x_i, y_i = channel_position.get(channel, [0, 0])  
        a_i_sq = a_squared[j]
        
        sum_x_a += x_i * a_i_sq
        sum_y_a += y_i * a_i_sq
        sum_a += a_i_sq
    
    if sum_a == 0:
        return pd.Series({'position_1': 0, 'position_2': 0})
    
    x_hat = sum_x_a / sum_a
    y_hat = sum_y_a / sum_a
    return pd.Series({'position_1': x_hat, 'position_2': y_hat})

all_cluster_inf[['position_1', 'position_2']] = all_cluster_inf.apply(calculate_position, axis=1)

In [35]:
import pandas as pd
import numpy as np

all_cluster_inf['Neuron'] = None

current_max_neuron = 1  

for i in range(1, len(all_cluster_inf)):
    current_pos1 = all_cluster_inf.at[i, 'position_1']
    current_pos2 = all_cluster_inf.at[i, 'position_2']
    
    mask = (
        (all_cluster_inf.loc[:i-1, 'position_1'] - current_pos1).abs().lt(10) & 
        (all_cluster_inf.loc[:i-1, 'position_2'] - current_pos2).abs().lt(10)
    )
    
    matched = all_cluster_inf.loc[:i-1][mask]
    
    if not matched.empty:
        all_cluster_inf.at[i, 'Neuron'] = matched['Neuron'].iloc[-1]
    else:
        current_max_neuron += 1
        all_cluster_inf.at[i, 'Neuron'] = f'Neuron_{current_max_neuron}'


In [36]:
len(os.listdir("/media/ubuntu/sda/data/sort_output/mouse2/natural_image"))

12

In [37]:
neuron_date = pd.crosstab(all_cluster_inf['Neuron'], all_cluster_inf['date'])   
neuron_date[neuron_date > 1] = 1
neuron_date = neuron_date.sum(axis=1)
neuron_date = neuron_date[neuron_date == 12]
neuron_date = neuron_date.index

In [38]:
all_cluster_inf = all_cluster_inf[all_cluster_inf['Neuron'].isin(neuron_date)]
all_cluster_inf['cluster_date'] = all_cluster_inf['date']  + "_" +  all_cluster_inf['cluster']
all_spike_inf['cluster_date'] = all_spike_inf['date']  + "_" +  all_spike_inf['cluster']

all_spike_inf = all_spike_inf[all_spike_inf['cluster_date'].isin(list(all_cluster_inf['cluster_date']))]

In [39]:
import numpy as np

def calculate_position_waveform(row, channel_position, channel_indices, power=2):
    x_target = row['position_1']
    y_target = row['position_2']
    probe_group = str(row['probe_group'])
    channels = channel_indices[probe_group]  
    waveforms = row['waveform']  
    
    distances = []
    for channel in channels:
        x_channel, y_channel = channel_position.get(channel, [np.nan, np.nan])
        if np.isnan(x_channel):  
            continue
        distance = np.sqrt((x_target - x_channel)**2 + (y_target - y_channel)**2)
        distances.append(distance)
    
    if not distances:  
        return np.zeros(61)
    
    #IDW
    weights = 1 / (np.array(distances) ** power)
    if np.any(distances == 0):
        zero_idx = np.argwhere(distances == 0).flatten()
        return waveforms[:, zero_idx[0]]
    
    weights /= np.sum(weights)
    
    synthesized_waveform = np.zeros(61)
    for t in range(61): 
        weighted_sum = np.dot(waveforms[t, :], weights)
        synthesized_waveform[t] = weighted_sum
    
    return synthesized_waveform

all_cluster_inf['position_waveform'] = all_cluster_inf.apply(
    calculate_position_waveform, 
    axis=1, 
    args=(channel_position, channel_indices, 2)  
)

In [40]:
for neuron in all_cluster_inf['Neuron']:
    temp = all_cluster_inf[all_cluster_inf['Neuron'] == neuron]
    temp.index = temp['cluster_date']
    df_expanded = temp['position_waveform'].apply(pd.Series)
    df_expanded.to_csv(f"/media/ubuntu/sda/data/filter_neuron/mouse_2/natural_image/waveform/waveform_mean_{neuron}.csv")

In [41]:
num = 0
results = {}
folder_path = '/media/ubuntu/sda/data/filter_neuron/mouse_2/natural_image/waveform'

csv_files = [f for f in os.listdir(folder_path) if f.startswith('waveform_mean_Neuron_') and f.endswith('.csv')]

label_df = pd.DataFrame()
for csv_file in csv_files:
    df = pd.read_csv(os.path.join(folder_path, csv_file), index_col=0)
    
    from sklearn.cluster import DBSCAN
    from sklearn.decomposition import PCA
    
    pca = PCA(n_components=2)
    principal_components = pca.fit_transform(df)

    eps = 1.8
    min_samples = 1

    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    dbscan.fit(principal_components)

    label = pd.DataFrame(dbscan.labels_, columns=['labels'])
    label['cluster_date'] = df.index
    label['date'] = label['cluster_date'].apply(lambda x: x.split('_')[0])

    remain_label = label['labels'].value_counts()
    remain_label = remain_label[remain_label >= 12]
    for i in remain_label.index:
        temp = label[label['labels'] == i]
        if temp['date'].nunique() != 12:
            remain_label = remain_label.drop(i)
    label = label[label['labels'].isin(remain_label.index)]
    for i in label['labels'].unique():
        results[num] = label.loc[label['labels'] ==i, 'cluster_date'].values
        num += 1

In [42]:
len(results)

6

In [43]:
all_cluster_inf['Neuron'] = None
for key,item in results.items():
    all_cluster_inf.loc[all_cluster_inf['cluster_date'].isin(item), 'Neuron'] = f'Neuron_{key+1}'

In [44]:
all_cluster_inf = all_cluster_inf.dropna(subset=['Neuron'])
all_spike_inf = all_spike_inf[all_spike_inf['cluster_date'].isin(all_cluster_inf['cluster_date'].unique())]

In [45]:
all_spike_inf['Neuron'] = None
for i in range(len(all_cluster_inf)):
    all_spike_inf.loc[all_spike_inf['cluster_date'] == all_cluster_inf.iloc[i, 28], "Neuron"] = all_cluster_inf.iloc[i, 27]

In [46]:
all_cluster_inf['neuron_date'] = all_cluster_inf['date'] + "_" + all_cluster_inf['Neuron']
all_spike_inf['neuron_date'] = all_spike_inf['date'] + "_" + all_spike_inf['Neuron']

In [47]:
waveform_mean = pd.DataFrame()
csv_files = [f for f in os.listdir('/media/ubuntu/sda/data/filter_neuron/mouse_2/natural_image/waveform') if f.startswith('waveform_mean_Neuron_') and f.endswith('.csv')]
for csv_file in csv_files:
    df = pd.read_csv(os.path.join('/media/ubuntu/sda/data/filter_neuron/mouse_2/natural_image/waveform', csv_file), index_col=0)
    waveform_mean = pd.concat((waveform_mean, df), axis=0)

waveform_mean = waveform_mean.loc[list(all_cluster_inf['cluster_date'])]

In [48]:
all_cluster_inf = all_cluster_inf.set_index('cluster_date')
all_cluster_inf = all_cluster_inf.join(waveform_mean, how="right")

In [49]:
from scipy.stats import pearsonr
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import colorsys
from matplotlib.backends.backend_pdf import PdfPages

def generate_base_colors(n, saturation=0.6, lightness=0.5):
    hues = np.linspace(0, 1, n, endpoint=False)  
    base_colors = []
    for h in hues:
        rgb = colorsys.hls_to_rgb(h, lightness, saturation)
        base_colors.append(rgb)
    return base_colors

n_neurons = len(all_cluster_inf['Neuron'].unique())
base_palette = generate_base_colors(n_neurons) 


def get_gradient_palette(base_color, n_levels=16, reverse=False):
    palette = sns.light_palette(base_color, n_levels, reverse=reverse)
    return palette


with PdfPages('figure/cluster_view.pdf') as pdf:
    n_neurons = len(all_cluster_inf['Neuron'].unique())
    base_palette = generate_base_colors(n_neurons)
    
    for idx, neuron in enumerate(all_cluster_inf['Neuron'].unique()):
        temp = all_cluster_inf[all_cluster_inf['Neuron'] == neuron]
        
        current_base_color = base_palette[idx]
        line_palette = get_gradient_palette(current_base_color, n_levels=25, reverse=False)
        
        fig, ax = plt.subplots(figsize=(1.5, 1.5))
        for i in range(11):
            sns.lineplot(
                x=range(32),
                y=temp.iloc[i, 41:73],
                color=line_palette[i],  
                ax=ax
            )
        
        ax.set_ylabel("Amplitude")
        ax.set_xticks([])
        ax.set_title(neuron)
        pdf.savefig(fig)
        plt.close(fig)

In [50]:
all_cluster_inf

Unnamed: 0_level_0,cluster,num_spikes,firing_rate,presence_ratio,snr,isi_violations_ratio,isi_violations_count,rp_contamination,rp_violations,sliding_rp_violation,...,51,52,53,54,55,56,57,58,59,60
cluster_date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
031722_13,13,41026,7.225309,1.0,8.119507,1.186358,1055,1.000000,601,,...,0.278167,0.232467,0.189995,0.151885,0.109964,0.071485,0.046131,0.026353,0.001617,-0.025760
031722_15,15,89477,15.758275,1.0,7.196227,0.111111,470,0.056940,156,0.025,...,0.163522,0.152965,0.132256,0.104174,0.071485,0.041073,0.013200,-0.014378,-0.038284,-0.047708
031722_16,16,92800,16.343506,1.0,10.314345,0.158241,720,0.123693,352,0.060,...,0.309491,0.288109,0.260605,0.221891,0.168960,0.115705,0.069822,0.023557,-0.024651,-0.058007
031722_30,30,36643,6.453395,1.0,5.714277,0.277693,197,0.126395,56,0.030,...,0.240426,0.224323,0.206875,0.173951,0.122084,0.066111,0.020023,-0.020407,-0.057979,-0.072519
031722_32,32,145392,25.605766,1.0,10.241227,0.234138,2615,0.169970,1158,0.040,...,0.283139,0.265236,0.234398,0.193061,0.143500,0.097541,0.055691,0.011361,-0.032671,-0.060431
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
092222_10,10,39694,16.363745,1.0,7.059677,0.150362,293,0.107375,132,0.065,...,0.162862,0.143825,0.112307,0.080696,0.057129,0.032310,0.014297,0.000486,-0.006694,-0.021599
092222_12,12,48010,19.791993,1.0,7.762495,0.112606,321,0.076602,140,0.045,...,0.257173,0.219855,0.163727,0.105242,0.079799,0.056887,0.045369,0.036280,0.031077,-0.002066
092222_20,20,15887,6.549373,1.0,5.538572,0.198623,62,0.074885,15,0.015,...,0.252958,0.205498,0.136522,0.071691,0.043233,0.019863,0.009704,0.008632,0.011757,-0.014989
092222_24,24,69462,28.635522,1.0,11.103218,0.319578,1907,0.307426,1035,0.150,...,0.276599,0.249447,0.207006,0.162607,0.122829,0.082905,0.053420,0.025603,0.000826,-0.038073


In [51]:
all_cluster_inf.to_csv('/media/ubuntu/sda/data/filter_neuron/mouse_2/natural_image/cluster_inf.tsv', sep = '\t')
all_spike_inf.to_csv("/media/ubuntu/sda/data/filter_neuron/mouse_2/natural_image/spike_inf.tsv", sep='\t')