In [4]:
import numpy as np
import os
import pandas as pd
from scipy.io import loadmat
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.cross_decomposition import CCA

In [5]:
# Dicionário para armazenar os dados
original = {}

# Path para o file com os dados
directory = "Data_trials"

# Iterar através de cada pasta de participante
for participant_folder in os.listdir(directory):
    participant_path = os.path.join(directory, participant_folder)
    if os.path.isdir(participant_path):
        participant_number = participant_folder[1:]  # Extrair número do participante do nome da pasta

        # Iterar através dos arquivos MATLAB na pasta do participante
        for file_name in os.listdir(participant_path):
            if file_name.endswith(".mat") and not file_name.endswith(("5.mat", "6.mat")):
                file_path = os.path.join(participant_path, file_name)

                # Carregar arquivo MATLAB
                mat_data = loadmat(file_path)

                # Selecionar a key com o nome do file
                keys = mat_data.keys()
                key = list(keys)[3]

                # Criar DataFrame a partir dos dados; .T para transformar linhas em colunas
                df = pd.DataFrame(mat_data[key].T, columns=['TimeStamps','PO3', 'POz', 'PO4', 'O1', 'Oz', 'O2'])

                # Adicionar os dados ao dicionário usando o nome da variável como chave
                if key not in original:
                    original[key] = []
                original[key].append(df)

In [6]:
# Cálculo da frequência de amostragem

data = original['P01_T1_R1_1'][0]
time_diff = data['TimeStamps'].diff().mean()
fs = 1 / time_diff
print("Sampling frequency =", fs, "Hz")

Sampling frequency = 512.0 Hz


In [7]:
num_samples_to_trim = int(0.5 * fs)

for key, dfs in original.items():
    trimmed_dfs = []
    for df in dfs:
        df_trimmed = df.iloc[num_samples_to_trim:-(num_samples_to_trim)].reset_index(drop=True)
        trimmed_dfs.append(df_trimmed)
    original[key] = trimmed_dfs

In [10]:
#Definição de parâmetros dos filtros
notch_freq = 50.0 
quality_factor = 40.0
highcut = 20
order = 8
lowcut = 5

#Lowpass & Notch
sos = signal.iirfilter(order, highcut, btype='lowpass', analog=False, ftype='butter', fs=fs, output='sos')
b_notch, a_notch = signal.iirnotch(notch_freq, quality_factor, fs)
b_hp, a_hp = signal.butter(order, lowcut, btype='highpass', fs=fs)

#Dataframe com os dados filtrados
filtrado= {}

#Aplicação do low pass filter
for key, dfs in original.items():
    filtrado[key] = []
    for df in dfs:
        timestamps = df['TimeStamps']
        df_without_timestamps = df.drop(columns=['TimeStamps'])
        df_filtrado_lp = pd.DataFrame(signal.sosfiltfilt(sos, df_without_timestamps.values, axis=0), columns=df_without_timestamps.columns)
        df_filtrado_lphp = pd.DataFrame(signal.filtfilt(b_hp, a_hp, df_filtrado_lp.values, axis=0), columns=df_without_timestamps.columns)
        df_filtrado = pd.concat([timestamps, df_filtrado_lphp], axis=1)
        filtrado[key].append(df_filtrado)

#Aplicação do notch filter
for key, dfs in filtrado.items():
    for df in dfs:
        for column in df.columns[1:]:
            df[column] = signal.filtfilt(b_notch, a_notch, df[column])

#Oz filtrado
filtered_channels = {}
for key, dfs in filtrado.items():
    filtered_channels[key] = []
    for df in dfs:
        df_without_timestamps = df.drop(columns=['TimeStamps'])
        filtered_channels[key].append(df_without_timestamps)

In [44]:
for key, dfs in filtered_channels.items():
    for df in dfs:
        for channel in df:
            first_matrix = np.array(df[channel]).T
            X = first_matrix.T

In [42]:
# Creating the time series windows
window_size = X.shape[0] # Assuming you want the same window size as the first matrix's rows
t = np.linspace(0, 4, window_size, endpoint=False)

# Generating sine and cosine reference signals
frequencies = [7, 11, 13, 17]
reference_signals = []
for freq in frequencies:
    sine_wave = 10 * np.sin(2 * np.pi * freq * t)
    cosine_wave = 10 * np.cos(2 * np.pi * freq * t)
    reference_signals.append(sine_wave)
    reference_signals.append(cosine_wave)
ref = np.array(reference_signals).T
print(ref.shape)

(2048, 8)


In [14]:
# Função para calcular as correlações
def calculate_correlations(matrix, reference_signals):
    cca = CCA(n_components=1)  # Defina o número de componentes canônicos
    Correlation = []

    for ref_signal in reference_signals:
        # Ajustar o CCA aos dados
        cca.fit(matrix, ref_signal)
        # Transformar os dados usando o CCA
        x1, x2 = cca.transform(matrix, ref_signal)
        corr = np.corrcoef(x1.T, x2.T)[0, 1]
        Correlation.append(corr)
    
    return Correlation

In [80]:
# Iterar sobre todas as chaves em oz_filtered e calcular as correlações
all_correlations = {}

total_correct = 0
total_trials = 0

for key in filtered_channels:
    # if key == "P01_T1_R1_1":
    for df in filtered_channels[key]:
        correlation_totals = []
        matrix = np.array(df).T
        for i in range(matrix.shape[0]):
            column = matrix[i, :][np.newaxis, :]
            column = column.T
            correlations = calculate_correlations(column, reference_signals)
            correlation_totals.append(correlations)
        average_correlations = np.mean(correlation_totals, axis=0)
        max_correlation_index = np.argmax(average_correlations)// 2  # Índice da coluna com o valor máximo
        all_correlations[key] = {'average correlations': average_correlations, 'max_correlation_index': max_correlation_index + 1}
        # print('average correlations:', average_correlations, 'max_correlation_index:', max_correlation_index)
    
    last_character = key[-1]
    if str(max_correlation_index+1) == last_character:  # Comparar com o último caractere
        total_correct += 1
    total_trials += 1

# Calcular a accuracy
accuracy = total_correct / total_trials

# # Exibir as correlações
# for key, correlations in all_correlations.items():
#     print(f"{key}: {all_correlations[key]}")

print(f"Total de trials: {total_trials}")
print(f"Total de respostas corretas: {total_correct}")
print(f"Accuracy: {accuracy * 100:.2f}%")

Total de trials: 560
Total de respostas corretas: 380
Accuracy: 67.86%
