In [71]:
# Clear every user-defined variable without asking for confirmation, so the
# following cells start from an empty namespace (imports must be re-run afterwards).
%reset -f

In [1]:
import spikeinterface.full as si
import spikeinterface as s
import spikeinterface.qualitymetrics as sqm
from spikeinterface.postprocessing import compute_correlograms
import matplotlib.pyplot as plt
from pathlib import Path
from tools import *
import pandas as pd
import os
import sys
import numpy as np
import scipy.stats
import spikeinterface.core as sc
import spikeinterface.extractors as se
from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import (
    compute_snrs,
    compute_firing_rates,
    compute_isi_violations,
    calculate_pc_metrics,
    compute_quality_metrics,
)

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'tools'

## Load recording and sorting data

In [73]:
# Root folder of one SpikeGLX probe recording (hard-coded absolute Windows path; edit per session).
base_folder = r'H:\Neuropixels_computer_data\Neuropixels_data_Third_batch\Ninth_batch_NP_1\20240925-cage1-4\cage1-4-R-DCN-1\cage1-4-R-DCN-1_g1\cage1-4-R-DCN-1_g1_imec1'

# Walk base_folder top-down and take the first sub-directory named 'kilosort4';
# fall back to None when the walk finishes without finding one.
kilosort4_folder = next(
    (
        os.path.join(walk_root, 'kilosort4')
        for walk_root, sub_dirs, _ in os.walk(base_folder)
        if 'kilosort4' in sub_dirs
    ),
    None,
)

# Report the outcome of the search.
print(
    f"找到 kilosort4 文件夹，路径为: {kilosort4_folder}"
    if kilosort4_folder
    else "未在 base_folder 下找到 kilosort4 文件夹。"
)


找到 kilosort4 文件夹，路径为: H:\Neuropixels_computer_data\Neuropixels_data_Third_batch\Ninth_batch_NP_1\20240925-cage1-4\cage1-4-R-DCN-1\cage1-4-R-DCN-1_g1\cage1-4-R-DCN-1_g1_imec1\kilosort4


In [74]:
# Load the AP stream of the SpikeGLX recording, without the sync channel.
# NOTE(review): base_folder is the `_imec1` probe folder while the requested stream
# is 'imec0.ap' -- confirm this is the intended probe/stream.
recording = si.read_spikeglx(
    base_folder,
    stream_name='imec0.ap',
    load_sync_channel=False,
)

In [75]:
# Load the Kilosort4 output. keep_good_only=False keeps every cluster, not only those
# labelled 'good', and remove_empty_units=False keeps units that have no spikes.
sorting_info = se.KiloSortSortingExtractor(
    kilosort4_folder,
    keep_good_only=False,
    remove_empty_units=False,
)
unit_ids = sorting_info.get_unit_ids()

printb('Sorting information', sorting_info)

# Sampling rate (Hz) of the recording; used later to convert spike sample indices to seconds.
sampling_rate = recording.get_sampling_frequency()

printg('Unit ids', unit_ids)

