In [1]:
import numpy as np
import pywt

import seaborn as sns #绘制confusion matrix heatmap

import os
import scipy.io as sio

from statsmodels.tsa.ar_model import AR

import tqdm
import  time

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import warnings
warnings.simplefilter('ignore') # ignore all warnings

In [13]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.cluster import DBSCAN

import xgboost

In [4]:
# Samples per second; multiplied by durations in seconds to obtain sample counts.
sample_rate = 256
# Channel count -- presumably the 16 columns of each recorded matrix (TODO confirm).
origin_channel = 16


# Channel names, 16 entries -- presumably in the column order of the recordings (TODO confirm).
SAMPLE_CHANNEL = ['Pz' , 'PO3' , 'PO4' , 'O1' , 'O2' , 'Oz' , 'O9' , 'FP2' ,
                  'C4' , 'C6' , 'CP3' , 'CP1' ,
                  'CPZ' , 'CP2' , 'CP4' , 'PO8']

# Integer label -> name string.
# NOTE(review): 10 entries here, while get_center iterates over 9 labels -- confirm which is intended.
LABEL2STR = {0:'sen' , 1:'hong' , 2:'zhao',
             3:'fen' , 4:'xiao' , 5:'yu' , 
             6:'bin' , 7:'wang' , 8:'wei' , 
             9:'fei'}

CLIP_FORWARD = 1 # seconds clipped from the start of each recording
CLIP_BACKWARD = 1 # seconds clipped from the end of each recording

trial_time = 3 #segment second


# Whether to normalize
#reference:a study on performance increasing in ssvep based bci application
#IS_NORMALIZE = True

# Whether to filter
#IS_FILTER = False
# EEG frequency range
#reference:a study on performance increasing in ssvep based bci application
LO_FREQ = 0.5
HI_FREQ = 40

# Whether to apply a notch filter
#IS_NOTCH = False
NOTCH_FREQ = 50 # notch filter at the power-line frequency


In [5]:
from keras.utils import to_categorical

Using TensorFlow backend.


In [6]:
def load_data(filename):
    """Load one recording from a MATLAB ``.mat`` file and trim both ends.

    :param filename: path of a ``.mat`` file holding a ``data_received`` variable.
    :return: the ``data_received`` array without its first
        ``CLIP_FORWARD * sample_rate`` rows and its last
        ``CLIP_BACKWARD * sample_rate`` rows.
    """
    data = sio.loadmat(file_name=filename)['data_received'] # length*16 matrix

    start = CLIP_FORWARD * sample_rate
    # Use an explicit end index: with CLIP_BACKWARD == 0 the slice `data[start:-0]`
    # is `data[start:0]`, which silently returns an empty array.
    stop = max(data.shape[0] - CLIP_BACKWARD * sample_rate , 0)

    return data[start:stop]

In [7]:
def separate(data , label , overlap_length):
    """Cut a recording into fixed-length, possibly overlapping segments.

    Each segment is ``sample_rate * trial_time`` rows long; consecutive segments
    start ``size - overlap_length`` rows apart. A trailing remainder shorter than
    one segment is dropped.

    :param data: 2-D array, samples along axis 0.
    :param label: label attached to every segment cut from ``data``.
    :param overlap_length: number of rows shared by consecutive segments.
    :return: ``(segments, labels)`` as numpy arrays with one entry per segment.
    :raises ValueError: if ``overlap_length`` is not smaller than the segment
        length (the window would never advance).
    """
    size = sample_rate * trial_time # one segment, e.g. 256*3 data points
    step = size - overlap_length

    if step <= 0:
        # The original while-loop would spin forever in this situation.
        raise ValueError('overlap_length (%d) must be smaller than the segment length (%d)'
                         % (overlap_length , size))

    data_length = data.shape[0]

    # Start indices of all windows that fit entirely inside the recording.
    starts = range(0 , data_length - size + 1 , step)

    train_data = [data[idx : idx+size , :] for idx in starts]
    train_labels = [label] * len(train_data)

    return np.array(train_data) , np.array(train_labels)

In [8]:
def shuffle_t_v(filenames):
    """Draw 10 random entries from ``filenames``.

    Used so that each of the repeated accuracy runs works on a random pick of
    10 recordings. ``np.random.choice`` is called with its defaults, so the draw
    is with replacement and the same name may appear more than once.

    :param filenames: 1-D sequence of file names.
    :return: numpy array holding the 10 drawn names.
    """
    n_picked = 10
    picked = np.random.choice(filenames , size=n_picked)

    return picked

