# Obtaining Data

In [None]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
import src.functions as src


# Collect the EEGLAB epoch files (".set") in the data directory.
# os.listdir returns entries in arbitrary order, so sort them to make the
# subject processing order (and every dict built from it) reproducible.
# Match the full ".set" extension rather than any name ending in "set".
files = os.listdir('data')
file = sorted(item for item in files if item.endswith('.set'))

# Sampling frequency in Hz, passed to src.getpsd below.
# NOTE(review): hard-coded; assumed to match the recordings -- not read from the files.
fs = 500

# Frequency band edges [low, high] in Hz, keyed by band name.
freq_bands = {
    'delta': [0, 4],
    'theta': [4, 8],
    'alpha': [8, 13],
    'beta': [13, 35],
    'high': [35, 45],
    'all': [1.5, 45],
    'low': [1.5, 12],
}

In [None]:
# Load every epoch file and reorganise it per subject and per channel.
#   data[subject_id][channel_name] -> list of epochs, each a list of samples
#   labels[subject_id]             -> np.array with one label per epoch:
#                                     0 = proCorr/antiCorr, 1 = proErr/antiErr/nogoErr
# Epochs are stacked group by group in the order of _CONDITION_GROUPS, so
# labels[subject_id][k] describes the k-th epoch of every channel list.
data = {}
labels = {}

# (event names selected together, label assigned to those epochs)
_CONDITION_GROUPS = (
    (('proCorr', 'antiCorr'), 0),
    (('proErr', 'antiErr', 'nogoErr'), 1),
)

for file_name in file:
    path_ = os.path.join('data', file_name)
    dados = mne.read_epochs_eeglab(path_)
    # Left at module level on purpose: the PSD cells reuse channel_names.
    channel_names = dados.info['ch_names']
    # Assumes file names start with a 3-character subject code (e.g. 'P01') -- TODO confirm.
    subject_id = file_name[:3]

    subject_labels = []
    channel_data = {channel_name: [] for channel_name in channel_names}
    for event_names, label in _CONDITION_GROUPS:
        # Load the samples of this condition group once and reuse them for
        # both the label count and the per-channel split.
        group_epochs = dados[event_names].get_data()
        subject_labels.extend([label] * len(group_epochs))
        for epoch_values in group_epochs:
            # epoch_values holds one row of samples per channel, in channel_names order.
            for channel_name, channel_values in zip(channel_names, epoch_values):
                channel_data[channel_name].append(list(channel_values))

    labels[subject_id] = np.array(subject_labels)
    data[subject_id] = channel_data


# Splitting epochs into pre- and post-response segments

In [None]:
# Split every epoch into a pre-response half and a post-response half.
# Result layout: pre_response[subject_id][channel_name] -> list of per-epoch
# sample lists (same for post_response).
# Both halves get exactly len(epoch) // 2 samples. For odd-length epochs this
# matches the previous `[half:-1]` slice. For even-length epochs the old slice
# silently dropped one post-response sample, so the halves had different lengths.
pre_response = {}
post_response = {}
for subject_id, dados in data.items():
    channel_data_pre = {}
    channel_data_post = {}
    for channel_name, values in dados.items():
        values_pre = []
        values_post = []
        for epoch in values:
            half = len(epoch) // 2
            values_pre.append(epoch[:half])
            values_post.append(epoch[half:2 * half])
        channel_data_pre[channel_name] = values_pre
        channel_data_post[channel_name] = values_post

    pre_response[subject_id] = channel_data_pre
    post_response[subject_id] = channel_data_post

# Obtaining PSD

In [None]:
# Power spectral density of the pre- and post-response segments, computed in
# that order. src.getpsd yields a (frequencies, spectra) pair.
# Per the original note, S_* is nested as (subjects, channels, events, points) -- TODO confirm.
(f_pre, S_pre), (f_post, S_post) = (
    src.getpsd(segments, fs) for segments in (pre_response, post_response)
)

# Sorting PSD data per subject per channel

In [None]:
# Re-key the pre- and post-response spectra by channel name using the
# channel list read while loading the data.
pre_data_per_channel, post_data_per_channel = (
    src.getdataperchannel(spectra, channel_names) for spectra in (S_pre, S_post)
)

In [None]:
# Overlay, for one subject, each channel's post-response PSD averaged over
# epochs (natural log of the averaged values).
subject = 'P01'
channel_means = {}

for channel_name, values in post_data_per_channel[subject].items():
    # src.mean_of_lists presumably averages the per-epoch spectra element-wise -- TODO confirm.
    channel_means[channel_name] = src.mean_of_lists(list(values))

