In [1]:
import numpy as np
import pywt

import seaborn as sns #绘制confusion matrix heatmap

import sklearn
import os

import warnings

warnings.simplefilter('ignore') #忽略警告

In [2]:
import scipy
import scipy.io as sio

import scipy.signal

from scipy import linalg

import pandas as pd
from sklearn.decomposition import PCA
#分类器
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.model_selection import train_test_split 

import xgboost
import lightgbm

#模型集成
from sklearn.ensemble import VotingClassifier
from sklearn.ensemble import BaggingClassifier
from mlxtend.classifier import StackingClassifier

#模型调节
from sklearn.model_selection import GridSearchCV #参数搜索
from mlxtend.feature_selection import SequentialFeatureSelector #特征选择函数 选择合适的feature

#结果可视化
from sklearn.metrics import classification_report , confusion_matrix #混淆矩阵

#相关指标
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score

from sklearn.metrics import cohen_kappa_score

from sklearn.metrics import roc_auc_score

#二分类其多分类化
#from sklearn.multiclass import OneVsOneClassifier
#from sklearn.multiclass import OneVsRestClassifier

#from sklearn.preprocessing import StandardScaler
#from sklearn.cluster import KMeans

#距离函数 度量向量距离
from sklearn.metrics.pairwise import manhattan_distances
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics.pairwise import cosine_distances
from sklearn.metrics.pairwise import cosine_similarity #余弦相似度

#one-hot使用
#from keras.utils import to_categorical

from sklearn.preprocessing import label_binarize

#绘图
import matplotlib.pyplot as plt

import scipy.linalg as la

import gc

%matplotlib inline

In [3]:
sample_rate = 256  # sampling rate in Hz

# Recorded channels, in column order of the recordings
# (column i of each window is SAMPLE_CHANNEL[i]).
# Unused channels, if any, would be marked with None (none currently are).
# Reference: "A study on performance increasing in SSVEP based BCI application".
SAMPLE_CHANNEL = ['Pz', 'PO3', 'PO4', 'O1', 'O2', 'Oz', 'O9', 'FP2',
                  'C4', 'C6', 'CP3', 'CP1',
                  'CPZ', 'CP2', 'CP4', 'PO8']

# Number of EEG channels per recording (16). Derived from the list above so
# the two definitions cannot drift apart.
origin_channel = len(SAMPLE_CHANNEL)

# Seconds trimmed from the start (CLIP_FORWARD) and the end (CLIP_BACKWARD)
# of every recording before it is windowed.
CLIP_FORWARD = 2
CLIP_BACKWARD = 1

# Length of one analysis window, in seconds.
trial_time = 3

trial_offset = 0      # seconds
start_trial_time = 0  # actual trial start inside a window (s)
end_trial_time = 2    # actual trial end inside a window (s); must be < trial_time
assert start_trial_time < end_trial_time < trial_time, 'trial bounds must lie inside the window'

# EEG band-pass range in Hz (reference: same paper as above).
# Normalization and filtering are toggled where they are applied
# (feature_extraction_dwt(is_normalize=...) and butter_worth).
LO_FREQ = 0.5
HI_FREQ = 40

NOTCH_FREQ = 50  # power-line frequency (Hz) for an optional notch filter



In [4]:
def butter_worth(data , lowcut , highcut , order=6):
    nyq = 0.5 * sample_rate
    
    lo = lowcut / nyq
    hi = highcut / nyq
    
    b,a = scipy.signal.butter(order , [lo , hi] , btype='bandpass')

    return np.array([scipy.signal.filtfilt(b , a , data[: , i]) for i in range(data.shape[1])]).reshape((-1 , origin_channel))

def load_data(filename, clip_forward=None, clip_backward=None, fs=None):
    """Load one recording from a MATLAB file and trim its beginning and end.

    Parameters
    ----------
    filename : str
        Path to a .mat file containing a ``'data_received'`` variable. Its
        rows are samples; per the original note it is a length x 16 matrix.
    clip_forward, clip_backward : float, optional
        Seconds to drop from the start / end. Default to ``CLIP_FORWARD`` /
        ``CLIP_BACKWARD``.
    fs : float, optional
        Sampling rate in Hz. Defaults to ``sample_rate``.

    Returns
    -------
    np.ndarray
        Rows ``[clip_forward*fs, n_rows - clip_backward*fs)`` of the recording.
        The result is empty if the recording is shorter than the two clips.
    """
    fs = sample_rate if fs is None else fs
    clip_forward = CLIP_FORWARD if clip_forward is None else clip_forward
    clip_backward = CLIP_BACKWARD if clip_backward is None else clip_backward

    data = sio.loadmat(file_name=filename)['data_received']

    start = int(round(clip_forward * fs))
    # Compute the end index explicitly. A negative slice bound such as
    # `-clip_backward * fs` becomes 0 when clip_backward is 0, which would
    # silently return an empty array.
    stop = data.shape[0] - int(round(clip_backward * fs))

    # Band-pass filtering is intentionally not applied here (see butter_worth).
    return data[start:stop]

