In [1]:
import os
import re
import glob
import numpy as np
import pandas as pd
import numpy as np
import scipy.io as sio
import spikeinterface.extractors as se
import spikeinterface as si
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
from pathlib import Path
import matplotlib.pyplot as plt
import json

from typing import List, Tuple

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Session being processed -------------------------------------------------
DATE_STR = "20240112"
ROOT_SORT_DIR = "/media/ubuntu/sda/Monkey/sorted_result_combined/20240112/Hub1-instance1_V1/phy_folder_for_kilosort/"

# --- File names looked up inside the phy folder ------------------------------
# Spike-level arrays (one entry per detected spike).
SPIKE_TIMES_FILENAME = "spike_times.npy"
SPIKE_CLUSTERS_FILENAME = "spike_clusters.npy"

# Cluster-level metrics table (read from phy_folder_for_kilosort).
CLUSTER_INFO_FILENAME = "cluster_info.tsv"

In [3]:
def _normalize_cluster_id_column(table: pd.DataFrame, source_path: str) -> pd.DataFrame:
    """Return ``table`` with its key column exposed as ``cluster_id``.

    A table that already has ``cluster_id`` is returned untouched. A table that
    names the key column ``id`` instead gets that column renamed. Anything else
    raises ``ValueError`` naming ``source_path``.
    """
    if 'cluster_id' in table.columns:
        return table
    if 'id' in table.columns:
        # Same key, different header: rename rather than reject the file.
        return table.rename(columns={'id': 'cluster_id'})
    raise ValueError(f"{source_path} 中缺少 cluster_id 列")


def load_cluster_info(phy_dir: str) -> pd.DataFrame:
    """Load cluster-level information from a phy / Kilosort output folder.

    Sources are tried in this order and the first one found wins:

    1. ``CLUSTER_INFO_FILENAME`` (tab-separated), returned with all its columns.
    2. ``cluster_group.tsv`` (tab-separated).
    3. ``SPIKE_CLUSTERS_FILENAME``: the unique cluster ids found in the array,
       each assigned the placeholder group ``'unsorted'``.

    A warning is printed whenever the primary file is missing.

    Args:
        phy_dir: Directory containing the phy files.

    Returns:
        DataFrame with at least a ``cluster_id`` column. For the two TSV
        sources a key column named ``id`` is renamed to ``cluster_id``.

    Raises:
        ValueError: If a TSV source has neither ``cluster_id`` nor ``id``, or
            if none of the three sources exists in ``phy_dir``.
    """
    cluster_info_path = os.path.join(phy_dir, CLUSTER_INFO_FILENAME)
    if os.path.exists(cluster_info_path):
        cluster_table = pd.read_csv(cluster_info_path, sep='\t')
        return _normalize_cluster_id_column(cluster_table, cluster_info_path)

    print(f"警告: {cluster_info_path} 不存在，尝试从其他文件构建cluster信息")

    # Fallback 1: the manual-curation labels file.
    cluster_group_path = os.path.join(phy_dir, "cluster_group.tsv")
    if os.path.exists(cluster_group_path):
        cluster_table = pd.read_csv(cluster_group_path, sep='\t')
        return _normalize_cluster_id_column(cluster_table, cluster_group_path)

    # Fallback 2: derive the id list from the per-spike cluster assignments.
    spike_clusters_path = os.path.join(phy_dir, SPIKE_CLUSTERS_FILENAME)
    if not os.path.exists(spike_clusters_path):
        raise ValueError(f"无法找到任何cluster信息文件: {phy_dir}")

    unique_clusters = np.unique(np.load(spike_clusters_path))
    return pd.DataFrame({
        'cluster_id': unique_clusters,
        'group': 'unsorted',  # placeholder group: no curation labels available
    })


