In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import irfft, rfft, rfftfreq
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn import svm
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier

In [4]:
# Collect the per-channel feature files (one CSV of static features per EEG channel).
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that
# `files` -- and therefore tie-breaking between equally accurate channels in the
# search below -- is reproducible across machines and runs.
directory = 'static_signs_eeg_all/'
files = sorted(
    os.path.join(directory, filename)
    for filename in os.listdir(directory)
    if os.path.isfile(os.path.join(directory, filename))
)

len(files)

34

In [5]:
# Sanity check of the fixed [25:-8] slice used throughout the notebook: it strips the
# 'static_signs_eeg_all/eeg_' prefix (25 chars) and the '_all.csv' suffix (8 chars),
# leaving just the channel name.
files[0][len('static_signs_eeg_all/eeg_'):-len('_all.csv')]

'af3'

In [6]:
# Target labels, one row per sample; `calculate` reads its classification target
# (the 'valence' column by default) from this frame.
labels = pd.read_csv(filepath_or_buffer='deap_all_labels.csv', index_col=0)

In [7]:
# Channel name of every feature file, e.g. 'static_signs_eeg_all/eeg_af3_all.csv' -> 'af3'.
names = [path[25:-8] for path in files]

## Построение моделей, группировка каналов

In [6]:
def calculate(data, eeg, join=True, target='valence', split_index=None, make_classifier=None):
    """Score every channel feature file as a candidate, alone or added to a channel group.

    For each path in the notebook-global ``files`` list, a feature matrix is built
    (the candidate file alone, or ``data`` joined with it), a classifier is trained
    and four test metrics for ``labels[target]`` are recorded.

    Parameters
    ----------
    data : pandas.DataFrame
        Features of the channels already in the group. Ignored when ``join`` is False.
    eeg : list of str
        Paths of the channel files already contained in ``data``. When ``join`` is
        True these are not re-scored (adding a channel to itself is meaningless).
    join : bool
        True  -> each candidate's columns are joined onto ``data``.
        False -> each channel file is scored on its own (used to rank single channels).
    target : str
        Column of the notebook-global ``labels`` frame to predict.
    split_index : int or None
        None  -> random 75/25 train/test split with a fixed seed (subject-dependent
                 setup).
        int n -> first n rows train, the remaining rows test (subject-independent
                 setup; an earlier version of this cell used n = 960).
    make_classifier : callable or None
        Zero-argument factory returning an unfitted scikit-learn classifier.
        Default: distance-weighted k-NN with k = 35. NOTE(review): the arousal
        results appear to come from an earlier variant using
        ``RandomForestClassifier(criterion='entropy', n_estimators=15, max_depth=8,
        min_samples_leaf=12, min_samples_split=5)`` averaged over 3 fits -- pass such
        a factory here to reproduce it.

    Returns
    -------
    pandas.DataFrame
        One row per scored file path with float columns acc/f1/precision/recall,
        sorted by accuracy (descending; ties keep ``files`` order).
    """
    if make_classifier is None:
        def make_classifier():
            return KNeighborsClassifier(n_neighbors=35, weights='distance', algorithm='auto')

    y = labels[target]
    scores = {}
    for path in files:
        # Channels already in the group would only duplicate their own features.
        if join and path in eeg:
            continue

        candidate = pd.read_csv(path, index_col=0)
        # Overlapping column names from the candidate get a suffix. Name it after the
        # candidate's own channel (chars 25:-8 of the path). Suffixing with the last
        # group channel's name mislabels the columns and can collide with the columns
        # `next_step` already suffixed with that name.
        features = data.join(candidate, rsuffix=path[25:-8]) if join else candidate

        if split_index is None:
            # Subject-dependent setup: random split, fixed seed for reproducibility.
            X_train, X_test, y_train, y_test = train_test_split(
                features, y, test_size=0.25, random_state=249)
        else:
            # Subject-independent setup: positional split into train head / test tail.
            X_train, X_test = features.iloc[:split_index], features.iloc[split_index:]
            y_train, y_test = y.iloc[:split_index], y.iloc[split_index:]

        clf = make_classifier()
        clf.fit(X_train, y_train)
        predicted = clf.predict(X_test)
        scores[path] = [
            accuracy_score(y_test, predicted),
            f1_score(y_test, predicted),
            precision_score(y_test, predicted),
            recall_score(y_test, predicted),
        ]

    table = pd.DataFrame.from_dict(
        scores, orient='index', columns=['acc', 'f1', 'precision', 'recall'])
    # Stable sort: equally accurate candidates keep `files` order, so the beam
    # (`[:3]`) picked by the callers is deterministic.
    return table.sort_values(by='acc', ascending=False, kind='mergesort')

