Libraries

In [7]:
import numpy as np
import pandas as pd
import os
from scipy import signal
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import pandas as pd
import os
from scipy import signal
from sklearn.metrics import classification_report
from mne.time_frequency import psd_array_multitaper
import pyxdf

Pre-processing

In [8]:
data_path = os.getcwd()

# Dataframe com os dados originais 
dataframes_original = {}
xdf_files = [file for file in os.listdir(data_path) if file.endswith('.xdf')]

for file in xdf_files:
    estado = file.split('.')[0]  # Extrair o estado do nome do arquivo sem a extensão
    streams, _ = pyxdf.load_xdf(os.path.join(data_path, file))
    
    eeg_data_found = False  # Flag para indicar se os dados EEG foram encontrados
    
    for stream in streams:     
        if stream['info']['type'][0] == 'EEG':            
            eeg_data = pd.DataFrame(stream['time_series'])
            
            # Extrai os nomes das colunas 
            channel_names = []
            for channel_desc in stream['info']['desc'][0]['channels']:
                if 'label' in channel_desc:
                    channel_names.append(channel_desc['label'][0])   

            # Verifica se o número de nomes de colunas corresponde ao número de colunas
            if len(channel_names) != eeg_data.shape[1]:
                print("Warning: Number of column names does not match number of columns in EEG data. Adjusting...")
                channel_names = [f"Channel_{i+1}" for i in range(eeg_data.shape[1])]
            
            # Define os nomes das colunas no DataFrame
            eeg_data.columns = channel_names
            key = estado
            
            if key not in dataframes_original:
                dataframes_original[key] = []
            dataframes_original[key].append(eeg_data)

            eeg_data_found = True  # Define a flag como True
            break  # Para de iterar sobre os streams após encontrar os dados de EEG
            
    if not eeg_data_found:
        print("No EEG data found in file:", file)

# Parâmetros dos filtros
notch_freq = 50 # Notch
quality_factor = 40
fs = 256  # Sampling rate in Hz
time_interval = 1.0 / fs  # Time interval between samples
highcut = 90 # Low-pass
lowcut = 4 # High-pass
order = 8

# Aplicação dos parâmetros dos filtros
b_notch, a_notch = signal.iirnotch(notch_freq, quality_factor, fs)
sos = signal.iirfilter(order, highcut, btype='lowpass', analog=False, ftype='butter', fs=256, output='sos')
b_hp, a_hp = signal.butter(order, lowcut, btype='highpass', fs=256)

# Low-pass e high-pass
dataframes_filtrado = {}
for key, dfs in dataframes_original.items():
    dataframes_filtrado[key] = []
    for df in dfs:
        df_filtrado = pd.DataFrame(signal.sosfiltfilt(sos, df.values, axis=0), columns=df.columns)
        df_filtrado_lphp = pd.DataFrame(signal.filtfilt(b_hp, a_hp, df_filtrado.values, axis=0), columns=df.columns)
        dataframes_filtrado[key].append(df_filtrado_lphp)

# Notch
for key, dfs in dataframes_filtrado.items():
    dataframes_filtrado[key] = []
    for df in dfs:
        filtered_channel=pd.DataFrame()
        for channel in df.columns:
            filtered_data = pd.DataFrame(signal.filtfilt(b_notch, a_notch, df[channel]))
            filtered_channel[channel]=filtered_data
        dataframes_filtrado[key].append(filtered_channel)



Feature extraction

In [9]:
def calculate_average_power(freq, magnitude, low_freq, high_freq):
    mask = (freq >= low_freq) & (freq <= high_freq)
    freq_interval = freq[mask]
    magnitude_interval = magnitude[mask]
    average_power = np.trapz(magnitude_interval, x=freq_interval)
    return average_power

# Banda de interesse
beta = (12, 35)

def extract_features(df):
    features = {}
    for column in df.columns:
        mag_signal = df[column]
        beta_power = calculate_average_power(df.index, mag_signal, beta[0], beta[1])
        features[f'_beta_power'] = beta_power
    return features

dataframes_multitaper = {}

for key, dfs in dataframes_filtrado.items():
    dataframes_multitaper[key] = []
    for df in dfs:
        df_multitaper = pd.DataFrame()
        for column in df.columns:
            psd_mt, freq_mt = psd_array_multitaper(df[column], fs, normalization='full', verbose=0)
            channel_name = column.split('_')[0]  # Extrair o nome do canal do nome da coluna
            df_multitaper[f"{channel_name}_{column}_freq"] = freq_mt
            df_multitaper[column] = psd_mt  # Salvar os valores PSD diretamente
        dataframes_multitaper[key].append(df_multitaper)

all_features = []

for estado, dfs in dataframes_multitaper.items():
    for df in dfs:
        for channel_prefix in ['Chan1', 'Chan2', 'Chan3', 'Chan4']:
            channel_features = extract_features(df)
            all_features.append({
                'condition': estado,
                'channel': channel_prefix,
                **channel_features
            })

all_data = []
labels = []

for item in all_features:
    values = [item[key] for key in item.keys() if key.endswith('_beta_power')]
    all_data.append(values)
    labels.append(item['condition'])

# Converter listas para arrays numpy
all_data = np.array(all_data)
labels = np.array(labels)

Classification

In [10]:
# Create RF classifier
rf_classifier = RandomForestClassifier(max_depth= None, min_samples_leaf=1, min_samples_split=2, n_estimators=100)
rf_classifier.fit(all_data, labels)
y_pred = rf_classifier.predict(all_data)

# Mapear os estados
trials_dict = {'neutral': 0, 'relaxed': 1, 'concentrating': 2}
y_pred_num = [trials_dict[label] for label in y_pred]
print("Classification report:")
print(classification_report(labels, y_pred))

Classification report:
               precision    recall  f1-score   support

concentrating       1.00      1.00      1.00         4
      neutral       1.00      1.00      1.00         4
      relaxed       1.00      1.00      1.00         4

     accuracy                           1.00        12
    macro avg       1.00      1.00      1.00        12
 weighted avg       1.00      1.00      1.00        12