def load_spike_level(phy_dir: str) -> pd.DataFrame:
    """Build a per-spike table from the numpy arrays in a phy folder.

    Both ``SPIKE_CLUSTERS_FILENAME`` and ``SPIKE_TIMES_FILENAME`` are loaded
    from ``phy_dir`` and flattened to one dimension. The values are cast to
    ``int`` and otherwise left as stored; no unit conversion is applied.

    Args:
        phy_dir: Directory containing the two ``.npy`` files.

    Returns:
        DataFrame with columns ``cluster_id`` and ``time``, one row per spike.

    Raises:
        ValueError: If the two arrays hold a different number of entries.
    """
    def _load_flat(filename):
        # Arrays may be stored as (n, 1); collapse to a plain 1-D vector.
        return np.ravel(np.load(os.path.join(phy_dir, filename)))

    cluster_ids = _load_flat(SPIKE_CLUSTERS_FILENAME)
    sample_times = _load_flat(SPIKE_TIMES_FILENAME)

    if len(cluster_ids) != len(sample_times):
        raise ValueError(f"spike_clusters 与 spike_times 行数不一致: {phy_dir}")

    columns = {
        'cluster_id': cluster_ids.astype(int),
        'time': sample_times.astype(int),
    }
    return pd.DataFrame(columns)

In [6]:
# Export the curated ("good") units of one array and split the concatenated
# spike train back into per-block times.
#
# NOTE(review): ARRAY_TAG is the only array-specific input. The same steps are
# repeated in another cell for the other array; consider moving them into a
# shared function taking the array name as a parameter.
ARRAY_TAG = "Hub1-instance1"                # as used in directory / file names
ARRAY_LABEL = ARRAY_TAG.replace("-", "_")   # value written to the 'array' column
SORTED_SESSION_DIR = f"/media/ubuntu/sda/Monkey/sorted_result_combined/{DATE_STR}"
RAW_SESSION_DIR = f"/media/ubuntu/sda/Monkey/TVSD/monkeyF/{DATE_STR}"

ROOT_SORT_DIR = f"{SORTED_SESSION_DIR}/{ARRAY_TAG}_V1/phy_folder_for_kilosort/"
cluster_inf = load_cluster_info(phy_dir=ROOT_SORT_DIR)
spike_inf = load_spike_level(phy_dir=ROOT_SORT_DIR)

# Keep only clusters labelled 'good' and their spikes. `.copy()` makes the
# filtered frames independent, so the column assignments below do not trigger
# SettingWithCopyWarning or write into a temporary slice.
cluster_inf = cluster_inf[cluster_inf['group'] == 'good'].copy()
spike_inf = spike_inf[spike_inf['cluster_id'].isin(cluster_inf['cluster_id'].unique())].copy()

# Length of every recorded block, in samples. The sample count is read directly
# from the recording instead of `int(duration * 30000)`, which hard-codes the
# sampling rate and can lose a sample to floating-point truncation.
time_mapping = {}     # block directory name -> number of samples
block_numbers = {}    # block number -> block directory name
for block in os.listdir(RAW_SESSION_DIR):
    block_match = re.fullmatch(r"Block_(\d+)", block)
    if block_match is None:
        continue  # ignore anything that is not a Block_<n> entry
    block_number = int(block_match.group(1))
    # Assumes raw files are named <array>_B<3-digit block number>.ns6 — TODO confirm for blocks >= 10.
    record_temp = se.read_blackrock(f"{RAW_SESSION_DIR}/{block}/{ARRAY_TAG}_B{block_number:03d}.ns6")
    time_mapping[block] = int(record_temp.get_total_samples())
    block_numbers[block_number] = block

# [start, end) of each block inside the concatenated recording.
# Assumes blocks were concatenated in ascending block number before sorting — TODO confirm.
cumulative_time = 0
time_boundaries = {}
for block_number in sorted(block_numbers):
    block_length = time_mapping[block_numbers[block_number]]
    time_boundaries[block_number] = (cumulative_time, cumulative_time + block_length)
    cumulative_time += block_length

# Tag every spike with its block and express its time relative to the block's
# start. The block start offset is subtracted, not the block's first spike
# time: subtracting the first spike would discard that spike's latency and
# shift each array by a different amount.
spike_inf['block'] = None
absolute_time = spike_inf['time'].to_numpy(copy=True)  # frozen copy; 'time' is rewritten below
for block_id, (start_time, end_time) in time_boundaries.items():
    in_block = (absolute_time >= start_time) & (absolute_time < end_time)
    spike_inf.loc[in_block, 'block'] = block_id
    spike_inf.loc[in_block, 'time'] = absolute_time[in_block] - start_time

