In [1]:
import os
import sys
import random
import mne
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE

import pandas as pd
import logging
import time
import json

from collections import Counter

In [2]:
def compute_overlap(interval1, interval2):
    """Return how many integer positions two closed intervals share.

    Each argument is an indexable pair ``(start, end)`` whose end point is
    inclusive. Disjoint intervals give 0; intervals that merely touch at a
    single end point give 1.
    """
    lo1, hi1 = interval1[0], interval1[1]
    lo2, hi2 = interval2[0], interval2[1]

    # The intervals intersect unless one of them ends before the other starts.
    if not (hi1 < lo2 or hi2 < lo1):
        # Shared span runs from the later start to the earlier end (inclusive).
        return min(hi1, hi2) - max(lo1, lo2) + 1

    return 0

def check_overlap(interval1_np_list, interval2, min_overlap=None):
    """Tell whether any interval of a batch overlaps ``interval2`` enough.

    Args:
        interval1_np_list: pair ``[starts, ends]`` of index-aligned sequences;
            ``starts[i]`` and ``ends[i]`` delimit the i-th inclusive interval.
        interval2: inclusive ``[start, end]`` interval to test against.
        min_overlap: smallest overlap (in the same unit as the intervals)
            that counts as a hit. When omitted it falls back to
            ``NEW_SAMP_RATE / 10``, i.e. 0.1 s for intervals expressed in
            samples at ``NEW_SAMP_RATE``.

    Returns:
        True if at least one interval overlaps ``interval2`` by
        ``min_overlap`` or more, otherwise False.
    """
    if min_overlap is None:
        # Resolved at call time so the default follows the notebook-level
        # sampling rate that is current when the function is used.
        min_overlap = NEW_SAMP_RATE / 10

    starts, ends = interval1_np_list[0], interval1_np_list[1]
    return any(
        compute_overlap([starts[i], ends[i]], interval2) >= min_overlap
        for i in range(len(starts))
    )

# def calculate_accuracy_np(y_true, y_pred):
#     # labels_true = torch.argmax(y_true, dim=1)  # 将真实标签转换为标签
#     # labels_pred = torch.argmax(y_pred, dim=1)  # 将预测结果转换为标签
#     accuracy = np.sum(y_true == y_pred) / len(y_true)  # 计算准确率
#     return accuracy

# def calculate_sensitivity_np(y_true, y_pred, class_index):
#     # labels_true = torch.argmax(y_true, dim=1)  # 将真实标签转换为标签
#     # labels_pred = torch.argmax(y_pred, dim=1)  # 将预测结果转换为标签
#     positive_labels = (y_true == class_index) # 将指定类别的标签设为1，其他类别设为0
#     true_positive = np.sum((y_pred == y_true) * positive_labels)  # 统计真阳性的数量
#     false_negative = np.sum((y_pred != y_true) * positive_labels)  # 统计假阴性的数量
#     if (true_positive + false_negative) == 0:
#         return -1
#     sensitivity = true_positive / (true_positive + false_negative)  # 计算灵敏度
#     return sensitivity

def count_label_np(one_hot_data, label):
    """Count the rows of a one-hot matrix whose arg-max column equals ``label``."""
    # Collapse every one-hot row to its class index.
    decoded = np.argmax(one_hot_data, axis=1)
    # Summing the boolean hits yields the number of occurrences.
    hits = decoded == label
    return np.sum(hits)

def calculate_accuracy_np(y_true, y_pred):
    """Return the fraction of positions where ``y_pred`` equals ``y_true``.

    Both inputs are used as class labels directly (no arg-max decoding is
    applied); the denominator is ``len(y_true)``.
    """
    # Element-wise agreement, counted as integers.
    n_correct = np.sum(y_true == y_pred)
    n_total = len(y_true)
    return n_correct / n_total

