In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import spikeinterface as si
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm
import pickle
from scipy.stats import  pearsonr
import neo
from quantities import ms
from elephant.statistics import instantaneous_rate
from elephant.kernels import GaussianKernel

import glob
import neo

In [2]:
def merge_tsv_files(directory_path, expected_rows=555):
    """Load every ``*.tsv`` file in a directory and combine them into one DataFrame.

    The base name of each matching file is printed first. Every file is then
    read as tab-separated text; a file that cannot be read is reported and
    skipped. When ``expected_rows`` is not None, files whose row count differs
    from it are skipped too (and reported).

    If all retained tables share a ``cluster_id`` column, they are outer-merged
    on that column one after another. On a column-name clash the earlier table
    keeps the plain name and the k-th table's column gets a ``_file<k>``
    suffix. Otherwise the tables are simply stacked with ``pd.concat``.

    Parameters
    ----------
    directory_path : str
        Directory searched (non-recursively) for ``*.tsv`` files.
    expected_rows : int or None, default 555
        Required number of data rows per file. Pass ``None`` to keep every
        file regardless of length.

    Returns
    -------
    pandas.DataFrame
        The merged (or concatenated) table.

    Raises
    ------
    ValueError
        If no file could be loaded or none passed the row-count filter.
    """
    # NOTE: glob returns paths in arbitrary filesystem order, and the
    # ``_file<k>`` suffixes below follow that order.
    tsv_files = glob.glob(os.path.join(directory_path, "*.tsv"))

    for file in tsv_files:
        print(f"  - {os.path.basename(file)}")

    # Read all tsv files, keeping only the ones with the expected length.
    dataframes = []
    for file_path in tsv_files:
        try:
            df = pd.read_csv(file_path, sep='\t')
        except Exception as e:
            # Best effort: report the unreadable file and carry on with the rest.
            print(f"读取文件 {file_path} 时出错: {e}")
            continue

        if expected_rows is not None and len(df) != expected_rows:
            print(
                f"  skipped {os.path.basename(file_path)}: "
                f"{len(df)} rows, expected {expected_rows}"
            )
            continue

        dataframes.append(df)

    # Fail with an explicit message instead of an IndexError on dataframes[0].
    if not dataframes:
        raise ValueError(
            f"No usable .tsv file found in {directory_path!r} "
            f"(expected_rows={expected_rows})"
        )

    # Columns present in every retained table.
    common_columns = set(dataframes[0].columns)
    for df in dataframes[1:]:
        common_columns &= set(df.columns)

    # Without a shared cluster_id column there is no merge key: just stack.
    if 'cluster_id' not in common_columns:
        return pd.concat(dataframes, ignore_index=True)

    # Start from the first table and outer-join the others one by one so
    # that every cluster_id is preserved.
    merged_df = dataframes[0].copy()
    for i, df in enumerate(dataframes[1:], 1):
        merged_df = pd.merge(
            merged_df, df, on='cluster_id', how='outer', suffixes=('', f'_file{i+1}')
        )

    return merged_df

In [3]:
# Directory that is scanned for the per-cluster *.tsv tables.
directory_path = "/media/ubuntu/sda/duan"

# Merge the tables, then keep only the rows whose KSLabel equals 'good'.
cluster_inf = merge_tsv_files(directory_path)
is_good = cluster_inf['KSLabel'] == 'good'
cluster_inf = cluster_inf.loc[is_good]

  - cluster_KSLabel.tsv
  - cluster_info.tsv
  - cluster_ContamPct.tsv
  - cluster_group.tsv
  - cluster_Amplitude.tsv


In [4]:
# Load the per-spike arrays stored in the same directory as the cluster tables.
spike_clusters = np.load(os.path.join(directory_path, "spike_clusters.npy"))
spike_times = np.load(os.path.join(directory_path, "spike_times.npy"))

# Build the table column-wise. The earlier DataFrame([a, b]).T form created a
# 2-row frame with one column per spike before transposing, and forced both
# arrays to a common dtype. np.ravel accepts (N,) as well as (N, 1) arrays.
spike_inf = pd.DataFrame({
    'cluster': np.ravel(spike_clusters),
    'time': np.ravel(spike_times),
})

