In [1]:
from pathlib import Path
from kilosort.io import load_ops
import sys
import spikeinterface as si
import matplotlib.pyplot as plt
import json
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
from kilosort import io
import os
warnings.filterwarnings('ignore')
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
target_folder = os.path.abspath('/media/ubuntu/sda/data/paper_architecture/utils')

if target_folder not in sys.path:
    sys.path.append(target_folder)
import spiking_sorting

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Electrode layout and recording schedule shared by the cells below.

# Channel ids grouped under the keys "1".."5" (six channels per group).
# Every channel listed in `channel_position` must belong to exactly one group.
# Fix: group "1" previously listed channel 6, which already belongs to group
# "5", while channel 7 (located at [50, 100], next to channels 1/3/5/9/11 in
# `channel_position`) was missing from every group.
channel_indices = {
        "1": [1, 3, 5, 7, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10]
        }

# channel id -> two-element coordinate of that channel.
# NOTE(review): units are not stated in this notebook (presumably micrometres
# on the probe plane) — confirm against the probe specification.
channel_position = {
    0: [650, 0],
    2: [650, 50],
    4: [650, 100],
    6: [600, 100],
    8: [600, 50],
    10: [600, 0],
    1: [0, 0],
    3: [0, 50],
    5: [0, 100],
    7: [50, 100],
    9: [50, 50],
    11: [50, 0],
    13: [150, 200], 
    15: [150, 250],
    17: [150, 300],
    19: [200, 300],
    21: [200, 250],
    23: [200, 200],
    12: [500, 200],
    14: [500, 250],
    16: [500, 300],
    18: [450, 300],
    20: [450, 250],
    22: [450, 200],
    24: [350, 400],
    26: [350, 450],
    28: [350, 500],
    25: [300, 400],
    27: [300, 450],
    29: [300, 500]}


# mouse id -> ordered list of session folder names. The order matters: the
# first two entries are used as the "foundation" sessions by the tracking cell
# and the remaining ones are processed sequentially.
# NOTE(review): the strings look like MMDDYY dates — confirm.
mouse_month = {
    'mouse6': ['021322', '022522', '031722', '042422', 
              '052422', '062422', '072322', '082322', 
              '092422', '102122', '112022', '122022', 
              '012123', '022223', '032123', '042323'],
    'mouse5': ['030222', '042422', '052322', '062322', '082422', 
              '092222', '102522', '112822', '122322', 
              '012123', '022423', '032323', '042323', '052423', '062323', '072123'],
    'mouse2': ['031722', '042322', '052322', '062422', '072322', '082322', 
              '092222', '102522', '112822', '122022', '012123', '022223'],
    'mouse11': ['021722', '030122', '032322', '042322', '052322', '052422', 
              '062422', '072422', '082422', '092222', '102522', '112822', '122322', '012123', 
              '022423', '032323', '042323', '052423', '062323', '072123']
}

In [38]:
# ---------------------------------------------------------------------------
# Closed-loop neuron tracking across recording sessions, one mouse at a time.
#
# Stage 1 (foundation): clusters from the first two sessions are grouped by
#   estimated position, each group is refined with PCA + DBSCAN on the
#   position waveforms, and only groups seen in BOTH sessions are kept as
#   tracked neurons. A per-neuron template (mean position, mean waveform) is
#   stored in `all_neuron_inf`.
# Stage 2 (tracking): every later session is matched against those templates
#   by position and then by waveform correlation. Neurons that are not found
#   in a session are dropped from all accumulated tables from then on.
#
# Relies on names defined in earlier cells: `mouse_month`, `channel_position`,
# `channel_indices` and the project module `spiking_sorting`.
# ---------------------------------------------------------------------------
paper_root = '/media/ubuntu/sda/data/paper_architecture'
position_tolerance = 10        # max |delta| per coordinate for a position match
waveform_corr_threshold = 0.9  # Pearson r must be strictly above this to keep a match
eps = 3                        # DBSCAN neighbourhood radius in PCA space
min_samples = 1                # DBSCAN: every point is a core point (no noise label)