def calculate_sensitivity_np(y_true, y_pred, class_index):
    """Sensitivity (recall) of ``y_pred`` for the class ``class_index``.

    Inputs are used as class labels directly (no arg-max decoding).
    Returns ``TP / (TP + FN)``, or -1 when ``class_index`` never occurs in
    ``y_true``.
    """
    # Mask of the samples that truly belong to the requested class.
    is_positive = (y_true == class_index)

    # Positives predicted correctly / incorrectly.
    hits = np.sum((y_pred == y_true) * is_positive)
    misses = np.sum((y_pred != y_true) * is_positive)

    n_positive = hits + misses
    if n_positive == 0:
        # Sentinel: the class is absent, so sensitivity is undefined.
        return -1
    return hits / n_positive


def get_names( ch_num, l ):
    """Return ``record['name'][0]`` for the first ``ch_num`` records of ``l``.

    ``l`` must hold at least ``ch_num`` records, each a mapping with a
    non-empty ``'name'`` sequence.
    """
    return [l[idx]['name'][0] for idx in range(ch_num)]


def top_n_elements(lists, n=6):
    """Find the names that most often appear among the first ``n`` records.

    Args:
        lists: iterable of ranked record lists; every record is a mapping
            whose ``'name'`` entry is a sequence, of which the first item is
            used as the name.
        n: how many leading records of each sub-list are counted, and also
            how many of the most frequent names are returned.

    Returns:
        ``(top_elements, counts)`` - the up-to-``n`` most frequent names and
        their occurrence counts, most frequent first.
    """
    tally = Counter()

    for sublist in lists:
        # Slice instead of indexing range(n): a sub-list shorter than n
        # contributes the records it has rather than raising IndexError.
        tally.update(record['name'][0] for record in sublist[:n])

    ranked = tally.most_common(n)

    top_elements = [name for name, _ in ranked]
    counts = [hits for _, hits in ranked]

    return top_elements, counts

# top_n_elements( total_acc_list )