def combine(freq):
    """Build a shuffled set of segments from the base recordings of persons 0-8.

    For every person id 0..8, 10 files are drawn at random from
    ``../incremental/data/base_rf/<freq>/<id>/``, loaded, concatenated along
    axis 0 and cut into overlapping segments labelled with the person id.
    The segments of all persons are then merged and shuffled jointly.

    :param freq: value interpolated (via ``%s``) into the directory path.
    :return: ``(train_data, train_labels)`` in the same random order.
    """
    overlap_length = 2 * sample_rate # consecutive segments share 2 seconds of data

    person_segments = []
    person_labels = []

    # One pass per person replaces nine copy-pasted pipelines; the order of the
    # random draws (person 0..8, then the final shuffle) is unchanged.
    for person_id in range(9):
        person_dir = '../incremental/data/base_rf/%s/%d/' % (freq , person_id)

        # Random pick of recordings for this person
        filenames = shuffle_t_v( os.listdir(person_dir) )

        # Open the signal files and concatenate them
        person = np.concatenate([load_data(person_dir + filename) for filename in filenames] , axis = 0)

        # Cut the training data into segments
        segments , labels = separate(person , label = person_id , overlap_length=overlap_length)

        person_segments.append(segments)
        person_labels.append(labels)

    # Merge all persons
    train_data = np.concatenate(person_segments)
    train_labels = np.concatenate(person_labels)

    # Create an index list and shuffle it so data and labels stay aligned
    idx_train_data = list(range(train_data.shape[0]))
    np.random.shuffle(idx_train_data)

    return train_data[idx_train_data] , train_labels[idx_train_data]


In [9]:
def session_data_labels(session_id , freq , is_training):
    """Load, segment and shuffle every subject's recordings of one session.

    Subject directories are listed under
    ``../incremental/data/incremental/<freq>/s<session_id>/``; each directory
    name is converted with ``int`` and used as the label of that subject.

    :param session_id: integer session number (formatted with ``%d``).
    :param freq: value converted with ``str`` and used as a path component.
    :param is_training: when true, consecutive segments overlap by 256*2 rows;
        otherwise they do not overlap.
    :return: ``(data, labels)`` shuffled with one common permutation.
    """
    overlap_length = 256*2 if is_training else 0

    freq_name = str(freq)

    # Subject ids are the directory names inside the session folder
    subject_ids = os.listdir('../incremental/data/incremental/%s/s%d/' % (freq_name , session_id))

    per_subject_data = []
    per_subject_labels = []

    for subject_id in subject_ids:
        recording_names = os.listdir('../incremental/data/incremental/%s/s%d/%s/' % (freq_name , session_id , subject_id))

        recordings = []
        for recording_name in recording_names:
            recordings.append(load_data('../incremental/data/incremental/%s/s%d/%s/%s' % (freq_name , session_id , subject_id , recording_name)))

        signal = np.concatenate(recordings , axis = 0)

        segments , segment_labels = separate( signal , label = int(subject_id) , overlap_length = overlap_length)

        per_subject_data.append(segments)
        per_subject_labels.append(segment_labels)

    # Merge all subjects
    merged_data = np.concatenate(per_subject_data)
    merged_labels = np.concatenate(per_subject_labels)

    # Shuffle data and labels with the same index permutation
    order = list(range(merged_data.shape[0]))
    np.random.shuffle(order)

    return merged_data[order] , merged_labels[order]


In [10]:
def feature_extraction_RMS(data):
    """Compute the root-mean-square of every channel of every segment.

    :param data: iterable of 2-D segments with samples along axis 0 and
        channels along axis 1 (e.g. 768 * 16).
    :return: array with one row per segment and one RMS value per channel.
    """
    def _channel_rms(segment):
        # Transposing lets us walk over the channels one at a time.
        values = []
        for channel in segment.T:
            values.append(np.sqrt(np.mean(np.square(channel))))
        return values

    return np.array([_channel_rms(segment) for segment in data])

