In [None]:
# BMINT: preprocessing of the CHB-MIT EEG dataset
# Author: YanLi @ Fudan University

In [None]:
import numpy as np
import mne
import os
import re
# %matplotlib notebook
import matplotlib.pyplot as plt
import warnings
# 忽略RuntimeWarning警告
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Channel names are not unique")
np.random.seed(77)

In [None]:
# Index of the patient processed by this run; selects the matching entries of
# patient_list and ch_list below.
current = 0
# Per-patient channel index (position in the EDF channel list) to extract.
ch_list = [20, 3, 19, 10, 15, 13, 4, 16, 6, 1]
patient_list = ['chb01','chb02','chb03','chb04','chb05','chb06','chb07','chb08','chb09','chb10']
if len(ch_list) != len(patient_list):
    raise ValueError('ch_list and patient_list must have the same length')
# Dataset root. Set the CHB_DATA_DIR environment variable to override the
# original author's local path without editing the notebook.
dataset_dir = os.environ.get('CHB_DATA_DIR', 'C:/Users/gaosiy/Desktop/chb_data/')
# Later cells build paths by plain string concatenation, so the root must end
# with a path separator.
if not dataset_dir.endswith(('/', '\\')):
    dataset_dir += '/'
ch = ch_list[current]

In [None]:
# Sanity check: open the patient's "_02" recording and print the name of the
# channel that the configured index `ch` points to.
probe_file = '{0}{1}/{1}_02.edf'.format(dataset_dir, patient_list[current])
raw = mne.io.read_raw_edf(probe_file, verbose=False)
# Channel names in file order; `ch` indexes into this list.
ch_names = raw.ch_names
print('Channel is', ch_names[ch])

In [None]:
def load_info(info_path):
    """Parse a summary text file into seizure and non-seizure file records.

    The file is split into blocks separated by blank lines. Every block that
    contains a ``File Name:`` entry describes one recording. Seizure times are
    collected from all ``Seizure [N] Start Time: X seconds`` /
    ``Seizure [N] End Time: Y seconds`` lines of the block, wherever they
    appear and however the value is padded with spaces.

    Parameters
    ----------
    info_path : str
        Path of the summary text file.

    Returns
    -------
    seizure_info : list
        One ``[file_name, start_seconds, stop_seconds]`` entry per recording
        with at least one seizure time, where the last two items are lists of
        ints in the order they appear in the file.
    non_seizure_info : list of str
        File names of the recordings without any seizure time.
    """
    with open(info_path, 'r') as infile:
        info_string = infile.read()

    name_re = re.compile(r"File Name:\s*(\S+)")
    # The optional number covers the "Seizure 1 Start Time" spelling; \s*
    # tolerates values that are padded with extra spaces.
    start_re = re.compile(r"Seizure\s*\d*\s*Start Time:\s*(\d+)\s*seconds")
    end_re = re.compile(r"Seizure\s*\d*\s*End Time:\s*(\d+)\s*seconds")

    seizure_info = []
    non_seizure_info = []

    # A separator line may contain stray whitespace, so split on "blank" lines
    # rather than on an exact '\n\n'.
    for block in re.split(r"\n\s*\n", info_string):
        name_match = name_re.search(block)
        if name_match is None:
            # Header blocks (sampling rate, channel list, ...) carry no file.
            continue
        file_name = name_match.group(1)

        start_sec = [int(value) for value in start_re.findall(block)]
        stop_sec = [int(value) for value in end_re.findall(block)]

        if start_sec or stop_sec:
            seizure_info.append([file_name, start_sec, stop_sec])
        else:
            non_seizure_info.append(file_name)

    return seizure_info, non_seizure_info

In [None]:
# Parse the selected patient's summary file and show what was found.
patient_id = patient_list[current]
seizure_info, non_seizure_info = load_info('{}{}-summary.txt'.format(dataset_dir, patient_id))
for label, value in (
    ('patients', patient_id),
    ('seizure files\n', seizure_info),
    ('nonseizure files\n', non_seizure_info),
):
    print(label, value)

In [None]:
# Load the data of a recording that contains no seizure.
def get_noseizure_file_data(name):
    """Return the selected channel of one EDF recording as 256-sample windows.

    Reads ``dataset_dir/<patient>/<name>`` for the patient selected by the
    module-level ``current`` and extracts the channel with index ``ch``.

    Parameters
    ----------
    name : str
        EDF file name inside the patient's directory.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n_windows, 256)``. Samples left over after
        the last complete 256-sample window are dropped.
    """
    edf_filename = dataset_dir + patient_list[current]+ "/" + name
    edf = mne.io.read_raw_edf(edf_filename,stim_channel=None, verbose=False)
    # Ask MNE for the one channel that is needed instead of materialising
    # every channel of the recording and discarding all but one.
    data = edf.get_data(picks=[ch])[0].astype(np.float32)* 1e6 #  to uV
    # Keep whole windows only, so a trailing partial window cannot make the
    # reshape fail.
    n_windows = data.shape[0] // 256
    data_1s = data[:n_windows * 256].reshape(n_windows, 256)
    return data_1s

# Gather the windows of every seizure-free file, then concatenate once.
# (Concatenating inside the loop re-copies the whole accumulated array on
# every iteration.) The leading empty chunk keeps the (0, 256) float32 result
# when there are no seizure-free files.
ns_chunks = [np.empty([0, 256], dtype=np.float32)]
for name in non_seizure_info:
    ns_chunks.append(get_noseizure_file_data(name))
ns_data = np.concatenate(ns_chunks, axis=0)

