In [1]:
# Clear the interactive namespace (the -f flag skips the confirmation prompt) before anything else runs.
%reset -f

In [2]:
import spikeinterface.full as si
import spikeinterface as s
import spikeinterface.qualitymetrics as sqm
from spikeinterface.postprocessing import compute_correlograms
import matplotlib.pyplot as plt
from pathlib import Path
from tools import *
import pandas as pd
import os
import sys
import numpy as np
import scipy.stats
import spikeinterface.core as sc
import spikeinterface.extractors as se
from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import (
    compute_snrs,
    compute_firing_rates,
    compute_isi_violations,
    calculate_pc_metrics,
    compute_quality_metrics,
)

## Load recording and sorting data

In [3]:
base_folder = r'/data1/zhangyuhao/xinchao_data/NP1/20230112_PVsyt2_tremor/'

# Walk base_folder top-down and keep the first directory named 'Xinchao_sort';
# the result stays None when no such directory exists anywhere below base_folder.
kilosort4_folder = next(
    (
        os.path.join(parent, 'Xinchao_sort')
        for parent, subdirs, _ in os.walk(base_folder)
        if 'Xinchao_sort' in subdirs
    ),
    None,
)

# Report the outcome of the search.
if kilosort4_folder is None:
    print("未在 base_folder 下找到 kilosort4 文件夹。")
else:
    print(f"找到 kilosort4 文件夹，路径为: {kilosort4_folder}")


找到 kilosort4 文件夹，路径为: /data1/zhangyuhao/xinchao_data/NP1/20230112_PVsyt2_tremor/Sorted/Xinchao_sort


In [5]:
# Read the 'imec0.ap' stream of the SpikeGLX recording stored in the 'Rawdata'
# sub-folder of base_folder; the sync channel is not loaded.
# os.path.join keeps the path valid whether or not base_folder ends with a
# path separator (plain string concatenation only worked with a trailing '/').
recording = si.read_spikeglx(
    os.path.join(base_folder, 'Rawdata'),
    stream_name='imec0.ap',
    load_sync_channel=False,
)

In [6]:
# Load the sorter output from the folder found earlier. Nothing is filtered at
# load time: keep_good_only and remove_empty_units are both switched off.
sorting_info = se.KiloSortSortingExtractor(
    kilosort4_folder,
    remove_empty_units=False,
    keep_good_only=False,
)
print('Sorting information',sorting_info)

# Unit ids of the sorting and the sampling frequency of the recording;
# both names are reused by later cells.
unit_ids = sorting_info.get_unit_ids()
sampling_rate = recording.get_sampling_frequency()
print('Unit ids',unit_ids)

Sorting information KiloSortSortingExtractor: 656 units - 1 segments - 30.0kHz
Unit ids [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  17  18  19
  20  21  22  23  24  25  27  28  30  31  32  33  34  35  36  37  38  39
  40  41  42  44  45  46  47  48  49  50  51  52  53  54  55  56  58  59
  60  61  62  63  64  65  66  67  68  69  70  72  73  74  75  76  77  78
  80  82  83  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101
 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
 120 121 122 123 124 125 126 127 128 129 130 131 132 134 136 137 138 140
 141 142 144 146 147 148 149 151 152 153 154 155 156 158 159 160 161 162
 163 164 165 166 167 168 169 170 171 172 173 174 176 177 178 179 180 181
 182 183 184 185 187 188 189 190 191 192 193 194 195 196 197 198 199 200
 201 202 203 204 205 206 207 208 209 211 212 213 215 216 217 218 221 222
 223 224 225 226 227 228 229 230 231 232 233 234 236 237 238 239 240 242
 243 244 245 246 248 249 250 251 252

In [7]:
# Pair the sorting with its recording in a SortingAnalyzer that is kept in memory.
analyzer = si.create_sorting_analyzer(
    recording=recording,
    sorting=sorting_info,
    format="memory",
)
print(analyzer)

estimate_sparsity (no parallelization):   0%|          | 0/601 [00:00<?, ?it/s]



SortingAnalyzer: 384 channels - 656 units - 1 segments - memory - sparse - has recording
Loaded 0 extensions


In [8]:

# Make sure the __main__ module carries a __spec__ attribute before any
# extension is computed with worker processes (n_jobs=2 below).
# NOTE(review): presumably a workaround for multiprocessing inside an
# interactive kernel - confirm it is still required.
if '__spec__' not in sys.modules['__main__'].__dict__:
    sys.modules['__main__'].__spec__ = None

# The extensions are computed unconditionally. They were previously nested
# inside the guard above, so re-running this cell once __spec__ had been set
# skipped every computation and only printed the analyzer.
# Order matters: waveforms are extracted for the sampled spikes, templates are
# derived afterwards, and noise levels are computed last.
analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205)
analyzer.compute("waveforms", ms_before=1.3, ms_after=2.6, n_jobs=2)
analyzer.compute("templates", operators=["average", "median", "std"])
analyzer.compute("noise_levels")

print(analyzer)

compute_waveforms (workers: 2 processes):   0%|          | 0/601 [00:00<?, ?it/s]

noise_level (no parallelization):   0%|          | 0/20 [00:00<?, ?it/s]

SortingAnalyzer: 384 channels - 656 units - 1 segments - memory - sparse - has recording
Loaded 4 extensions: random_spikes, waveforms, templates, noise_levels


In [9]:
# Firing rate of every unit.
firing_rates = compute_firing_rates(analyzer)
print(firing_rates)

# ISI violations: the call yields a (ratio, count) pair, unpacked into two names.
isi_result = compute_isi_violations(analyzer)
isi_violation_ratio, isi_violations_count = isi_result
print(isi_violation_ratio)

# Signal-to-noise ratio of every unit.
snrs = compute_snrs(analyzer)
print(snrs)


{np.int64(0): 5.06166667701487, np.int64(1): 1.5033333364067871, np.int64(2): 30.28000006190522, np.int64(3): 4.275000008739921, np.int64(4): 1.8233333370610034, np.int64(5): 0.18000000036799668, np.int64(6): 0.18000000036799668, np.int64(7): 0.31333333397392016, np.int64(8): 0.21000000042932948, np.int64(9): 0.19833333373881118, np.int64(10): 0.19333333372858905, np.int64(11): 0.3383333340250308, np.int64(12): 26.521666720888252, np.int64(13): 0.23833333382058822, np.int64(14): 1.461666669654936, np.int64(17): 1.9850000040581857, np.int64(18): 0.21000000042932948, np.int64(19): 1.1983333357832373, np.int64(20): 0.4850000009915466, np.int64(21): 60.46166679027607, np.int64(22): 18.830000038496543, np.int64(23): 7.5316666820646025, np.int64(24): 1.031666668775833, np.int64(25): 1.3300000027190866, np.int64(27): 0.1733333336877005, np.int64(28): 0.3766666674367338, np.int64(30): 0.270000000551995, np.int64(31): 0.2866666672527355, np.int64(32): 0.23666666715051415, np.int64(33): 0.530000

## Save quality metrics

In [10]:
# Metrics requested from compute_quality_metrics. 'rp_violation' was part of an
# earlier version of this list and is currently not requested.
qc_metric_names = [
    'firing_rate',
    'presence_ratio',
    'snr',
    'isi_violation',
    'amplitude_cutoff',
]
metrics = compute_quality_metrics(analyzer, metric_names=qc_metric_names)

# Attach the unit ids as a 'cluster_id' column and move that column to the front.
metrics_df = pd.DataFrame(metrics)
metrics_df['cluster_id'] = unit_ids
metrics_df = metrics_df[['cluster_id', *metrics_df.columns.drop('cluster_id')]]

# Write the table next to the raw data, without the DataFrame index.
qc_csv_path = os.path.join(base_folder,'quality_metrics.csv')
metrics_df.to_csv(qc_csv_path, index=False)

print('Metrics saved to' , base_folder)
print(metrics)

Metrics saved to /data1/zhangyuhao/xinchao_data/NP1/20230112_PVsyt2_tremor/
     firing_rate  presence_ratio        snr  isi_violations_ratio  \
0       5.061667             1.0   5.841489              7.784574   
1       1.503333             0.8   7.857068              0.000000   
2      30.280000             1.0   3.345547              0.029084   
3       4.275000             1.0   3.590757              1.337544   
4       1.823333             1.0   4.566010              4.846111   
..           ...             ...        ...                   ...   
713     0.170000             0.8  66.888934             19.223376   
714     1.285000             0.8  82.068555              0.000000   
715     0.911667             0.8  18.411307             72.858771   
716     0.518333             0.7  67.208527              0.000000   
718     0.728333             0.8  10.839482             96.350716   

     isi_violations_count  amplitude_cutoff  
0                   359.0          0.006858  
1  



## Extract good units

In [11]:
# Upper bound on 'isi_violations_ratio': the filtering cell keeps a unit only when its ratio is below this value.
isi_th = 0.1
# Lower bound on 'snr': the filtering cell keeps a unit only when its SNR is above this value.
snr_th = 2

In [12]:
qc_file_path = os.path.join(base_folder, 'quality_metrics.csv')

try:
    # Reload the saved quality metrics from the CSV file.
    metrics_df = pd.read_csv(qc_file_path)
except FileNotFoundError:
    # A missing file is only reported; no exception is propagated.
    print(f"未找到文件：{qc_file_path}，请检查文件路径是否正确。")
else:
    # Show the table that was read.
    print("读取到的数据：")
    print(metrics_df)

    # Also expose the data as a {column name: list of values} dict under `metrics`.
    metrics = metrics_df.to_dict(orient='list')
    print("转换为字典格式的数据：")
    print(metrics)

读取到的数据：
     cluster_id  firing_rate  presence_ratio        snr  isi_violations_ratio  \
0             0     5.061667             1.0   5.841489              7.784574   
1             1     1.503333             0.8   7.857068              0.000000   
2             2    30.280000             1.0   3.345547              0.029084   
3             3     4.275000             1.0   3.590757              1.337544   
4             4     1.823333             1.0   4.566010              4.846111   
..          ...          ...             ...        ...                   ...   
651         713     0.170000             0.8  66.888934             19.223376   
652         714     1.285000             0.8  82.068555              0.000000   
653         715     0.911667             0.8  18.411307             72.858771   
654         716     0.518333             0.7  67.208527              0.000000   
655         718     0.728333             0.8  10.839482             96.350716   

     isi_violations

In [13]:
# 筛选满足条件的神经元
filtered_neurons = metrics_df[
    (metrics_df['isi_violations_ratio'] < isi_th) & (metrics_df['snr'] > snr_th)
]

# 提取筛选后的神经元 ID
filtered_neuro_ids = filtered_neurons['cluster_id'].tolist()

# 打印结果
print("满足条件的神经元 ID：")
print(filtered_neuro_ids)
print("满足条件的神经元数量：", len(filtered_neuro_ids))

# 如果需要保存筛选后的结果到新的 CSV 文件
filtered_neurons.to_csv(os.path.join(base_folder,'filtered_quality_metrics.csv'),index=False)
print("筛选后的数据已保存到 'filtered_quality_metrics.csv'")

满足条件的神经元 ID：
[1, 2, 5, 6, 7, 13, 14, 22, 28, 30, 32, 33, 34, 35, 36, 37, 38, 39, 42, 45, 47, 48, 52, 53, 55, 56, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 75, 76, 77, 78, 80, 82, 83, 91, 93, 95, 96, 97, 99, 103, 105, 107, 108, 111, 112, 113, 115, 117, 118, 119, 120, 122, 123, 130, 131, 132, 134, 136, 137, 138, 140, 141, 142, 144, 146, 148, 149, 152, 154, 156, 158, 159, 160, 165, 167, 170, 176, 183, 185, 187, 188, 191, 198, 199, 200, 203, 204, 206, 207, 208, 209, 213, 216, 217, 218, 222, 223, 225, 233, 234, 242, 249, 250, 253, 255, 257, 259, 260, 262, 263, 265, 267, 268, 272, 273, 275, 277, 278, 279, 281, 282, 285, 291, 292, 293, 295, 296, 297, 300, 302, 304, 305, 312, 314, 315, 316, 317, 318, 320, 322, 323, 328, 330, 332, 333, 335, 337, 338, 339, 341, 342, 345, 349, 351, 356, 357, 361, 365, 368, 379, 380, 391, 393, 397, 426, 428, 433, 438, 439, 446, 449, 451, 453, 457, 458, 465, 466, 467, 468, 470, 475, 478, 479, 480, 484, 485, 486, 489, 494, 507, 510, 511, 518, 523, 527, 528

## Manually calculate ISI violations
Ref: [UMS] https://github.com/danamics/UMS2K/blob/master/quality_measures/rpv_contamination.m

In this method, the number of spikes that violate the refractory period, denoted $n_v$, is used.
Here, the refractory period $t_r$ is adjusted to take account of the data recording system's minimum possible refractory period. E.g. if a system has a sampling rate of $f \text{ Hz}$, the closest that two spikes from the same unit can possibly be is $1/f \, \text{s}$. Hence the refractory period $t_r$ is the expected biological threshold minus this minimum possible threshold.

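The contamination rate is calculated to be

$$C = \frac{ n_v T }{ 2 N^2 t_r }$$

where $N$ is the number of spikes of the unit and $T$ is the duration (in seconds) over which the unit was observed.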

In [None]:

# Table of manually computed metrics; the first row is the header.
# NOTE(review): the header names three columns but the rows appended in this
# cell hold two values (unit id, contamination rate) - presumably 'snr' is
# filled in by a later cell; confirm.
m_quality_metrics = []
m_quality_metrics.append(['neuro_ids', 'isi_violation','snr'])


biological_t_r = 0.0015    # biological refractory limit in seconds; IBL and spikeinterface use 1.5 ms
min_possible_t_r = 0       # limit imposed by the sampling system; IBL and spikeinterface use 0 ms, strictly it would be 1/sampling_rate
t_r = biological_t_r - min_possible_t_r  # adjusted refractory period

for unit_id in unit_ids:
    spike_train = sorting_info.get_unit_spike_train(unit_id)

    # Spike times converted from samples to seconds.
    spike_train_s = spike_train / sampling_rate
    N = len(spike_train_s)

    if N == 0:
        # The sorting was loaded with remove_empty_units=False, so a unit may
        # have no spikes. Indexing [-1] would raise IndexError and N**2 would
        # be zero, so the contamination rate is reported as undefined (NaN).
        C = np.nan
    else:
        # Number of inter-spike intervals shorter than the refractory period.
        n_v = np.sum(np.diff(spike_train_s) < t_r)

        # NOTE(review): T is the span between the unit's first and last spike,
        # not the full recording duration - confirm this is intended.
        T = spike_train_s[-1] - spike_train_s[0]
        C = (n_v * T) / (2 * N**2 * t_r)

    m_quality_metrics.append([unit_id, C])
    print(f"Unit {unit_id}: ISI violations contamination rate = {C:.4f}")

## Manually calculate signal-to-noise ratio
ref https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html#SNR
ref Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Society of Neuroscience Abstract. 2005.
ref Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.

### Calculation
- $A_{\mu s}$: maximum amplitude of the mean spike waveform (on the best channel).
- $\sigma_b$: standard deviation of the background noise on the same channel (usually computed via the median absolute deviation).

$$\mathrm{SNR} = \frac{A_{\mu s}}{\sigma_b}$$