In [11]:
def concat_and_shuffle(orig_X , orig_y , session_id , freq):
    """Append one session's features and labels to an existing set, then shuffle.

    NOTE(review): ``feature_extraction_ar`` is not defined in the visible part
    of this notebook, so calling this function raises ``NameError`` unless that
    helper is provided elsewhere -- confirm.

    :param orig_X: existing feature array, samples along axis 0.
    :param orig_y: existing label array aligned with ``orig_X``.
    :param session_id: session passed on to ``session_data_labels``.
    :param freq: frequency value passed on to ``session_data_labels``.
    :return: ``(X, y)`` containing old and new samples in one common random order.
    """
    new_segments , new_labels = session_data_labels(session_id , freq , is_training=True)
    new_features = feature_extraction_ar(new_segments)

    merged_X = np.concatenate((orig_X , new_features) , axis=0)
    merged_y = np.concatenate((orig_y , new_labels) , axis=0)

    # One permutation for both arrays keeps features and labels aligned
    order = list(range(merged_X.shape[0]))
    np.random.shuffle(order)

    return merged_X[order] , merged_y[order]

In [12]:
def get_center(data , label , label_ids=None):
    """Compute the mean point of the samples carrying each requested label.

    :param data: array-like with samples along axis 0.
    :param label: array-like with one label per sample.
    :param label_ids: iterable of label values to compute centers for. Defaults
        to ``range(9)`` (nine subjects in total), which is the original
        behavior. Pass e.g. ``np.unique(label)`` for clusterers whose number of
        clusters is not fixed; labels not listed here (such as ``-1``) are ignored.
    :return: array with one row per entry of ``label_ids``, in that order.
        A label without any sample yields a row of NaN.
    """
    data = np.asarray(data)
    label = np.asarray(label)

    if label_ids is None:
        label_ids = range(9) # nine subjects in total

    # np.mean keeps floating dtypes and promotes everything else to float64;
    # the NaN filler rows use the same dtype so the stacked result is unchanged.
    center_dtype = data.dtype if np.issubdtype(data.dtype , np.floating) else np.float64

    centers = []

    for label_id in label_ids:
        members = data[label == label_id]

        if members.shape[0] == 0:
            # Explicit NaN row instead of relying on the "mean of empty slice" warning
            center = np.full(data.shape[1:] , np.nan , dtype=center_dtype)
        else:
            center = np.mean(members , axis = 0)

        centers.append(center)

    return np.array(centers)

def get_center_simple(data):
    """Compute the center of a single user's EEG samples.

    :param data: array-like with samples along axis 0.
    :return: mean over axis 0.
    """
    center = np.mean(data , axis=0)
    return center

In [14]:
# Load the overlapping segments of session 1 at freq value 6 and reduce each
# segment to one RMS value per channel. X, y and X_sbp are read by later cells.
X , y = session_data_labels(1 , 6 , is_training=True)
X_sbp = feature_extraction_RMS(X)

In [15]:
# Project the RMS features onto 2 principal components and time the fit.
# time.perf_counter() replaces time.clock(), which was removed in Python 3.8.
pca = PCA(n_components=2)
start = time.perf_counter()
X_pca = pca.fit_transform(X_sbp)
time1 = time.perf_counter() - start

In [17]:
# Cluster the 2-D PCA projection with DBSCAN using scikit-learn's default parameters.
# NOTE(review): the recorded output of the following fit_predict cell is all -1,
# which suggests the defaults do not suit the scale of X_pca -- confirm and tune
# eps / min_samples.
dbscan = DBSCAN()
_ = dbscan.fit(X_pca)

In [18]:
# Fit DBSCAN on the PCA projection again and show the per-sample cluster labels as the cell output.
dbscan.fit_predict(X_pca)

array([-1, -1, -1, ..., -1, -1, -1], dtype=int64)

# PCA降维

## kmeans聚类

In [15]:
# PCA (2 components) followed by k-means (9 clusters) for every freq / session;
# prints the cluster centers and the combined PCA + clustering time.
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.perf_counter() replaces time.clock(), which was removed in Python 3.8
        pca = PCA(n_components=2)
        start = time.perf_counter()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.perf_counter() - start

        kmeans = KMeans(n_clusters=9)
        start = time.perf_counter()
        _ = kmeans.fit_transform(X_pca)
        time2 = time.perf_counter() - start

        print('session : ' , session_id , kmeans.cluster_centers_ , time1 + time2)

