In [1]:
import numpy as np
import pywt

import seaborn as sns #绘制confusion matrix heatmap

import os
import scipy.io as sio

from statsmodels.tsa.ar_model import AR

import tqdm
import time

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import warnings
warnings.simplefilter('ignore') #忽略警告

In [3]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.mixture import GaussianMixture
from sklearn.cluster import DBSCAN

import xgboost

In [4]:
sample_rate = 256
origin_channel = 16


SAMPLE_CHANNEL = ['Pz' , 'PO3' , 'PO4' , 'O1' , 'O2' , 'Oz' , 'O9' , 'FP2' ,
                  'C4' , 'C6' , 'CP3' , 'CP1' ,
                  'CPZ' , 'CP2' , 'CP4' , 'PO8']

LABEL2STR = {0:'sen' , 1:'hong' , 2:'zhao',
             3:'fen' , 4:'xiao' , 5:'yu' , 
             6:'bin' , 7:'wang' , 8:'wei' , 
             9:'fei'}

CLIP_FORWARD = 1 #首部裁掉时间
CLIP_BACKWARD = 1 #尾部裁掉时间

trial_time = 3 #segment second


#是否进行归一化
#reference:a study on performance increasing in ssvep based bci application
#IS_NORMALIZE = True

#是否进行滤波
#IS_FILTER = False
#EEG频率范围
#reference:a study on performance increasing in ssvep based bci application
LO_FREQ = 0.5
HI_FREQ = 40

#是否陷波
#IS_NOTCH = False
NOTCH_FREQ = 50 #陷波 工频


In [5]:
from keras.utils import to_categorical

Using TensorFlow backend.


In [6]:
def load_data(filename):

    data = sio.loadmat(file_name=filename)['data_received'] #length*16 matrix

    data = data[CLIP_FORWARD * sample_rate : - CLIP_BACKWARD * sample_rate] #首部 尾部 进行裁剪
   
    return data 

In [7]:
def separate(data , label , overlap_length):
    train_data = []
    train_labels = []

    size = sample_rate * trial_time #一小段 256*3 个数据点
    data_length = data.shape[0]

    idx = 0
    
    while idx<=data_length-size:
        train_data.append(data[idx : idx+size , :])
        train_labels.append(label)

        idx = idx + (size - overlap_length)
        
    return np.array(train_data) , np.array(train_labels)

In [8]:
def shuffle_t_v(filenames):
    # np.random.shuffle(filenames)
    
    return np.random.choice(filenames , size=10) #20次的计算准确率中 每次随机选择10个样本进行训练测试

