In [107]:
from pathlib import Path
from kilosort.io import load_ops
import sys
import spikeinterface as si
import matplotlib.pyplot as plt

import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import spikeinterface.qualitymetrics as sqm
import json
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
from kilosort import io
import os
warnings.filterwarnings('ignore')

global_job_kwargs = dict(n_jobs = 4)
si.set_global_job_kwargs(**global_job_kwargs)

In [108]:
def get_spike_inf(file_path, date):
    cluster_inf = pd.read_csv(file_path + "/analyzer_kilosort4_binary/extensions/quality_metrics/metrics.csv")
    cluster_inf.columns = ['cluster', 'num_spikes', 'firing_rate', 'presence_ratio', 'snr',
                           'isi_violations_ratio', 'isi_violations_count', 'rp_contamination',
                           'rp_violations', 'sliding_rp_violation', 'amplitude_cutoff',
                           'amplitude_median', 'amplitude_cv_median', 'amplitude_cv_range',
                           'sync_spike_2', 'sync_spike_4', 'sync_spike_8', 'firing_range',
                           'drift_ptp', 'drift_std', 'drift_mad', 'sd_ratio']
    
    cluster_inf['cluster'] = cluster_inf['cluster'].astype(str)
    cluster_inf['position_1'] = None
    cluster_inf['position_2'] = None

    spike_clusters = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_clusters.npy").astype(str))
    spike_positions = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_positions.npy").astype(float))
    spike_templates = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_templates.npy"))
    spike_times = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/spike_times.npy").astype(int))
    tf = pd.DataFrame(np.load(file_path + "/kilosort4/sorter_output/tF.npy")[:, 0, :])

    spike_inf = pd.concat((spike_clusters, spike_positions, spike_templates, spike_times, tf), axis=1)
    spike_inf.columns = ['cluster', 'position_1', 'position_2', 'templates', 'time', 'PC_1', 'PC_2', 'PC_3', 'PC_4', 'PC_5', 'PC_6']

    for i in spike_inf['cluster'].value_counts().index:
        temp = spike_inf[spike_inf['cluster'] == i]
        cluster_inf.loc[cluster_inf['cluster'] == i, 'position_1'] = np.mean(temp['position_1'])
        cluster_inf.loc[cluster_inf['cluster'] == i, 'position_2'] = np.mean(temp['position_2'])

    cluster_inf['probe_group'] = "1"

    for i in spike_inf['cluster'].value_counts().index:
        cluster_rows = cluster_inf[cluster_inf['cluster'] == i]
        if (cluster_rows['position_1'] > 100).any() and (cluster_rows['position_1'] < 250).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "2"
        elif (cluster_rows['position_1'] > 250).any() and (cluster_rows['position_1'] < 400).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "3"
        elif (cluster_rows['position_1'] > 400).any() and (cluster_rows['position_1'] < 550).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "4"
        elif (cluster_rows['position_1'] > 550).any():
            cluster_inf.loc[cluster_inf['cluster'] == i, 'probe_group'] = "5"

    waveform = np.load(file_path + "/kilosort4/sorter_output/templates.npy")
    cluster_inf['waveform'] = [waveform[i] for i in range(waveform.shape[0])]

    cluster_inf = cluster_inf[((cluster_inf['snr'] > 3) & (cluster_inf['num_spikes'] > int(5000))) | ((cluster_inf['snr'] < 3) & (cluster_inf['num_spikes'] > 8000))]
    spike_inf = spike_inf[spike_inf['cluster'].isin(list(cluster_inf['cluster']))]
    spike_inf = spike_inf[spike_inf['time'] > 200]
    cluster_inf['date'] = date
    spike_inf['date'] = date
    
    channel_indices = {
        "1": [1, 3, 5, 6, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10, 12]
        }

    for index, row in cluster_inf.iterrows():
        probe_group = row['probe_group']
        if probe_group in channel_indices:
            selected_channels = channel_indices[probe_group]
            cluster_inf.at[index, 'waveform'] = row['waveform'][:, selected_channels]

    return cluster_inf, spike_inf