freq :  6
session :  1 [[ 10010.419   -4230.933 ]
 [ 48636.555    3198.599 ]
 [-16749.674    3082.0154]
 [ -7521.4233  -2716.4714]
 [ -5313.1445  18285.004 ]
 [ -5103.003  -15641.252 ]
 [  4645.6934 -15693.303 ]
 [-19978.266   -1479.0504]
 [ -3248.2925   -224.5728]] 0.10130159999999999
session :  3 [[-1.1086482e+05 -1.2722072e+04]
 [ 8.7625075e+05  3.9666547e+02]
 [-1.2930768e+05  8.8162275e+03]
 [-6.9935992e+04 -7.5937616e+02]
 [-8.4653727e+04 -1.1317825e+04]
 [-1.0002624e+05  2.0964303e+04]
 [-1.1739029e+05 -5.7423096e+03]
 [ 8.6270756e+05  7.7370148e+02]
 [-1.1413672e+05  1.8936217e+04]] 0.05342350000000007
session :  5 [[-37208.484  -23244.922 ]
 [197092.86    -7218.011 ]
 [-11680.661   28292.348 ]
 [-23386.14     4150.3994]
 [-31626.629  -12890.508 ]
 [ -5761.245   37323.125 ]
 [-14019.395   21071.229 ]
 [-30075.668   -5626.795 ]
 [-31527.512  -17591.283 ]] 0.05576400000000015