def combine(freq):    
    overlap_length = 2*256 #重叠2秒数据
    
    #保证随机性 进行置乱
    person_0_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/0/' % freq) )
    person_1_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/1/' % freq) )
    person_2_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/2/' % freq) )
    person_3_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/3/' % freq) )
    person_4_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/4/' % freq) )
    person_5_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/5/' % freq) )
    person_6_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/6/' % freq) )
    person_7_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/7/' % freq) )
    person_8_filenames = shuffle_t_v( os.listdir('../incremental/data/base_rf/%s/8/' % freq) )

    #打开信号文件 并 合并
    person_0 = np.concatenate([load_data('../incremental/data/base_rf/%s/0/' % freq + filename) for filename in person_0_filenames] , axis = 0)
    person_1 = np.concatenate([load_data('../incremental/data/base_rf/%s/1/' % freq + filename) for filename in person_1_filenames] , axis = 0)
    person_2 = np.concatenate([load_data('../incremental/data/base_rf/%s/2/' % freq + filename) for filename in person_2_filenames] , axis = 0)
    person_3 = np.concatenate([load_data('../incremental/data/base_rf/%s/3/' % freq + filename) for filename in person_3_filenames] , axis = 0)
    person_4 = np.concatenate([load_data('../incremental/data/base_rf/%s/4/' % freq + filename) for filename in person_4_filenames] , axis = 0)
    person_5 = np.concatenate([load_data('../incremental/data/base_rf/%s/5/' % freq + filename) for filename in person_5_filenames] , axis = 0)
    person_6 = np.concatenate([load_data('../incremental/data/base_rf/%s/6/' % freq + filename) for filename in person_6_filenames] , axis = 0)
    person_7 = np.concatenate([load_data('../incremental/data/base_rf/%s/7/' % freq + filename) for filename in person_7_filenames] , axis = 0)
    person_8 = np.concatenate([load_data('../incremental/data/base_rf/%s/8/' % freq + filename) for filename in person_8_filenames] , axis = 0)
    
    #============
    #训练数据分段
    train_person_data_0 , train_person_labels_0 = separate(person_0 , label = 0 , overlap_length=overlap_length)
    train_person_data_1 , train_person_labels_1 = separate(person_1 , label = 1 , overlap_length=overlap_length)
    train_person_data_2 , train_person_labels_2 = separate(person_2 , label = 2 , overlap_length=overlap_length)
    train_person_data_3 , train_person_labels_3 = separate(person_3 , label = 3 , overlap_length=overlap_length)
    train_person_data_4 , train_person_labels_4 = separate(person_4 , label = 4 , overlap_length=overlap_length)
    train_person_data_5 , train_person_labels_5 = separate(person_5 , label = 5 , overlap_length=overlap_length)
    train_person_data_6 , train_person_labels_6 = separate(person_6 , label = 6 , overlap_length=overlap_length)
    train_person_data_7 , train_person_labels_7 = separate(person_7 , label = 7 , overlap_length=overlap_length)
    train_person_data_8 , train_person_labels_8 = separate(person_8 , label = 8 , overlap_length=overlap_length)

    #合并数据
    train_data = np.concatenate((train_person_data_0 , train_person_data_1 , train_person_data_2 ,
                                 train_person_data_3 , train_person_data_4 , train_person_data_5 ,
                                 train_person_data_6 , train_person_data_7 , train_person_data_8 ,
                                 ))
    
    train_labels = np.concatenate((train_person_labels_0 , train_person_labels_1 , train_person_labels_2 ,
                                   train_person_labels_3 , train_person_labels_4 , train_person_labels_5 ,
                                   train_person_labels_6 , train_person_labels_7 , train_person_labels_8 ,
                                    ))
    
    #产生索引并置乱
    idx_train_data = list(range(train_data.shape[0]))
    np.random.shuffle(idx_train_data)

    #将训练数据置乱
    train_data = train_data[idx_train_data]
    train_labels = train_labels[idx_train_data]
        
    return train_data , train_labels


In [9]:
def session_data_labels(session_id , freq , is_training):
    if is_training:
        overlap_length = 256*2
    else:
        overlap_length = 0
    
    str_freq = str(freq)
    
    subjcets = os.listdir('../incremental/data/incremental/%s/s%d/' % (str_freq , session_id)) #受试者ID
    
    data = []
    labels = []
    
    for subjcet in subjcets:
        filenames = os.listdir('../incremental/data/incremental/%s/s%d/%s/' % (str_freq , session_id , subjcet))
        
        person = np.concatenate([load_data('../incremental/data/incremental/%s/s%d/%s/%s' % (str_freq , session_id , subjcet , filename)) for filename in filenames] , axis = 0)
        
        person_data , person_label = separate( person , label = int(subjcet) , overlap_length = overlap_length)
        
        data.append(person_data)
        labels.append(person_label)
    
    #合并数据
    data = np.concatenate(data)
    labels = np.concatenate(labels)
    
    #shuffle
    idx_data = list(range(data.shape[0]))
    np.random.shuffle(idx_data)

    data = data[idx_data]
    labels = labels[idx_data]
    
    return data , labels


In [10]:
def ar(data):
    X = []
    
    for i in range(16):
        X.append( AR(data[: , i]).fit().params.reshape([1 , -1]).squeeze() )
        
    return np.array(X).reshape([1 , -1]).squeeze()


def feature_extraction_ar(data):
    X = []
    
    for datum in data:
        X.append( ar(datum) )
    
    return np.array(X)

In [11]:
def concat_and_shuffle(orig_X , orig_y , session_id , freq):
    session_id_data , session_id_labels = session_data_labels(session_id , freq , is_training=True)
    session_id_data = feature_extraction_ar(session_id_data)
    # session_id_labels = to_categorical(session_id_labels , num_classes=9)
    
    orig_X = np.concatenate((orig_X , session_id_data) , axis=0)
    orig_y = np.concatenate((orig_y , session_id_labels) , axis=0)
    
    idx = list(range(orig_X.shape[0]))
    np.random.shuffle(idx)
    
    orig_X = orig_X[idx]
    orig_y = orig_y[idx]
    
    return orig_X , orig_y