In [3]:
# Channel-ranking experiment, repeated for every patient folder chb01 .. chb24.
# For each patient, TOTAL_ITERATION rounds are run; in every round each channel
# of COMMON_CH_TOTAL is evaluated on its own (read EDFs -> 5 s segments ->
# labels -> PCA -> SMOTE -> decision tree) and the channels are sorted by test
# accuracy. The per-round rankings are written to
# '<patient>_sel_ch_30iter_with_SMOTE.json'.
# NOTE(review): the EDF reading / segmentation below sits inside both the
# iteration loop and the channel loop although nothing in it depends on
# `iteration`; the explicitly unseeded steps (see the commented-out
# random_state arguments) are train_test_split, SMOTE and the tree.
for patient_index in range(1, 25):
    # Root folder of the dataset.
    # NOTE(review): hard-coded absolute Windows path.
    data_folder = 'E:\\EEG\\chb-mit-scalp-eeg-database-1.0.0'
    # patients_folder = []
    # for i in range(1, 25):
    #     if i < 10:
    #         patients_folder.append("chb0"+str(i))
    #     else:
    #         patients_folder.append("chb"+str(i))
    # os.path.join(data_folder, patients_folder[0])

    # NUM_OF_PATIENTS_USED = 1000 # Use all data
    # PATIENT_TO_USE = 'chb13' #06' # 12'
    # Folder / file prefix of the current patient, e.g. 'chb01'.
    PATIENT_TO_USE = 'chb' + str(patient_index).zfill(2)
    print( PATIENT_TO_USE )

    # Number of repeated evaluation rounds per patient.
    TOTAL_ITERATION = 30

    # Recordings are decimated by keeping every DOWN_SAMP_RATE-th sample.
    SAMPLE_RATE = 256
    DOWN_SAMP_RATE = 2
    NEW_SAMP_RATE = int(SAMPLE_RATE / DOWN_SAMP_RATE)

    # split the raw data to 5 sec segments
    SEG_TIME = 5
    SEG_LEN = NEW_SAMP_RATE * SEG_TIME
    OVERLAP_TIME = 0 # 4.5
    # With OVERLAP_TIME = 0 both step sizes equal SEG_LEN (no overlap at all).
    STEP_FOR_OVERLAP = int(NEW_SAMP_RATE * (SEG_TIME - OVERLAP_TIME))
    STEP_NO_OVERLAP  = SEG_LEN

    PIL = 15 * 60 * NEW_SAMP_RATE# Preictal interval length (PIL): 15 min expressed in down-sampled samples

    # Different patients have different channels, so pick the common channels
    COMMON_CH_TOTAL = ['P4-O2', 'T7-P7', 'C4-P4', 'FZ-CZ', 'F3-C3', 'P7-O1', 'CZ-PZ', 'P8-O2', 'F4-C4', 'P3-O1', 'FP2-F8', 'F8-T8', 'FP1-F3', 'F7-T7', 'C3-P3', 'FP1-F7', 'FP2-F4']

    start_time = time.time()
    # 1-based index of the pre-ictal run singled out further below.
    SEIZURE_FOR_TEST = 1
    # One entry per iteration: that iteration's channel results sorted by accuracy.
    total_acc_list = []
    for iteration in range(TOTAL_ITERATION):
        iter_start_time = time.time()

        # Results of the current iteration, one dict per channel.
        acc_list = []

        for ch_i in range( len(COMMON_CH_TOTAL) ):
            # Evaluate a single channel at a time.
            COMMON_CH = [COMMON_CH_TOTAL[ch_i]]

            # These files have no common channels with other files
            FILE_EXCLUDED = ['chb12_27.edf', 'chb12_28.edf', 'chb12_29.edf'] 

            logger = logging.getLogger('mne')
            # logger.setLevel(logging.WARNING)
            logger.setLevel(logging.ERROR) # every EDF read otherwise emits a duplicate-channel-name warning

            # Load the data
            raws = []
            file_names = []
            labels = []
            # Maps an EDF file name to {'start': [...], 'end': [...]} seizure times.
            sec_seizure = {}

            # cnt = 0
            # for subfolder in os.listdir(data_folder):
            patient_path_folder = os.path.join(data_folder, PATIENT_TO_USE)
            if os.path.isdir(patient_path_folder):
                # only read 16 people's data because of the limited memory
                # cnt += 1
                # if cnt > NUM_OF_PATIENTS_USED:
                #     break

                # extract seizure information from summary
                summary_path = os.path.join(patient_path_folder, PATIENT_TO_USE+"-summary.txt")
                with open(summary_path, 'r') as file:
                    content = file.read()
                # Sections of the summary are separated by blank lines; only
                # sections that mention a seizure start time are parsed.
                sections = content.split('\n\n')
                for section in sections:
                    if ("Seizure Start Time" in section) or ("Seizure 1 Start Time" in section):
                        lines = section.split('\n')
                        temp_name = ''
                        for i, line in enumerate(lines):
                            if "File Name:" in line:
                                temp_name = line.replace('File Name: ', '')
                                sec_seizure[temp_name] = {}
                                sec_seizure[temp_name]['start'] = []
                                sec_seizure[temp_name]['end'] = []
                            elif ("Seizure" in line) and ("Start Time" in line):
                                # The value is built from the digits found after
                                # the first ':'; the matching end time is taken
                                # from the line that follows the start line.
                                parts = line.split(":")
                                sec_seizure[temp_name]['start'].append( int(''.join(filter(str.isdigit, parts[1]))) )
                                parts_end = lines[i+1].split(":")
                                sec_seizure[temp_name]['end'].append( int(''.join(filter(str.isdigit, parts_end[1]))) )
                                # print(sec_seizure[temp_name]['start'], sec_seizure[temp_name]['end'])
                            # elif ("Seizure Start Time" in line):
                            #     sec_seizure[temp_name]['start'] = int(''.join(filter(str.isdigit, line)))
                            #     sec_seizure[temp_name]['end'] = int(''.join(filter(str.isdigit, lines[i+1])))
                            # elif ("Seizure 1 Start Time" in line):
                            #     sec_seizure[temp_name]['start'] = int(''.join(filter(str.isdigit, line.replace('Seizure 1', ''))))
                            #     sec_seizure[temp_name]['end'] = int(''.join(filter(str.isdigit, lines[i+1].replace('Seizure 1', ''))))
                # read edf data
                for file in os.listdir(patient_path_folder):
                    if file.endswith('.edf') and file not in FILE_EXCLUDED:
                        file_path = os.path.join(patient_path_folder, file)
                        raw = mne.io.read_raw_edf(file_path)
                        raws.append(raw.pick(COMMON_CH))
                        file_names.append(file)
                        # Seizure files get their start/end lists plus the file
                        # name; seizure-free files are marked with ''.
                        if file in sec_seizure:
                            labels.append({**sec_seizure[file], 'name':file})
                        else:
                            labels.append('')

            # split the data into 5 sec segments

            segments = []
            seg_labels = []

            one_patient_segs = []
            one_patient_labels = []
            for index, (raw, label, file_name) in enumerate(zip(raws, labels, file_names)):
                # Decimate by plain slicing (every DOWN_SAMP_RATE-th sample).
                data = raw.get_data()[:, ::DOWN_SAMP_RATE]
            #     for ch_idx in range(data.shape[0]):
            #         ch_data = data[ch_idx]
            #         min_val = np.min(ch_data)
            #         max_val = np.max(ch_data)
            #         normalized_ch_data = (ch_data - min_val) / (max_val - min_val)
            #         data[ch_idx] = normalized_ch_data
            # #     data = np.array(data)
                if label == '':
                    # Seizure-free file: split data without overlapping; a
                    # trailing remainder shorter than SEG_LEN is dropped.
                    for i in range(0, data.shape[1], STEP_NO_OVERLAP):
                        if data.shape[1] - i < SEG_LEN:
                            break
                        one_patient_segs.append( torch.tensor(data[:, i:(i+SEG_LEN)]) )  # raw slice (no noise is actually added)
                        one_patient_labels.append('inter_ictal')
                else:
                    # Seizure file: ictal / pre-ictal regions advance by
                    # STEP_FOR_OVERLAP, everything else by STEP_NO_OVERLAP
                    # (identical while OVERLAP_TIME = 0).
                    # Summary times are scaled to down-sampled sample indices
                    # - assumes the summary values are seconds, TODO confirm.
                    seizure_start = np.array( label['start'] ) * NEW_SAMP_RATE
                    seizure_end   = np.array( label['end'] ) * NEW_SAMP_RATE
                    i = 0
                    while i < data.shape[1]:
                        if data.shape[1] - i < SEG_LEN:
                            break
                        one_patient_segs.append( torch.tensor(data[:, i:(i+SEG_LEN)]) ) # raw slice (no noise is actually added)
                        # Label priority: ictal first, then pre-ictal (the PIL
                        # samples preceding a seizure start), else inter-ictal.
                        if check_overlap([seizure_start, seizure_end], [i, (i+SEG_LEN-1)]):
                            one_patient_labels.append('ictal')
                            i += STEP_FOR_OVERLAP
                        elif check_overlap([seizure_start-PIL, seizure_start], [i, (i+SEG_LEN-1)]):
                            one_patient_labels.append('pre_ictal')
                            i += STEP_FOR_OVERLAP
                        else:
                            one_patient_labels.append('inter_ictal')
                            i += STEP_NO_OVERLAP

                # Flush the buffered segments when the 5-character file prefix
                # changes ...
                if (index != 0) and (file_name[0:5] != file_names[index-1][0:5]):
                    # segments.append( one_patient_segs )
                    # seg_labels.append( one_patient_labels )
                    segments.extend( one_patient_segs )
                    seg_labels.extend( one_patient_labels )
                    one_patient_segs = []
                    one_patient_labels = []
                    # print(index, file_name, file_names[index-1][0:5])
                # ... and after the last file.
                if index == len(raws)-1:
                    # segments.append( one_patient_segs )
                    # seg_labels.append( one_patient_labels )
                    segments.extend( one_patient_segs )
                    seg_labels.extend( one_patient_labels )
                    one_patient_segs = []
                    one_patient_labels = []
                    # print(index, file_name, file_names[index-1][0:5])

            # for index, label in enumerate(seg_labels):
            #     if label == 'inter_ictal':
            #         seg_labels[index] = 1
            #     elif label == 'ictal':
            #         seg_labels[index] = 1
            #     else:
            #         seg_labels[index] = 2

            # Convert the string labels in place: inter-ictal and ictal both
            # become 1, pre-ictal becomes 2. While doing so, locate the first
            # and last index of the SEIZURE_FOR_TEST-th pre-ictal run.
            pre_cnt = 0
            # Find_target_pre_start = False
            # Find_target_pre_end = False
            target_pre_start = -1
            target_pre_end = -1
            for index in range(0, len(seg_labels)):
                if seg_labels[index] == 'inter_ictal':
                    seg_labels[index] = 1
                elif seg_labels[index] == 'ictal':
                    seg_labels[index] = 1
                else:
                    seg_labels[index] = 2
                    # Enter a new pre-ictal segment
                    # (entries before `index` are already converted to ints).
                    # NOTE(review): for index == 0 this reads seg_labels[-1],
                    # i.e. the last, still unconverted entry.
                    if seg_labels[index-1] != 2:
                        pre_cnt += 1
                        if pre_cnt == SEIZURE_FOR_TEST:
                            target_pre_start = index
                    # Leave a pre-ictal segment
                    # (entries after `index` are still strings).
                    # NOTE(review): raises IndexError if the very last segment
                    # is pre-ictal.
                    if seg_labels[index+1] != 'pre_ictal':
                        if pre_cnt == SEIZURE_FOR_TEST:
                            target_pre_end = index

            # Boolean mask around the selected pre-ictal run, widened by half
            # its length on both sides.
            # NOTE(review): For_test is not referenced again in this cell; the
            # split actually used is the random train_test_split below.
            pre_len_for_test = target_pre_end - target_pre_start + 1
            step_on_both_sides = int( pre_len_for_test / 2 )
            For_test = np.zeros(len(seg_labels), dtype=bool)
            For_test[ target_pre_start-step_on_both_sides : target_pre_end+step_on_both_sides ] = True

            # NOTE(review): with the two integer classes {1, 2} LabelBinarizer
            # presumably returns a single 0/1 column (1 -> 0, 2 -> 1), so class
            # index 1 in the metric calls below would be the pre-ictal class -
            # confirm against the sklearn documentation.
            one_hot_enc = LabelBinarizer()
            seg_labels = one_hot_enc.fit_transform(seg_labels)

            # Stack the segments into one array and rescale by 1e4.
            segments = np.array(segments) * 1e4
            seg_labels = np.array(seg_labels)

            # Random 70 / 30 split; no random_state, so it differs per run.
            segments_for_train,  segments_for_test, seg_labels_for_train, seg_labels_for_test = train_test_split(segments, seg_labels, test_size=0.3)#, random_state=42)

            # segments_for_train = segments[0: int( len(seg_labels)*0.7 )]
            # seg_labels_for_train = seg_labels[0: int( len(seg_labels)*0.7 )]
            # segments_for_test = segments[int( len(seg_labels)*0.7 ) : ]
            # seg_labels_for_test = seg_labels[int( len(seg_labels)*0.7 ) : ]

            seg_train_pca_list = []
            seg_test_pca_list = []
            # print("segments_for_train.shape =", segments_for_train.shape, "segments_for_train.squeeze().shape =", segments_for_train.squeeze().shape)

            # start_time = time.perf_counter()
            ##############
            # PCA is fitted on the training segments only and then applied to
            # the test segments; squeeze() presumably drops the single-channel
            # axis - TODO confirm shapes.
            pca = PCA(n_components=60)
            # seg_train_pca = pca.fit_transform(segments_for_train.squeeze())
            # seg_test_pca = pca.transform(segments_for_test.squeeze())
            # seg_train_pca_list.append(seg_train_pca)
            # seg_test_pca_list.append(seg_test_pca)

            # segments_for_train_pca = np.hstack(seg_train_pca_list)
            # segments_for_test_pca = np.hstack(seg_test_pca_list)
            segments_for_train_pca = pca.fit_transform(segments_for_train.squeeze())
            segments_for_test_pca = pca.transform(segments_for_test.squeeze())
            ##############
            # end_time = time.perf_counter()
            # execution_time = end_time - start_time
            # print(f"PCA execution time：{execution_time} s")

            # start_time = time.perf_counter()
            ##############
            # Oversample the training set only; the test set is left untouched.
            smote = SMOTE()#(random_state=42)
            segments_for_train_pca, seg_labels_for_train = smote.fit_resample(segments_for_train_pca, seg_labels_for_train)
            ##############
            # end_time = time.perf_counter()
            # execution_time = end_time - start_time
            # print(f"SMOTE execution time：{execution_time} s")

            # start_time = time.perf_counter()
            ##############
            clf = DecisionTreeClassifier()#random_state=0)
            clf.fit(segments_for_train_pca, seg_labels_for_train)
            ##############
            # end_time = time.perf_counter()
            # execution_time = end_time - start_time
            # print(f"Training execution time：{execution_time} s")

            # start_time = time.perf_counter()
            ##############
            y_pred = clf.predict(segments_for_test_pca)
            ##############
            # end_time = time.perf_counter()
            # execution_time = end_time - start_time
            # print(f"Testing execution time：{execution_time} s")

            # seg_labels_for_test = np.array(seg_labels_for_test)
            # true_positive = np.sum(( seg_labels_for_test == 2 ) * ( y_pred == 2 ))
            # false_negtive = 

            # Sensitivity for encoded class 1 and overall accuracy on the test split.
            # sens_p = calculate_sensitivity_np(seg_labels_for_test, y_pred, 2)
            sens_i = calculate_sensitivity_np(seg_labels_for_test.squeeze(), y_pred, 1)
            # sens_inter = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
            accuracy = calculate_accuracy_np(seg_labels_for_test.squeeze(), y_pred)
            # print(COMMON_CH)
            # print("***********\nsensitivity", sens_i)
            # print("Acc:", accuracy)
            # print("***********\n")

            # Record this channel's result for the current iteration.
            temp = {}
            temp['sens'] = sens_i
            temp['acc'] = accuracy
            temp['name'] = COMMON_CH
            acc_list.append(temp)

        iter_end_time = time.time()
        print(iteration, iter_end_time-iter_start_time, 's')
        # Rank this iteration's channels by accuracy, best first.
        sorted_acc_list = sorted( acc_list, key=lambda x: x["acc"], reverse=True)
        # sorted_acc_list = sorted( acc_list, key=lambda x: x["sens"], reverse=True)
        total_acc_list.append(sorted_acc_list)


    end_time = time.time()

    # Per-patient report: channels that most often reached an iteration's top 10.
    print( '*****************' )
    print( PATIENT_TO_USE )
    print("Total time cost:", end_time - start_time, "s")
    print( top_n_elements( total_acc_list, 10 ) )
    print( '*****************' )

    # Persist all per-iteration rankings (written relative to the working directory).
    with open(PATIENT_TO_USE+'_sel_ch_30iter_with_SMOTE.json', 'w') as json_file:
        json.dump(total_acc_list, json_file)