for mouse in mouse_month.keys():
    mouse_out_dir = f'{paper_root}/01_closed_loop/{mouse}'
    for out_dir in (mouse_out_dir, f'{mouse_out_dir}/pkl', f'{mouse_out_dir}/tsv'):
        os.makedirs(out_dir, exist_ok=True)

    date_order = mouse_month[mouse]

    # ------------------------------------------------------------------
    # Stage 1: build the foundation from the first two sessions.
    # ------------------------------------------------------------------
    cluster_frames, spike_frames = [], []
    for date in date_order[:2]:
        cluster_inf, spike_inf = spiking_sorting.get_spike_inf(
            file_path=f"{paper_root}/00_sorted_output/{mouse}/{date}", date=date)
        cluster_frames.append(cluster_inf)
        spike_frames.append(spike_inf)
    # Concatenate once instead of growing a DataFrame inside the loop.
    all_cluster_inf = pd.concat(cluster_frames, ignore_index=True)
    all_spike_inf = pd.concat(spike_frames, ignore_index=True)

    all_cluster_inf[['position_1', 'position_2']] = all_cluster_inf.apply(spiking_sorting.calculate_position, axis=1)

    # Greedy position grouping: a cluster joins the group of the most recent
    # earlier cluster lying strictly within `position_tolerance` on both
    # coordinates, otherwise it opens a new group.
    # Fix: the loop used to start at row 1 with the counter at 1, so row 0 was
    # never labelled; its `None` label propagated to every cluster matching it
    # and that whole group was silently discarded.
    all_cluster_inf['Neuron'] = None
    current_max_neuron = 0
    for i in range(len(all_cluster_inf)):
        current_pos1 = all_cluster_inf.at[i, 'position_1']
        current_pos2 = all_cluster_inf.at[i, 'position_2']

        earlier = all_cluster_inf.iloc[:i]
        mask = (
            (earlier['position_1'] - current_pos1).abs().lt(position_tolerance) &
            (earlier['position_2'] - current_pos2).abs().lt(position_tolerance)
        )
        matched = earlier[mask]

        if not matched.empty:
            all_cluster_inf.at[i, 'Neuron'] = matched['Neuron'].iloc[-1]
        else:
            current_max_neuron += 1
            all_cluster_inf.at[i, 'Neuron'] = f'Neuron_{current_max_neuron}'

    # Keep only position groups that contain clusters from both foundation sessions.
    dates_per_neuron = all_cluster_inf.groupby('Neuron')['date'].nunique()
    neuron_date = dates_per_neuron[dates_per_neuron == 2].index
    # .copy() so the column assignments below act on an independent frame.
    all_cluster_inf = all_cluster_inf[all_cluster_inf['Neuron'].isin(neuron_date)].copy()
    all_cluster_inf['cluster_date'] = all_cluster_inf['date']  + "_" +  all_cluster_inf['cluster']
    all_cluster_inf['position_waveform'] = None
    for idx, row in all_cluster_inf.iterrows():
        all_cluster_inf.at[idx, 'position_waveform'] = spiking_sorting.calculate_position_waveform(row, channel_position, channel_indices, 2)

    # One table per position group: rows = cluster_date, columns = waveform samples.
    waveform_dict = {}
    for neuron in all_cluster_inf['Neuron'].unique():
        temp = all_cluster_inf[all_cluster_inf['Neuron'] == neuron].set_index('cluster_date', drop=False)
        waveform_dict[neuron] = temp['position_waveform'].apply(pd.Series)

    # Refine each position group by waveform shape: project onto two principal
    # components, cluster with DBSCAN, and keep only DBSCAN clusters whose
    # members come from both foundation sessions.
    num = 0
    results = {}
    for _, df in waveform_dict.items():
        pca = PCA(n_components=2)
        principal_components = pca.fit_transform(df)

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        dbscan.fit(principal_components)

        label = pd.DataFrame(dbscan.labels_, columns=['labels'])
        label['cluster_date'] = df.index
        label['date'] = label['cluster_date'].apply(lambda x: x.split('_')[0])

        # Two distinct dates implies at least two members, so one test suffices.
        dates_per_label = label.groupby('labels')['date'].nunique()
        remain_label = dates_per_label[dates_per_label == 2]
        label = label[label['labels'].isin(remain_label.index)]
        for i in label['labels'].unique():
            results[num] = label.loc[label['labels'] == i, 'cluster_date'].values
            num += 1

    if not results:
        # Nothing can be tracked for this mouse; the original code crashed
        # further down (column-length mismatch) in this situation.
        print(f'{mouse}: no cluster group spans both foundation sessions; skipping')
        continue

    # Re-label clusters with the final neuron ids (1-based, in discovery order).
    neuron_by_cluster = {
        cluster_date: f'Neuron_{key+1}'
        for key, item in results.items()
        for cluster_date in item
    }
    all_cluster_inf['Neuron'] = all_cluster_inf['cluster_date'].map(neuron_by_cluster)
    all_cluster_inf = all_cluster_inf.dropna(subset=['Neuron']).copy()
    all_cluster_inf['neuron_date'] = all_cluster_inf['date'] + "_" + all_cluster_inf['Neuron']

    # Attach the expanded waveform samples as columns of the cluster table.
    waveform_mean = pd.concat(list(waveform_dict.values()), axis=0)
    waveform_mean = waveform_mean.loc[list(all_cluster_inf['cluster_date'])]
    waveform_cols = list(waveform_mean.columns)

    all_cluster_inf = all_cluster_inf.set_index('cluster_date')
    all_cluster_inf = all_cluster_inf.join(waveform_mean, how="right")
    all_cluster_inf['cluster_date'] = all_cluster_inf.index

    # Keep only spikes of retained clusters and label them with their neuron.
    # Columns are addressed by name; the original used positional indices
    # (iloc[i, 27] / iloc[i, -1]) that depend on the helper's column count.
    all_spike_inf['cluster_date'] = all_spike_inf['date']  + "_" +  all_spike_inf['cluster']
    all_spike_inf = all_spike_inf[all_spike_inf['cluster_date'].isin(all_cluster_inf['cluster_date'].values)].copy()
    all_spike_inf['Neuron'] = all_spike_inf['cluster_date'].map(
        dict(zip(all_cluster_inf['cluster_date'], all_cluster_inf['Neuron'])))

    # NOTE(review): these two TSVs are written under pkl/, unlike every other
    # TSV in this cell (tsv/). Paths are kept as-is because downstream readers
    # may depend on them — confirm and move if unintended.
    all_spike_inf.to_csv(f"{mouse_out_dir}/pkl/foundation_data_spike.tsv", sep = '\t', index=False)
    all_cluster_inf.to_csv(f"{mouse_out_dir}/pkl/foundation_data_cluster.tsv", sep = '\t', index= False)

    all_spike_inf.reset_index(drop=True).to_pickle(f"{mouse_out_dir}/pkl/foundation_data_spike.pkl")
    all_cluster_inf.reset_index(drop=True).to_pickle(f"{mouse_out_dir}/pkl/foundation_data_cluster.pkl")

    # Per-neuron template: mean position and mean waveform over member clusters.
    # Built from a list of records, so numeric columns are real floats.
    neuron_rows = []
    for neuron, members in all_cluster_inf.groupby('Neuron', sort=False):
        neuron_rows.append(
            [neuron, members['position_1'].mean(), members['position_2'].mean()]
            + members[waveform_cols].mean(axis=0).tolist()
        )
    all_neuron_inf = pd.DataFrame(
        neuron_rows,
        columns=['Neuron', 'position_1', 'position_2'] + list(range(1, len(waveform_cols) + 1)),
    )
    all_neuron_inf.to_csv(f"{mouse_out_dir}/tsv/foundation_data_neuron.tsv", sep = '\t', index= False)
    all_neuron_inf.reset_index(drop=True).to_pickle(f'{mouse_out_dir}/pkl/foundation_data_neuron.pkl')

    # ------------------------------------------------------------------
    # Stage 2: follow the foundation neurons through the remaining sessions.
    # ------------------------------------------------------------------
    for date in date_order[2:]:
        cluster_inf, spike_inf = spiking_sorting.get_spike_inf(
            file_path=f"{paper_root}/00_sorted_output/{mouse}/{date}", date=date)

        cluster_inf[['position_1', 'position_2']] = cluster_inf.apply(spiking_sorting.calculate_position, axis=1)
        cluster_inf['Neuron'] = None
        cluster_inf['cluster_date'] = cluster_inf['date']  + "_" +  cluster_inf['cluster']
        cluster_inf['position_waveform'] = cluster_inf.apply(
            spiking_sorting.calculate_position_waveform,
            axis=1,
            args=(channel_position, channel_indices, 2))
        cluster_inf = cluster_inf.reset_index()

        # Position match: each cluster is assigned to the FIRST template within
        # tolerance on both coordinates.
        # NOTE(review): this comparison is inclusive (<=) whereas the foundation
        # grouping above is strict (<); kept as in the original.
        # Fix: iterate over every cluster — the loop used to start at row 1 and
        # therefore ignored the first cluster of each session. The whole
        # matching pass was also duplicated verbatim; it now runs once.
        neuron_templates = list(zip(
            all_neuron_inf['Neuron'], all_neuron_inf['position_1'], all_neuron_inf['position_2']))
        neuron_match_position_dict = {neuron: [] for neuron in all_neuron_inf['Neuron'].unique()}
        for cluster_date, current_pos1, current_pos2 in zip(
                cluster_inf['cluster_date'], cluster_inf['position_1'], cluster_inf['position_2']):
            for neuron, neuron_pos1, neuron_pos2 in neuron_templates:
                if (abs(current_pos1 - neuron_pos1) <= position_tolerance
                        and abs(current_pos2 - neuron_pos2) <= position_tolerance):
                    neuron_match_position_dict[neuron].append(cluster_date)
                    break

        # Waveform check: drop candidates whose correlation with the template
        # waveform is not above the threshold. A NaN correlation fails the
        # `<=` test and is therefore kept, exactly as before.
        for neuron, candidates in neuron_match_position_dict.items():
            if not candidates:
                continue
            template = all_neuron_inf[all_neuron_inf['Neuron'] == neuron].iloc[:, 3:].values[0].astype(float)
            temp = cluster_inf[cluster_inf['cluster_date'].isin(candidates)]
            for cluster_date, waveform in zip(temp['cluster_date'].values, temp['position_waveform'].values):
                corr, _ = pearsonr(waveform, template)
                if corr <= waveform_corr_threshold:
                    candidates.remove(cluster_date)

        neuron_match_position_dict = {k: v for k, v in neuron_match_position_dict.items() if v}

        # Invert to cluster_date -> neuron (first neuron wins) and label clusters.
        cluster_to_neuron = {}
        for neuron, dates in neuron_match_position_dict.items():
            for cluster_date in dates:
                cluster_to_neuron.setdefault(cluster_date, neuron)
        cluster_inf['Neuron'] = cluster_inf['cluster_date'].map(cluster_to_neuron)

        cluster_inf = cluster_inf.dropna(subset=['Neuron'])
        spike_inf['cluster_date'] = spike_inf['date']  + "_" +  spike_inf['cluster']
        spike_inf = spike_inf[spike_inf['cluster_date'].isin(cluster_inf['cluster_date'].values)].copy()
        spike_inf['Neuron'] = spike_inf['cluster_date'].map(cluster_to_neuron)

        # Closed loop: only neurons found again in this session survive, in the
        # accumulated cluster/spike tables and in the template table alike.
        remaining_neurons = cluster_inf['Neuron'].unique()
        all_cluster_inf = pd.concat((all_cluster_inf, cluster_inf), axis= 0)
        all_spike_inf = pd.concat((all_spike_inf, spike_inf), axis=0)
        all_cluster_inf = all_cluster_inf[all_cluster_inf['Neuron'].isin(remaining_neurons)]
        all_spike_inf = all_spike_inf[all_spike_inf['Neuron'].isin(remaining_neurons)]
        all_neuron_inf = all_neuron_inf[all_neuron_inf['Neuron'].isin(remaining_neurons)]

        all_cluster_inf.to_csv(f"{mouse_out_dir}/tsv/cluster_{date}.tsv", sep = '\t', index=False)
        all_spike_inf.to_csv(f'{mouse_out_dir}/tsv/spike_{date}.tsv', sep = '\t', index=False)
        all_neuron_inf.to_csv(f'{mouse_out_dir}/tsv/neuron_{date}.tsv', sep = '\t', index=False)

        all_cluster_inf.reset_index(drop=True).to_pickle(f"{mouse_out_dir}/pkl/cluster_{date}.pkl")
        all_spike_inf.reset_index(drop=True).to_pickle(f'{mouse_out_dir}/pkl/spike_{date}.pkl')
        all_neuron_inf.reset_index(drop=True).to_pickle(f'{mouse_out_dir}/pkl/neuron_{date}.pkl')