In [1]:
# Clear the interactive namespace before running the notebook; -f skips the confirmation prompt.
%reset -f

In [2]:
import spikeinterface.full as si
import spikeinterface as s
import spikeinterface.qualitymetrics as sqm
from spikeinterface.postprocessing import compute_correlograms
import matplotlib.pyplot as plt
from pathlib import Path
from tools import *
import pandas as pd
import os
import sys
import numpy as np
import scipy.stats
import spikeinterface.core as sc
import spikeinterface.extractors as se
from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import (
    compute_snrs,
    compute_firing_rates,
    compute_isi_violations,
    calculate_pc_metrics,
    compute_quality_metrics,
)

## Load recording and sorting data

In [4]:
# Root directory of this recording session.
base_folder = r'/data1/zhangyuhao/xinchao_data/NP1/20230523_Syt2_conditional_tremor_mice1/'

# Walk the tree below base_folder top-down and keep the first sub-directory
# named 'Xinchao_sort'; the variable stays None when no such directory exists.
kilosort4_folder = next(
    (
        os.path.join(parent_dir, child_name)
        for parent_dir, child_names, _ in os.walk(base_folder)
        for child_name in child_names
        if child_name == 'Xinchao_sort'
    ),
    None,
)

# Report whether the sorting folder was located.
if kilosort4_folder is None:
    print("未在 base_folder 下找到 kilosort4 文件夹。")
else:
    print(f"找到 kilosort4 文件夹，路径为: {kilosort4_folder}")


找到 kilosort4 文件夹，路径为: /data1/zhangyuhao/xinchao_data/NP1/20230523_Syt2_conditional_tremor_mice1/Sorted/Xinchao_sort


In [5]:
# Open the 'imec0.ap' stream found under the session's Rawdata folder,
# without loading the sync channel.
recording = si.read_spikeglx(
    base_folder + 'Rawdata',
    stream_name='imec0.ap',
    load_sync_channel=False,
)

In [6]:
# Load the sorter output from the located folder; both filtering options are
# switched off, so no unit is dropped at load time.
sorting_info = se.KiloSortSortingExtractor(
    kilosort4_folder,
    keep_good_only=False,
    remove_empty_units=False,
)
print('Sorting information',sorting_info)

# Values reused by later cells: the unit ids and the recording's sampling frequency.
unit_ids = sorting_info.get_unit_ids()
sampling_rate = recording.get_sampling_frequency()
print('Unit ids',unit_ids)

Sorting information KiloSortSortingExtractor: 495 units - 1 segments - 30.0kHz
Unit ids [  0   2   3   4   5   7   9  10  11  12  13  14  15  16  17  18  19  20
  21  22  23  24  25  26  27  28  29  30  32  33  34  35  36  37  38  39
  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  57  58  59
  61  62  63  66  67  68  69  71  72  73  74  75  76  78  79  80  81  82
  83  84  85  87  88  89  91  92  93  94  95  96  97  98  99 100 101 102
 103 104 105 106 107 108 109 110 111 112 113 115 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 134 135 138 139 141 142 144 145
 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
 183 185 186 187 188 189 190 191 192 193 194 195 196 200 201 202 203 204
 205 206 207 208 209 210 212 214 215 216 219 220 221 222 223 224 225 226
 227 228 230 231 232 233 234 235 236 237 238 239 240 241 243 244 245 246
 247 248 249 250 251 252 253 255 256

In [7]:
# Pair the sorting with its recording in a SortingAnalyzer held in memory.
analyzer = si.create_sorting_analyzer(
    sorting=sorting_info,
    recording=recording,
    format="memory",
)
print(analyzer)

estimate_sparsity (no parallelization):   0%|          | 0/2184 [00:00<?, ?it/s]

SortingAnalyzer: 384 channels - 495 units - 1 segments - memory - sparse - has recording
Loaded 0 extensions




In [8]:

# Make sure the __main__ module exposes a __spec__ attribute.
# NOTE(review): presumably a workaround needed by the worker processes started
# for the waveform extraction below (n_jobs=2) -- confirm.
if '__spec__' not in sys.modules['__main__'].__dict__:
    sys.modules['__main__'].__spec__ = None

# The extensions are computed unconditionally. They used to sit inside the guard
# above, so re-running this cell after __spec__ had been set skipped all of them
# silently and left the analyzer without the extensions later cells need.
analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205)
analyzer.compute("waveforms", ms_before=1.3, ms_after=2.6, n_jobs=2)
analyzer.compute("templates", operators=["average", "median", "std"])
analyzer.compute("noise_levels")

print(analyzer)

compute_waveforms (workers: 2 processes):   0%|          | 0/2184 [00:00<?, ?it/s]

noise_level (no parallelization):   0%|          | 0/20 [00:00<?, ?it/s]

SortingAnalyzer: 384 channels - 495 units - 1 segments - memory - sparse - has recording
Loaded 4 extensions: random_spikes, waveforms, templates, noise_levels


In [9]:
# Firing rates for the units in the analyzer.
firing_rates = compute_firing_rates(analyzer)
print(firing_rates)

# ISI violations: the result unpacks into a ratio part and a count part.
isi_violation_result = compute_isi_violations(analyzer)
isi_violation_ratio, isi_violations_count = isi_violation_result
print(isi_violation_ratio)

# Signal-to-noise ratios.
snrs = compute_snrs(analyzer)
print(snrs)


{np.int64(0): 85.66083124301146, np.int64(2): 10.992972852583828, np.int64(3): 0.08151762905182142, np.int64(4): 0.14975429606711013, np.int64(5): 1.6564565409013374, np.int64(7): 0.06823666701528872, np.int64(9): 0.344847048741694, np.int64(10): 1.5300584194501985, np.int64(11): 0.17448436330755032, np.int64(12): 0.5023867363474612, np.int64(13): 0.293097093220032, np.int64(14): 0.11494901624723132, np.int64(15): 0.0819755932599777, np.int64(16): 0.05266588393797451, np.int64(17): 0.053581812354287106, np.int64(18): 0.09067691321494742, np.int64(19): 0.12456626461851361, np.int64(20): 0.04762827764825521, np.int64(21): 0.23127192511893152, np.int64(22): 0.16486711493626802, np.int64(23): 1.1307136299379048, np.int64(24): 7.337044578872083, np.int64(25): 6.599264239532284, np.int64(26): 0.25508606394305916, np.int64(27): 0.5852782580237514, np.int64(28): 1.0661406765878665, np.int64(29): 0.07785391538657101, np.int64(30): 0.23081396091077525, np.int64(32): 3.180103461437348, np.int64(3

## Save quality metrics

In [10]:
# Alternative metric list that additionally requests 'rp_violation':
# qc_metric_names=['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'rp_violation', 'amplitude_cutoff']
qc_metric_names = ['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff']
metrics = compute_quality_metrics(analyzer, metric_names=qc_metric_names)

# Work on an explicit copy so that adding a column can never alter `metrics`.
metrics_df = pd.DataFrame(metrics, copy=True)

# Put cluster_id first, taken from the table's own index (the unit ids in the
# printed output). This keeps every id on its own row and removes the dependency
# on the `unit_ids` variable left over from an earlier cell.
metrics_df.insert(0, 'cluster_id', metrics_df.index)

# Persist the table; the index is dropped because cluster_id now carries the ids.
metrics_df.to_csv(os.path.join(base_folder,'quality_metrics.csv'), index=False)

print('Metrics saved to' , base_folder)
print(metrics)

Metrics saved to /data1/zhangyuhao/xinchao_data/NP1/20230523_Syt2_conditional_tremor_mice1/
     firing_rate  presence_ratio       snr  isi_violations_ratio  \
0      85.660831        1.000000  1.138398              0.004618   
2      10.992973        1.000000  1.081219              0.050529   
3       0.081518        1.000000  1.411541              0.000000   
4       0.149754        1.000000  1.318532              0.000000   
5       1.656457        1.000000  1.143209              0.500717   
..           ...             ...       ...                   ...   
546     0.453385        0.944444  1.855136             43.815607   
547     1.054234        0.888889  1.099030              0.137353   
548     0.585278        0.861111  2.132802             44.118496   
549     0.808765        0.916667  1.608668             17.270206   
550     0.070069        0.805556  1.666495              0.000000   

     isi_violations_count  amplitude_cutoff  
0                   222.0          0.007347  



## Extract good units

In [11]:
# Unit-selection thresholds used by the filtering cell:
# isi_th bounds 'isi_violations_ratio' from above, snr_th bounds 'snr' from below.
isi_th, snr_th = 0.1, 2

In [12]:
qc_file_path = os.path.join(base_folder, 'quality_metrics.csv')

try:
    # Only the read itself can raise the FileNotFoundError handled below.
    metrics_df = pd.read_csv(qc_file_path)
except FileNotFoundError:
    print(f"未找到文件：{qc_file_path}，请检查文件路径是否正确。")
else:
    # Show the table that was read back.
    print("读取到的数据：")
    print(metrics_df)

    # Also expose the data as a dict of column lists, mirroring the `metrics` variable.
    metrics = metrics_df.to_dict(orient='list')
    print("转换为字典格式的数据：")
    print(metrics)

读取到的数据：
     cluster_id  firing_rate  presence_ratio       snr  isi_violations_ratio  \
0             0    85.660831        1.000000  1.138398              0.004618   
1             2    10.992973        1.000000  1.081219              0.050529   
2             3     0.081518        1.000000  1.411541              0.000000   
3             4     0.149754        1.000000  1.318532              0.000000   
4             5     1.656457        1.000000  1.143209              0.500717   
..          ...          ...             ...       ...                   ...   
490         546     0.453385        0.944444  1.855136             43.815607   
491         547     1.054234        0.888889  1.099030              0.137353   
492         548     0.585278        0.861111  2.132802             44.118496   
493         549     0.808765        0.916667  1.608668             17.270206   
494         550     0.070069        0.805556  1.666495              0.000000   

     isi_violations_count  ampl

In [13]:
# Keep the units that pass both criteria: low ISI-violation ratio and high SNR.
passes_isi = metrics_df['isi_violations_ratio'].lt(isi_th)
passes_snr = metrics_df['snr'].gt(snr_th)
filtered_neurons = metrics_df.loc[passes_isi & passes_snr]

# Ids of the units that survived the filter.
filtered_neuro_ids = list(filtered_neurons['cluster_id'])

# Report the selection.
print("满足条件的神经元 ID：")
print(filtered_neuro_ids)
print("满足条件的神经元数量：", len(filtered_neuro_ids))

# Store the filtered rows in their own CSV file next to the full table.
filtered_csv_path = os.path.join(base_folder,'filtered_quality_metrics.csv')
filtered_neurons.to_csv(filtered_csv_path, index=False)
print("筛选后的数据已保存到 'filtered_quality_metrics.csv'")

满足条件的神经元 ID：
[43, 44, 52, 55, 79, 92, 138, 182, 275, 292, 294, 298, 303, 304, 317, 348, 349, 375, 376, 377, 403, 409, 414, 416, 417, 426, 427, 428, 453, 479, 503, 525]
满足条件的神经元数量： 32
筛选后的数据已保存到 'filtered_quality_metrics.csv'


## Manually calculate ISI violations
ref [UMS] https://github.com/danamics/UMS2K/blob/master/quality_measures/rpv_contamination.m

This method uses the number of spikes that violate the refractory period, denoted \(n_v\).
Here, the refractory period \(t_r\) is adjusted to take account of the data recording system’s minimum possible refractory period. E.g. if a system has a sampling rate of \(f \text{ Hz}\), the closest that two spikes from the same unit can possibly be is \(1/f \, \text{s}\). Hence the refractory period \(t_r\) is the expected biological threshold minus this minimum possible threshold.

The contamination rate is calculated to be

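\[C = \frac{ n_v T }{ 2 N^2 t_r }\]

where \(N\) is the number of spikes of the unit and \(T\) is the duration of the observation period in seconds.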

In [None]:

# Table of manually computed metrics; the first row is the header.
# NOTE(review): the header names three columns but this loop appends only two
# values per unit -- presumably 'snr' is filled in by a later step; confirm.
m_quality_metrics = []
m_quality_metrics.append(['neuro_ids', 'isi_violation','snr'])


biological_t_r = 0.0015    # biological refractory limit; IBL and spikeinterface use 1.5 ms
min_possible_t_r = 0       # system sampling limit; IBL and spikeinterface use 0 ms, though it should really be 1/sampling_rate
t_r = biological_t_r - min_possible_t_r  # adjusted refractory period

for unit_id in unit_ids:
    spike_train = sorting_info.get_unit_spike_train(unit_id)

    # Convert sample indices to seconds.
    spike_train_s = spike_train / sampling_rate
    N = len(spike_train_s)

    if N == 0:
        # The sorting is loaded with remove_empty_units=False, so a unit may have
        # no spikes at all. Indexing [-1] would raise IndexError and N**2 would be
        # zero, so the contamination is recorded as undefined instead.
        C = np.nan
    else:
        # Number of consecutive-spike intervals shorter than the refractory period.
        n_v = np.sum(np.diff(spike_train_s) < t_r)

        # NOTE(review): T is the span between the unit's first and last spike,
        # not the full recording duration -- confirm this is the intended T.
        T = spike_train_s[-1] - spike_train_s[0]
        C = (n_v * T) / (2 * N**2 * t_r)

    m_quality_metrics.append([unit_id, C])
    print(f"Unit {unit_id}: ISI violations contamination rate = {C:.4f}")
# print(m_quality_metrics)

## Manually calculate signal-to-noise ratio
ref https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html#SNR
ref Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Society of Neuroscience Abstract. 2005.
ref Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.

### Calculation
\(A_{\mu s}\): maximum amplitude of the mean spike waveform (on the best channel).
\(\sigma_b\): standard deviation of the background noise on the same channel (usually computed via the median absolute deviation).
\[\mathrm{SNR} = \frac{A_{\mu s}}{\sigma_b}\]