chb01
0 143.6372594833374 s
1 146.79923963546753 s
2 143.1516628265381 s
3 144.44317889213562 s
4 144.0181233882904 s
5 144.72223043441772 s
6 143.40331768989563 s
7 144.46002411842346 s
8 144.73128461837769 s
9 142.94955921173096 s
10 142.90636229515076 s
11 144.52351069450378 s
12 142.22881507873535 s
13 143.47526741027832 s
14 142.48201823234558 s
15 144.57673716545105 s
16 143.25933599472046 s
17 140.14404392242432 s
18 143.2715187072754 s
19 141.21577978134155 s
20 144.69064044952393 s
21 142.7572042942047 s
22 142.02234768867493 s
23 141.94794869422913 s
24 143.66711831092834 s
25 141.84386777877808 s
26 144.05030179023743 s
27 143.822270154953 s
28 142.88473272323608 s
29 139.2898073196411 s
*****************
chb01
Total time cost: 4297.37650847435 s
(['P4-O2', 'FZ-CZ', 'FP2-F4', 'C3-P3', 'P8-O2', 'P3-O1', 'C4-P4', 'CZ-PZ', 'FP2-F8', 'T7-P7'], [30, 30, 30, 30, 30, 29, 26, 26, 20, 14])
*****************
chb02
0 127.15800404548645 s
1 126.02490258216858 s
2 126.87970519065857 s
3 