In [12]:
def get_center(data , label):
    centers = []
    
    for label_id in range(9): #一共9个受试者
        equal_idx = np.where(label == label_id)
    
        center = np.mean(data[equal_idx] , axis = 0)
        centers.append(center)
        
    return np.array(centers)

def get_center_simple(data):
    '''
    计算单个用户的脑电的中心
    '''
    return np.mean(data , axis=0)

# PCA降维

## kmeans聚类

In [18]:
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)
    
    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:
        
        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_ar(X)
        
        pca = PCA(n_components=2)
        start = time.clock()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.clock() - start
        
        kmeans = KMeans(n_clusters=9)
        start = time.clock()
        _ = kmeans.fit_transform(X_pca)
        time2 = time.clock() - start
        
        print('session : ' , session_id , kmeans.cluster_centers_ , time1 + time2)

freq :  6
session :  1 [[-3.32194478e+03 -4.02106409e+02]
 [ 1.14215906e+04 -4.72063580e+03]
 [ 9.93004056e+03  5.81818049e+03]
 [ 5.26060386e+03 -2.22674282e+03]
 [ 3.28877188e+03  5.38241052e+03]
 [ 2.14123414e+04 -8.16472409e+03]
 [ 3.28138342e+02  1.14909103e+01]
 [ 1.80884030e+04  3.10979946e+03]
 [ 6.52481918e+03  1.28175908e+04]] 0.10313789999999999
session :  3 [[-1.04022409e+04  1.01732014e+01]
 [ 6.71446771e+04 -1.01928614e+04]
 [ 1.61514503e+05  6.62199656e+04]
 [ 1.36421845e+05  2.19759943e+05]
 [ 3.41625903e+04  2.09543131e+03]
 [-1.80992637e+03  3.28491421e+02]
 [ 1.21051765e+05  1.02849499e+04]
 [ 1.03248987e+05 -3.03781552e+04]
 [ 1.82625025e+05 -4.18484889e+04]] 0.07418529999999635
