Libraries

In [9]:
import numpy as np
import pandas as pd
import os
from scipy import signal
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from mne.time_frequency import psd_array_multitaper

Pre-processing

In [10]:
data_path = os.getcwd()

#Dataframe com os dados originais
dataframes_original = {}

for file in os.listdir(data_path):
    #Escolher só os files que terminem com 1.csv
    if file.endswith('1.csv'):
        if 'subject' in file:

            file_path = os.path.join(data_path, file)
            df = pd.read_csv(file_path) # df armazena os dados dos files
            
            data = df.iloc[:, 1:5]  # data fica com as colunas 2 a 5
            
            # Sujeito e estado conforme indicado no nome 
            sujeito = file.split('subject')[1][0]
            estado = file.split('-')[1]

            key = (estado, sujeito)
            if key not in dataframes_original:
                dataframes_original[key] = []
            dataframes_original[key].append(data)


# Parâmetros dos filtros
notch_freq = 50 # Notch
quality_factor = 40
fs = 256  # Sampling rate in Hz
time_interval = 1.0 / fs  # Time interval between samples
highcut = 90 # Low-pass
lowcut = 4 # High-pass
order = 8

# Aplicação dos parâmetros dos filtros
b_notch, a_notch = signal.iirnotch(notch_freq, quality_factor, fs)
sos = signal.iirfilter(order, highcut, btype='lowpass', analog=False, ftype='butter', fs=256, output='sos')
b_hp, a_hp = signal.butter(order, lowcut, btype='highpass', fs=256)

dataframes_filtrado= {}

# Low-pass e high-pass
for key, dfs in dataframes_original.items():
    dataframes_filtrado[key] = []
    for df in dfs:
        df_filtrado = pd.DataFrame(signal.sosfiltfilt(sos, df.values, axis=0), columns=df.columns)
        df_filtrado_lphp = pd.DataFrame(signal.filtfilt(b_hp, a_hp, df_filtrado.values, axis=0), columns=df.columns)
        dataframes_filtrado[key].append(df_filtrado_lphp)

# Notch
for key, dfs in dataframes_filtrado.items():
    dataframes_filtrado[key] = []
    for df in dfs:
        filtered_channel=pd.DataFrame()
        for channel in df.columns:
            filtered_data = pd.DataFrame(signal.filtfilt(b_notch, a_notch, df[channel]))
            filtered_channel[channel]=filtered_data
        dataframes_filtrado[key].append(filtered_channel)

Feature extraction

In [11]:
def calculate_average_power(freq, magnitude, low_freq, high_freq):
    mask = (freq >= low_freq) & (freq <= high_freq)
    freq_interval = freq[mask]
    magnitude_interval = magnitude[mask]
    average_power = np.trapz(magnitude_interval, x=freq_interval)
    return average_power

# Banda de interesse
beta = (12, 35)

def extract_features(df, channel_prefix):
    features = {}
    mag_column = f'{channel_prefix}_psd'
    if mag_column in df.columns:
        mag_signal = df[mag_column]
        freqs = df[f'{channel_prefix}_freq']
        beta_power = calculate_average_power(freqs, mag_signal, beta[0], beta[1])
        features[f'{channel_prefix}_beta_power'] = beta_power
    return features

dataframes_multitaper = {}

for key, dfs in dataframes_filtrado.items():
    dataframes_multitaper[key] = []
    for df in dfs:
        df_multitaper = pd.DataFrame()
        for column in df.columns:
            psd_mt, freq_mt = psd_array_multitaper(df[column], fs, normalization='full', verbose=0)
            df_multitaper[f"{column}_freq"] = freq_mt
            df_multitaper[f"{column}_psd"] = psd_mt
        dataframes_multitaper[key].append(df_multitaper)

all_features = []

for (condition, sujeito), dfs in dataframes_multitaper.items():
    for df in dfs:
        for channel_prefix in ['TP9', 'AF7', 'AF8', 'TP10']:
            channel_features = extract_features(df, channel_prefix)
            all_features.append({
                'condition': condition,
                'sujeito': sujeito,
                'channel': channel_prefix,
                **channel_features
            })

all_data = []
labels = []

for item in all_features:
    values = [item[key] for key in item.keys() if key.endswith('_beta_power')]
    all_data.append(values)
    labels.append(item['condition'])

# Converter listas para arrays numpy
all_data = np.array(all_data)
labels = np.array(labels)

Classification

In [12]:
# Create RF classifier
rf_classifier = RandomForestClassifier(max_depth= None, min_samples_leaf=1, min_samples_split=2, n_estimators=100)

# Ajuste dos modelos a todos os dados
rf_classifier.fit(all_data, labels)
y_pred = rf_classifier.predict(all_data)

# Mapear os estados
trials_dict = {'neutral': 0, 'relaxed': 1, 'concentrating': 2}
y_pred_num = [trials_dict[label] for label in y_pred]
print("Classification report:")
print(classification_report(labels, y_pred))

Classification report:
               precision    recall  f1-score   support

concentrating       1.00      1.00      1.00        16
      neutral       1.00      1.00      1.00        16
      relaxed       1.00      1.00      1.00        16

     accuracy                           1.00        48
    macro avg       1.00      1.00      1.00        48
 weighted avg       1.00      1.00      1.00        48