In [5]:
def separate(data, label, overlap_length, window_size=None):
    """Cut a continuous recording into fixed-length, possibly overlapping windows.

    Parameters
    ----------
    data : np.ndarray, shape (n_samples, n_channels)
        Continuous recording, time along axis 0.
    label : int
        Class label attached to every window.
    overlap_length : int
        Samples shared by consecutive windows. Must be smaller than ``window_size``.
    window_size : int, optional
        Window length in samples. Defaults to ``sample_rate * trial_time``.

    Returns
    -------
    windows : np.ndarray, shape (n_windows, window_size, n_channels)
        Every full window ``data[s:s + window_size]`` with ``s`` stepping by
        ``window_size - overlap_length``, including the last one that ends
        exactly at ``n_samples``.
    labels : np.ndarray, shape (n_windows,)
        ``label`` repeated once per window.

    Raises
    ------
    ValueError
        If ``overlap_length >= window_size``, since the window would never advance.
    """
    if window_size is None:
        window_size = sample_rate * trial_time

    step = window_size - overlap_length
    if step <= 0:
        raise ValueError('overlap_length (%d) must be smaller than window_size (%d)'
                         % (overlap_length, window_size))

    # "+ 1" so the final window that ends exactly at the last sample is kept.
    starts = range(0, data.shape[0] - window_size + 1, step)
    windows = [data[s:s + window_size] for s in starts]

    if windows:
        train_data = np.array(windows)
    else:
        # Keep a 3-D shape so results from short recordings still concatenate.
        train_data = np.empty((0, window_size) + data.shape[1:], dtype=data.dtype)
    train_labels = np.full(len(windows), label)

    return train_data, train_labels

In [6]:
def shuffle_t_v(filenames, rng=None):
    """Shuffle a list of filenames in place and return the same list.

    The list is sorted first, so that with a seeded ``rng`` (or a seeded global
    ``np.random``) the result does not depend on the arbitrary order returned
    by ``os.listdir``.

    Parameters
    ----------
    filenames : list of str
        Modified in place.
    rng : np.random.Generator or np.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state.

    Returns
    -------
    list of str
        ``filenames`` itself, now shuffled.
    """
    filenames.sort()
    (np.random if rng is None else rng).shuffle(filenames)
    return filenames

def combine(freq, n_persons=9, data_dir='data', overlap_seconds=2):
    """Load, window and label every recording for one stimulus frequency.

    Recordings of person ``p`` are read from ``<data_dir>/<p>/<freq>/``. Each
    person's files are shuffled, trimmed with ``load_data``, and concatenated
    into one continuous signal. That signal is then cut into windows with
    ``separate``. The person index ``p`` is used as the class label.

    Parameters
    ----------
    freq : int or float
        Stimulus frequency; also the name of the sub-directory (e.g. ``7.5``).
    n_persons : int
        Number of person directories ``0 .. n_persons-1``.
    data_dir : str
        Root data directory.
    overlap_seconds : float
        Overlap between consecutive windows, in seconds.

    Returns
    -------
    data : np.ndarray, shape (n_windows, window_samples, n_channels)
    labels : np.ndarray, shape (n_windows,)

    Raises
    ------
    FileNotFoundError
        If a person directory is missing or contains no recordings.
    """
    overlap_length = int(overlap_seconds * sample_rate)  # 2 s -> 512 samples at 256 Hz

    all_windows, all_labels = [], []
    for person in range(n_persons):
        person_dir = os.path.join(data_dir, str(person), '%s' % freq)

        # Skip hidden entries (.DS_Store, .ipynb_checkpoints, ...) that are not recordings.
        filenames = [f for f in os.listdir(person_dir) if not f.startswith('.')]
        if not filenames:
            raise FileNotFoundError('no recordings found in %s' % person_dir)

        # Randomize the order in which recordings are concatenated.
        filenames = shuffle_t_v(filenames)

        recording = np.concatenate(
            [load_data(os.path.join(person_dir, f)) for f in filenames], axis=0)

        windows, labels = separate(recording, label=person, overlap_length=overlap_length)
        all_windows.append(windows)
        all_labels.append(labels)

    return np.concatenate(all_windows), np.concatenate(all_labels)