session :  6 [[-16869.49   -8411.644]
 [158905.12   17630.44 ]
 [-42015.74  109952.76 ]
 [-22492.465 -311

session :  13 [[ -7789.5283 -13885.631 ]
 [ 82256.97    -2301.4983]
 [ -9219.157   42118.92  ]
 [-14140.697   -6390.0186]
 [ -9487.165   16576.344 ]
 [-10550.521  -23179.678 ]
 [-18159.055     167.5012]
 [ 97083.35     1045.0869]
 [ -6621.1943  -5080.2617]] 0.07316969999999756
freq :  10
session :  1 [[ -9302.281    9358.878 ]
 [ 67134.48    -9615.407 ]
 [-18624.572  -16730.826 ]
 [  1268.5735  45104.74  ]
 [  1733.7273  -2390.3757]
 [-12894.699   -3169.2358]
 [-10074.35   -18801.066 ]
 [ -3157.382    6269.955 ]
 [-15773.902  -12611.174 ]] 0.06489899999999693
session :  3 [[-28127.28     4861.327 ]
 [ 29153.604  -11050.339 ]
 [ 77723.02    -4007.5789]
 [ 18351.63    26126.31  ]
 [-16483.838   -4077.8218]
 [-39738.31     1646.167 ]
 [-12386.646    4684.975 ]
 [-20766.662  -15163.515 ]
 [ -5692.335   -2773.9329]] 0.05827729999999676
session :  5 [[-32770.043  -18794.709 ]
 [189497.69    -6726.6606]
 [-17450.646   26168.758 ]
 [-28735.492   -7717.9097]
 [ -6093.433   34902.586 ]
 [-16188.

## DBSCAN聚类

In [20]:
# PCA (2 components) followed by DBSCAN for every freq / session;
# prints the per-label centers and the combined PCA + clustering time.
for freq in [6 , 7.5 , 8.5 , 10]:

    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.perf_counter() replaces time.clock(), which was removed in Python 3.8
        pca = PCA(n_components=2)
        start = time.perf_counter()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.perf_counter() - start

        dbscan = DBSCAN()
        start = time.perf_counter()
        res = dbscan.fit_predict(X_pca)
        time2 = time.perf_counter() - start

        # get_center only looks at labels 0..8, so DBSCAN's noise label (-1)
        # and any cluster id above 8 are not reported here.
        _centers = get_center(X_pca , res)

        print('session : ' , session_id , _centers , time1 + time2)
    

freq :  6
session :  1 [[-16711.828     3190.6697 ]
 [ 48636.54      3198.572  ]
 [  4645.7017  -15693.256  ]
 [ -3240.757     -217.37665]
 [ -5313.139    18285.031  ]
 [ 10010.424    -4230.8457 ]
 [-19890.053    -1405.8038 ]
 [ -5102.991   -15641.181  ]
 [ -7504.9453   -2711.261  ]] 0.1062566999999035
session :  3 [[-1.1690031e+05 -5.9514463e+03]
 [ 8.6270744e+05  7.7369934e+02]
 [-6.9936000e+04 -7.5938055e+02]
 [-1.2930764e+05  8.8162275e+03]
 [-1.1401582e+05  1.8944424e+04]
 [-8.4653703e+04 -1.1317825e+04]
 [ 8.7625081e+05  3.9666428e+02]
 [-1.1061524e+05 -1.3618818e+04]
 [-9.9477852e+04  2.1080760e+04]] 0.08094520000008742
session :  5 [[-31527.518  -17591.295 ]
 [197093.03    -7218.021 ]
 [-14005.834   21160.791 ]
 [-23386.143    4150.3926]
 [ -5735.7754  37394.996 ]
 [-30075.662   -5626.8   ]
 [-37208.477  -23244.932 ]
 [-31626.64   -12890.5205]
 [-11618.522   28380.424 ]] 0.14718729999981406
session :  6 [[-16869.475  -8411.635]
 [159511.67   17409.344]
 [-42015.766 109952.77 ]


session :  13 [[ -6621.184    -5080.258  ]
 [ 82256.91     -2301.4963 ]
 [ -9219.148    42118.94   ]
 [ -7789.514   -13885.625  ]
 [ -9487.154    16576.322  ]
 [-14169.301    -6379.9175 ]
 [ 97083.266     1045.0894 ]
 [-10550.511   -23179.656  ]
 [-18135.441      299.95743]] 0.10341429999994034
freq :  10
session :  1 [[ -3157.3816   6269.9663]
 [-15588.048  -12632.162 ]
 [ 67134.47    -9615.405 ]
 [  1268.5748  45104.77  ]
 [-12894.693   -3169.225 ]
 [  1733.7256  -2390.371 ]
 [-10005.857  -18844.863 ]
 [ -9302.281    9358.885 ]
 [-18642.775  -16336.068 ]] 0.10139379999986886
session :  3 [[ 29153.6    -11050.349 ]
 [-16483.834   -4077.8303]
 [ 77723.02    -4007.575 ]
 [ 18351.617   26126.312 ]
 [-39738.31     1646.1669]
 [ -5692.3228  -2773.9302]
 [-28127.277    4861.3354]
 [-20766.66   -15163.526 ]
 [-12386.632    4684.988 ]] 0.1777159999999185
session :  5 [[-28202.371  -17923.602 ]
 [189497.72    -6726.634 ]
 [-11510.16    23182.482 ]
 [ -6093.4814  34902.605 ]
 [-28735.5     -771

## GMM聚类

In [26]:
# PCA (2 components) followed by a 9-component Gaussian mixture for every
# freq / session; prints the per-label centers and the combined time.
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.perf_counter() replaces time.clock(), which was removed in Python 3.8
        pca = PCA(n_components=2)
        start = time.perf_counter()
        X_pca = pca.fit_transform(X_sbp)
        time1 = time.perf_counter() - start

        gmm = GaussianMixture(n_components=9)
        start = time.perf_counter()
        res = gmm.fit_predict(X_pca)
        time2 = time.perf_counter() - start

        _centers = get_center(X_pca , res)

        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[ -5169.169   -15644.102  ]
 [ 48636.56      3198.5923 ]
 [ -4139.677     3321.9773 ]
 [-19162.223     -408.67822]
 [ -5313.135    18284.988  ]
 [ 10010.425    -4230.9077 ]
 [  4572.184   -15690.796  ]
 [-16004.312     4059.5881 ]
 [ -5096.125    -2450.335  ]] 0.08241759999998521
session :  3 [[-1.29307680e+05  8.81623242e+03]
 [ 8.62624562e+05  7.74667664e+02]
 [-8.46537031e+04 -1.13178154e+04]
 [-1.16900328e+05 -5.95145312e+03]
 [-1.14015867e+05  1.89444434e+04]
 [-6.99359844e+04 -7.59372070e+02]
 [ 8.76135250e+05  4.01504425e+02]
 [-1.10615266e+05 -1.36188203e+04]
 [-9.94778516e+04  2.10807676e+04]] 0.015134300000028134
session :  5 [[-31124.2    -12557.742 ]
 [197092.97    -7218.026 ]
 [-11657.449   28407.947 ]
 [-23449.268    4066.822 ]
 [-37208.523  -23244.934 ]
 [ -5761.2554  37323.137 ]
 [-13944.828   21256.992 ]
 [-32303.95   -16507.182 ]
 [-30161.045   -5773.82  ]] 0.0289231000000143
session :  6 [[-22586.2    -31317.904 ]
 [157977.98    18178.129 ]
 [

session :  13 [[ -9487.156    16576.332  ]
 [ 97207.76      1065.35   ]
 [-14061.952    -6401.0347 ]
 [ -9219.151    42118.953  ]
 [-10550.508   -23179.662  ]
 [ 82353.11     -2274.5967 ]
 [ -6621.183    -5080.2563 ]
 [ -7789.5137  -13885.627  ]
 [-18236.592     -196.36684]] 0.018960399999997435
freq :  10
session :  1 [[-16392.807  -14179.877 ]
 [ 69623.71   -13428.807 ]
 [  1268.5433  45104.754 ]
 [ -3157.4077   6269.958 ]
 [  1733.7035  -2390.366 ]
 [ -8994.623  -19260.648 ]
 [-12894.727   -3169.235 ]
 [ -9302.311    9358.88  ]
 [ 66188.58    -8166.308 ]] 0.03509079999997766
session :  3 [[ 18351.621   26126.314 ]
 [-12386.646    4684.983 ]
 [ 77723.04    -4007.5774]
 [ 29153.617  -11050.342 ]
 [-28127.291    4861.328 ]
 [-20766.664  -15163.529 ]
 [-39738.324    1646.1635]
 [-16483.844   -4077.8296]
 [ -5692.3315  -2773.935 ]] 0.014724200000017618
session :  5 [[-28202.393 -17923.623]
 [189497.6    -6726.663]
 [-17450.662  26168.762]
 [ -6093.454  34902.59 ]
 [-28735.5    -7717.923]

# LDA降维

## kmeans聚类

In [20]:
# LDA (2 components, supervised by y) followed by k-means (9 clusters) for every
# freq / session; prints the cluster centers and the combined time.
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.perf_counter() replaces time.clock(), which was removed in Python 3.8
        lda = LinearDiscriminantAnalysis(n_components=2)
        start = time.perf_counter()
        X_lda = lda.fit_transform(X_sbp , y)
        time1 = time.perf_counter() - start

        kmeans = KMeans(n_clusters=9)
        start = time.perf_counter()
        _ = kmeans.fit_transform(X_lda)
        time2 = time.perf_counter() - start

        print('session : ' , session_id , kmeans.cluster_centers_ , time1 + time2)

freq :  6
session :  1 [[-18.62413864   9.78983397]
 [ 49.31160122   5.197085  ]
 [ -7.50327039 -30.32203151]
 [ -8.28978304  41.48761179]
 [ 30.61685239   4.70215318]
 [-14.57127081 -12.30654097]
 [  9.14441938 -21.62942599]
 [-17.27719695  -2.3079119 ]
 [-22.80721316   5.38922643]] 0.0469891
session :  3 [[-1.08966118e+02 -1.69605774e+01]
 [ 1.09195420e+03 -1.08175654e+00]
 [-1.25165742e+02  2.17844358e+01]
 [-1.82611289e+02 -1.93314790e+01]
 [-1.27239376e+02  5.78426824e+01]
 [-1.54276160e+02  7.74974150e+00]
 [-1.58003725e+02 -2.90187059e+01]
 [-1.17082097e+02  7.25090542e+00]
 [-1.18609693e+02 -2.82352463e+01]] 0.045040699999999934
session :  5 [[-7.02803822e+01 -6.10576096e-01]
 [ 5.53118241e+02  7.72121043e-02]
 [-5.68995509e+01 -4.28517964e+01]
 [-5.99401497e+01  3.40092146e+01]
 [-6.04741820e+01  1.41261773e+01]
 [-7.68351833e+01 -1.13706486e+01]
 [-7.61492071e+01  6.10520272e+00]
 [-7.64860229e+01  1.65227056e+00]
 [-7.60269497e+01 -1.30870706e+00]] 0.05377579999999993
sessio

session :  11 [[-1.12315499e+02  3.92769278e+01]
 [ 7.11095331e+02  1.77448040e+01]
 [-1.35814290e+01 -1.49318810e+02]
 [-8.83125413e+01 -8.48586557e+00]
 [-1.07908651e+02  2.87342278e+01]
 [-7.51946939e+01  7.88651243e+00]
 [-8.95987055e+01  5.18539558e-01]
 [-1.22132249e+02  3.73054563e+01]
 [-1.02051563e+02  2.63382080e+01]] 0.04231919999999789
session :  12 [[-104.80319141  -46.67561664]
 [ 471.43921303  -22.490635  ]
 [ -38.42769935   24.48406942]
 [ -14.18043751   81.57878592]
 [ -70.36743017  -43.96971897]
 [ -33.16255809    8.19340499]
 [ -83.86664043  -31.73379612]
 [ -31.00709757   81.28758869]
 [ -95.6241585   -50.6740823 ]] 0.04263010000000378
session :  13 [[  44.76465255   -7.05424933]
 [ -58.22507018  -36.69843681]
 [  -8.49479837    8.34534804]
 [ 144.29680638  -17.85501612]
 [ -70.78817759   46.94950751]
 [ -43.56112818 -100.67879655]
 [   3.6768816    55.82472055]
 [ -21.344033     30.18095828]
 [   9.67486679   20.98596445]] 0.04196190000000044
freq :  10
session :  

## DBSCAN聚类

In [21]:
# LDA (2 components, supervised by y) followed by DBSCAN for every freq / session;
# prints the per-label centers and the combined LDA + clustering time.
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.perf_counter() replaces time.clock(), which was removed in Python 3.8
        lda = LinearDiscriminantAnalysis(n_components=2)
        start = time.perf_counter()
        X_lda = lda.fit_transform(X_sbp , y)
        time1 = time.perf_counter() - start

        # Cluster the LDA projection computed above. The original clustered
        # X_pca, a stale variable left over from an earlier PCA cell.
        dbscan = DBSCAN()
        start = time.perf_counter()
        res = dbscan.fit_predict(X_lda)
        time2 = time.perf_counter() - start

        # get_center only looks at labels 0..8, so DBSCAN's noise label (-1)
        # and any cluster id above 8 are not reported here.
        _centers = get_center(X_lda , res)

        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[ -7.50327039 -30.32203151]
 [ 49.31160122   5.197085  ]
 [-18.62413864   9.78983397]
 [ -8.28978304  41.48761179]
 [-14.57127081 -12.30654097]
 [  9.14441938 -21.62942599]
 [ 30.61685239   4.70215318]
 [-22.80721316   5.38922643]
 [-17.27719695  -2.3079119 ]] 0.08003570000005311
session :  3 [[-1.18609693e+02 -2.82352463e+01]
 [ 1.09195420e+03 -1.08175654e+00]
 [-1.54276160e+02  7.74974150e+00]
 [-1.27239376e+02  5.78426824e+01]
 [-1.82611289e+02 -1.93314790e+01]
 [-1.17082097e+02  7.25090542e+00]
 [-1.08966118e+02 -1.69605774e+01]
 [-1.58003725e+02 -2.90187059e+01]
 [-1.25165742e+02  2.17844358e+01]] 0.08667140000011386
session :  5 [[-5.77562717e+01 -4.31391687e+01]
 [ 5.53118241e+02  7.72121043e-02]
 [-7.61805985e+01  5.25694830e+00]
 [-5.99401497e+01  3.40092146e+01]
 [-7.68351833e+01 -1.13706486e+01]
 [-6.04741820e+01  1.41261773e+01]
 [-7.62719751e+01 -4.31155472e-01]
 [-7.02803822e+01 -6.10576096e-01]
 [-5.36794620e+01 -4.17716728e+01]] 0.086371099999951

session :  11 [[-1.02051563e+02  2.63382080e+01]
 [ 7.11095331e+02  1.77448040e+01]
 [-1.35814290e+01 -1.49318810e+02]
 [-8.95987055e+01  5.18539558e-01]
 [-1.22132249e+02  3.73054563e+01]
 [-7.51946939e+01  7.88651243e+00]
 [-1.12315499e+02  3.92769278e+01]
 [-8.83125413e+01 -8.48586557e+00]
 [-1.07908651e+02  2.87342278e+01]] 0.07157340000003387
session :  12 [[ -14.18043751   81.57878592]
 [ 471.43921303  -22.490635  ]
 [ -70.36743017  -43.96971897]
 [ -38.42769935   24.48406942]
 [ -95.6241585   -50.6740823 ]
 [ -31.00709757   81.28758869]
 [ -83.86664043  -31.73379612]
 [ -33.16255809    8.19340499]
 [-104.80319141  -46.67561664]] 0.06209539999986191
session :  13 [[   9.67486679   20.98596445]
 [ -43.56112818 -100.67879655]
 [ 144.29680638  -17.85501612]
 [ -70.78817759   46.94950751]
 [ -58.22507018  -36.69843681]
 [  44.76465255   -7.05424933]
 [ -21.344033     30.18095828]
 [   3.6768816    55.82472055]
 [  -8.49479837    8.34534804]] 0.06455909999999676
freq :  10
session :  

## GMM聚类

In [25]:
# LDA (2 components, supervised by y) followed by a 9-component Gaussian mixture
# for every freq / session; prints the per-label centers and the combined time.
for freq in [6 , 7.5 , 8.5 , 10]:
    print('freq : ' , freq)

    for session_id in [1 , 3 , 5 , 6 , 7 , 8 , 9 , 11 , 12 , 13]:

        X , y = session_data_labels(session_id , freq , is_training=True)
        X_sbp = feature_extraction_RMS(X)

        # time.perf_counter() replaces time.clock(), which was removed in Python 3.8
        lda = LinearDiscriminantAnalysis(n_components=2)
        start = time.perf_counter()
        X_lda = lda.fit_transform(X_sbp , y)
        time1 = time.perf_counter() - start

        gmm = GaussianMixture(n_components=9)
        start = time.perf_counter()
        res = gmm.fit_predict(X_lda)
        time2 = time.perf_counter() - start

        _centers = get_center(X_lda , res)

        print('session : ' , session_id , _centers , time1 + time2)

freq :  6
session :  1 [[  9.14441938 -21.62942599]
 [-18.62413864   9.78983397]
 [ 49.31160122   5.197085  ]
 [ -8.28978304  41.48761179]
 [-17.27719695  -2.3079119 ]
 [ 30.61685239   4.70215318]
 [ -7.50327039 -30.32203151]
 [-14.57127081 -12.30654097]
 [-22.80721316   5.38922643]] 0.015922900000006734
session :  3 [[-1.54276160e+02  7.74974150e+00]
 [ 1.09195420e+03 -1.08175654e+00]
 [-1.18609693e+02 -2.82352463e+01]
 [-1.25165742e+02  2.17844358e+01]
 [-1.58003725e+02 -2.90187059e+01]
 [-1.27239376e+02  5.78426824e+01]
 [-1.82611289e+02 -1.93314790e+01]
 [-1.08966118e+02 -1.69605774e+01]
 [-1.17082097e+02  7.25090542e+00]] 0.016042799999993917
session :  5 [[-7.62833098e+01 -5.06320012e-01]
 [ 5.53118241e+02  7.72121043e-02]
 [-5.35015034e+01 -4.18244904e+01]
 [-5.99401497e+01  3.40092146e+01]
 [-6.04741820e+01  1.41261773e+01]
 [-7.68351833e+01 -1.13706486e+01]
 [-7.61710024e+01  5.16657920e+00]
 [-5.77261029e+01 -4.31016817e+01]
 [-7.02803822e+01 -6.10576096e-01]] 0.0175801000000

session :  11 [[-1.35814290e+01 -1.49318810e+02]
 [ 7.11095331e+02  1.77448040e+01]
 [-1.02051563e+02  2.63382080e+01]
 [-8.95987055e+01  5.18539558e-01]
 [-1.22132249e+02  3.73054563e+01]
 [-7.51946939e+01  7.88651243e+00]
 [-8.83125413e+01 -8.48586557e+00]
 [-1.12315499e+02  3.92769278e+01]
 [-1.07908651e+02  2.87342278e+01]] 0.015385899999998287
session :  12 [[ -83.86664043  -31.73379612]
 [ 471.43921303  -22.490635  ]
 [ -31.00709757   81.28758869]
 [ -33.16255809    8.19340499]
 [-104.80319141  -46.67561664]
 [ -14.18043751   81.57878592]
 [ -70.36743017  -43.96971897]
 [ -38.42769935   24.48406942]
 [ -95.6241585   -50.6740823 ]] 0.01566259999999886
session :  13 [[ -70.78817759   46.94950751]
 [  44.76465255   -7.05424933]
 [ -43.56112818 -100.67879655]
 [   9.67486679   20.98596445]
 [ 144.29680638  -17.85501612]
 [ -58.22507018  -36.69843681]
 [   3.6768816    55.82472055]
 [ -21.344033     30.18095828]
 [  -8.49479837    8.34534804]] 0.015781100000012316
freq :  10
session :