def next_step(data, eeg, prev, max_group_size=7, branch=3, results=None):
    """Depth-first beam search that grows a channel group one channel at a time.

    For every candidate in ``prev`` (rows produced by ``calculate``, best first), the
    candidate is appended to the group. Its scores are recorded in ``results`` under
    the space-joined channel names (e.g. ``'oz f4 po3'``). If the group is still
    smaller than ``max_group_size``, the candidate's features are joined onto
    ``data``, every remaining channel is re-scored with ``calculate``, and the search
    recurses into the ``branch`` best of them.

    Parameters
    ----------
    data : pandas.DataFrame
        Features of the current group (the channels in ``eeg``).
    eeg : list of str
        Paths of the channel files in the current group. Mutated during the search
        but always restored to its original contents before returning.
    prev : pandas.DataFrame
        Candidates to try, indexed by file path, with columns acc/f1/precision/recall.
    max_group_size : int
        Stop extending once a group has this many channels.
    branch : int
        Number of best candidates explored at each deeper level.
    results : pandas.DataFrame or None
        Table that receives one column per evaluated group. Defaults to the
        notebook-global ``all_check``.

    Prints the current group on entry.
    """
    if results is None:
        results = all_check
    print(eeg)
    for elem in prev.index:
        eeg.append(elem)
        try:
            group_name = ' '.join(path[25:-8] for path in eeg)
            results[group_name] = prev.loc[[elem]].to_numpy()[0]
            if len(eeg) < max_group_size:
                extended = data.join(pd.read_csv(elem, index_col=0), rsuffix=elem[25:-8])
                candidates = calculate(extended, eeg)
                next_step(extended, eeg, candidates[:branch], max_group_size, branch, results)
        finally:
            # Undo the append even if scoring fails, so the caller's group is intact.
            eeg.pop()
    

In [7]:
# Channel-group search for valence. Rank single channels first, then grow each of
# the 10 best with `next_step`. `all_check` accumulates one column per evaluated
# group across all seeds. After each seed it is written out as a checkpoint CSV
# named after that seed channel.
all_check = pd.DataFrame(
    {'criterion': ['accuracy_score', 'f1_score', 'precision_score', 'recall_score']}
)
start = calculate([], [], False)
for elem in start.index[:10]:
    all_check[elem[25:-8]] = start.loc[[elem]].to_numpy()[0]
    data = pd.read_csv(elem, index_col=0)
    queue = [elem]
    res = calculate(data, queue)
    next_step(data, queue, res[:3])
    (
        all_check.T
        .iloc[1:, :]
        .rename(columns={0: 'acc', 1: 'f1', 2: 'precision', 3: 'recall'})
        .to_csv(f'eeg_group_deap_valence_{elem[25:-8]}.csv')
    )