In [7]:
def feature_extraction_dwt_meta(data, n, level=5, wavelet='db4'):
    """Per-channel statistics of one DWT coefficient band for every window.

    Every window is decomposed along its time axis with
    ``pywt.wavedec(..., wavelet, level=level)``. ``n`` indexes the returned
    coefficient list: ``0`` is the level-``level`` approximation, and ``k >= 1``
    is the detail band ``cD(level - k + 1)``. For example, with level=5,
    n=3 -> cD3 and n=4 -> cD2.

    Parameters
    ----------
    data : array_like, shape (n_windows, n_samples, n_channels)
        Windows as produced by ``separate``.
    n : int
        Index into the ``wavedec`` coefficient list.
    level : int
        Decomposition depth.
    wavelet : str
        PyWavelets wavelet name.

    Returns
    -------
    np.ndarray, shape (n_windows, 3 * n_channels)
        Columns ordered as [std per channel | mean square per channel |
        mean per channel], with channels in input column order. Sixteen
        channels give 48 features.
    """
    data = np.asarray(data, dtype=float)
    n_windows, _, n_channels = data.shape
    if n_windows == 0:
        return np.zeros((0, 3 * n_channels))

    # One call decomposes all windows and channels along time (axis=1).
    # The selected band has shape (n_windows, band_length, n_channels).
    band = pywt.wavedec(data, wavelet=wavelet, level=level, axis=1)[n]

    return np.concatenate([
        np.std(band, axis=1),       # spread of the band
        np.mean(band ** 2, axis=1),  # average energy of the band
        np.mean(band, axis=1),      # mean coefficient value
    ], axis=1)

def normalize(data, normalization_type='mean_std'):
    """Scale a feature matrix.

    Parameters
    ----------
    data : array_like, shape (n_samples, n_features)
    normalization_type : {'mean_std', 'min_max'}
        'mean_std' -- per-column z-score. Constant columns map to 0 instead
                      of NaN.
        'min_max'  -- (data - min) / (max - min) using the global min and
                      max of the whole array (not per column). Constant
                      input maps to 0.
                      NOTE(review): confirm global rather than per-column
                      scaling is intended.

    Returns
    -------
    np.ndarray of float, same shape as ``data``.

    Raises
    ------
    ValueError
        For an unknown ``normalization_type``. It subclasses Exception, so
        existing ``except Exception`` handlers still catch it.
    """
    data = np.asarray(data, dtype=float)

    if normalization_type == 'mean_std':
        mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)
        # A zero-variance feature would otherwise divide by zero and become NaN/inf.
        std = np.where(std == 0, 1.0, std)
        return (data - mean) / std

    if normalization_type == 'min_max':
        lo, hi = np.min(data), np.max(data)
        span = hi - lo
        return (data - lo) / (span if span != 0 else 1.0)

    raise ValueError('wrong normalization type')
    
def feature_extraction_dwt(data, is_normalize=True):
    """Build the DWT feature matrix fed to the classifiers.

    Features from coefficient-list indices 3 and 4 of
    ``feature_extraction_dwt_meta`` are concatenated side by side, in that
    order. By default every column is then z-scored with ``normalize``.

    NOTE(review): with the default level=5 at 256 Hz, indices 3 and 4 are the
    cD3 (~16-32 Hz) and cD2 (~32-64 Hz) bands. The stimulus frequencies tested
    in this notebook (6-10 Hz) fall in cD5/cD4 (indices 1/2). Confirm these
    are the intended bands.

    NOTE(review): normalization uses statistics of the whole ``data`` array.
    When this runs before a train/validation split, validation statistics leak
    into the training features.
    """
    band_indices = (3, 4)
    features = np.concatenate(
        [feature_extraction_dwt_meta(data, band) for band in band_indices],
        axis=-1,
    )
    return normalize(features) if is_normalize else features

In [8]:
# Random forest: 20 repeated random 90/10 splits per stimulus frequency.
#
# NOTE(review): these accuracies are likely optimistic. `combine` cuts 3 s
# windows with 2 s overlap, and the split assigns windows at random, so most
# validation windows share raw samples with neighbouring training windows.
# `feature_extraction_dwt` also normalizes with full-data statistics before
# the split. A leakage-free estimate needs the split to be done per recording
# (file) before windowing.
SEED = 0
N_REPEATS = 20