[1;34mSorting information[0m KiloSortSortingExtractor: 268 units - 1 segments - 30.0kHz
[1;32mUnit ids[0m [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
 216 217 218 2

In [None]:
# Build an in-memory SortingAnalyzer that pairs the Kilosort4 units with the raw recording.
analyzer = si.create_sorting_analyzer(
    recording=recording,
    sorting=sorting_info,
    format="memory",
)
print(analyzer)

estimate_sparsity (no parallelization):   0%|          | 0/3118 [00:00<?, ?it/s]

In [None]:

# Workaround used when running parallel jobs (n_jobs > 1) from an interactive session:
# make sure __main__ has a __spec__ attribute. NOTE(review): presumably needed for the
# 'spawn' process start method (e.g. on Windows) -- confirm.
if '__spec__' not in sys.modules['__main__'].__dict__:
    sys.modules['__main__'].__spec__ = None

# The extensions must be computed unconditionally. Nesting them under the `if` above
# would skip all of them whenever __main__ already has a __spec__, leaving the analyzer
# without templates or noise levels for the quality metrics below.
analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205)
analyzer.compute("waveforms", ms_before=1.3, ms_after=2.6, n_jobs=2)
analyzer.compute("templates", operators=["average", "median", "std"])
analyzer.compute("noise_levels")
# The 'amplitude_cutoff' metric requested in the quality-metrics cell is computed
# from per-spike amplitudes.
analyzer.compute("spike_amplitudes")

print(analyzer)

In [None]:
# Quick look at a few individual metrics before the full quality-metrics table.

# Firing rate of each unit.
firing_rates = compute_firing_rates(analyzer)
print(firing_rates)

# The ISI-violation result unpacks into (ratio, count); only the ratio is shown.
isi_violation_ratio, isi_violations_count = compute_isi_violations(analyzer)
print(isi_violation_ratio)

# Signal-to-noise ratio of each unit.
snrs = compute_snrs(analyzer)
print(snrs)


## Save quality metrics

In [None]:
# Quality metrics to compute and save ('rp_violation' is intentionally not included).
qc_metric_names = ['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff']
metrics = compute_quality_metrics(analyzer, metric_names=qc_metric_names)

# The metrics table is indexed by unit id. Turning that index into a leading 'neuro_id'
# column keeps every id aligned with its own row, instead of relying on the order of a
# separately fetched unit-id array. rename_axis/reset_index build a new frame, so
# `metrics` itself is left unmodified.
metrics_df = pd.DataFrame(metrics).rename_axis('neuro_id').reset_index()

qc_save_path = os.path.join(base_folder, 'quality_metrics_test.csv')
metrics_df.to_csv(qc_save_path, index=False)

printg('Metrics saved to', qc_save_path)
print(metrics)

## Extract good units

In [None]:
# Unit-selection thresholds: keep units whose ISI-violation ratio is below isi_th
# and whose SNR is above snr_th.
isi_th, snr_th = 0.1, 2

In [None]:
# Reload the quality metrics written by the "Save quality metrics" cell, so the unit
# selection below can run without recomputing them. The filename must match the one
# that cell writes ('quality_metrics_test.csv'); reading any other file would select
# units from stale or unrelated metrics.
qc_file_path = os.path.join(base_folder, 'quality_metrics_test.csv')

try:
    metrics_df = pd.read_csv(qc_file_path)

    print("读取到的数据：")
    print(metrics_df)

    # Column-oriented dict {column name: list of values}, for code that expects
    # a plain-dict view of the metrics table.
    metrics = metrics_df.to_dict(orient='list')
    print("转换为字典格式的数据：")
    print(metrics)
except FileNotFoundError:
    # metrics_df / metrics keep whatever values they had before this cell ran.
    print(f"未找到文件：{qc_file_path}，请检查文件路径是否正确。")

In [None]:
# 筛选满足条件的神经元
filtered_neurons = metrics_df[
    (metrics_df['isi_violations_ratio'] < isi_th) & (metrics_df['snr'] > snr_th)
]

# 提取筛选后的神经元 ID
filtered_neuro_ids = filtered_neurons['neuro_id'].tolist()

# 打印结果
print("满足条件的神经元 ID：")
print(filtered_neuro_ids)
printg("满足条件的神经元数量：", len(filtered_neuro_ids))

# 如果需要保存筛选后的结果到新的 CSV 文件
filtered_neurons.to_csv(os.path.join(base_folder,'filtered_quality_metrics.csv'),index=False)
print("筛选后的数据已保存到 'filtered_quality_metrics.csv'")

## Manually calculate ISI violations
Reference: [UMS] https://github.com/danamics/UMS2K/blob/master/quality_measures/rpv_contamination.m

This method uses the number of refractory-period violations, denoted $n_v$: the number of inter-spike intervals shorter than the refractory period.
Here, the refractory period $t_r$ is adjusted to account for the recording system's minimum possible refractory period. For example, if a system samples at $f$ Hz, two spikes from the same unit can be no closer than $1/f$ s. Hence $t_r$ is the expected biological threshold minus this minimum possible threshold.

The contamination rate is calculated as

$$C = \frac{n_v T}{2 N^2 t_r}$$

where $T$ is the total duration of the recording and $N$ is the number of spikes of the unit.

In [None]:

# Manual re-implementation of the ISI-violation contamination ratio
# C = n_v * T / (2 * N**2 * t_r)  (UMS2K rpv_contamination / spikeinterface).
# Header row first; each unit then appends [unit_id, C]. The 'snr' column is not
# filled by this cell.
m_quality_metrics = []
m_quality_metrics.append(['neuro_ids', 'isi_violation', 'snr'])

biological_t_r = 0.0015  # biological refractory period (s); IBL and spikeinterface use 1.5 ms
min_possible_t_r = 0     # system-imposed minimum ISI (s); IBL/spikeinterface use 0, strictly it is 1/sampling_rate
t_r = biological_t_r - min_possible_t_r  # adjusted refractory period (s)

# T is the duration of the whole recording, as in UMS2K and spikeinterface, not the span
# between a unit's first and last spike. That span would shrink T and under-estimate C.
T = recording.get_total_duration()

for unit_id in unit_ids:
    spike_train = sorting_info.get_unit_spike_train(unit_id)
    spike_train_s = spike_train / sampling_rate  # sample index -> seconds
    N = len(spike_train_s)

    if N == 0:
        # Empty units are kept (remove_empty_units=False); the ratio is undefined for them.
        C = np.nan
    else:
        n_v = np.sum(np.diff(spike_train_s) < t_r)  # number of refractory-period violations
        C = (n_v * T) / (2 * N**2 * t_r)

    m_quality_metrics.append([unit_id, C])
    print(f"Unit {unit_id}: ISI violations contamination rate = {C:.4f}")

## Manually calculate the signal-to-noise ratio (SNR)
ref https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html#SNR
ref Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Society of Neuroscience Abstract. 2005.
ref Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.

**Calculation**

- $A_{\mu s}$: maximum amplitude of the mean spike waveform (on the best channel).
- $\sigma_b$: standard deviation of the background noise on the same channel (usually estimated via the median absolute deviation).

$$\mathrm{SNR} = \frac{A_{\mu s}}{\sigma_b}$$