In [None]:
# Load the data of a recording that contains seizures.
def get_seizure_file_data(name):
    """Split one seizure recording into seizure and non-seizure windows.

    Reads ``dataset_dir/<patient>/<file>`` for the patient selected by the
    module-level ``current``, extracts the channel with index ``ch`` and cuts
    it into 256-sample windows. Row ``i`` of the windowed signal holds samples
    ``[256*i, 256*(i+1))``, so a seizure annotated from ``start`` to ``stop``
    seconds (assuming 256 samples per second, as the window length implies)
    covers rows ``start`` .. ``stop - 1``.

    Every (start, stop) pair of the record is used, so recordings with several
    seizures do not leak later seizures into the non-seizure output.

    Parameters
    ----------
    name : list
        Record ``[file_name, start_seconds, stop_seconds]`` in the layout
        produced by ``load_info``; the last two items are lists of ints.

    Returns
    -------
    s_data : numpy.ndarray
        float32 windows inside any annotated seizure, shape ``(n, 256)``.
    ns_data : numpy.ndarray
        float32 windows outside every annotated seizure, shape ``(m, 256)``.
    """
    print(name)
    info, starts, stops = name[0], name[1], name[2]
    edf_filename = dataset_dir + patient_list[current]+ "/" + info
    edf = mne.io.read_raw_edf(edf_filename,stim_channel=None, verbose=False)
    # Only the selected channel is needed; avoid loading all of them.
    data = edf.get_data(picks=[ch])[0].astype(np.float32)* 1e6 #  to uV
    # Keep whole windows only, so a trailing partial window cannot make the
    # reshape fail.
    n_windows = data.shape[0] // 256
    data_1s = data[:n_windows * 256].reshape(n_windows, 256)

    # Mark the windows that belong to any seizure. Clamping at 0 prevents a
    # negative bound from being interpreted as "count from the end".
    seizure_mask = np.zeros(n_windows, dtype=bool)
    for start, stop in zip(starts, stops):
        seizure_mask[max(start, 0):max(stop, 0)] = True

    s_data = data_1s[seizure_mask]
    # Everything outside the annotated seizures is non-seizure data.
    ns_data = data_1s[~seizure_mask]
    return s_data, ns_data

# Gather the seizure / non-seizure windows of every seizure file, then
# concatenate once per class instead of re-copying the accumulated arrays on
# every iteration. The leading empty chunks keep the (0, 256) float32 result
# when there are no seizure files.
s_chunks = [np.empty([0, 256], dtype=np.float32)]
ns_add_chunks = [np.empty([0, 256], dtype=np.float32)]
for info in seizure_info:
    s_part, ns_part = get_seizure_file_data(info)
    s_chunks.append(s_part)
    ns_add_chunks.append(ns_part)
s_data = np.concatenate(s_chunks, axis=0)
ns_data_add = np.concatenate(ns_add_chunks, axis=0)
print('seizure ', s_data.shape[0], '\nnoseizure ', ns_data_add.shape[0])

In [None]:
# Stack the two non-seizure sources (seizure-free files plus the non-seizure
# parts of seizure files) into the final non-seizure set.
ns = np.vstack((ns_data, ns_data_add))
s = s_data
print('ns：', len(ns), 's:', len(s))

# Remove windows that are exact duplicates of another window in the same class.
ns_de, s_de = (np.unique(windows, axis=0) for windows in (ns, s))

print('ns：', len(ns_de), 's:', len(s_de))

# Remove abnormal windows: those whose samples all have the same value.
def check_duplicate_arrays(arr):
    """Drop rows of ``arr`` in which every element equals the row's first element.

    The indices of the removed rows are printed; ``'pass!'`` is printed when
    nothing has to be removed.

    Parameters
    ----------
    arr : numpy.ndarray
        2-D array with one window per row.

    Returns
    -------
    numpy.ndarray
        ``arr`` itself when no row is constant, otherwise a copy without the
        constant rows.
    """
    rows = np.asarray(arr)
    # Compare every column with the first column in one vectorised pass; a row
    # is constant when all of its comparisons hold.
    is_constant = (rows == rows[:, :1]).all(axis=1)
    index_list = np.flatnonzero(is_constant).tolist()
    if len(index_list)>0:
        print('Warning!delte bad data ',index_list)
        arr = np.delete(arr, index_list, axis=0)
    else:
        print('pass!')
    return arr
ns_de = check_duplicate_arrays(ns_de)
s_de = check_duplicate_arrays(s_de)


def _split_train_test(samples, train_ratio=0.7):
    """Shuffle ``samples`` in place along axis 0 and cut it into (train, test).

    The first ``int(len(samples) * train_ratio)`` shuffled rows form the
    training part, the remaining rows the test part. The shuffle draws from
    NumPy's global random generator.
    """
    np.random.shuffle(samples)
    cut_point = int(samples.shape[0] * train_ratio)
    return samples[:cut_point], samples[cut_point:]


# Split each class 70/30 into train and test. The non-seizure set is shuffled
# first, then the seizure set, so the global random stream is consumed in a
# fixed order.
train_ns, test_ns = _split_train_test(ns_de)
train_s, test_s = _split_train_test(s_de)

print('训练集正样本',train_s.shape[0])
print('训练集负样本',train_ns.shape[0])
print('测试集正样本',test_s.shape[0])
print('测试集负样本',test_ns.shape[0])

In [None]:
# Save the train/test splits to one file so that different algorithms can be
# evaluated on the same data. Stored arrays: train_s, train_ns, test_s, test_ns.
splits = {
    'train_s': train_s,
    'train_ns': train_ns,
    'test_s': test_s,
    'test_ns': test_ns,
}

# Report the in-memory size of each split (bytes / 1024 / 1024).
for split_name, split_data in splits.items():
    print(split_name + "：", split_data.nbytes/1024/1024, "Mb")

np.savez('dataset.npz', **splits)