# Seeds the global RNG used by shuffle_t_v inside combine().
np.random.seed(SEED)

for freq in [6, 7.5, 8.5, 10]:
    print('freq = ', freq)
    acc_s = []

    data, labels = combine(freq)                 # load + window recordings
    data_feature = feature_extraction_dwt(data)  # DWT feature extraction

    for t in range(N_REPEATS):  # average over repeated splits
        # Seeding with the repeat index makes every run reproducible.
        # Stratifying keeps each class's share in the small validation set.
        classifier = RandomForestClassifier(n_estimators=20, random_state=t)
        train_X, val_X, train_y, val_y = train_test_split(
            data_feature, labels, test_size=0.1, stratify=labels, random_state=t)

        classifier.fit(train_X, train_y)
        acc = classifier.score(val_X, val_y)
        acc_s.append(acc)
        print(acc)

    print('freq = ', freq, 'average_acc = ', np.average(acc_s), 'var_acc = ', np.var(acc_s))

freq =  6
1.0
0.9973474801061007
0.9960212201591512
1.0
0.9986737400530504
1.0
0.9973474801061007
1.0
0.9960212201591512
0.9973474801061007
0.9946949602122016
0.9973474801061007
0.9973474801061007
0.9973474801061007
0.9946949602122016
0.9946949602122016
0.9907161803713528
0.9920424403183024
0.9973474801061007
0.9946949602122016
freq =  6 average_acc =  0.9966843501326259 var_acc =  6.244327336433793e-06
freq =  7.5
1.0
1.0
0.993368700265252
1.0
0.9973474801061007
0.9920424403183024
0.9960212201591512
0.9986737400530504
1.0
1.0
0.9973474801061007
1.0
1.0
1.0
0.9946949602122016
1.0
1.0
0.9973474801061007
0.9973474801061007
0.9946949602122016
freq =  7.5 average_acc =  0.997944297082228 var_acc =  6.239929922816601e-06
freq =  8.5
0.9946949602122016
0.9986737400530504
0.9946949602122016
0.9960212201591512
0.9880636604774535
0.9946949602122016
0.9946949602122016
0.9986737400530504
1.0
0.9946949602122016
0.9973474801061007
0.9893899204244032
0.9946949602122016
0.9960212201591512
0.994694960

In [10]:
# Decision tree: 20 repeated random 90/10 splits per stimulus frequency.
#
# NOTE(review): these accuracies are likely optimistic. `combine` cuts 3 s
# windows with 2 s overlap, and the split assigns windows at random, so most
# validation windows share raw samples with neighbouring training windows.
# `feature_extraction_dwt` also normalizes with full-data statistics before
# the split. A leakage-free estimate needs the split to be done per recording
# (file) before windowing.
SEED = 0
N_REPEATS = 20

# Seeds the global RNG used by shuffle_t_v inside combine().
np.random.seed(SEED)

for freq in [6, 7.5, 8.5, 10]:
    print('freq = ', freq)
    acc_s = []

    data, labels = combine(freq)                 # load + window recordings
    data_feature = feature_extraction_dwt(data)  # DWT feature extraction

    for t in range(N_REPEATS):  # average over repeated splits
        # Seeding with the repeat index makes every run reproducible.
        # Stratifying keeps each class's share in the small validation set.
        classifier = DecisionTreeClassifier(random_state=t)
        train_X, val_X, train_y, val_y = train_test_split(
            data_feature, labels, test_size=0.1, stratify=labels, random_state=t)

        classifier.fit(train_X, train_y)
        acc = classifier.score(val_X, val_y)
        acc_s.append(acc)
        print(acc)

    print('freq = ', freq, 'average_acc = ', np.average(acc_s), 'var_acc = ', np.var(acc_s))

