### Basic Idea
- Foundation data: Month 1 & 2
    - cluster_inf
    - spike_inf
    - neuron_inf
- Month 3 is added:
    - for each neuron in neuron_inf:
        - Match position
        - Match waveform
- Update data: Month 1 & 2 & 3
- Bin data, train (months 1 & 2), and test (month 3)

#### Function Setting

In [1]:
from pathlib import Path
from kilosort.io import load_ops
import sys
import spikeinterface as si
import matplotlib.pyplot as plt
import json
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
from kilosort import io
import os
warnings.filterwarnings('ignore')
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Channels belonging to each probe group (keys are the string labels stored in
# the 'probe_group' column).
# Fix: group "1" previously listed channel 6, which already belongs to group
# "5" and sits at x=600, while channel 7 (x=50, next to the other group-1
# sites) was assigned to no group at all. Every channel 0..29 now appears in
# exactly one group.
channel_indices = {
    "1": [1, 3, 5, 7, 9, 11],
    "2": [13, 15, 17, 19, 21, 23],
    "3": [24, 25, 26, 27, 28, 29],
    "4": [12, 14, 16, 18, 20, 22],
    "5": [0, 2, 4, 6, 8, 10],
}

# [x, y] coordinate of every channel, keyed by channel number.
channel_position = {
    # group "5"
    0: [650, 0],
    2: [650, 50],
    4: [650, 100],
    6: [600, 100],
    8: [600, 50],
    10: [600, 0],
    # group "1"
    1: [0, 0],
    3: [0, 50],
    5: [0, 100],
    7: [50, 100],
    9: [50, 50],
    11: [50, 0],
    # group "2"
    13: [150, 200],
    15: [150, 250],
    17: [150, 300],
    19: [200, 300],
    21: [200, 250],
    23: [200, 200],
    # group "4"
    12: [500, 200],
    14: [500, 250],
    16: [500, 300],
    18: [450, 300],
    20: [450, 250],
    22: [450, 200],
    # group "3"
    24: [350, 400],
    26: [350, 450],
    28: [350, 500],
    25: [300, 400],
    27: [300, 450],
    29: [300, 500],
}

In [3]:
def get_spike_inf(file_path, date):
    """Load one session's sorter output and build per-cluster and per-spike tables.

    Parameters
    ----------
    file_path : str
        Session directory. It must contain
        ``analyzer_kilosort4_binary/extensions/quality_metrics/metrics.csv``
        and ``kilosort4/sorter_output/*.npy``.
    date : str
        Session label written to the ``date`` column of both tables.

    Returns
    -------
    (cluster_inf, spike_inf) : tuple of pandas.DataFrame
        ``cluster_inf`` has one row per cluster that passed the quality
        filter, with the quality metrics, mean spike position, probe group and
        the template waveform restricted to the channels of its probe group.
        ``spike_inf`` has one row per spike of the kept clusters with
        ``time > 200``.
    """
    sorter_dir = file_path + "/kilosort4/sorter_output"

    cluster_inf = pd.read_csv(file_path + "/analyzer_kilosort4_binary/extensions/quality_metrics/metrics.csv")
    # Positional rename: requires metrics.csv to hold exactly these 22 columns
    # in this order (pandas raises on a length mismatch).
    cluster_inf.columns = ['cluster', 'num_spikes', 'firing_rate', 'presence_ratio', 'snr',
                           'isi_violations_ratio', 'isi_violations_count', 'rp_contamination',
                           'rp_violations', 'sliding_rp_violation', 'amplitude_cutoff',
                           'amplitude_median', 'amplitude_cv_median', 'amplitude_cv_range',
                           'sync_spike_2', 'sync_spike_4', 'sync_spike_8', 'firing_range',
                           'drift_ptp', 'drift_std', 'drift_mad', 'sd_ratio']
    cluster_inf['cluster'] = cluster_inf['cluster'].astype(str)

    spike_clusters = pd.DataFrame(np.load(sorter_dir + "/spike_clusters.npy").astype(str))
    spike_positions = pd.DataFrame(np.load(sorter_dir + "/spike_positions.npy").astype(float))
    spike_templates = pd.DataFrame(np.load(sorter_dir + "/spike_templates.npy"))
    spike_times = pd.DataFrame(np.load(sorter_dir + "/spike_times.npy").astype(int))
    # Only index 0 along the second axis of tF is kept.
    tf = pd.DataFrame(np.load(sorter_dir + "/tF.npy")[:, 0, :])

    spike_inf = pd.concat((spike_clusters, spike_positions, spike_templates, spike_times, tf), axis=1)
    spike_inf.columns = ['cluster', 'position_1', 'position_2', 'templates', 'time',
                         'PC_1', 'PC_2', 'PC_3', 'PC_4', 'PC_5', 'PC_6']

    # Mean spike position per cluster in a single pass (previously the whole
    # spike table was rescanned once per cluster). Clusters without any spike
    # get NaN instead of None.
    mean_position = spike_inf.groupby('cluster')[['position_1', 'position_2']].mean()
    cluster_inf['position_1'] = cluster_inf['cluster'].map(mean_position['position_1'])
    cluster_inf['position_2'] = cluster_inf['cluster'].map(mean_position['position_2'])

    # Probe group from the mean x position. Intervals are half-open so that a
    # value lying exactly on 250 / 400 / 550 no longer falls through to the
    # default group "1". NaN positions keep the default.
    x_pos = cluster_inf['position_1']
    group_conditions = [
        (x_pos >= 100) & (x_pos < 250),
        (x_pos >= 250) & (x_pos < 400),
        (x_pos >= 400) & (x_pos < 550),
        x_pos >= 550,
    ]
    cluster_inf['probe_group'] = np.select(group_conditions, ["2", "3", "4", "5"], default="1")

    # NOTE(review): assumes templates.npy has one entry per metrics.csv row, in
    # the same order — confirm against the sorter/analyzer output.
    templates = np.load(sorter_dir + "/templates.npy")
    cluster_inf['waveform'] = list(templates)

    # Quality filter: high-SNR clusters need > 5000 spikes, the rest > 8000.
    # snr == 3 used to satisfy neither branch; it now takes the stricter one.
    snr = cluster_inf['snr']
    num_spikes = cluster_inf['num_spikes']
    keep = ((snr > 3) & (num_spikes > 5000)) | ((snr <= 3) & (num_spikes > 8000))
    # .copy() so the column assignments below act on independent frames
    # rather than on slices of the unfiltered tables.
    cluster_inf = cluster_inf[keep].copy()
    spike_inf = spike_inf[spike_inf['cluster'].isin(cluster_inf['cluster'])]
    spike_inf = spike_inf[spike_inf['time'] > 200].copy()
    cluster_inf['date'] = date
    spike_inf['date'] = date

    # Keep only the template columns of the cluster's own probe group. Uses
    # the module-level channel_indices (the local duplicate was removed so the
    # mapping is defined in one place only).
    cluster_inf['waveform'] = [
        waveform[:, channel_indices[group]] if group in channel_indices else waveform
        for waveform, group in zip(cluster_inf['waveform'], cluster_inf['probe_group'])
    ]

    return cluster_inf, spike_inf

def calculate_position(row):
    """Estimate a cluster position as the energy-weighted centroid of its channels.

    ``row`` must provide 'probe_group' and 'waveform'. Column ``j`` of the
    waveform is paired with the ``j``-th channel of the row's probe group (via
    the module-level ``channel_indices``); its weight is the sum of squared
    samples. Channels missing from ``channel_position`` are placed at [0, 0].

    Returns a Series with 'position_1' and 'position_2'; both are 0 when the
    total weight is 0.
    """
    group_channels = channel_indices[str(row['probe_group'])]
    wf = row['waveform']

    # Sum of squared samples of each waveform column.
    energies = [np.sum(wf[:, col] ** 2) for col in range(len(group_channels))]

    weighted_x = 0
    weighted_y = 0
    total_energy = 0
    for energy, channel in zip(energies, group_channels):
        x_coord, y_coord = channel_position.get(channel, [0, 0])
        weighted_x += x_coord * energy
        weighted_y += y_coord * energy
        total_energy += energy

    # Guard against dividing by zero for an all-zero waveform.
    if total_energy == 0:
        return pd.Series({'position_1': 0, 'position_2': 0})

    return pd.Series({
        'position_1': weighted_x / total_energy,
        'position_2': weighted_y / total_energy,
    })

def calculate_position_waveform(row, channel_position, channel_indices, power=2):
    """Synthesize the waveform at a cluster's position by inverse-distance weighting.

    Parameters
    ----------
    row : mapping
        Must provide 'position_1', 'position_2', 'probe_group' and 'waveform'
        (2-D array, samples x channels, with column ``j`` belonging to the
        ``j``-th channel of the row's probe group).
    channel_position : dict
        Channel number -> [x, y].
    channel_indices : dict
        Probe-group label (str) -> list of channel numbers.
    power : float, default 2
        Exponent applied to the distance in the weights ``1 / d**power``.

    Returns
    -------
    numpy.ndarray
        1-D waveform with one value per sample. All zeros when none of the
        group's channels has a known position; the channel's own waveform when
        the target coincides with a channel.
    """
    x_target = row['position_1']
    y_target = row['position_2']
    channels = channel_indices[str(row['probe_group'])]
    waveforms = np.asarray(row['waveform'])
    # Length taken from the data instead of a hard-coded 61.
    n_samples = waveforms.shape[0]

    # Track which waveform column each distance belongs to, so channels
    # skipped for lack of coordinates do not shift the weights onto the wrong
    # columns.
    columns = []
    distances = []
    for column, channel in enumerate(channels):
        x_channel, y_channel = channel_position.get(channel, [np.nan, np.nan])
        if np.isnan(x_channel):
            continue
        columns.append(column)
        distances.append(np.sqrt((x_target - x_channel) ** 2 + (y_target - y_channel) ** 2))

    if not distances:
        return np.zeros(n_samples)

    distances = np.asarray(distances, dtype=float)

    # Exact hit on a channel: return its waveform. This must be tested before
    # dividing, and on an ndarray — comparing the Python list to 0 was always
    # False, so the old code went on to produce inf/NaN weights.
    zero_idx = np.flatnonzero(distances == 0)
    if zero_idx.size:
        return waveforms[:, columns[zero_idx[0]]]

    # Inverse-distance weights, normalised to sum to 1.
    weights = 1.0 / distances ** power
    weights /= weights.sum()

    return waveforms[:, columns] @ weights

#### Foundation data

In [4]:
# Session labels in processing order; each one is also the name of the
# session's sub-directory under the sort_output folder.
date_order = [
    '030222',
    '042422',
    '052322',
    '062322',
    '082422',
    '092222',
    '102522',
    '112822',
    '122322',
    '012123',
    '022423',
    '032323',
    '042323',
    '052423',
    '062323',
    '072123',
]

In [5]:
# Load the first two sessions (the foundation data) and stack their cluster
# and spike tables, renumbering the row index from 0.
cluster_frames = [pd.DataFrame()]
spike_frames = [pd.DataFrame()]

for date in date_order[:2]:
    cluster_inf, spike_inf = get_spike_inf(
        file_path=f"/media/ubuntu/sda/data/sort_output/mouse5/natural_image/{date}",
        date=date,
    )
    cluster_frames.append(cluster_inf)
    spike_frames.append(spike_inf)

all_cluster_inf = pd.concat(cluster_frames, ignore_index=True)
all_spike_inf = pd.concat(spike_frames, ignore_index=True)

In [6]:
# Recompute every cluster position from its waveform, then group clusters into
# candidate neurons by position.
all_cluster_inf[['position_1', 'position_2']] = all_cluster_inf.apply(calculate_position, axis=1)
all_cluster_inf['Neuron'] = None
current_max_neuron = 1

# Fix: the first cluster seeds 'Neuron_1'. Previously row 0 was never
# labelled, so it — and every later cluster matching only it — kept
# Neuron=None and was silently dropped by the crosstab below.
if not all_cluster_inf.empty:
    all_cluster_inf.at[0, 'Neuron'] = f'Neuron_{current_max_neuron}'

for i in range(1, len(all_cluster_inf)):
    current_pos1 = all_cluster_inf.at[i, 'position_1']
    current_pos2 = all_cluster_inf.at[i, 'position_2']

    # Earlier rows lying within 10 of the current cluster on both axes.
    mask = (
        (all_cluster_inf.loc[:i-1, 'position_1'] - current_pos1).abs().lt(10) &
        (all_cluster_inf.loc[:i-1, 'position_2'] - current_pos2).abs().lt(10)
    )

    matched = all_cluster_inf.loc[:i-1][mask]

    if not matched.empty:
        # Inherit the label of the most recent matching row.
        all_cluster_inf.at[i, 'Neuron'] = matched['Neuron'].iloc[-1]
    else:
        current_max_neuron += 1
        all_cluster_inf.at[i, 'Neuron'] = f'Neuron_{current_max_neuron}'

# Keep only neurons that have at least one cluster on both dates.
neuron_date = pd.crosstab(all_cluster_inf['Neuron'], all_cluster_inf['date'])
neuron_date[neuron_date > 1] = 1
neuron_date = neuron_date.sum(axis=1)
neuron_date = neuron_date[neuron_date == 2]
neuron_date = neuron_date.index

# .copy() so the new columns are added to an independent frame, not a slice.
all_cluster_inf = all_cluster_inf[all_cluster_inf['Neuron'].isin(neuron_date)].copy()
all_cluster_inf['cluster_date'] = all_cluster_inf['date'] + "_" + all_cluster_inf['cluster']

# One synthesized waveform (array) per cluster, stored as an object cell.
all_cluster_inf['position_waveform'] = None
for idx, row in all_cluster_inf.iterrows():
    all_cluster_inf.at[idx, 'position_waveform'] = calculate_position_waveform(row, channel_position, channel_indices, 2)

In [7]:
# For every candidate neuron, build a table with one row per member cluster
# (indexed by cluster_date) and one column per waveform sample.
waveform_dict = {}
for neuron in all_cluster_inf['Neuron'].unique():
    members = all_cluster_inf.loc[all_cluster_inf['Neuron'] == neuron]
    member_waveforms = members.set_index('cluster_date')['position_waveform']
    waveform_dict[neuron] = member_waveforms.apply(pd.Series)


In [8]:
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA

# Within each position-based candidate, project the member waveforms onto two
# principal components and split them with DBSCAN. A DBSCAN group is kept only
# if it has at least two clusters and they come from exactly two dates.
num = 0
results = {}

for member_waveforms in waveform_dict.values():
    components = PCA(n_components=2).fit_transform(member_waveforms)
    clustering = DBSCAN(eps=3, min_samples=1).fit(components)

    label = pd.DataFrame(clustering.labels_, columns=['labels'])
    label['cluster_date'] = member_waveforms.index
    # The date is the part of cluster_date before the first underscore.
    label['date'] = label['cluster_date'].str.split('_').str[0]

    group_sizes = label['labels'].value_counts()
    candidate_labels = group_sizes[group_sizes >= 2].index
    kept_labels = [
        lab for lab in candidate_labels
        if label.loc[label['labels'] == lab, 'date'].nunique() == 2
    ]

    label = label[label['labels'].isin(kept_labels)]
    for lab in label['labels'].unique():
        results[num] = label.loc[label['labels'] == lab, 'cluster_date'].values
        num += 1

In [9]:
# Replace the position-based labels with the waveform-confirmed groups:
# entry k of `results` becomes 'Neuron_{k+1}'; clusters in no group are dropped.
neuron_of_cluster = {
    cluster_date: f'Neuron_{key+1}'
    for key, members in results.items()
    for cluster_date in members
}
all_cluster_inf['Neuron'] = all_cluster_inf['cluster_date'].map(neuron_of_cluster)

all_cluster_inf = all_cluster_inf.dropna(subset=['Neuron'])
all_cluster_inf['neuron_date'] = all_cluster_inf['date'] + "_" + all_cluster_inf['Neuron']

In [10]:
# Stack the per-neuron waveform tables, keep the rows of the surviving
# clusters, and attach the sample columns to the cluster table (which becomes
# indexed by cluster_date).
waveform_mean = pd.concat([pd.DataFrame(), *waveform_dict.values()], axis=0)
waveform_mean = waveform_mean.loc[all_cluster_inf['cluster_date'].tolist()]

all_cluster_inf = (
    all_cluster_inf
    .set_index('cluster_date')
    .join(waveform_mean, how="right")
)

In [11]:
# Re-expose the index as a regular (last) column named cluster_date.
all_cluster_inf = all_cluster_inf.assign(cluster_date=all_cluster_inf.index)

In [12]:
# Keep only the spikes of the surviving clusters and tag each with its neuron.
all_spike_inf['cluster_date'] = all_spike_inf['date'] + "_" + all_spike_inf['cluster']
# .copy() so the 'Neuron' column is added to an independent frame, not a slice.
all_spike_inf = all_spike_inf[all_spike_inf['cluster_date'].isin(all_cluster_inf['cluster_date'].values)].copy()

# Look columns up by name instead of by position (previously iloc[i, -1] and
# iloc[i, 27]), and label all spikes with a single map rather than scanning
# the spike table once per cluster.
neuron_by_cluster = dict(zip(all_cluster_inf['cluster_date'], all_cluster_inf['Neuron']))
all_spike_inf['Neuron'] = all_spike_inf['cluster_date'].map(neuron_by_cluster)

In [14]:
# Write the foundation spike and cluster tables as tab-separated files,
# without the row index.
for frame, out_name in (
    (all_spike_inf, "foundation_data_spike.tsv"),
    (all_cluster_inf, "foundation_data_cluster.tsv"),
):
    frame.to_csv(out_name, sep='\t', index=False)

In [15]:
# Build one row per neuron: its label, the mean position of its member
# clusters, and the mean of the member clusters' per-sample columns.
all_neuron_inf = pd.DataFrame()
for neuron in all_cluster_inf['Neuron'].unique():
    temp = all_cluster_inf[all_cluster_inf['Neuron'] == neuron]
    neuron_position_1 = temp['position_1'].mean()
    neuron_position_2 = temp['position_2'].mean()
    # NOTE(review): positional slice — assumes columns 30 up to (not including)
    # the last one are exactly the waveform sample columns joined in earlier;
    # confirm if the column layout of all_cluster_inf ever changes.
    neuron_waveform = temp.iloc[:, 30:-1].mean(axis = 0)

    # Stack label, positions and waveform means into one column, then
    # transpose it into a single row and append it to the table.
    df_temp = pd.DataFrame([neuron, neuron_position_1, neuron_position_2])
    df_temp = pd.concat((df_temp, neuron_waveform.T), axis=0)
    all_neuron_inf = pd.concat((all_neuron_inf, df_temp.T), axis=0)

In [16]:
# Label the columns: neuron name, the two position coordinates, then the
# waveform samples numbered 1..61.
all_neuron_inf.columns = ['Neuron', 'position_1', 'position_2', *range(1, 62)]

In [17]:
# Write the per-neuron table as a tab-separated file, without the row index.
all_neuron_inf.to_csv(
    "foundation_data_neuron.tsv",
    sep='\t',
    index=False,
)

#### Closed_loop

In [18]:
# For every later session: load its clusters, match them to the known neurons
# first by position and then by waveform correlation, append the matches to
# the accumulated tables, and keep only the neurons found again in this
# session.
for date_len in range(2, len(date_order)):
    date = date_order[date_len]
    cluster_inf, spike_inf = get_spike_inf(file_path=f"/media/ubuntu/sda/data/sort_output/mouse5/natural_image/{date}", date = date)

    cluster_inf[['position_1', 'position_2']] = cluster_inf.apply(calculate_position, axis=1)
    cluster_inf['Neuron'] = None

    cluster_inf['cluster_date'] = cluster_inf['date'] + "_" + cluster_inf['cluster']

    cluster_inf['position_waveform'] = cluster_inf.apply(
        calculate_position_waveform,
        axis=1,
        args=(channel_position, channel_indices, 2))

    cluster_inf = cluster_inf.reset_index()

    # --- Step 1: position match ------------------------------------------
    # Each cluster goes to the first known neuron lying within 10 on both
    # axes. Fixes: the scan now starts at row 0 (range(1, ...) skipped the
    # first cluster of every session), and the block is no longer run twice.
    neuron_match_position_dict = {neuron: [] for neuron in all_neuron_inf['Neuron'].unique()}
    known_neurons = list(zip(all_neuron_inf['Neuron'],
                             all_neuron_inf['position_1'],
                             all_neuron_inf['position_2']))

    for i in range(len(cluster_inf)):
        current_pos1 = cluster_inf.at[i, 'position_1']
        current_pos2 = cluster_inf.at[i, 'position_2']

        for neuron, neuron_pos1, neuron_pos2 in known_neurons:
            if abs(current_pos1 - neuron_pos1) <= 10 and abs(current_pos2 - neuron_pos2) <= 10:
                neuron_match_position_dict[neuron].append(cluster_inf.at[i, 'cluster_date'])
                break

    # --- Step 2: waveform match ------------------------------------------
    # Drop candidates whose synthesized waveform does not correlate with the
    # neuron's reference waveform above 0.9.
    for neuron, candidates in neuron_match_position_dict.items():
        if not candidates:
            continue
        reference = all_neuron_inf[all_neuron_inf['Neuron'] == neuron].iloc[:, 3:].values[0].astype(float)
        temp = cluster_inf[cluster_inf['cluster_date'].isin(candidates)]
        for cluster_date, waveform in zip(temp['cluster_date'].values, temp['position_waveform'].values):
            corr, _ = pearsonr(waveform, reference)
            # `not corr > 0.9` also rejects NaN (e.g. a constant waveform),
            # which `corr <= 0.9` used to let through as a match.
            if not corr > 0.9:
                candidates.remove(cluster_date)

    neuron_match_position_dict = {k: v for k, v in neuron_match_position_dict.items() if v}

    # --- Step 3: label clusters and spikes --------------------------------
    # A cluster appears in at most one candidate list (step 1 stops at the
    # first neuron), so the reverse mapping is unambiguous.
    neuron_by_cluster = {
        cluster_date: neuron
        for neuron, dates in neuron_match_position_dict.items()
        for cluster_date in dates
    }
    cluster_inf['Neuron'] = cluster_inf['cluster_date'].map(neuron_by_cluster)

    cluster_inf = cluster_inf.dropna(subset=['Neuron']).copy()
    spike_inf['cluster_date'] = spike_inf['date'] + "_" + spike_inf['cluster']
    spike_inf = spike_inf[spike_inf['cluster_date'].isin(cluster_inf['cluster_date'].values)].copy()

    # Named lookup replaces the positional iloc[i, -2] / iloc[i, -3] access.
    spike_inf['Neuron'] = spike_inf['cluster_date'].map(neuron_by_cluster)

    # --- Step 4: accumulate and prune -------------------------------------
    # Only neurons matched in this session survive in all three tables.
    remained_neruon = cluster_inf['Neuron'].unique()
    all_cluster_inf = pd.concat((all_cluster_inf, cluster_inf), axis= 0)
    all_spike_inf = pd.concat((all_spike_inf, spike_inf), axis=0)
    all_cluster_inf = all_cluster_inf[all_cluster_inf['Neuron'].isin(remained_neruon)]
    all_spike_inf = all_spike_inf[all_spike_inf['Neuron'].isin(remained_neruon)]
    all_neuron_inf = all_neuron_inf[all_neuron_inf['Neuron'].isin(remained_neruon)]

    all_cluster_inf.to_csv(f"cluster_{date}.tsv", sep = '\t')
    all_spike_inf.to_csv(f'spike_{date}.tsv', sep = '\t')
    all_neuron_inf.to_csv(f'neuron_{date}.tsv', sep = '\t')
        