In [1]:
import numpy as np
import pywt

import seaborn as sns #绘制confusion matrix heatmap

import os
import scipy.io as sio

from statsmodels.tsa.ar_model import AR

import tqdm
import  time

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import warnings
warnings.simplefilter('ignore') #忽略警告

In [3]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

import xgboost

In [4]:
# ---- Recording parameters ----
sample_rate = 256      # data points per second of signal
origin_channel = 16    # channel count of the raw recording

# Electrode names (16 entries)
SAMPLE_CHANNEL = [
    'Pz', 'PO3', 'PO4', 'O1', 'O2', 'Oz', 'O9', 'FP2',
    'C4', 'C6', 'CP3', 'CP1',
    'CPZ', 'CP2', 'CP4', 'PO8',
]

# Integer class label -> short name tag
LABEL2STR = dict(enumerate(
    ['sen', 'hong', 'zhao', 'fen', 'xiao', 'yu', 'bin', 'wang', 'wei', 'fei']
))

# ---- Segmentation ----
CLIP_FORWARD = 1   # amount of time trimmed from the head of each recording
CLIP_BACKWARD = 1  # amount of time trimmed from the tail of each recording

trial_time = 3     # length of one segment, in seconds

# ---- Pre-processing options ----
# Normalisation switch (currently disabled)
# reference: a study on performance increasing in ssvep based bci application
#IS_NORMALIZE = True

# Band-pass filtering switch (currently disabled)
#IS_FILTER = False
# EEG frequency range
# reference: a study on performance increasing in ssvep based bci application
LO_FREQ = 0.5
HI_FREQ = 40

# Notch filtering switch (currently disabled)
#IS_NOTCH = False
NOTCH_FREQ = 50    # notch filter at the mains (power-line) frequency


In [5]:
from keras.utils import to_categorical

Using TensorFlow backend.


In [6]:
def load_data(filename):
    """Load one recording from a .mat file and trim its head and tail.

    The array stored under the 'data_received' key is read, then
    ``CLIP_FORWARD * sample_rate`` rows are dropped from the start and
    ``CLIP_BACKWARD * sample_rate`` rows from the end.

    :param filename: path of the .mat file to read.
    :return: the trimmed array (rows are data points; the original note
             describes it as a length*16 matrix).
    """
    data = sio.loadmat(file_name=filename)['data_received']

    start = CLIP_FORWARD * sample_rate
    # Compute the stop index explicitly: a negative-offset slice such as
    # data[start:-0] is empty, which would discard the whole recording
    # whenever CLIP_BACKWARD is set to 0.
    stop = max(0, data.shape[0] - CLIP_BACKWARD * sample_rate)

    return data[start:stop]

In [7]:
def separate(data , label , overlap_length):
    """Cut a recording into fixed-length, optionally overlapping segments.

    Each segment holds ``sample_rate * trial_time`` consecutive rows of
    ``data``; consecutive segments start ``size - overlap_length`` rows apart.
    A trailing remainder shorter than one segment is discarded.

    :param data: 2-D array, rows are data points and columns are channels.
    :param label: label assigned to every segment cut from ``data``.
    :param overlap_length: number of rows shared by consecutive segments;
                           must be smaller than the segment length.
    :return: ``(segments, labels)`` where ``segments`` has shape
             ``(n_segments, size, n_columns)`` and ``labels`` has shape
             ``(n_segments,)``. Both keep these shapes when ``n_segments``
             is 0, so they can still be concatenated with other results.
    :raises ValueError: if ``overlap_length`` is not smaller than the
                        segment length (the window would never advance).
    """
    size = sample_rate * trial_time  # rows per segment
    step = size - overlap_length
    if step <= 0:
        raise ValueError(
            'overlap_length (%s) must be smaller than the segment length (%s)'
            % (overlap_length, size))

    starts = range(0, data.shape[0] - size + 1, step)

    if len(starts) == 0:
        # Recording shorter than one segment: return correctly shaped empties.
        segments = np.empty((0, size) + tuple(data.shape[1:]), dtype=data.dtype)
    else:
        segments = np.stack([data[start:start + size, :] for start in starts])

    # np.full keeps the label dtype even when there are no segments.
    labels = np.full(len(starts), label)

    return segments , labels