# Keep only spikes that belong to a cluster still present in cluster_inf.
spike_inf = spike_inf[spike_inf['cluster'].isin(cluster_inf['cluster_id'].values)]

In [5]:
# NOTE(review): class_types is not referenced in this cell; it looks like a
# stray auto-import. Confirm nothing else needs it before removing.
from six import class_types


# Count spikes per cluster and keep the clusters with more than 15000 spikes.
spikes_per_cluster = spike_inf['cluster'].value_counts()
valid_cluster = spikes_per_cluster.loc[spikes_per_cluster > 15000].index

# Restrict both the spike table and the cluster table to those clusters.
keep_spikes = spike_inf['cluster'].isin(valid_cluster)
keep_clusters = cluster_inf['cluster_id'].isin(valid_cluster)
spike_inf = spike_inf[keep_spikes]
cluster_inf = cluster_inf[keep_clusters]

In [6]:
# Read the recording-parameter table and keep only rows with bhv_codes == 10.
# NOTE(review): code 10 presumably marks the trial event of interest — confirm.
rec_params = pd.read_csv(os.path.join(directory_path, "rec_params.csv"))

# Filter and cast in a single chain: .astype returns a new frame, so no column
# is assigned onto a filtered slice (the SettingWithCopyWarning pattern).
trigger_time = (
    rec_params
    .loc[rec_params['bhv_codes'] == 10]
    .astype({'trial_condition': int, 'trial_target': int})
)


In [7]:
def create_spike_train_dict(spike_inf, trigger_time, t_start=0, t_stop=500,
                            samples_per_ms=30, kernel_sigma_ms=50,
                            sampling_period_ms=25):
    """Compute per-trial, per-cluster instantaneous firing rates.

    Trials are grouped by ``(trial_condition, trial_target)``. For every trial
    the spikes falling in ``[t_start, t_stop)`` ms around the trial's
    ``rec_codes_points`` sample are selected, converted to milliseconds
    relative to that sample, and turned into one ``neo.SpikeTrain`` per
    cluster. Each train is smoothed with ``elephant``'s
    ``instantaneous_rate`` using a Gaussian kernel.

    Parameters
    ----------
    spike_inf : pandas.DataFrame
        Needs a ``cluster`` column and a ``time`` column. ``time`` is treated
        as a sample index and divided by ``samples_per_ms``.
    trigger_time : pandas.DataFrame
        Needs ``trial_condition``, ``trial_target`` and ``rec_codes_points``
        columns; one row per trial.
    t_start, t_stop : number, default 0 and 500
        Window bounds in milliseconds relative to ``rec_codes_points``.
        ``t_start`` may be negative to include pre-trigger spikes.
    samples_per_ms : number, default 30
        Samples per millisecond of the ``time`` column.
    kernel_sigma_ms : number, default 50
        Value (in ms) passed to ``GaussianKernel``.
    sampling_period_ms : number, default 25
        ``sampling_period`` (in ms) passed to ``instantaneous_rate``.

    Returns
    -------
    dict
        Maps ``"<trial_condition>_<trial_target>"`` to a list with one entry
        per trial. Each entry is a list holding, for every cluster in
        ``spike_inf['cluster'].unique()`` order, the ``.magnitude`` of the
        estimated rate. Cluster ids themselves are not stored in the result.
    """
    gk = GaussianKernel(kernel_sigma_ms * ms)

    # Loop-invariant data, extracted once instead of once per trial.
    cluster_ids = spike_inf['cluster'].unique()
    spike_clusters = spike_inf['cluster'].to_numpy()
    # Float samples keep the subtraction below safe for unsigned inputs.
    spike_samples = spike_inf['time'].to_numpy(dtype=float)

    spike_train_dict = {}

    # Group trials by trial_condition and trial_target.
    grouped = trigger_time.groupby(['trial_condition', 'trial_target'])
    for (trial_condition, trial_target), group in grouped:
        key = f"{trial_condition}_{trial_target}"
        print(f"Processing {key}: {len(group)} trials")

        condition_trials = []
        for rec_codes_point in group['rec_codes_points']:
            # Window in samples. The start honours t_start so a negative
            # t_start really reaches back before the trigger.
            window_start = rec_codes_point + t_start * samples_per_ms
            window_stop = rec_codes_point + t_stop * samples_per_ms
            in_window = (spike_samples >= window_start) & (spike_samples < window_stop)

            trial_clusters = spike_clusters[in_window]
            # Spike times in ms relative to the trigger sample.
            trial_times_ms = (spike_samples[in_window] - rec_codes_point) / float(samples_per_ms)

            trial_spike_trains = []
            for cluster_id in cluster_ids:
                cluster_spikes_ms = trial_times_ms[trial_clusters == cluster_id]
                valid_spikes = cluster_spikes_ms[
                    (cluster_spikes_ms >= t_start) & (cluster_spikes_ms <= t_stop)
                ]

                # NOTE: astype(int) truncates spike times to whole
                # milliseconds; kept so results match earlier runs.
                spike_train = neo.SpikeTrain(
                    valid_spikes.astype(int) * ms,
                    t_stop=t_stop * ms,
                    t_start=t_start * ms
                )

                inst_rate = instantaneous_rate(
                    spike_train, kernel=gk, sampling_period=sampling_period_ms * ms
                ).magnitude

                trial_spike_trains.append(inst_rate)

            condition_trials.append(trial_spike_trains)

        spike_train_dict[key] = condition_trials

    return spike_train_dict


