In [18]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import io
import os
from collections import Counter
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from scipy.signal import butter, filtfilt

def seed_assign(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` plus both CUDA seeding calls) with the same value.

    Args:
        seed: Seed value handed unchanged to each library.
    """
    random.seed(seed)
    np.random.seed(seed)
    # The three torch seeding calls take the same argument, so apply them in turn.
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)


In [19]:
# Root directory of the recordings, and the per-recording file-name stems
# that are appended to it to build concrete file paths.
data_path = '/data2/spike_sorting/neuropixels_choi'
filename = [
    "/set1/20141202_all_es",
    "/set2/20150924_1_e",
    "/set3/20150601_all_s",
    "/set4/20150924_1_GT",
]

In [20]:
# Position in `filename` of the recording to process, and the path of the
# matching ground-truth .mat file.
index = 2
ground_truth_path = f"{data_path}{filename[index]}_gtTimes.mat"

In [21]:
# NOTE: the probe ("prob") file does not seem to contain a 129th channel?


## Merge

In [22]:
# Chunking, recording and reproducibility settings.
width = 1_500_000  # number of samples covered by each chunk file
max_iter = 5  # number of chunk files iterated over when merging
n_channel = 129
# The recording selected by index 2 is the exception with 8 units; every other
# index uses 7. NOTE(review): the original note called this recording "set2",
# while filename[2] points at the "/set3/" stem -- confirm the intended naming.
n_unit = 8 if index == 2 else 7
frequency = 25_000  # sampling rate in Hz
my_seed = 42
seed_assign(my_seed)

In [23]:
# Use only the first 3,125,000 samples (real time: 3,125,000 / 25,000 Hz = 125 s).

In [24]:
# Maximum number of samples kept per channel: 125 s at 25,000 Hz.
sample_num_limit = 125 * 25_000

In [25]:
# Load the raw recording chunk by chunk and merge it along the sample axis.
#
# Only as many chunks as are needed to reach `sample_num_limit` samples are
# read: each chunk is trimmed to the number of samples still missing and the
# loop stops once the limit is reached. The merged array is the same as
# loading all `max_iter` chunks and truncating afterwards, without reading and
# holding the unused chunks in memory.
all_waveforms = []
samples_needed = sample_num_limit
for i in range(max_iter):
    # The file-name suffix of chunk i is start{i*width+1}_end{(i+1)*width}.
    start_index = i * width + 1
    end_index = (i + 1) * width

    current_dataset = f"{data_path}{filename[index]}_start{start_index}_end{end_index}.mat"
    mat1 = io.loadmat(current_dataset)
    print(mat1.keys())

    # Keep only the part of this chunk that still fits under the sample limit.
    chunk = mat1['raw_data'][:, :samples_needed]
    all_waveforms.append(chunk)
    samples_needed -= chunk.shape[1]
    if samples_needed <= 0:
        break

# The final slice is a no-op safeguard; the chunks were already trimmed above.
merged_waveform = np.concatenate(all_waveforms, axis=1)[:, :sample_num_limit]
print(merged_waveform.shape)
print(mat1['raw_data'].dtype)


dict_keys(['__header__', '__version__', '__globals__', 'raw_data'])
dict_keys(['__header__', '__version__', '__globals__', 'raw_data'])
dict_keys(['__header__', '__version__', '__globals__', 'raw_data'])
dict_keys(['__header__', '__version__', '__globals__', 'raw_data'])
dict_keys(['__header__', '__version__', '__globals__', 'raw_data'])
(129, 3125000)
int16


In [26]:
# Load the ground-truth file of the selected recording and collect, per unit,
# its spike times (restricted to the truncated recording) and its channels.
ground_truth_path = f"{data_path}{filename[index]}_gtTimes.mat"

mat1 = io.loadmat(ground_truth_path)

gt_times = mat1['gtTimes'][0]
spike_times_int_all_units = []
for unit in range(n_unit):
    unit_times = gt_times[unit][:, 0]
    # Drop spikes at or beyond the sample limit applied to the waveform.
    kept_times = unit_times[unit_times < sample_num_limit]
    spike_times_int_all_units.append(kept_times)
    print('unit', unit, '스파이크 튄 gttime', kept_times.shape)

gt_chans = mat1['gtChans'][0]
spike_chans_int_all_units = []
for unit in range(n_unit):
    unit_chans = gt_chans[unit][:, 0]
    spike_chans_int_all_units.append(unit_chans)
    print('unit', unit, '연관 ch', unit_chans.shape)

# Number of units stored in the ground-truth file.
print('유닛 개수', gt_times.shape)

unit 0 스파이크 튄 gttime (2,)
unit 1 스파이크 튄 gttime (9,)
unit 2 스파이크 튄 gttime (34,)
unit 3 스파이크 튄 gttime (0,)
unit 4 스파이크 튄 gttime (2,)
unit 5 스파이크 튄 gttime (31,)
unit 6 스파이크 튄 gttime (37,)
unit 7 스파이크 튄 gttime (279,)
unit 0 연관 ch (12,)
unit 1 연관 ch (9,)
unit 2 연관 ch (11,)
unit 3 연관 ch (12,)
unit 4 연관 ch (13,)
unit 5 연관 ch (12,)
unit 6 연관 ch (14,)
unit 7 연관 ch (11,)
유닛 개수 (8,)


In [27]:
# Sanity check: number of channels, then the shape of one channel's trace.
print(len(merged_waveform), merged_waveform[0].shape, sep='\n')

129
(3125000,)


In [28]:
# Sanity check: number of units, then the spike-time array shape of unit 0.
print(len(spike_times_int_all_units), spike_times_int_all_units[0].shape, sep='\n')

8
(2,)


In [29]:
# Total number of retained spikes, summed over the first n_unit units.
total_spike = sum(spike_times_int_all_units[unit].shape[0] for unit in range(n_unit))
print('모든 채널의 모든 스파이크 개수 합', total_spike)

모든 채널의 모든 스파이크 개수 합 394


In [30]:
# List, unit by unit, the channels on which that unit is well observed.
print('총 유닛', len(spike_chans_int_all_units))
for unit_channels in spike_chans_int_all_units:
    print(unit_channels)

총 유닛 8
[28 27 26 24 23 22 60 55 56 57 58 59]
[ 77  68 100  79  78  67  66  99 111]
[29 30 17 18 19 62 63 64 50 51 52]
[ 94  95  96  83  84 125 126 127 128 114 115 116]
[ 75  76  77  69  70  74 106 107 108 109 101 102 103]
[ 75  76  77  68  69  70 107 108 109 100 101 102]
[ 76  77  68  69  70 108 109 100 101 102  78  67  99 111]
[ 92  86  87  88  90  91 124 119 120 122 123]


In [31]:
# Scale the raw samples by 1/256 (true division, so the result is floating point).
# NOTE(review): the meaning of 256 is not stated here -- presumably a gain or
# scale factor; confirm. Re-running this cell divides the data again.
merged_waveform = np.true_divide(merged_waveform, 256)

In [32]:
# Band-pass filter helper
def bandpass_filter(data, lowcut, highcut, fs, order=4, axis=-1):
    """Apply a forward-backward Butterworth band-pass filter to ``data``.

    The filter is designed with ``scipy.signal.butter`` (cut-offs normalised
    by the Nyquist frequency) and applied with ``scipy.signal.filtfilt``.

    Args:
        data: Array-like signal to filter.
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling frequency.
        order: Order passed to ``butter``. Defaults to 4.
        axis: Axis of ``data`` along which to filter. Defaults to the last
            axis, which is what ``filtfilt`` uses when no axis is given, so
            existing calls behave as before.

    Returns:
        The filtered signal as returned by ``filtfilt``; same shape as ``data``.

    Raises:
        ValueError: If the cut-offs do not satisfy
            ``0 < lowcut < highcut < fs / 2``.
    """
    nyquist = 0.5 * fs  # Nyquist frequency
    # Fail early with a clear message instead of relying on butter's own checks.
    if not (0 < lowcut < highcut < nyquist):
        raise ValueError(
            f"cut-offs must satisfy 0 < lowcut < highcut < fs/2; "
            f"got lowcut={lowcut}, highcut={highcut}, fs={fs}"
        )
    low = lowcut / nyquist
    high = highcut / nyquist

    # Design the Butterworth band-pass filter.
    b, a = butter(order, [low, high], btype='band')

    # Apply it forward and backward along the requested axis.
    filtered_data = filtfilt(b, a, data, axis=axis)
    return filtered_data


In [33]:
# Apply the band-pass filter to every channel.
lowcut = 300.0  # lower cut-off of the filter (Hz)
highcut = 6000.0  # upper cut-off of the filter (Hz)

# Filter one channel at a time, overwriting each row of the array in place.
for channel, trace in enumerate(merged_waveform):
    merged_waveform[channel] = bandpass_filter(trace, lowcut, highcut, frequency)
print(merged_waveform.shape)


(129, 3125000)


In [34]:
# Save the processed array under the same path stem as the source recording.
merged_output_path = f"{data_path}{filename[index]}_merged_{sample_num_limit}_limit.npy"
np.save(merged_output_path, merged_waveform)