In [8]:
def shuffle_t_v(filenames, size=10, replace=True):
    """Randomly draw file names for one train/test round.

    The defaults reproduce the original behaviour: 10 names are drawn
    *with replacement*, so the same file can be selected more than once.
    Pass ``replace=False`` to draw distinct files (``size`` must then not
    exceed ``len(filenames)``).

    :param filenames: non-empty sequence of candidate file names.
    :param size: number of names to draw.
    :param replace: whether a name may be drawn more than once.
    :return: numpy array holding the drawn names.
    """
    return np.random.choice(filenames , size=size , replace=replace)

def combine(freq, n_subjects=9):
    """Build a shuffled, labelled set of segments from the base_rf recordings.

    For every subject id in ``range(n_subjects)`` the files under
    ``../incremental/data/base_rf/<freq>/<id>/`` are listed, a random
    selection is drawn with ``shuffle_t_v``, the selected recordings are
    loaded and concatenated, and the result is cut into segments that
    overlap by two seconds. Segments of all subjects are then merged and
    shuffled together.

    :param freq: value interpolated into the directory path
                 (the notebook uses 6, 7.5, 8.5 and 10).
    :param n_subjects: number of subject directories (0 .. n_subjects-1).
    :return: ``(train_data, train_labels)``, shuffled with the same
             permutation; labels are the subject ids.
    """
    overlap_length = 2 * sample_rate  # consecutive segments share 2 seconds

    subject_dirs = ['../incremental/data/base_rf/%s/%d/' % (freq , subject_id)
                    for subject_id in range(n_subjects)]

    # Draw the files for every subject before loading anything. Listings are
    # sorted because os.listdir returns entries in arbitrary order, which
    # would make a seeded run pick different files on different machines.
    selections = [shuffle_t_v(sorted(os.listdir(subject_dir)))
                  for subject_dir in subject_dirs]

    segment_sets = []
    label_sets = []
    for subject_id , (subject_dir , filenames) in enumerate(zip(subject_dirs , selections)):
        # Load the selected recordings and join them along the time axis.
        recording = np.concatenate(
            [load_data(subject_dir + filename) for filename in filenames] , axis = 0)

        # Cut into segments labelled with the subject id.
        segments , labels = separate(recording , label = subject_id ,
                                     overlap_length = overlap_length)
        segment_sets.append(segments)
        label_sets.append(labels)

    train_data = np.concatenate(segment_sets)
    train_labels = np.concatenate(label_sets)

    # Shuffle data and labels with one shared permutation.
    order = list(range(train_data.shape[0]))
    np.random.shuffle(order)

    return train_data[order] , train_labels[order]


In [10]:
def session_data_labels(session_id , freq , is_training):
    """Load, segment and shuffle all recordings of one incremental session.

    Every sub-directory of
    ``../incremental/data/incremental/<freq>/s<session_id>/`` is treated as
    one subject; its name is converted with ``int()`` and used as the label.
    All files of a subject are loaded, concatenated along the time axis and
    cut into segments.

    :param session_id: integer session number (formatted with ``%d``).
    :param freq: value converted with ``str()`` and interpolated into the path.
    :param is_training: when true, consecutive segments overlap by two
                        seconds; otherwise they do not overlap.
    :return: ``(data, labels)`` shuffled with the same permutation.
    """
    # Derive the overlap from sample_rate instead of a hard-coded 256.
    overlap_length = 2 * sample_rate if is_training else 0

    session_dir = '../incremental/data/incremental/%s/s%d/' % (str(freq) , session_id)

    data = []
    labels = []

    # Sorted listings: os.listdir order is arbitrary, which would otherwise
    # make a seeded run differ between machines.
    for subject in sorted(os.listdir(session_dir)):  # directory name = subject id
        subject_dir = '%s%s/' % (session_dir , subject)

        recording = np.concatenate(
            [load_data(subject_dir + filename)
             for filename in sorted(os.listdir(subject_dir))] , axis = 0)

        person_data , person_label = separate(recording , label = int(subject) ,
                                              overlap_length = overlap_length)

        data.append(person_data)
        labels.append(person_label)

    # Merge all subjects.
    data = np.concatenate(data)
    labels = np.concatenate(labels)

    # Shuffle data and labels with one shared permutation.
    order = list(range(data.shape[0]))
    np.random.shuffle(order)

    return data[order] , labels[order]