def calculate_position(row):
    probe_group = str(row['probe_group'])
    channels = channel_indices[probe_group]
    waveform = row['waveform'] 
    
    a_squared = [np.sum(waveform[:, j]**2) for j in range(len(channels))]
    
    sum_x_a = 0
    sum_y_a = 0
    sum_a = 0
    
    for j, channel in enumerate(channels):
        x_i, y_i = channel_position.get(channel, [0, 0])  
        a_i_sq = a_squared[j]
        
        sum_x_a += x_i * a_i_sq
        sum_y_a += y_i * a_i_sq
        sum_a += a_i_sq
    
    if sum_a == 0:
        return pd.Series({'position_1': 0, 'position_2': 0})
    
    x_hat = sum_x_a / sum_a
    y_hat = sum_y_a / sum_a
    return pd.Series({'position_1': x_hat, 'position_2': y_hat})

def calculate_position_waveform(row, channel_position, channel_indices, power=2):
    x_target = row['position_1']
    y_target = row['position_2']
    probe_group = str(row['probe_group'])
    channels = channel_indices[probe_group]  
    waveforms = row['waveform']  
    
    distances = []
    for channel in channels:
        x_channel, y_channel = channel_position.get(channel, [np.nan, np.nan])
        if np.isnan(x_channel):  
            continue
        distance = np.sqrt((x_target - x_channel)**2 + (y_target - y_channel)**2)
        distances.append(distance)
    
    if not distances:  
        return np.zeros(61)
    
    #IDW
    weights = 1 / (np.array(distances) ** power)
    if np.any(distances == 0):
        zero_idx = np.argwhere(distances == 0).flatten()
        return waveforms[:, zero_idx[0]]
    
    weights /= np.sum(weights)
    
    synthesized_waveform = np.zeros(61)
    for t in range(61): 
        weighted_sum = np.dot(waveforms[t, :], weights)
        synthesized_waveform[t] = weighted_sum
    
    return synthesized_waveform

In [109]:
channel_indices = {
        "1": [1, 3, 5, 6, 9, 11],
        "2": [13, 15, 17, 19, 21, 23],
        "3": [24, 25, 26, 27, 28, 29],
        "4": [12, 14, 16, 18, 20, 22],
        "5": [0, 2, 4, 6, 8, 10, 12]
        }
channel_position = {
    0: [650, 0],
    2: [650, 50],
    4: [650, 100],
    6: [600, 100],
    8: [600, 50],
    10: [600, 0],
    1: [0, 0],
    3: [0, 50],
    5: [0, 100],
    7: [50, 100],
    9: [50, 50],
    11: [50, 0],
    13: [150, 200], 
    15: [150, 250],
    17: [150, 300],
    19: [200, 300],
    21: [200, 250],
    23: [200, 200],
    12: [500, 200],
    14: [500, 250],
    16: [500, 300],
    18: [450, 300],
    20: [450, 250],
    22: [450, 200],
    24: [350, 400],
    26: [350, 450],
    28: [350, 500],
    25: [300, 400],
    27: [300, 450],
    29: [300, 500]
}