In [8]:
# Build the rate dictionary for a 0-500 ms window and report how many
# condition/target groups it contains.
print("Creating spike train dictionary...")
spike_train_dict = create_spike_train_dict(
    spike_inf,
    trigger_time,
    t_start=0,
    t_stop=500,
)

n_combinations = len(spike_train_dict)
print("\nSpike train dictionary created!")
print(f"Number of condition-target combinations: {n_combinations}")


Creating spike train dictionary...
Processing 1_1: 13 trials
Processing 1_2: 23 trials
Processing 1_3: 20 trials
Processing 2_1: 11 trials
Processing 2_2: 17 trials
Processing 2_3: 19 trials
Processing 3_1: 14 trials
Processing 3_2: 13 trials
Processing 3_3: 19 trials
Processing 4_1: 10 trials
Processing 4_2: 19 trials
Processing 4_3: 15 trials
Processing 5_1: 20 trials
Processing 5_2: 13 trials
Processing 5_3: 21 trials
Processing 6_1: 18 trials
Processing 6_2: 20 trials
Processing 6_3: 25 trials
Processing 7_1: 10 trials
Processing 7_2: 10 trials
Processing 7_3: 32 trials
Processing 8_1: 19 trials
Processing 8_2: 15 trials
Processing 8_3: 21 trials
Processing 9_1: 15 trials
Processing 9_2: 17 trials
Processing 9_3: 16 trials
Processing 10_1: 16 trials
Processing 10_2: 19 trials
Processing 10_3: 21 trials
Processing 11_1: 17 trials
Processing 11_2: 18 trials
Processing 11_3: 16 trials
Processing 12_1: 19 trials
Processing 12_2: 23 trials
Processing 12_3: 19 trials
Processing 13_1: 15 

In [9]:
import pickle

# Persist the per-trial rate dictionary so later cells can reload it.
activity_path = 'trail_activity_500.pkl'
with open(activity_path, mode='wb') as pkl_file:
    pickle.dump(spike_train_dict, pkl_file)

In [1]:
import pickle

# Reload the rate dictionary written by the dump cell.
# NOTE: unpickling can execute arbitrary code; only load files you produced.
activity_path = 'trail_activity_500.pkl'
with open(activity_path, mode='rb') as pkl_file:
    spike_train_dict = pickle.load(pkl_file)

In [17]:
# For every condition key and every trial, reduce each neuron's entry to its
# mean, giving one 1-D array (one value per neuron) per trial.
spike_train_mean = {
    key: [
        np.array([neuron.mean() for neuron in trial])
        for trial in trials
    ]
    for key, trials in spike_train_dict.items()
}

In [21]:
import pickle

# Persist the per-trial mean values next to the full rate dictionary.
mean_activity_path = 'trail_activity_mean_500.pkl'
with open(mean_activity_path, mode='wb') as pkl_file:
    pickle.dump(spike_train_mean, pkl_file)