spike_inf['array'] = ARRAY_LABEL
cluster_inf['array'] = ARRAY_LABEL

cluster_inf.to_csv(f"{SORTED_SESSION_DIR}/cluster_inf_{ARRAY_TAG}_V1.csv", index=False)
spike_inf.to_csv(f"{SORTED_SESSION_DIR}/spike_inf_{ARRAY_TAG}_V1.csv", index=False)

In [5]:
# Export the curated ("good") units of one array and split the concatenated
# spike train back into per-block times.
#
# NOTE(review): ARRAY_TAG is the only array-specific input. The same steps are
# repeated in another cell for the other array; consider moving them into a
# shared function taking the array name as a parameter.
ARRAY_TAG = "Hub2-instance1"                # as used in directory / file names
ARRAY_LABEL = ARRAY_TAG.replace("-", "_")   # value written to the 'array' column
SORTED_SESSION_DIR = f"/media/ubuntu/sda/Monkey/sorted_result_combined/{DATE_STR}"
RAW_SESSION_DIR = f"/media/ubuntu/sda/Monkey/TVSD/monkeyF/{DATE_STR}"

ROOT_SORT_DIR = f"{SORTED_SESSION_DIR}/{ARRAY_TAG}_V1/phy_folder_for_kilosort/"
cluster_inf = load_cluster_info(phy_dir=ROOT_SORT_DIR)
spike_inf = load_spike_level(phy_dir=ROOT_SORT_DIR)

# Keep only clusters labelled 'good' and their spikes. `.copy()` makes the
# filtered frames independent, so the column assignments below do not trigger
# SettingWithCopyWarning or write into a temporary slice.
cluster_inf = cluster_inf[cluster_inf['group'] == 'good'].copy()
spike_inf = spike_inf[spike_inf['cluster_id'].isin(cluster_inf['cluster_id'].unique())].copy()

# Length of every recorded block, in samples. The sample count is read directly
# from the recording instead of `int(duration * 30000)`, which hard-codes the
# sampling rate and can lose a sample to floating-point truncation.
time_mapping = {}     # block directory name -> number of samples
block_numbers = {}    # block number -> block directory name
for block in os.listdir(RAW_SESSION_DIR):
    block_match = re.fullmatch(r"Block_(\d+)", block)
    if block_match is None:
        continue  # ignore anything that is not a Block_<n> entry
    block_number = int(block_match.group(1))
    # Assumes raw files are named <array>_B<3-digit block number>.ns6 — TODO confirm for blocks >= 10.
    record_temp = se.read_blackrock(f"{RAW_SESSION_DIR}/{block}/{ARRAY_TAG}_B{block_number:03d}.ns6")
    time_mapping[block] = int(record_temp.get_total_samples())
    block_numbers[block_number] = block

# [start, end) of each block inside the concatenated recording.
# Assumes blocks were concatenated in ascending block number before sorting — TODO confirm.
cumulative_time = 0
time_boundaries = {}
for block_number in sorted(block_numbers):
    block_length = time_mapping[block_numbers[block_number]]
    time_boundaries[block_number] = (cumulative_time, cumulative_time + block_length)
    cumulative_time += block_length

# Tag every spike with its block and express its time relative to the block's
# start. The block start offset is subtracted, not the block's first spike
# time: subtracting the first spike would discard that spike's latency and
# shift each array by a different amount.
spike_inf['block'] = None
absolute_time = spike_inf['time'].to_numpy(copy=True)  # frozen copy; 'time' is rewritten below
for block_id, (start_time, end_time) in time_boundaries.items():
    in_block = (absolute_time >= start_time) & (absolute_time < end_time)
    spike_inf.loc[in_block, 'block'] = block_id
    spike_inf.loc[in_block, 'time'] = absolute_time[in_block] - start_time

spike_inf['array'] = ARRAY_LABEL
cluster_inf['array'] = ARRAY_LABEL

cluster_inf.to_csv(f"{SORTED_SESSION_DIR}/cluster_inf_{ARRAY_TAG}_V1.csv", index=False)
spike_inf.to_csv(f"{SORTED_SESSION_DIR}/spike_inf_{ARRAY_TAG}_V1.csv", index=False)