freq =  6
0.9774535809018567
0.9734748010610079
0.9681697612732095
0.9708222811671088
0.9774535809018567
0.9734748010610079
0.9761273209549072
0.9748010610079576
0.9840848806366048
0.9734748010610079
0.9708222811671088
0.980106100795756
0.9708222811671088
0.9694960212201591
0.9840848806366048
0.9787798408488063
0.9734748010610079
0.9814323607427056
0.986737400530504
0.9814323607427056
freq =  6 average_acc =  0.9763262599469495 var_acc =  2.7136439431783888e-05
freq =  7.5
0.993368700265252
0.9854111405835544
0.9840848806366048
0.9787798408488063
0.9840848806366048
0.9840848806366048
0.9814323607427056
0.9854111405835544
0.993368700265252
0.9854111405835544
0.980106100795756
0.9721485411140584
0.9708222811671088
0.9814323607427056
0.9761273209549072
0.9761273209549072
0.9814323607427056
0.9721485411140584
0.9814323607427056
0.9827586206896551
freq =  7.5 average_acc =  0.9814986737400531 var_acc =  3.5262859796382287e-05
freq =  8.5
0.9827586206896551
0.9588859416445623
0.9668435013262

In [8]:
# SVM / GBDT / XGBoost / LightGBM: 20 repeated random 90/10 splits per
# stimulus frequency.
#
# NOTE(review): these accuracies are likely optimistic. `combine` cuts 3 s
# windows with 2 s overlap, and the split assigns windows at random, so most
# validation windows share raw samples with neighbouring training windows.
# `feature_extraction_dwt` also normalizes with full-data statistics before
# the split.
SEED = 0
N_REPEATS = 20
FREQS = [6, 7.5, 8.5, 10]

algorithms = {
    'svm': SVC,
    'GBDT': GradientBoostingClassifier,
    'xgb': xgboost.XGBClassifier,
    'lightgbm': lightgbm.LGBMClassifier,
}

# Seeds the global RNG used by shuffle_t_v inside combine().
np.random.seed(SEED)

# Features do not depend on the classifier, so compute them once per frequency
# instead of once per (algorithm, frequency) pair. Only the compact feature
# matrices are kept; the raw windows are released right away.
features_by_freq = {}
for freq in FREQS:
    data, labels = combine(freq)
    features_by_freq[freq] = (feature_extraction_dwt(data), labels)
    del data

for name, algorithm in algorithms.items():
    print(name)

    for freq in FREQS:
        print('freq = ', freq)
        data_feature, labels = features_by_freq[freq]
        acc_s = []

        for t in range(N_REPEATS):  # average over repeated splits
            # Seeding with the repeat index makes every run reproducible.
            # Stratifying keeps each class's share in the small test set.
            classifier = algorithm(random_state=t)
            X_train, X_test, y_train, y_test = train_test_split(
                data_feature, labels, test_size=0.1, stratify=labels, random_state=t)

            classifier.fit(X_train, y_train)
            acc = classifier.score(X_test, y_test)
            acc_s.append(acc)
            print(acc)

        print('freq = ', freq, 'average_acc = ', np.average(acc_s), 'var_acc = ', np.var(acc_s))

svm
freq =  6
0.5106100795755968
0.5079575596816976
0.5172413793103449
0.46286472148541113
0.4827586206896552
0.506631299734748
0.473474801061008
0.4854111405835544
0.493368700265252
0.4986737400530504
0.46949602122015915
0.48010610079575594
0.4854111405835544
0.5026525198938993
0.4708222811671088
0.5053050397877984
0.5053050397877984
0.5026525198938993
0.5278514588859416
0.5092838196286472
freq =  6 average_acc =  0.49489389920424404 var_acc =  0.0002978412217070407
freq =  7.5
0.4854111405835544
0.5039787798408488
0.5119363395225465
0.4880636604774536
0.493368700265252
0.5013262599469496
0.5490716180371353
0.5159151193633952
0.48010610079575594
0.4880636604774536
0.5039787798408488
0.47745358090185674
0.4960212201591512
0.4854111405835544
0.5026525198938993
0.5092838196286472
0.5159151193633952
0.53315649867374
0.47745358090185674
0.4880636604774536
freq =  7.5 average_acc =  0.5003315649867375 var_acc =  0.0003303117238564969
freq =  8.5
0.6724137931034483
0.6511936339522546
0.63925

In [8]:
# Gaussian naive Bayes: 20 repeated random 90/10 splits per stimulus frequency.
#
# NOTE(review): these accuracies are likely optimistic. `combine` cuts 3 s
# windows with 2 s overlap, and the split assigns windows at random, so most
# validation windows share raw samples with neighbouring training windows.
# `feature_extraction_dwt` also normalizes with full-data statistics before
# the split.
SEED = 0
N_REPEATS = 20

# Seeds the global RNG used by shuffle_t_v inside combine().
np.random.seed(SEED)