['static_signs_eeg_all/eeg_oz_all.csv']
['static_signs_eeg_all/eeg_oz_all.csv', 'static_signs_eeg_all/eeg_f4_all.csv']
['static_signs_eeg_all/eeg_oz_all.csv', 'static_signs_eeg_all/eeg_f4_all.csv', 'static_signs_eeg_all/eeg_po3_all.csv']
['static_signs_eeg_all/eeg_oz_all.csv', 'static_signs_eeg_all/eeg_f4_all.csv', 'static_signs_eeg_all/eeg_po3_all.csv', 'static_signs_eeg_all/eeg_f8_all.csv']
['static_signs_eeg_all/eeg_oz_all.csv', 'static_signs_eeg_all/eeg_f4_all.csv', 'static_signs_eeg_all/eeg_po3_all.csv', 'static_signs_eeg_all/eeg_f8_all.csv', 'static_signs_eeg_all/eeg_fc2_all.csv']
['static_signs_eeg_all/eeg_oz_all.csv', 'static_signs_eeg_all/eeg_f4_all.csv', 'static_signs_eeg_all/eeg_po3_all.csv', 'static_signs_eeg_all/eeg_f8_all.csv', 'static_signs_eeg_all/eeg_fc2_all.csv', 'static_signs_eeg_all/eeg_po4_all.csv']
['static_signs_eeg_all/eeg_oz_all.csv', 'static_signs_eeg_all/eeg_f4_all.csv', 'static_signs_eeg_all/eeg_po3_all.csv', 'static_signs_eeg_all/eeg_f8_all.csv', 'static_si

## Результаты

In [8]:
# Each row is one channel group (space-separated channel names in the index). The
# search can reach the same set of channels in several orders, so keep one row per
# unordered channel set. A bare .drop_duplicates() compares only the metric columns.
# It would also silently discard distinct groups that happen to tie on all four scores.
deap_valence = pd.read_csv('eeg_group_valence.csv', index_col=0)
deap_valence = deap_valence[
    ~deap_valence.index.map(lambda group: frozenset(str(group).split())).duplicated()
]
deap_valence.sort_values(by='acc', ascending=False).head(10)

Unnamed: 0,acc,f1,precision,recall
f8 c3 p8 pz fz,0.715625,0.777506,0.719457,0.845745
po4 oz c3 p8 po3 fz,0.715625,0.776413,0.721461,0.840426
o2 af4 f3 f8 p8 p4 cp1,0.715625,0.781775,0.71179,0.867021
f8 p8 af3 o1 oz fz,0.7125,0.777778,0.712389,0.856383
f8 p8 p3 fz cp2 po3 po4,0.7125,0.779904,0.708696,0.867021
po4 oz c3 p8 pz fc6,0.7125,0.772277,0.722222,0.829787
cp2 cz po3 oz c3 f8 p3,0.7125,0.780952,0.706897,0.87234
f8 c3 f3 p8 oz fc6,0.709375,0.766917,0.725118,0.81383
f8 p8 af3 o1 oz fz o2,0.709375,0.774818,0.711111,0.851064
po4 oz p8 f3 c3 fc6 fp1,0.709375,0.762148,0.73399,0.792553


In [10]:
# One row per unordered channel set (see the valence cell above for why
# .drop_duplicates() on the metric columns alone is not used).
deap_valence_independent = pd.read_csv('eeg_group_deap_valence_independent.csv', index_col=0)
deap_valence_independent = deap_valence_independent[
    ~deap_valence_independent.index.map(lambda group: frozenset(str(group).split())).duplicated()
]
deap_valence_independent.sort_values(by='acc', ascending=False).head(10)

Unnamed: 0,acc,f1,precision,recall
p7 cp1 p3 cz po3 p8 o2,0.671875,0.774194,0.674157,0.909091
oz cp6 af3 p7 t7 po3 f4,0.665625,0.754023,0.691983,0.828283
p7 cp1 c4 oz cp6 f8 af3,0.6625,0.765217,0.671756,0.888889
cp5 af4 o1 p4 t7 cp6 pz,0.6625,0.770213,0.665441,0.914141
cp1 o2 cp2 cp5 t7 fp2 fc5,0.659375,0.750572,0.686192,0.828283
oz cp6 o2 fc5 t8 cp1 f8,0.659375,0.757238,0.677291,0.858586
oz f4 f8 t8 po4 cp5 fc2,0.659375,0.762527,0.670498,0.883838
fp2 af3 p8 fc5 cz cp1,0.659375,0.770526,0.66065,0.924242
p7 p8 cp1 po3 fz cz,0.659375,0.758315,0.675889,0.863636
cp5 af4 o1 p4 t7 cp6,0.65625,0.764957,0.662963,0.90404


In [11]:
# Keep one row per unordered channel set: the search reaches the same channels in
# several orders, while .drop_duplicates() on the metric columns alone would also
# drop distinct groups that tie on every score.
deap_arousal = pd.read_csv('eeg_group_arousal.csv', index_col=0)
deap_arousal = deap_arousal[
    ~deap_arousal.index.map(lambda group: frozenset(str(group).split())).duplicated()
]
deap_arousal.sort_values(by='acc', ascending=False).head(10)

Unnamed: 0,acc,f1,precision,recall
p3 cp6 c3 cp2 po3,0.672917,0.755869,0.665289,0.875
p3 af3 fz po3 c3 o1,0.669792,0.762125,0.662651,0.896739
p3 af3 c3 t8 o1 o2,0.669792,0.759434,0.670833,0.875
p3 af3 fz cp6 p7 o1 cp1,0.669792,0.743119,0.642857,0.880435
p3 cp6 fz o2 f3 f7 fc2,0.66875,0.749403,0.668085,0.853261
p3 af3 fz cp6 c3 cp2 fp1,0.66875,0.733179,0.639676,0.858696
p3 af3 fz po3 fp1 f7 o2,0.667708,0.756881,0.654762,0.896739
p3 cp6 fz o2 f3,0.667708,0.754098,0.662551,0.875
p3 af3 cp1 fz p4 fc5 cp5,0.667708,0.764302,0.660079,0.907609
p3 af3 fz cp6 p7 fp1,0.667708,0.762353,0.672199,0.880435


In [13]:
# Keep one row per unordered channel set: the search reaches the same channels in
# several orders, while .drop_duplicates() on the metric columns alone would also
# drop distinct groups that tie on every score.
deap_arousal_independent = pd.read_csv('eeg_group_deap_arousal_independent.csv', index_col=0)
deap_arousal_independent = deap_arousal_independent[
    ~deap_arousal_independent.index.map(lambda group: frozenset(str(group).split())).duplicated()
]
deap_arousal_independent.sort_values(by='acc', ascending=False).head(10)

Unnamed: 0,acc,f1,precision,recall
p7 c3 t8 f4 cp6,0.622917,0.714617,0.636364,0.814815
c3 po4 p7 f4 po3 cp6,0.622917,0.718954,0.611111,0.873016
fp2 af4 c3 fp1 c4 po3 cz,0.621875,0.741722,0.636364,0.888889
p7 fp2 af4 p4 c4 cz,0.619792,0.741228,0.632959,0.89418
fp2 af4 c3 fp1 cp1,0.619792,0.760504,0.630662,0.957672
c3 po3 cp5 p8 fc2 fp1,0.61875,0.755459,0.643123,0.915344
c4 c3 f3 fc2 o2 fp1,0.61875,0.716157,0.609665,0.867725
c3 po4 fz cp6 p7 t8,0.617708,0.718062,0.615094,0.862434
o2 p7 c3 fp2 p4 fp1,0.617708,0.708046,0.626016,0.814815
cp6 fz c3 f4 p4,0.617708,0.665012,0.626168,0.708995