In [14]:
def feature_extraction_RMS(data):
    """Compute the root-mean-square of every channel of every segment.

    :param data: iterable of 2-D segments laid out as (data points, channels);
                 the original note gives 768 * 16 as the segment shape.
    :return: array with one row per segment and one RMS value per channel.
    """
    def _channel_rms(samples):
        # sqrt(mean(x^2)) over one channel's samples
        return np.sqrt(np.mean(np.square(samples)))

    return np.array([
        [_channel_rms(channel) for channel in segment.T]
        for segment in data
    ])

In [12]:
def concat_and_shuffle(orig_X , orig_y , session_id , freq):
    """Append one session's features to an existing set and reshuffle.

    The session is loaded in training mode, converted to features, appended
    to ``orig_X`` / ``orig_y`` along axis 0, and the combined set is shuffled
    with one shared permutation. Labels stay as returned by
    ``session_data_labels`` (the one-hot conversion is disabled).

    :return: ``(X, y)`` containing the old and the new samples, shuffled.
    """
    new_segments , new_labels = session_data_labels(session_id , freq , is_training=True)

    # NOTE(review): feature_extraction_ar is not defined in the visible part
    # of this notebook (only feature_extraction_RMS is) - confirm it exists
    # before calling this function on a fresh kernel.
    new_features = feature_extraction_ar(new_segments)

    merged_X = np.concatenate((orig_X , new_features) , axis=0)
    merged_y = np.concatenate((orig_y , new_labels) , axis=0)

    order = list(range(merged_X.shape[0]))
    np.random.shuffle(order)

    return merged_X[order] , merged_y[order]

# PCA dimensionality reduction and K-means clustering

In [13]:
def get_center(data , label , n_classes=9):
    """Compute the mean feature vector of every class.

    :param data: array of samples, one per row.
    :param label: class id of every row of ``data``; lists are accepted and
                  converted to arrays so the element-wise comparison works.
    :param n_classes: number of class ids, ``0 .. n_classes-1``
                      (defaults to the 9 subjects used in this notebook).
    :return: array with ``n_classes`` rows; row ``i`` is the mean of the
             samples labelled ``i``. A class without samples yields NaN.
    """
    data = np.asarray(data)
    # A plain list compared with an int gives a single False, which would
    # silently select no rows at all - compare as an array instead.
    label = np.asarray(label)

    centers = [np.mean(data[np.where(label == label_id)] , axis = 0)
               for label_id in range(n_classes)]

    return np.array(centers)

def get_center_simple(data):
    """Return the centre of a single user's EEG samples.

    The centre is the mean taken over axis 0, i.e. across the samples.
    """
    centroid = np.mean(data , axis=0)
    return centroid

In [15]:
# For every frequency and session: extract RMS features, project them to 2-D
# with PCA, cluster with K-means, and record the cluster centres together with
# the elapsed time of the two fits.
all_centers = []  # all_centers[freq_index][session_index] -> cluster centres
all_time = []     # all_time[freq_index][session_index] -> seconds (PCA + K-means)