for means in channel_means.values():
    plt.plot(f_post[subject], np.log(means))

# Customize plot
plt.xlim([0, 100])
plt.xlabel('frequency [Hz]')
plt.ylabel('PSD [log(V**2/Hz)]')
plt.title('Mean Event PSD per channel')
plt.grid(True)

# Show plot
plt.show()

In [None]:
# Overlay, for one subject, each channel's post-response PSD averaged over
# epochs (natural log of the averaged values).
subject = 'P02'
channel_means = {}

for channel_name, values in post_data_per_channel[subject].items():
    # src.mean_of_lists presumably averages the per-epoch spectra element-wise -- TODO confirm.
    channel_means[channel_name] = src.mean_of_lists(list(values))

for means in channel_means.values():
    plt.plot(f_post[subject], np.log(means))

# Customize plot
plt.xlim([0, 100])
plt.xlabel('frequency [Hz]')
plt.ylabel('PSD [log(V**2/Hz)]')
plt.title('Mean Event PSD per channel')
plt.grid(True)

# Show plot
plt.show()

# Feature Extraction

## Frequency Features

### FCZ Features

In [None]:
# FCZ features 1-7.
# NOTE(review): the argument meaning lives in src.feature (not visible here).
# From the calls, the pattern looks like (feature group, band A, band B,
# PSD data for A, PSD data for B, band table, freq axis for A, freq axis for B).
# 'other' is not a key of freq_bands, so it is presumably handled inside
# src.feature -- TODO confirm.
feature1 = src.feature(
    'fcz_features', 'theta', 'all',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature2 = src.feature(
    'fcz_features', 'theta', 'high',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature3 = src.feature(
    'fcz_features', 'theta', 'all',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature4 = src.feature(
    'fcz_features', 'theta', 'high',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature5 = src.feature(
    'fcz_features', 'theta', 'other',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature6 = src.feature(
    'fcz_features', 'theta', 'other',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
# Compares theta in the post-response segment against theta in the pre-response one.
feature7 = src.feature(
    'fcz_features', 'theta', 'theta',
    post_data_per_channel, pre_data_per_channel,
    freq_bands, f_post, f_pre)

### All Channels Features

In [None]:
# All-channel features 8-16 (argument pattern as in the FCZ cell; defined by src.feature).
feature8 = src.feature(
    'all_features', 'theta', 'all',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature9 = src.feature(
    'all_features', 'theta', 'high',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature10 = src.feature(
    'all_features', 'theta', 'all',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature11 = src.feature(
    'all_features', 'theta', 'high',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
# Post-response theta against pre-response theta.
feature12 = src.feature(
    'all_features', 'theta', 'theta',
    post_data_per_channel, pre_data_per_channel,
    freq_bands, f_post, f_pre)
feature13 = src.feature(
    'all_features', 'theta', 'delta',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature14 = src.feature(
    'all_features', 'theta', 'alpha',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature15 = src.feature(
    'all_features', 'theta', 'delta',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature16 = src.feature(
    'all_features', 'theta', 'alpha',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)

### Midfrontal Features

In [None]:
# Midfrontal features 17-18: post- and pre-response variants of the same call.
# 'other' is not a key of freq_bands; presumably handled inside src.feature -- TODO confirm.
feature17 = src.feature(
    'midfrontal_features', 'theta', 'other',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature18 = src.feature(
    'midfrontal_features', 'theta', 'other',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)

### Delta, Alpha and Low Features

In [None]:
# Low-frequency features 19-30: delta / alpha / low bands against 'all' and
# 'high'; features 19-24 use the post-response PSD, 25-30 the pre-response PSD.
# --- post-response ---
feature19 = src.feature(
    'low_features', 'delta', 'all',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature20 = src.feature(
    'low_features', 'alpha', 'all',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature21 = src.feature(
    'low_features', 'low', 'all',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature22 = src.feature(
    'low_features', 'delta', 'high',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature23 = src.feature(
    'low_features', 'alpha', 'high',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
feature24 = src.feature(
    'low_features', 'low', 'high',
    post_data_per_channel, post_data_per_channel,
    freq_bands, f_post, f_post)
# --- pre-response ---
feature25 = src.feature(
    'low_features', 'delta', 'all',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature26 = src.feature(
    'low_features', 'alpha', 'all',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature27 = src.feature(
    'low_features', 'low', 'all',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature28 = src.feature(
    'low_features', 'delta', 'high',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature29 = src.feature(
    'low_features', 'alpha', 'high',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)
feature30 = src.feature(
    'low_features', 'low', 'high',
    pre_data_per_channel, pre_data_per_channel,
    freq_bands, f_pre, f_pre)

## Time Features

In [None]:
# Time-domain features: for each epoch, the mean amplitude inside a fixed
# sample window, first per channel and then averaged over the selected channels.
# NOTE(review): windows are sample indices; at fs = 500 Hz one sample is 2 ms
# (e.g. samples 35..80 -> 70-160 ms). Revise them if the sampling rate changes.
feature31 = {}  # post-response ERN window, FCZ cluster
feature32 = {}  # post-response Pe window, Pz cluster
feature33 = {}  # pre-response first half, all channels ("500-250" in the original naming)
feature34 = {}  # pre-response second half, all channels ("250-0")
feature35 = {}  # post-response first half, all channels ("0-250")
feature36 = {}  # post-response second half, all channels ("250-500")
Cluster_Pz = ['CPZ', 'P1', 'PZ', 'P2', 'POZ']
Cluster_FCZ = ['FZ', 'FC1', 'FCZ', 'FC2', 'CZ']

# Window ends are inclusive in time (slice stop is last index + 1).
# The open-ended windows run to the END of the segment: the previous `:-1`
# stop silently excluded the final sample.
ERN_WINDOW = slice(35, 81)      # samples 35..80
PE_WINDOW = slice(100, None)    # sample 100 to the end
FIRST_HALF = slice(0, 126)      # samples 0..125
SECOND_HALF = slice(125, None)  # sample 125 to the end (sample 125 is shared with FIRST_HALF, as before)


def _window_mean_per_event(channels, window, cluster=None):
    """Return one channel-averaged window mean per epoch.

    channels: mapping channel name -> list of per-epoch sample sequences
        (one subject's entry in pre_response / post_response).
    window: slice applied to every epoch's samples.
    cluster: optional collection of channel names; when given, only those
        channels contribute to the average.
    """
    selected = [values for name, values in channels.items()
                if cluster is None or name in cluster]
    # Every channel holds the same number of epochs; read it from the first one.
    n_events = len(next(iter(channels.values())))
    return [np.mean([np.mean(values[event][window]) for values in selected])
            for event in range(n_events)]


for subject_id, values in post_response.items():
    feature31[subject_id] = _window_mean_per_event(values, ERN_WINDOW, Cluster_FCZ)
    feature32[subject_id] = _window_mean_per_event(values, PE_WINDOW, Cluster_Pz)
    feature35[subject_id] = _window_mean_per_event(values, FIRST_HALF)
    feature36[subject_id] = _window_mean_per_event(values, SECOND_HALF)

for subject_id, values in pre_response.items():
    feature33[subject_id] = _window_mean_per_event(values, FIRST_HALF)
    feature34[subject_id] = _window_mean_per_event(values, SECOND_HALF)

   



## Saving Feature Data

In [None]:
def _stack_features(feature_indices):
    """Regroup per-feature dicts into per-subject, per-event feature vectors.

    For each index i in feature_indices, reads the module-level variable
    `feature<i>` (a dict subject_id -> list of per-event values). Returns a
    dict subject_id -> list of events, where each event is the list of that
    event's values in the order of feature_indices.
    NOTE(review): if features disagree on a subject's event count, rows end
    up with different lengths (same as the previous inline loops).
    """
    stacked = {}
    namespace = globals()
    for feature_index in feature_indices:
        per_subject = namespace['feature{}'.format(feature_index)]
        for subject_id, values_list in per_subject.items():
            events = stacked.setdefault(subject_id, [])
            # Grow the event list if this feature has more events than seen so far.
            events.extend([] for _ in range(len(values_list) - len(events)))
            for event_values, value in zip(events, values_list):
                event_values.append(value)
    return stacked


frequency_features = _stack_features(range(1, 31))  # spectral features 1-30
theta_features = _stack_features(range(1, 19))      # theta-based features 1-18
temporal_features = _stack_features(range(31, 37))  # time-domain features 31-36
all_features = _stack_features(range(1, 37))        # every feature, 1-36


In [None]:
# Persist each feature table, then the labels, via the project helper.
for payload, out_name in (
    (all_features, 'all_features.txt'),
    (frequency_features, 'frequency_features.txt'),
    (theta_features, 'theta_features.txt'),
    (temporal_features, 'temporal_features.txt'),
    (labels, 'labels.txt'),
):
    src.save_dict_to_file(payload, out_name)