In [197]:
def cal_lossed_neurons(mouse, date_order):
    neuron_num = []
    for date_len in range(2, len(date_order) + 1):
        date_order_temp = date_order[:date_len]
        all_cluster_inf = pd.DataFrame()
        all_spike_inf = pd.DataFrame()

        for date in date_order_temp:
            cluster_inf, spike_inf = get_spike_inf(file_path=f"/media/ubuntu/sda/data/sort_output/mouse{mouse}/natural_image/{date}", date = date)
            all_cluster_inf = pd.concat([all_cluster_inf, cluster_inf], ignore_index=True)
            all_spike_inf = pd.concat([all_spike_inf, spike_inf], ignore_index=True)

        all_cluster_inf[['position_1', 'position_2']] = all_cluster_inf.apply(calculate_position, axis=1)
        all_cluster_inf['Neuron'] = None
        current_max_neuron = 1  

        for i in range(1, len(all_cluster_inf)):
            current_pos1 = all_cluster_inf.at[i, 'position_1']
            current_pos2 = all_cluster_inf.at[i, 'position_2']
            
            mask = (
                (all_cluster_inf.loc[:i-1, 'position_1'] - current_pos1).abs().lt(10) & 
                (all_cluster_inf.loc[:i-1, 'position_2'] - current_pos2).abs().lt(10)
            )
            
            matched = all_cluster_inf.loc[:i-1][mask]
            
            if not matched.empty:
                all_cluster_inf.at[i, 'Neuron'] = matched['Neuron'].iloc[-1]
            else:
                current_max_neuron += 1
                all_cluster_inf.at[i, 'Neuron'] = f'Neuron_{current_max_neuron}'
        
        neuron_date = pd.crosstab(all_cluster_inf['Neuron'], all_cluster_inf['date'])   
        neuron_date[neuron_date > 1] = 1
        neuron_date = neuron_date.sum(axis=1)
        neuron_date = neuron_date[neuron_date == date_len]
        neuron_date = neuron_date.index

        all_cluster_inf = all_cluster_inf[all_cluster_inf['Neuron'].isin(neuron_date)]
        all_cluster_inf['cluster_date'] = all_cluster_inf['date']  + "_" +  all_cluster_inf['cluster']

        all_cluster_inf['position_waveform'] = all_cluster_inf.apply(
            calculate_position_waveform, 
            axis=1, 
            args=(channel_position, channel_indices, 2))
        
        os.makedirs(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/waveform/mouse{mouse}/{date_len}", exist_ok=True)
        for neuron in all_cluster_inf['Neuron']:
            temp = all_cluster_inf[all_cluster_inf['Neuron'] == neuron]
            temp.index = temp['cluster_date']
            df_expanded = temp['position_waveform'].apply(pd.Series)
            df_expanded.to_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/waveform/mouse{mouse}/{date_len}/waveform_mean_{neuron}.csv") 

        num = 0
        results = {}
        folder_path = f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/waveform/mouse{mouse}/{date_len}'

        csv_files = [f for f in os.listdir(folder_path) if f.startswith('waveform_mean_Neuron_') and f.endswith('.csv')]
        if date_len < 13:
            eps_threshold = 3
        else:
            eps_threshold = 2.5

        for csv_file in csv_files:
            df = pd.read_csv(os.path.join(folder_path, csv_file), index_col=0)
            
            from sklearn.cluster import DBSCAN
            from sklearn.decomposition import PCA
            
            pca = PCA(n_components=2)
            principal_components = pca.fit_transform(df)

            eps = eps_threshold
            min_samples = 1

            dbscan = DBSCAN(eps=eps, min_samples=min_samples)
            dbscan.fit(principal_components)

            label = pd.DataFrame(dbscan.labels_, columns=['labels'])
            label['cluster_date'] = df.index
            label['date'] = label['cluster_date'].apply(lambda x: x.split('_')[0])

            remain_label = label['labels'].value_counts()
            remain_label = remain_label[remain_label >= date_len]
            for i in remain_label.index:
                temp = label[label['labels'] == i]
                if temp['date'].nunique() != date_len:
                    remain_label = remain_label.drop(i)
            label = label[label['labels'].isin(remain_label.index)]
            for i in label['labels'].unique():
                results[num] = label.loc[label['labels'] ==i, 'cluster_date'].values
                num += 1

        all_cluster_inf['Neuron'] = None
        for key,item in results.items():
            all_cluster_inf.loc[all_cluster_inf['cluster_date'].isin(item), 'Neuron'] = f'Neuron_{key+1}'

        all_cluster_inf = all_cluster_inf.dropna(subset=['Neuron'])
        all_cluster_inf['neuron_date'] = all_cluster_inf['date'] + "_" + all_cluster_inf['Neuron']

        waveform_mean = pd.DataFrame()
        csv_files = [f for f in os.listdir(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/waveform/mouse{mouse}/{date_len}') if f.startswith('waveform_mean_Neuron_') and f.endswith('.csv')]
        for csv_file in csv_files:
            df = pd.read_csv(os.path.join(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/waveform/mouse{mouse}/{date_len}', csv_file), index_col=0)
            waveform_mean = pd.concat((waveform_mean, df), axis=0)

        waveform_mean = waveform_mean.loc[list(all_cluster_inf['cluster_date'])]

        all_cluster_inf = all_cluster_inf.set_index('cluster_date')
        all_cluster_inf = all_cluster_inf.join(waveform_mean, how="right")

        from scipy.stats import pearsonr
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
        import colorsys
        from matplotlib.backends.backend_pdf import PdfPages

        n_neurons = len(all_cluster_inf['Neuron'].unique())
        print(f'{date_len}: {n_neurons}')

        with PdfPages(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/figure/mouse{mouse}/cluster_view_{date_len}.pdf') as pdf:
            
            for idx, neuron in enumerate(all_cluster_inf['Neuron'].unique()):
                temp = all_cluster_inf[all_cluster_inf['Neuron'] == neuron]
                                
                fig, ax = plt.subplots(figsize=(1.5, 1.5))
                for i in range(date_len):
                    sns.lineplot(
                        x=range(32),
                        y=temp.iloc[i, 41:73],
                        color='orange',  
                        ax=ax
                    )
                
                ax.set_ylabel("Amplitude")
                ax.set_xticks([])
                ax.set_title(neuron)
                pdf.savefig(fig)
                plt.close(fig)

        all_cluster_inf.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse{mouse}/cluster_inf_{date_len}.tsv', sep = '\t')
        neuron_num.append(n_neurons)

        with PdfPages(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/figure/mouse{mouse}/cluster_sep_detail_{date_len}.pdf') as pdf:            
            all_cluster_inf['cluster_date'] = all_cluster_inf.index
            for idx, neuron in enumerate(all_cluster_inf['Neuron'].unique()):
                b = all_cluster_inf[all_cluster_inf['Neuron'] == neuron]

                fig, axes = plt.subplots(6,6, figsize = (30, 30))
                axes = axes.flatten()

                for i in range(date_len):
                    ax = axes[i]
                    sns.lineplot(x = range(32),
                                y = b.iloc[i, 41:73], ax=ax)
                    ax.set_title(b['cluster_date'].values[i])

                fig.suptitle(neuron, fontsize = 20)
                pdf.savefig(fig)
                plt.close()
    return neuron_num

In [133]:
date_order = ['021322', '022522', '031722', '042422', 
              '052422', '062422', '072322', '082322', 
              '092422', '102122', '112022', '122022', 
              '012123', 
              '022223', '032123', '042323']

cal_lossed_neurons(6, date_order)

2: 25
3: 24
4: 20
5: 20
6: 21
7: 21
8: 21
9: 21
10: 21
11: 21
12: 21
13: 19
14: 19
15: 19
16: 19


[25, 24, 20, 20, 21, 21, 21, 21, 21, 21, 21, 19, 19, 19, 19]

In [182]:
neuron_num_mouse6 = [25, 24, 20, 20, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19]

In [136]:
ablated_neuron = {
    6: ['Neuron_6', 'Neuron_14'],
    7: ['Neuron_6', 'Neuron_14'],
    8: ['Neuron_6', 'Neuron_14'],
    9: ['Neuron_6', 'Neuron_14'],
    10: ['Neuron_6', 'Neuron_14'],
    11: ['Neuron_6', 'Neuron_14'],
    12: ['Neuron_6', 'Neuron_14']
}

In [137]:
for i in [6,7,8,9,10,11,12]:
    temp = pd.read_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6/cluster_inf_{i}.tsv', sep = '\t')
    temp = temp[~temp['Neuron'].isin(['Neuron_6', 'Neuron_14'])]
    temp.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6/cluster_inf_{i}.tsv', sep = '\t')

In [None]:
guide = pd.read_csv("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6/cluster_inf_2.tsv", sep = '\t')
guide_dict = {}
for neuron in guide['Neuron'].unique():
    guide_dict[neuron] = guide.loc[guide['Neuron'] == neuron, 'cluster_date'].values.tolist()
reverse_dict = {
    v: k  
    for k, values in guide_dict.items() 
    for v in values
}

In [166]:
os.makedirs("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6_processed", exist_ok=True)
for file in os.listdir('/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6'):
    if file != 'cluster_inf_2.tsv':
        temp = pd.read_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6/{file}", sep = '\t')
        temp['Neuron_new'] = temp['cluster_date'].map(reverse_dict)
        neuron_to_new = temp.groupby('Neuron')['Neuron_new'].transform('first')
        temp['Neuron_new'] = neuron_to_new
        temp.to_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse6_processed/{file}", sep = '\t')

In [179]:
date_order = ['021722', '030122', '032322', '042322', '052322', '052422', 
              '062422', '072422', '082422', '092222', 
              '102522', '112822', '122322', '012123', 
              '022423', '032323', '042323', '052423',
              '062323', '072123']

cal_lossed_neurons(11, date_order)

2: 21
3: 20
4: 17
5: 17
6: 17
7: 17
8: 17
9: 17
10: 14
11: 14
12: 14
13: 12
14: 12
15: 12
16: 13
17: 13
18: 13
19: 13
20: 13


[21, 20, 17, 17, 17, 17, 17, 17, 14, 14, 14, 12, 12, 12, 13, 13, 13, 13, 13]

In [184]:
neuron_num_mouse11 = [21, 20, 16, 16, 16, 16, 16, 16, 13, 13, 13, 12, 12, 12, 12, 12,  12, 11, 11]
ablated_neuron = {
    4: ['Neuron_8'],
    5: ['Neuron_8'],
    6: ['Neuron_8'],
    7: ['Neuron_8'],
    8: ['Neuron_8'],
    9: ['Neuron_8'],
    10: ['Neuron_8'],
    11: ['Neuron_8'],
    12: ['Neuron_8'],
    13: ['Neuron_6'],
    14: ['Neuron_6'],
    15: ['Neuron_6'],
    16: ['Neuron_6'],
    17: ['Neuron_6'],
    18: ['Neuron_6'],
    19: ["Neuron_6", 'Neuron_7'],
    20: ['Neuron_6', 'Neuron_7']    
}

In [186]:
for i in [4, 5, 6, 7, 8, 9, 10, 11, 12]:
    temp = pd.read_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_{i}.tsv', sep = '\t')
    temp = temp[~temp['Neuron'].isin(['Neuron_8'])]
    temp.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_{i}.tsv', sep = '\t')

for i in [13, 14, 15, 16, 17, 18]:
    temp = pd.read_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_{i}.tsv', sep = '\t')
    temp = temp[~temp['Neuron'].isin(['Neuron_6'])]
    temp.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_{i}.tsv', sep = '\t')

for i in [19, 20]:
    temp = pd.read_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_{i}.tsv', sep = '\t')
    temp = temp[~temp['Neuron'].isin(['Neuron_7', 'Neruon_6'])]
    temp.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_{i}.tsv', sep = '\t')

In [187]:
guide = pd.read_csv("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/cluster_inf_2.tsv", sep = '\t')
guide_dict = {}
for neuron in guide['Neuron'].unique():
    guide_dict[neuron] = guide.loc[guide['Neuron'] == neuron, 'cluster_date'].values.tolist()
reverse_dict = {
    v: k  
    for k, values in guide_dict.items() 
    for v in values
}

In [188]:
os.makedirs("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11_processed", exist_ok=True)
for file in os.listdir('/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11'):
    if file != 'cluster_inf_2.tsv':
        temp = pd.read_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11/{file}", sep = '\t')
        temp['Neuron_new'] = temp['cluster_date'].map(reverse_dict)
        neuron_to_new = temp.groupby('Neuron')['Neuron_new'].transform('first')
        temp['Neuron_new'] = neuron_to_new
        temp.to_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse11_processed/{file}", sep = '\t')

In [191]:
date_order = ['031722', '042322', '052322', '062422', '072322', '082322', 
              '092222', '102522', '112822', '122022', 
              '012123', '022223']

cal_lossed_neurons(2, date_order)

2: 14
3: 13
4: 12
5: 12
6: 12
7: 11
8: 9
9: 9
10: 8
11: 8
12: 8


[14, 13, 12, 12, 12, 11, 9, 9, 8, 8, 8]

In [192]:
neuron_num_mouse2 = [13, 12, 12, 12, 12, 11, 9, 9, 8, 8, 8]
ablated_neuron = {
    2: ['Neuron_3'],
    3: ['Neuron_3']
}

In [193]:
for i in [2, 3]:
    temp = pd.read_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2/cluster_inf_{i}.tsv', sep = '\t')
    temp = temp[~temp['Neuron'].isin(['Neuron_3'])]
    temp.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2/cluster_inf_{i}.tsv', sep = '\t')


In [194]:
guide = pd.read_csv("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2/cluster_inf_2.tsv", sep = '\t')
guide_dict = {}
for neuron in guide['Neuron'].unique():
    guide_dict[neuron] = guide.loc[guide['Neuron'] == neuron, 'cluster_date'].values.tolist()
reverse_dict = {
    v: k  
    for k, values in guide_dict.items() 
    for v in values
}

In [195]:
os.makedirs("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2_processed", exist_ok=True)
for file in os.listdir('/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2'):
    if file != 'cluster_inf_2.tsv':
        temp = pd.read_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2/{file}", sep = '\t')
        temp['Neuron_new'] = temp['cluster_date'].map(reverse_dict)
        neuron_to_new = temp.groupby('Neuron')['Neuron_new'].transform('first')
        temp['Neuron_new'] = neuron_to_new
        temp.to_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse2_processed/{file}", sep = '\t')

In [203]:
date_order = ['030222', '042422', '052322', '062322', '072322', '082422', 
              '092222', '102522', '112822', '122322', 
              '012123', '022423', '032323', '042323', '052423', '062323', '072123']

cal_lossed_neurons(5, date_order)

2: 11
3: 10
4: 7
5: 7
6: 8
7: 8
8: 8
9: 7
10: 6
11: 5
12: 5
13: 5
14: 5
15: 6
16: 6
17: 7


[11, 10, 7, 7, 8, 8, 8, 7, 6, 5, 5, 5, 5, 6, 6, 7]

In [204]:
neuron_num_mouse5 = [10, 8, 6, 6, 8, 8, 8, 7, 6, 5, 5, 5, 5, 5, 5, 6]
ablated_neuron = {
    2: ['Neuron_9'],
    3: ['Neuron_9', 'Neuron_1'],
    4: ['Neuron_7'],
    5: ['Neuron_7'],
    17: ['Neuron_7'],
    16: ['Neuron_6'],
    15: ['Neuron_6']
}

In [None]:
for i in [2, 3, 4, 5, 15, 16, 17]:
    temp = pd.read_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5/cluster_inf_{i}.tsv', sep = '\t')
    temp = temp[~temp['Neuron'].isin(ablated_neuron[i])]
    temp.to_csv(f'/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5/cluster_inf_{i}.tsv', sep = '\t')

In [207]:
guide = pd.read_csv("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5/cluster_inf_2.tsv", sep = '\t')
guide_dict = {}
for neuron in guide['Neuron'].unique():
    guide_dict[neuron] = guide.loc[guide['Neuron'] == neuron, 'cluster_date'].values.tolist()
reverse_dict = {
    v: k  
    for k, values in guide_dict.items() 
    for v in values
}

In [208]:
os.makedirs("/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5_processed", exist_ok=True)
for file in os.listdir('/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5'):
    if file != 'cluster_inf_2.tsv':
        temp = pd.read_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5/{file}", sep = '\t')
        temp['Neuron_new'] = temp['cluster_date'].map(reverse_dict)
        neuron_to_new = temp.groupby('Neuron')['Neuron_new'].transform('first')
        temp['Neuron_new'] = neuron_to_new
        temp.to_csv(f"/media/ubuntu/sda/data/filter_neuron/neuron_loss/cluster_inf/mouse5_processed/{file}", sep = '\t')