for freq in [6 , 7.5 , 8.5 , 10]:
    sub_centers = []
    sub_time = []

    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented replacement for measuring elapsed time.
        pca = PCA(n_components=2)
        start = time.perf_counter()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.perf_counter() - start

        # NOTE(review): no random_state is set, so centres vary between runs.
        kmeans = KMeans(n_clusters=9)
        start = time.perf_counter()
        # Result intentionally discarded (not bound to `_`, which IPython
        # reserves for the last cell output); only the fitted centres are used.
        kmeans.fit_transform(X_pca)
        time2 = time.perf_counter() - start

        sub_centers.append(kmeans.cluster_centers_)
        sub_time.append(time1 + time2)

        print('session : ' , session_id , kmeans.cluster_centers_ , time1 + time2)

    all_centers.append(sub_centers)
    all_time.append(sub_time)
    


freq :  6
session :  1 [[ 10010.419   -4230.933 ]
 [ 48636.555    3198.599 ]
 [-16749.674    3082.0154]
 [ -7521.4233  -2716.4714]
 [ -5313.1445  18285.004 ]
 [ -5103.003  -15641.252 ]
 [  4645.6934 -15693.303 ]
 [-19978.266   -1479.0504]
 [ -3248.2925   -224.5728]] 0.10130159999999999
session :  3 [[-1.1086482e+05 -1.2722072e+04]
 [ 8.7625075e+05  3.9666547e+02]
 [-1.2930768e+05  8.8162275e+03]
 [-6.9935992e+04 -7.5937616e+02]
 [-8.4653727e+04 -1.1317825e+04]
 [-1.0002624e+05  2.0964303e+04]
 [-1.1739029e+05 -5.7423096e+03]
 [ 8.6270756e+05  7.7370148e+02]
 [-1.1413672e+05  1.8936217e+04]] 0.05342350000000007
session :  5 [[-37208.484  -23244.922 ]
 [197092.86    -7218.011 ]
 [-11680.661   28292.348 ]
 [-23386.14     4150.3994]
 [-31626.629  -12890.508 ]
 [ -5761.245   37323.125 ]
 [-14019.395   21071.229 ]
 [-30075.668   -5626.795 ]
 [-31527.512  -17591.283 ]] 0.05576400000000015
session :  6 [[-16869.49   -8411.644]
 [158905.12   17630.44 ]
 [-42015.74  109952.76 ]
 [-22492.465 -311

session :  13 [[ -7789.5283 -13885.631 ]
 [ 82256.97    -2301.4983]
 [ -9219.157   42118.92  ]
 [-14140.697   -6390.0186]
 [ -9487.165   16576.344 ]
 [-10550.521  -23179.678 ]
 [-18159.055     167.5012]
 [ 97083.35     1045.0869]
 [ -6621.1943  -5080.2617]] 0.07316969999999756
freq :  10
session :  1 [[ -9302.281    9358.878 ]
 [ 67134.48    -9615.407 ]
 [-18624.572  -16730.826 ]
 [  1268.5735  45104.74  ]
 [  1733.7273  -2390.3757]
 [-12894.699   -3169.2358]
 [-10074.35   -18801.066 ]
 [ -3157.382    6269.955 ]
 [-15773.902  -12611.174 ]] 0.06489899999999693
session :  3 [[-28127.28     4861.327 ]
 [ 29153.604  -11050.339 ]
 [ 77723.02    -4007.5789]
 [ 18351.63    26126.31  ]
 [-16483.838   -4077.8218]
 [-39738.31     1646.167 ]
 [-12386.646    4684.975 ]
 [-20766.662  -15163.515 ]
 [ -5692.335   -2773.9329]] 0.05827729999999676
session :  5 [[-32770.043  -18794.709 ]
 [189497.69    -6726.6606]
 [-17450.646   26168.758 ]
 [-28735.492   -7717.9097]
 [ -6093.433   34902.586 ]
 [-16188.