In [4]:
# Reload the per-iteration channel rankings saved for PATIENT_TO_USE
# (the value left behind by the patient loop) and dump them as plain text.
with open(PATIENT_TO_USE+'_sel_ch_30iter_with_SMOTE.json', 'r') as json_file:
    loaded_data = json.loads(json_file.read())
print(loaded_data)

[[{'sens': 0.3408723747980614, 'acc': 0.6771739130434783, 'name': ['T7-P7']}, {'sens': 0.30806451612903224, 'acc': 0.6754347826086956, 'name': ['CZ-PZ']}, {'sens': 0.3359013867488444, 'acc': 0.6752173913043479, 'name': ['F7-T7']}, {'sens': 0.3568147013782542, 'acc': 0.6706521739130434, 'name': ['FZ-CZ']}, {'sens': 0.28892455858747995, 'acc': 0.6689130434782609, 'name': ['C4-P4']}, {'sens': 0.3030769230769231, 'acc': 0.6615217391304348, 'name': ['F4-C4']}, {'sens': 0.2850678733031674, 'acc': 0.6615217391304348, 'name': ['P3-O1']}, {'sens': 0.3220338983050847, 'acc': 0.6608695652173913, 'name': ['F3-C3']}, {'sens': 0.3275316455696203, 'acc': 0.6597826086956522, 'name': ['F8-T8']}, {'sens': 0.3105590062111801, 'acc': 0.6591304347826087, 'name': ['P8-O2']}, {'sens': 0.29606299212598425, 'acc': 0.6563043478260869, 'name': ['C3-P3']}, {'sens': 0.30806451612903224, 'acc': 0.6554347826086957, 'name': ['P4-O2']}, {'sens': 0.28278041074249605, 'acc': 0.6539130434782608, 'name': ['FP1-F3']}, {'se

In [5]:
# NOTE(review): sorting cannot change set membership, so this comparison is
# True for any list of channel names.
set( sorted(COMMON_CH_TOTAL) ) == set(COMMON_CH_TOTAL)

True

In [6]:
# Rank the channel results currently held in `acc_list` by accuracy (best
# first) and print rank, channel name and accuracy, one blank line apart.
sorted_acc_list = sorted( acc_list, key=lambda entry: entry["acc"], reverse=True)
for rank, dic in enumerate(sorted_acc_list, start=1):
    print(rank, dic['name'], dic['acc'], sep='\n')
    print()

1
['F3-C3']
0.6797826086956522

2
['CZ-PZ']
0.6758695652173913

3
['FZ-CZ']
0.675

4
['F7-T7']
0.6741304347826087

5
['P7-O1']
0.6693478260869565

6
['C3-P3']
0.6684782608695652

7
['C4-P4']
0.6682608695652174

8
['T7-P7']
0.6667391304347826

9
['P4-O2']
0.665

10
['F8-T8']
0.6641304347826087

11
['P8-O2']
0.6613043478260869

12
['P3-O1']
0.6606521739130434

13
['FP2-F4']
0.6573913043478261

14
['F4-C4']
0.6556521739130434

15
['FP1-F7']
0.6545652173913044

16
['FP2-F8']
0.6471739130434783

17
['FP1-F3']
0.6469565217391304



In [7]:
# Reload the saved rankings for PATIENT_TO_USE and show which channels most
# often appear among the leading entries of an iteration (default n of
# top_n_elements).
with open(PATIENT_TO_USE+'_sel_ch_30iter_with_SMOTE.json', 'r') as json_file:
    loaded_data = json.loads(json_file.read())
top_n_elements( loaded_data )

(['FZ-CZ', 'F3-C3', 'C3-P3', 'CZ-PZ', 'C4-P4', 'F8-T8'],
 [28, 23, 20, 19, 19, 16])