session :  5 [[-9.47196607e+03  4.98699164e+01]
 [ 7.00085499e+04 -1.65301389e+04]
 [ 1.34857724e+05  5.05945247e+03]
 [ 3.52437535e+04 -1.05477824e+04]
 [ 1.80095498e+05  1.07549143e+04]
 [ 1.28469759e+03 -2.11059381e+04]
 [ 2.99590898e+05  6.37945074e+04]
 [ 1.05895932e+05  8.74617658e+02]

session :  7 [[ -4780.07268045   5162.31921737]
 [574492.98699266  59951.47021534]
 [ 14666.29046675  -4930.07038399]
 [113115.97848985  -5030.59681704]
 [ 61236.05542771 -14076.68813682]
 [ -4221.42592218 -20881.94160065]
 [ 31565.27060567 -22114.95910727]
 [ -4064.5899586   -5674.20733109]
 [ 16802.31257759 -32370.2298477 ]] 0.07835189999991599
session :  8 [[ -2755.49129202    516.49989768]
 [ 29714.47107961  -6201.30126541]
 [  8080.49420452   1692.02373111]
 [ -4750.10855439  -3423.63761052]
 [  3128.68050341  11112.39272819]
 [ 43536.97914656 -10992.21125105]
 [   170.66066001   5027.97843145]
 [ 20951.51321481  -1106.97682206]
 [ 16836.28190923   7900.95436583]] 0.08727890000000116
session :  9 [[ -4283.86043073    262.3909538 ]
 [ 79831.5894439    -794.32847   ]
 [ -5257.06637075  36947.90721795]
 [114951.56480883   1854.59229813]
 [ 47415.01051706   1655.3615358 ]
 [ -4443.12969296  -5149.47874151]
 [179965.79528761  -8327.60122986]
 [ -4494.15991212  19265.94503106]
 [ 21895.

## DBSCAN聚类

In [13]:
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)
    
    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:
        
        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_ar(X)
        
        pca = PCA(n_components=2)
        start = time.clock()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.clock() - start
        
        dbscan = DBSCAN()
        start = time.clock()
        res = dbscan.fit_predict(X_pca)
        time2 = time.clock() - start
        
        _centers = get_center(X_pca , res)
        
        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[ 9809.30637055  5873.997126  ]
 [-3309.90182733  -396.18234087]
 [ 5541.31430907 -2332.18747607]
 [20737.07835662 -7529.41538644]
 [ 3238.93097603  5349.80016133]
 [  486.28901659   -60.47010689]
 [ 6524.81918274 12817.59079562]
 [17927.16905739  3663.26547372]
 [11240.12917738 -4617.62376199]] 0.1278658
session :  3 [[-1.68374150e+03  2.79139287e+02]
 [ 6.56552019e+04 -1.03255106e+04]
 [ 1.36421845e+05  2.19759943e+05]
 [ 1.22222721e+05  1.35985147e+04]
 [-1.02854305e+04  3.86199803e+01]
 [ 3.29287421e+04  3.08175132e+03]
 [ 1.82625025e+05 -4.18484888e+04]
 [ 1.61514503e+05  6.62199656e+04]
 [ 1.04275372e+05 -2.83989441e+04]] 0.1446489999999585
session :  5 [[-9.29221834e+03  1.15737434e+04]
 [ 1.80095498e+05  1.07549143e+04]
 [ 8.56339469e+04 -5.36079461e+03]
 [-9.49745890e+03  2.36227544e+01]
 [ 3.00815194e+04 -5.20802596e+03]
 [ 1.25990931e+05  4.63150406e+03]
 [ 2.11728512e+03 -2.16823413e+04]
 [ 2.99590898e+05  6.37945074e+04]
 [ 5.66778084e+04 -2.2469838

session :  7 [[ -4780.46129847   5186.02399512]
 [574492.98699266  59951.47021534]
 [ 37867.91417967 -21207.53011084]
 [ -4082.17463024 -20827.56721423]
 [113115.97848985  -5030.59681704]
 [ 18742.6506954  -31551.42501566]
 [ -4085.75402002  -5456.99420003]
 [ 14666.29046675  -4930.07038399]
 [ 64334.77212981 -10780.44670512]] 0.12536610000006476
session :  8 [[ -2804.06324679    631.95672497]
 [ 29161.69039388  -6423.13806508]
 [ -4750.10855439  -3423.63761052]
 [  3594.55988411  11425.7557516 ]
 [ 13631.89054594   3450.57715726]
 [ 22101.44360354    -86.95807366]
 [ 43536.97914656 -10992.21125105]
 [  3229.54901492   2387.49664641]
 [  -235.05071275   5993.47629266]] 0.09261320000041451
session :  9 [[ -4452.6930101   -5055.02006055]
 [ 79831.5894439    -794.32847   ]
 [ -4381.91535259  19667.11469854]
 [179965.79528761  -8327.60122986]
 [ 47415.01051706   1655.3615358 ]
 [ -5508.04703892  37172.80897328]
 [ -4260.55404453    532.39486546]
 [114951.56480883   1854.59229813]
 [ 21895.

## GMM聚类

In [15]:
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)
    
    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:
        
        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_ar(X)
        
        pca = PCA(n_components=2)
        start = time.clock()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.clock() - start
        
        gmm = GaussianMixture(n_components=9)
        start = time.clock()
        res = gmm.fit_predict(X_pca)
        time2 = time.clock() - start
        
        _centers = get_center(X_pca , res)
        
        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[-3.39862897e+03 -4.06767107e+02]
 [ 4.77477273e+03  2.00046935e+03]
 [ 8.63971678e+03 -3.78376692e+03]
 [-3.24232174e+03 -1.74824676e+02]
 [ 7.80161435e+02 -1.16362253e+03]
 [ 1.72622374e+04  2.88182805e+03]
 [-2.03494219e+03 -1.77499543e+01]
 [ 2.60107543e+04 -1.05975297e+04]
 [ 4.19019239e+03  8.13681021e+03]] 0.15496879999994917
session :  3 [[-12579.75053037  -2380.69232183]
 [128390.36928909  28941.47381467]
 [ -3109.85361538    421.6646709 ]
 [ 62498.58671562  -8834.48276055]
 [136421.84501787 219759.94297658]
 [107053.67205569 -22691.04846554]
 [ -9690.89723548    765.74628505]
 [189555.66356956 -56353.14586549]
 [ 12342.61080094   3057.64816841]] 0.07475379999999632
session :  5 [[ -9443.80760631    792.13318809]
 [125990.93061813   4631.5040574 ]
 [ 51022.60749626 -15547.69420163]
 [177327.87613477   7654.91851007]
 [  9883.18270769 -17129.8423624 ]
 [ -9224.98824967  11282.80581395]
 [ 86946.89203298  -8588.50085442]
 [255065.11548806  54324.68777835]

session :  7 [[ -4414.44003084   6437.79010644]
 [574492.98699266  59951.47021534]
 [  9808.93865958 -14037.94922931]
 [ -4404.54417337 -13630.84264661]
 [ 78136.58502197 -10246.14608743]
 [ 53176.85159028 -16700.57485626]
 [ -4743.05645374   4902.81871703]
 [ -5899.67453129   2007.413237  ]
 [ 25872.77054181 -21157.37531718]] 0.11030229999983021
session :  8 [[ -4844.35319847  -2523.79668928]
 [ 26882.04335898 -10301.04287808]
 [ -1267.42332533   3355.44321682]
 [ 10372.17443347    484.83763788]
 [  3434.87899946   6663.69269472]
 [ -3794.48441903  -3926.0153343 ]
 [ 21155.60184907    868.15195831]
 [ 34203.06821857  -1730.48956182]
 [ 43214.21207039 -10709.74048969]] 0.10082899999997608
session :  9 [[ -4369.77337538  -5411.18051681]
 [ 44583.78376917   1989.98757061]
 [ 82060.08284574   -305.27109422]
 [ -6459.75559785   9968.38572335]
 [ 12410.02787794   2308.90431628]
 [ -5433.32090189  41492.1710512 ]
 [179965.79528761  -8327.60122986]
 [ -4781.90590015  -1760.01149923]
 [ -4767.

# LDA降维

## kmeans聚类

In [19]:
for freq in [6 , 7.5 , 8.5 , 10]:    
    print('freq : ' , freq)
    
    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:
        
        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_ar(X)
        
        lda = LinearDiscriminantAnalysis(n_components=2)
        start = time.clock()
        X_lda = lda.fit_transform(X_sbp , y)
        time1 = time.clock() - start
        
        kmeans = KMeans(n_clusters=9)
        start = time.clock()
        _ = kmeans.fit_transform(X_lda)
        time2 = time.clock() - start
        
        print('session : ' , session_id , kmeans.cluster_centers_ , time1 + time2)

freq :  6
session :  1 [[ -7.40633606  -7.26570513]
 [ 32.51609351  10.55445111]
 [-19.46164484  19.4182442 ]
 [ -0.80503718   4.38237822]
 [  1.60063255 -22.00879584]
 [  0.9974996   -0.94014608]
 [ -3.67217993 -12.59894815]
 [ -3.80257775   6.85482548]
 [ -0.08225505   1.90684256]] 0.13290470000000001
session :  3 [[ -3.13032852  21.55771325]
 [ -8.91596997  -5.64628468]
 [ 60.04838805  -4.69626783]
 [-10.06750424 -36.12761009]
 [ -9.5348167    2.84328622]
 [  2.90308586  15.4385979 ]
 [ -6.58160884  -0.62123261]
 [-11.4765645    4.93588448]
 [-13.0125369    2.15002548]] 0.1211909999999996
session :  5 [[ 50.91880319  -1.74586357]
 [ -5.27691605   3.9516469 ]
 [ -6.27702049 -11.99550483]
 [ -5.95916174   9.4676117 ]
 [ -7.85982533 -21.50724702]
 [ -5.83531296  -5.63947138]
 [ -5.96727382  17.02481125]
 [-11.06641991  -2.17924767]
 [ -2.57883357  12.6610662 ]] 0.11396419999999807
session :  6 [[ 2.28276229e+01  8.03081238e-01]
 [-2.93655694e+01  4.22671927e+01]
 [-9.83991272e+00 -1.46

session :  12 [[ -0.9062761   -3.19982938]
 [  4.88592194  34.98393546]
 [ 49.51707458 -14.89709136]
 [-22.01403924 -30.34987191]
 [-10.21090705  14.74542234]
 [-18.15812274  -3.55401288]
 [ -0.39200512   4.38251547]
 [ -1.57267591   2.03374342]
 [  4.27317409  32.87912387]] 0.11966240000003836
session :  13 [[ 4.71657655e+01 -1.18812472e+01]
 [-1.10623696e-01  5.33311210e+00]
 [-1.19176440e+01 -2.38290855e+01]
 [ 8.67816526e+00  3.29547515e+01]
 [-2.46342847e+01  1.44182307e-02]
 [-3.82274369e+00 -1.72752221e+01]
 [-7.22218480e+00  5.40210409e+00]
 [-2.69529455e+00  5.42226859e+00]
 [-5.18162412e+00  3.74793569e+00]] 0.118233000000032
freq :  10
session :  1 [[ -2.70934126  -6.76936505]
 [ 45.72709419   8.7795833 ]
 [-25.06765414  27.15170332]
 [ -4.17839619 -44.41495376]
 [ 13.21319958   8.95720858]
 [ -5.56123351   3.65961372]
 [ -3.40977107   0.43690638]
 [ -9.77256753   2.15657844]
 [ -8.21188983  -0.11048183]] 0.11620259999995142
session :  3 [[ -4.77416922  -7.80558969]
 [ 44.62

## DBSCAN聚类

In [14]:
for freq in [6 , 7.5 , 8.5 , 10]:    
    print('freq : ' , freq)
    
    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:
        
        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_ar(X)
        
        lda = LinearDiscriminantAnalysis(n_components=2)
        start = time.clock()
        X_lda = lda.fit_transform(X_sbp , y)
        time1 = time.clock() - start
        
        dbscan = DBSCAN()
        start = time.clock()
        res = dbscan.fit_predict(X_pca)
        time2 = time.clock() - start
        
        _centers = get_center(X_pca , res)
        
        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[  1.60063255 -22.00879584]
 [ -0.83187987   4.43533803]
 [ 32.51609351  10.55445111]
 [-19.46164484  19.4182442 ]
 [ -7.40633606  -7.26570513]
 [  0.98188418  -0.91184319]
 [ -3.67217993 -12.59894815]
 [ -3.81894335   6.88017383]
 [ -0.11172537   1.97041014]] 0.1752728000001298
session :  3 [[-10.06750424 -36.12761009]
 [ -9.56689976   2.8643418 ]
 [ 60.04838805  -4.69626783]
 [  2.90308586  15.4385979 ]
 [ -8.91596997  -5.64628468]
 [ -3.13032852  21.55771325]
 [-11.50783815   4.91747164]
 [ -6.58367242  -0.60020887]
 [-13.02897857   2.12208643]] 0.14593909999985044
session :  5 [[ -2.57883357  12.6610662 ]
 [ 50.91880319  -1.74586357]
 [ -6.27702049 -11.99550483]
 [-11.06641991  -2.17924767]
 [ -7.85982533 -21.50724702]
 [ -5.27691605   3.9516469 ]
 [ -5.83531296  -5.63947138]
 [ -5.96727382  17.02481125]
 [ -5.95916174   9.4676117 ]] 0.15469799999937095
session :  6 [[-1.32318812e+01  1.10989334e+01]
 [ 5.80048478e+01  1.45565672e+01]
 [ 2.28276229e+01  8.03

session :  12 [[-18.15812274  -3.55401288]
 [ 49.51707458 -14.89709136]
 [  4.56622741  33.88577289]
 [ -0.33660664   4.30346312]
 [-22.01403924 -30.34987191]
 [-10.21090705  14.74542234]
 [ -0.88284978  -3.98447227]
 [ -1.69774827   2.08853439]
 [ -0.92165587  -2.41523872]] 0.1474222999995618
session :  13 [[-7.32863405e+00  5.54966718e+00]
 [ 4.71657655e+01 -1.18812472e+01]
 [-3.82274369e+00 -1.72752221e+01]
 [ 8.67816526e+00  3.29547515e+01]
 [-2.46342847e+01  1.44182307e-02]
 [-1.19176440e+01 -2.38290855e+01]
 [-1.50461684e-01  5.36964546e+00]
 [-2.79349068e+00  5.29734401e+00]
 [-5.45401055e+00  3.82736039e+00]] 0.17148750000069413
freq :  10
session :  1 [[ -3.41520008   0.45068902]
 [ -4.17839619 -44.41495376]
 [ 45.72709419   8.7795833 ]
 [-25.06765414  27.15170332]
 [ 13.21319958   8.95720858]
 [ -5.61953089   3.64873675]
 [ -2.70934126  -6.76936505]
 [ -9.81128664   2.15825935]
 [ -8.22633721  -0.10762628]] 0.15082389999952284
session :  3 [[ -9.55501089   0.97269895]
 [ 44.6

## GMM聚类

In [15]:
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)
    
    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:
        
        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_ar(X)
        
        lda = LinearDiscriminantAnalysis(n_components=2)
        start = time.clock()
        X_lda = lda.fit_transform(X_sbp , y)
        time1 = time.clock() - start
        
        gmm = GaussianMixture(n_components=9)
        start = time.clock()
        res = gmm.fit_predict(X_lda)
        time2 = time.clock() - start
        
        _centers = get_center(X_lda , res)
        
        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[  1.60063255 -22.00879584]
 [ -0.11908499   1.95574023]
 [ 32.51609351  10.55445111]
 [-19.46164484  19.4182442 ]
 [ -7.40633606  -7.26570513]
 [ -3.67217993 -12.59894815]
 [  0.99410333  -0.89322823]
 [ -0.80615102   4.40267829]
 [ -3.80257775   6.85482548]] 0.11523190000025352
session :  3 [[-10.06750424 -36.12761009]
 [  2.90308586  15.4385979 ]
 [ 60.04838805  -4.69626783]
 [ -6.60631691  -0.61761094]
 [ -3.13032852  21.55771325]
 [-11.48966135   5.00410668]
 [ -8.91596997  -5.64628468]
 [-12.99933477   2.18088692]
 [ -9.5693634    2.90170599]] 0.11502590000054624
session :  5 [[ -6.27702049 -11.99550483]
 [ 50.91880319  -1.74586357]
 [ -2.59038558  12.67843315]
 [ -5.27691605   3.9516469 ]
 [ -7.85982533 -21.50724702]
 [ -5.83531296  -5.63947138]
 [-11.06641991  -2.17924767]
 [ -5.95916174   9.4676117 ]
 [ -5.98052685  17.03944319]] 0.10874060000060126
session :  6 [[-8.97469240e+00 -1.40707387e+01]
 [ 2.28276229e+01  8.03081238e-01]
 [-2.93655694e+01  4.2

session :  12 [[-22.01403924 -30.34987191]
 [ -0.44874101   4.3660873 ]
 [ 49.05692736 -14.71723688]
 [  4.56622741  33.88577289]
 [-18.15812274  -3.55401288]
 [-10.21090705  14.74542234]
 [ -0.9062761   -3.19982938]
 [ 51.49924721 -15.67184915]
 [ -1.54192041   1.97842795]] 0.14870330000121612
session :  13 [[-2.46342847e+01  1.44182307e-02]
 [ 4.71657655e+01 -1.18812472e+01]
 [-2.75276468e+00  5.39567839e+00]
 [-1.08438636e+01 -2.26379522e+01]
 [ 8.67816526e+00  3.29547515e+01]
 [-3.82274369e+00 -1.72752221e+01]
 [-6.29925228e+00  4.63702697e+00]
 [-1.50530617e-01  5.33521982e+00]
 [-1.24372152e+01 -2.44054404e+01]] 0.12968589999945834
freq :  10
session :  1 [[ -4.24520231 -45.45523981]
 [ -9.03210001   1.0809659 ]
 [-25.06765414  27.15170332]
 [ 45.72709419   8.7795833 ]
 [ 13.21319958   8.95720858]
 [ -2.70934126  -6.76936505]
 [ -3.39535823   0.39702648]
 [ -5.52970446   3.62965334]
 [ -4.11159007 -43.37466772]] 0.11184469999898283
session :  3 [[  4.60446436  30.89768615]
 [  7.