for freq in [6, 7.5, 8.5, 10]:
    print('freq = ', freq)
    acc_s = []

    data, labels = combine(freq)                 # load + window recordings
    data_feature = feature_extraction_dwt(data)  # DWT feature extraction

    for t in range(N_REPEATS):  # average over repeated splits
        classifier = GaussianNB()
        # Seeded, stratified split: reproducible, with class shares preserved.
        train_X, val_X, train_y, val_y = train_test_split(
            data_feature, labels, test_size=0.1, stratify=labels, random_state=t)

        classifier.fit(train_X, train_y)
        acc = classifier.score(val_X, val_y)
        acc_s.append(acc)
        print(acc)

    print('freq = ', freq, 'average_acc = ', np.average(acc_s), 'var_acc = ', np.var(acc_s))

freq =  6
0.26657824933687
0.3302387267904509
0.27984084880636606
0.29310344827586204
0.3103448275862069
0.2864721485411141
0.29045092838196285
0.27984084880636606
0.27320954907161804
0.32625994694960214
0.26790450928381965
0.2917771883289125
0.2838196286472148
0.30238726790450926
0.29973474801061006
0.3103448275862069
0.25862068965517243
0.25862068965517243
0.3275862068965517
0.2891246684350133
freq =  6 average_acc =  0.2913129973474801 var_acc =  0.0004473532846920752
freq =  7.5
0.3037135278514589
0.3129973474801061
0.3448275862068966
0.32625994694960214
0.32360742705570295
0.3395225464190981
0.33819628647214856
0.3196286472148541
0.3395225464190981
0.3169761273209549
0.34880636604774534
0.32625994694960214
0.3302387267904509
0.3328912466843501
0.32625994694960214
0.35013262599469497
0.33421750663129973
0.3275862068965517
0.3063660477453581
0.34084880636604775
freq =  7.5 average_acc =  0.32944297082228113 var_acc =  0.00016347824863328382
freq =  8.5
0.32360742705570295
0.32493368

In [8]:
# k-nearest neighbours: 20 repeated random 90/10 splits per stimulus frequency.
#
# NOTE(review): these accuracies are likely optimistic. `combine` cuts 3 s
# windows with 2 s overlap, and the split assigns windows at random, so most
# validation windows share raw samples with neighbouring training windows.
# `feature_extraction_dwt` also normalizes with full-data statistics before
# the split.
SEED = 0
N_REPEATS = 20

# Seeds the global RNG used by shuffle_t_v inside combine().
np.random.seed(SEED)

for freq in [6, 7.5, 8.5, 10]:
    print('freq = ', freq)
    acc_s = []

    data, labels = combine(freq)                 # load + window recordings
    data_feature = feature_extraction_dwt(data)  # DWT feature extraction

    for t in range(N_REPEATS):  # average over repeated splits
        classifier = KNeighborsClassifier()
        # Seeded, stratified split: reproducible, with class shares preserved.
        train_X, val_X, train_y, val_y = train_test_split(
            data_feature, labels, test_size=0.1, stratify=labels, random_state=t)

        classifier.fit(train_X, train_y)
        acc = classifier.score(val_X, val_y)
        acc_s.append(acc)
        print(acc)

    print('freq = ', freq, 'average_acc = ', np.average(acc_s), 'var_acc = ', np.var(acc_s))

freq =  6
0.7652519893899205
0.7546419098143236
0.7811671087533156
0.7572944297082228
0.7931034482758621
0.7652519893899205
0.76657824933687
0.7838196286472149
0.773209549071618
0.7612732095490716
0.7811671087533156
0.7519893899204244
0.7771883289124668
0.7612732095490716
0.753315649867374
0.7811671087533156
0.7811671087533156
0.7824933687002652
0.7692307692307693
0.7718832891246684
freq =  6 average_acc =  0.7706233421750663 var_acc =  0.00013235775246431058
freq =  7.5
0.7970822281167109
0.7997347480106101
0.7877984084880637
0.8037135278514589
0.7745358090185677
0.7851458885941645
0.8143236074270557
0.7811671087533156
0.7785145888594165
0.7652519893899205
0.773209549071618
0.7785145888594165
0.7877984084880637
0.8037135278514589
0.7944297082228117
0.8116710875331565
0.7811671087533156
0.8050397877984085
0.7891246684350133
0.7851458885941645
freq =  7.5 average_acc =  0.7898541114058355 var_acc =  0.00017193447501917262
freq =  8.5
0.7970822281167109
0.8209549071618037
0.8050397877984