In [1]:
import pandas as pd
import numpy as np
import pickle as pickle
import os
import time
from pywt import wavedec
import pyeeg
import scipy.io as sio

from pathlib import Path
cwd = os.getcwd()
parent = Path(cwd).parent

from importlib.machinery import SourceFileLoader
thundersvm = SourceFileLoader("thundersvm", r"thundersvm\python\thundersvm\thundersvm.py").load_module()


In [2]:
from pywt import Wavelet
from math import floor, ceil
from numpy import concatenate, flipud, zeros, convolve, array

def padding_symmetric(signal, size=8):
    '''
    Applies a symmetric padding of the specified size to the input signal.
    Parameters
    ----------
    signal : ndarray
        The signal to be padded.
    size : int, optional
        The size of the padding which corresponds to the size of the filter. The default is 8.
    Returns
    -------
    padded_signal : ndarray
        Padded signal.
    '''
    
    padded_signal = concatenate([flipud(signal[:size]), signal, flipud(signal[-size:])])
    return padded_signal


def restore_signal(signal, reconstruction_filter, real_len):
    '''
    Restores the signal to its original size using the reconstruction filter.
    Parameters
    ----------
    signal : ndarray
        The signal to be restored.
    reconstruction_filter : list
        The reconstruction filter to be used for restoring the signal.
    real_len : int
        Real length of the signal.
    Returns
    -------
    restored_signal : ndarray
        Restored signal of the specified length.
    '''
    restored_signal = zeros(2 * len(signal) + 1)
    for i in range(len(signal)):
        restored_signal[i*2+1] = signal[i]
    restored_signal = convolve(restored_signal, reconstruction_filter)
    restored_len = len(restored_signal)
    exceed_len = (restored_len - real_len) / 2
    restored_signal = restored_signal[int(floor(exceed_len)):(restored_len - int(ceil(exceed_len)))]
    return restored_signal

def DWTfn(signal, level=3, mother_wavelet='db4'):
    '''
    Applies a Discrete Wavelet Transform to the signal.
    Parameters
    ----------
    signal : ndarray
        The signal on which the DWT will be applied.
    level : int, optional
        The decomposition levels for the DWT. The default is 3.
    mother_wavelet : str, optional
        The mother wavelet that it is going to be used in the DWT. The default is "db4".
    Returns
    -------
    restored_approx_coeff : list
        Restored approximations coefficients.
    restored_detail_coeff : list
        Restored detail coefficients.
    '''
    if type(signal).__name__ != "ndarray" and type(signal) != list:
        raise TypeError(f"'signal' must be 'ndarray', received: '{type(signal).__name__}'")
    if type(signal) == list:
        signal = array(signal)
    if "float" not in signal.dtype.name and "int" not in signal.dtype.name:
        raise TypeError(f"All elements of 'signal' must be numbers")
           
    if type(level) != int:
        raise TypeError(f"'level' must be 'int', received: '{type(level).__name__}'")
    if level < 1:
        raise TypeError(f"'level' must be greater than 0, received: {level}")
        
    if mother_wavelet not in ['haar', 'db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8', 'db9', 'db10', 'db11', 'db12', 'db13', 'db14', 'db15', 'db16', 'db17', 'db18', 'db19', 'db20', 'db21', 'db22', 'db23', 'db24', 'db25', 'db26', 'db27', 'db28', 'db29', 'db30', 'db31', 'db32', 'db33', 'db34', 'db35', 'db36', 'db37', 'db38', 'sym2', 'sym3', 'sym4', 'sym5', 'sym6', 'sym7', 'sym8', 'sym9', 'sym10', 'sym11', 'sym12', 'sym13', 'sym14', 'sym15', 'sym16', 'sym17', 'sym18', 'sym19', 'sym20', 'coif1', 'coif2', 'coif3', 'coif4', 'coif5', 'coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17', 'bior1.1', 'bior1.3', 'bior1.5', 'bior2.2', 'bior2.4', 'bior2.6', 'bior2.8', 'bior3.1', 'bior3.3', 'bior3.5', 'bior3.7', 'bior3.9', 'bior4.4', 'bior5.5', 'bior6.8', 'rbio1.1', 'rbio1.3', 'rbio1.5', 'rbio2.2', 'rbio2.4', 'rbio2.6', 'rbio2.8', 'rbio3.1', 'rbio3.3', 'rbio3.5', 'rbio3.7', 'rbio3.9', 'rbio4.4', 'rbio5.5', 'rbio6.8', 'dmey', 'gaus1', 'gaus2', 'gaus3', 'gaus4', 'gaus5', 'gaus6', 'gaus7', 'gaus8', 'mexh', 'morl', 'cgau1', 'cgau2', 'cgau3', 'cgau4', 'cgau5', 'cgau6', 'cgau7', 'cgau8', 'shan', 'fbsp', 'cmor']:
        raise TypeError(f"Invalid 'mother_wavelet' must be 'haar', 'db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8', 'db9', 'db10', 'db11', 'db12', 'db13', 'db14', 'db15', 'db16', 'db17', 'db18', 'db19', 'db20', 'db21', 'db22', 'db23', 'db24', 'db25', 'db26', 'db27', 'db28', 'db29', 'db30', 'db31', 'db32', 'db33', 'db34', 'db35', 'db36', 'db37', 'db38', 'sym2', 'sym3', 'sym4', 'sym5', 'sym6', 'sym7', 'sym8', 'sym9', 'sym10', 'sym11', 'sym12', 'sym13', 'sym14', 'sym15', 'sym16', 'sym17', 'sym18', 'sym19', 'sym20', 'coif1', 'coif2', 'coif3', 'coif4', 'coif5', 'coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17', 'bior1.1', 'bior1.3', 'bior1.5', 'bior2.2', 'bior2.4', 'bior2.6', 'bior2.8', 'bior3.1', 'bior3.3', 'bior3.5', 'bior3.7', 'bior3.9', 'bior4.4', 'bior5.5', 'bior6.8', 'rbio1.1', 'rbio1.3', 'rbio1.5', 'rbio2.2', 'rbio2.4', 'rbio2.6', 'rbio2.8', 'rbio3.1', 'rbio3.3', 'rbio3.5', 'rbio3.7', 'rbio3.9', 'rbio4.4', 'rbio5.5', 'rbio6.8', 'dmey', 'gaus1', 'gaus2', 'gaus3', 'gaus4', 'gaus5', 'gaus6', 'gaus7', 'gaus8', 'mexh', 'morl', 'cgau1', 'cgau2', 'cgau3', 'cgau4', 'cgau5', 'cgau6', 'cgau7', 'cgau8', 'shan', 'fbsp', or 'cmor', received: '{mother_wavelet}'")
        
    original_len = len(signal)
    approx_coeff = []
    detail_coeff = []
    wavelet = pywt.Wavelet(mother_wavelet)
    low_filter = wavelet.dec_lo
    high_filter = wavelet.dec_hi
    filter_size = len(low_filter)
    try:
        for _ in range(level):
            padded_signal = padding_symmetric(signal, filter_size)
            low_pass_filtered_signal = convolve(padded_signal, low_filter)[filter_size:(2*filter_size)+len(signal)-1] 
            low_pass_filtered_signal = low_pass_filtered_signal[1:len(low_pass_filtered_signal):2]
            high_pass_filtered_signal = convolve(padded_signal, high_filter)[filter_size:filter_size+len(signal)+filter_size-1]
            high_pass_filtered_signal = high_pass_filtered_signal[1:len(high_pass_filtered_signal):2]
            approx_coeff.append(low_pass_filtered_signal)
            detail_coeff.append(high_pass_filtered_signal)
            signal = low_pass_filtered_signal
    except:
        raise
    low_reconstruction_filter = wavelet.rec_lo
    high_reconstruction_filter = wavelet.rec_hi
    real_lengths = []
    for i in range(level-2,-1,-1):
        real_lengths.append(len(approx_coeff[i]))
    real_lengths.append(original_len)
    restored_approx_coeff = []
    for i in range(level):
        restored_signal = restore_signal(approx_coeff[i], low_reconstruction_filter, real_lengths[level-1-i])
        for j in range(i):
            restored_signal = restore_signal(restored_signal, low_reconstruction_filter, real_lengths[level-i+j])
        restored_approx_coeff.append(restored_signal)
    restored_detail_coeff = []
    for i in range(level):
        restored_signal = restore_signal(detail_coeff[i], high_reconstruction_filter, real_lengths[level-1-i])
        for j in range(i):
            restored_signal = restore_signal(restored_signal, high_reconstruction_filter, real_lengths[level-i+j])
        restored_detail_coeff.append(restored_signal)
    return restored_approx_coeff, restored_detail_coeff 

def entropy_fn(signal):
    entropy_val = 0
    for i in signal:
        entropy_val += (i**2)*(np.log2(i**2))     
    return entropy_val

def energy_fn(signal):
    return np.sum(np.array(signal)**2)
        
import pywt
def dwt_fn(signal):
    #print(signal)
    #coeffs = pywt.wavedec(signal, 'db4', level=4) 
    restored_approx_coeff,restored_detail_coeff = DWTfn(signal, 4, 'db4') 
    d4, d3, d2, d1 = restored_detail_coeff 
    
#     print(len(d1))
#     print(len(d2))
#     print(len(d3))
#     print(len(d4))
#     raise Exception()
    
    bands = {'theta':d4,'alpha':d3,'beta':d2,'gamma':d1}
    
    band_instance = {}
    
    for band_name, band in bands.items():
        band_instance[f"{band_name}_entropy"] = entropy_fn(band)
        band_instance[f"{band_name}_energy"] = energy_fn(band)
        band_instance[f"{band_name}_mean"] = np.mean(band)
        band_instance[f"{band_name}_std"] = np.std(band)
    
    return band_instance

from scipy.stats import kurtosis, skew, entropy
def extract_time_domain_features(signal, verbose=False):
    mean = np.mean(signal)
    std = np.std(signal)
    rnge = np.max(signal) - np.min(signal)
    skewness = skew(signal)
    kurt = kurtosis(signal)
    hjorth_param_activity = std**2
    hjorth_param_mobility, hjorth_param_complexity = pyeeg.hjorth(signal)    
    #feature_vector = (mean,std,rnge,skewness,kurt,hjorth_param_activity,hjorth_param_mobility,hjorth_param_complexity)    
    feature_vector_dict = {"mean":mean,"std":std,"range":rnge,"skewness":skewness,"kurtosis":kurt,"hjorth_param_activity":hjorth_param_activity, "hjorth_param_mobility":hjorth_param_mobility, "hjorth_param_complexity":hjorth_param_complexity}
    
    if verbose : print(feature_vector_dict)
    return feature_vector_dict

# DWT Feature Extraction

5-level DWT 
Time (stats) and Time-Frequency (wavelet energy/relative/entropy)
Dataset: 120 x 281160

In [9]:
from scipy import stats
def feature_extraction(subjects, channel=[1,7,15,17,25], window_size=640, step_size=320, sample_rate=128, timedomain=True, timefreq=True, baseline=False, directory='data_python'):
    usename=False
    chan_title=str(len(channel))+"chan"
    usename1=False
    usename2=False
    if channel == [1,7,15,17,25]:
        usename1=True
    if channel == [0,1,2,3,4]:
        usename2=True
    meta = []
    
    from os import path
    if baseline: feature = "dwt_baseline"
    else: feature = "dwt"
        
    tag_name = ""
    extension = "dat"
    if directory != "data_python":
        tag_name = "custom"
        extension = "mat"        
    
    if timedomain and timefreq: csv_filename = f'data{tag_name}/{feature}/{chan_title}_time_timefreq_{int(window_size/128)}s-{step_size/128}step.csv'
    if timedomain and not timefreq: csv_filename = f'data{tag_name}/{feature}/{chan_title}_time_{int(window_size/128)}s-{step_size/128}step.csv'
    if not timedomain and timefreq: csv_filename = f'data{tag_name}/{feature}/{chan_title}_timefreq_{int(window_size/128)}s-{step_size/128}step.csv'
    print(csv_filename)
    from os import path
    if path.exists(csv_filename):
        print(f"{csv_filename} already exists.")
        return {"csv_path":csv_filename, "data":None}
    
    reuse_date_optimization = False
    if feature == "dwt_baseline":
        csv_filename_without = csv_filename.replace("dwt_baseline","dwt")
        if path.exists(csv_filename_without):
            print(f"{csv_filename_without} already exists. Will use as trial data.")
            reuse_date_optimization = True
            data_without = pd.read_csv(csv_filename_without)
    
    for sub in subjects:
        #print(f"Loading subject {sub}")
        subject_time = time.time()
        try:
            with open(f'../{directory}/s{sub}.{extension}', 'rb') as file:
                subject = pickle.load(file, encoding='latin1') #resolve the python 2 data problem by encoding : latin1
        except: 
            subject = sio.loadmat(f"../{directory}/s{sub}.{extension}")
            
            num_trials = len(subject["data"])
            for trial in range (0,num_trials):
                eeg = subject["data"][trial]

                val = 1 if subject["labels"][trial][0] >= 5 else 0
                aro = 1 if subject["labels"][trial][1] >= 5 else 0

                if val == 0 and aro == 0:
                    emotion = 0 #LALV
                if val == 0 and aro == 1:
                    emotion = 1 #HALV
                if val == 1 and aro == 0:
                    emotion = 2 #LAHV
                if val == 1 and aro == 1:
                    emotion = 3 #HAHV 
                
                three_sec = 128*3
                if baseline:
                    baseline_instance = {"Sub":sub, "Trial":trial, "Valence":val, "Arousal":aro, "Emotion":emotion}
                    for chan in channel:
                        if usename1:
                            if chan == 1: chan_name = "AF3"
                            elif chan == 7: chan_name = "T7"
                            elif chan == 15: chan_name = "Pz"
                            elif chan == 17: chan_name = "AF4"
                            elif chan == 25: chan_name =  "T8"
                        if usename2:
                            if chan == 0: chan_name = "AF3"
                            elif chan == 1: chan_name = "T7"
                            elif chan == 2: chan_name = "Pz"
                            elif chan == 3: chan_name = "AF4"
                            elif chan == 4: chan_name =  "T8"
                        baseline_slice = eeg[chan][0 : three_sec]
                                               
                        if window_size == 384:
                            #time domain                        
                            if timedomain:
                                time_domain_features = extract_time_domain_features(baseline_slice)
                                for feature_name,value in time_domain_features.items():
                                    if usename: baseline_instance[f"{chan_name}_{feature_name}"] = value
                                    else: baseline_instance[f"{chan}_{feature_name}"] = value
                            #time-frequency domain                            
                            if timefreq:
                                time_freq_feats = dwt_fn(baseline_slice)  
                                for key,value in time_freq_feats.items():
                                    if usename: baseline_instance[f"{chan_name}_{key}"] = value
                                    else: baseline_instance[f"{chan}_{key}"] = value
                                        
                        elif window_size == 128:
                            slices = [[0,128],[128,256],[256,384]]                                
                            for time_slice in slices:
                                baseline_mini_slice = baseline_slice[time_slice[0]:time_slice[1]]
                                if timedomain:
                                    time_domain_features = extract_time_domain_features(baseline_mini_slice)
                                    for feature_name,value in time_domain_features.items():
                                        if usename: 
                                            try: baseline_instance[f"{chan_name}_{feature_name}"] += value
                                            except: baseline_instance[f"{chan_name}_{feature_name}"] = value
                                        else: 
                                            try: baseline_instance[f"{chan}_{feature_name}"] += value
                                            except: baseline_instance[f"{chan}_{feature_name}"] = value
                                #time-frequency domain   
                                if timefreq:
                                    time_freq_feats = dwt_fn(baseline_mini_slice)  
                                    for key,value in time_freq_feats.items():                                        
                                        if usename: 
                                            try: baseline_instance[f"{chan_name}_{key}"] += value
                                            except: baseline_instance[f"{chan_name}_{key}"] = value
                                        else: 
                                            try: baseline_instance[f"{chan}_{key}"] += value
                                            except: baseline_instance[f"{chan}_{key}"] = value    
                                            
                                                              
                        else: raise Exception("Window size must be either 1 or 3 seconds long to use baseline.")
                    if window_size == 128:
                        info_keys = ["Sub","Trial", "Valence", "Arousal", "Emotion"] 
                        for key, value_bl in baseline_instance.items():
                            if key not in info_keys:
                                baseline_instance[key] = value_bl/3
                if not reuse_date_optimization:
                    start = three_sec
                    while start + window_size < eeg.shape[1]:
                        instance = {"Sub":sub, "Trial":trial, "Valence":val, "Arousal":aro, "Emotion":emotion}
                        for chan in channel:                        
                            eeg_slice = eeg[chan][start : start + window_size] 
                            eeg_standardized = stats.zscore(eeg_slice)                        

                            if usename1:
                                if chan == 1: chan_name = "AF3"
                                elif chan == 7: chan_name = "T7"
                                elif chan == 15: chan_name = "Pz"
                                elif chan == 17: chan_name = "AF4"
                                elif chan == 25: chan_name =  "T8"
                            if usename2:
                                if chan == 0: chan_name = "AF3"
                                elif chan == 1: chan_name = "T7"
                                elif chan == 2: chan_name = "Pz"
                                elif chan == 3: chan_name = "AF4"
                                elif chan == 4: chan_name =  "T8"

                            #time domain

                            if timedomain:
                                time_domain_features = extract_time_domain_features(eeg_slice)
                                for feature_name,value in time_domain_features.items():
                                    if usename: instance[f"{chan_name}_{feature_name}"] = value
                                    else: instance[f"{chan}_{feature_name}"] = value

                            #time-frequency domain    

                            if timefreq:
                                time_freq_feats = dwt_fn(eeg_slice)  
                                for key,value in time_freq_feats.items():
                                    if usename: instance[f"{chan_name}_{key}"] = value
                                    else: instance[f"{chan}_{key}"] = value   

                        if baseline:
                            #print(baseline_instance)
                            info_keys = ["Sub","Trial", "Valence", "Arousal", "Emotion"]
                            for key,value_from_key_ffs in instance.items():
                                if key not in info_keys:
                                    #print(value_from_key_ffs)
                                    #print(baseline_instance[key])
                                    instance[key] = value_from_key_ffs - baseline_instance[key]
                                    #print(instance[key],"\n")
                        meta.append(instance)    
                        start = start + step_size
                else:
                    subject_num = int(sub.replace('0',''))
                    data_to_use = data_without.loc[(data_without['Sub']==subject_num) & (data_without['Trial']==trial)].drop(["Sub", "Trial", "Valence", "Arousal", "Emotion"],axis=1)                   
                    #baseline subtraction
                    data_to_use = data_to_use.to_dict(orient='records')
                    for row in data_to_use:
                        instance = {"Sub":sub, "Trial":trial, "Valence":val, "Arousal":aro, "Emotion":emotion}
                        for key,value in row.items():
                            instance[key] = value - baseline_instance[key]
                        meta.append(instance)  
        print(f"Completed subject {sub} in {round(time.time()-subject_time,2)}s")
        
    df = pd.DataFrame(meta)   
    
    df.to_csv(csv_filename,index=False)
        
    return {"csv_path":csv_filename, "data":df}

In [100]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
def find_best_params(x, y, random_state=1):
    #take smaller sample of full dataset, retaining class distribution    
    
    skf = StratifiedKFold(n_splits = 3, shuffle=True, random_state = random_state)        
    for train_index, test_index in skf.split(x,y):
        index_to_keep = test_index
        break   
    
    x = x[x.index.isin(index_to_keep)].reset_index(drop=True)
    y = y[y.index.isin(index_to_keep)].reset_index(drop=True)
    
    #lda = LDA(n_components=1)
    #x = lda.fit_transform(x, y)
    
    skf = StratifiedKFold(n_splits = 3, shuffle=True, random_state = random_state) 
    param_grid = {'C':[1, 50, 100, 200, 300],'gamma':[0.00001,0.001,1, 50, 100], 'kernel':['rbf']}  
    grid = GridSearchCV(svc_sklearn(), param_grid, refit = True, verbose=3, cv=skf.split(x,y), n_jobs=-1, scoring = 'accuracy')
    
    scaler = StandardScaler()
    x = scaler.fit_transform(x) 
    grid.fit(x, y)
    
    best_parameters = grid.best_params_
    
    return best_parameters

In [101]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
def subj_indept_target(target = None, df = None, hyperparameters=None, random_state=1, verbose=False):
    
    start_time = time.time()
    if target == "val" or target == "valaro": y = df.Valence
    if target == "aro" or target == "aroval": y = df.Arousal

    if target == "val" or target == "aro": x = df.drop(['Sub','Emotion','Trial','Valence','Arousal'],axis=1)
    if target == "valaro": x = df.drop(['Sub','Emotion','Trial','Valence'],axis=1)
    if target == "aroval": x = df.drop(['Sub','Emotion','Trial','Arousal'],axis=1)
        
#     skf = StratifiedKFold(n_splits = 8, shuffle=True, random_state = random_state)        
#     for train_index, test_index in skf.split(x,y):
#         index_to_keep = test_index
#         break
    
#     hold_out_test_x = x[x.index.isin(index_to_keep)]
#     hold_out_test_y = y[y.index.isin(index_to_keep)]
    
#     x = x[~x.index.isin(index_to_keep)].reset_index(drop=True)
#     y = y[~y.index.isin(index_to_keep)].reset_index(drop=True)
    
    if not hyperparameters: params = find_best_params(x,y)
    else: params = hyperparameters
    
#     ##################TEMP#########################
#     print(params)
#     print(time.time()-start_time)
#     return None, None, None
    
#     ##################TEMP#########################
    
    
    skf = StratifiedKFold(n_splits = 6, shuffle=True, random_state = random_state)
    
    scaler = StandardScaler()
    x = scaler.fit_transform(x) 

    svm = svc_thundersvm()
    svm.set_params(**params)
    print(svm)

    acc = []
    for train_index, test_index in skf.split(x,y):
        start_time = time.time()
        x_train_fold, x_test_fold = x[train_index], x[test_index] 
        y_train_fold, y_test_fold = y[train_index], y[test_index] 
        
        #lda = LDA(n_components=1)
        #x_train_fold = lda.fit_transform(x_train_fold, y_train_fold)
        
        svm.fit(x_train_fold, y_train_fold) 
        
        #x_test_fold = lda.transform(x_test_fold)

        score_fold = svm.score(x_test_fold, y_test_fold)
        acc.append(score_fold)
        print(f"Fold completed in {time.time()-start_time}")
    
    acc = np.mean(acc)
    std = np.std(acc)
    
    print(f"{target}: {acc}±{std}")
    return (acc, params, std)

In [15]:
from sklearn.preprocessing import StandardScaler
from thundersvm import SVC as svc_thundersvm
from sklearn.svm import SVC as svc_sklearn
from sklearn.model_selection import StratifiedKFold, ParameterGrid
from sklearn.model_selection import GridSearchCV
import time

def convert_dict_to_string(dictionary):
    hyperstring = ""
    for key,value in dictionary.items():
        hyperstring+= f"{key}:{value}, "
    hyperstring = hyperstring[:-2]
    return hyperstring

def subj_indept(csv_path=None, dataframe=None, verbose=False, hyperparameters=None, results_csv="results/Results - Indept - CUSTOM.csv", to_save=True, baseline=False, custom=False):
    
    start_time = time.time()
    try:
        df = pd.read_csv(csv_path)
    except: 
        df = dataframe
    val, val_params,valstd = subj_indept_target(target = "val", df = df, hyperparameters=hyperparameters)
    aro, aro_params,arostd = subj_indept_target(target = "aro", df = df, hyperparameters=hyperparameters)
    valaro, valaro_params,valarostd = subj_indept_target(target = "valaro", df = df, hyperparameters=hyperparameters)
    aroval, aroval_params,arovalstd = subj_indept_target(target = "aroval", df = df, hyperparameters=hyperparameters)
    
    end_time = round(time.time()-start_time,2)
    
    if verbose:
        print(f"VAL: {val}")
        print(f"ARO: {aro}")
        print(f"VALARO: {valaro}")
        print(f"AROVAL: {aroval}")
        
        print(f"VAL x ARO: {val * aro}")
        print(f"VAL x AROVAL: {val * aroval}")
        print(f"ARO x VALARO: {aro * valaro}")
        
        combo_dicts = {"VALxARO":val * aro, "VALxAROVAL":val * aroval, "AROxVALARO":aro * valaro}
        combo_dicts = dict(sorted(combo_dicts.items(), key=lambda item: item[1], reverse=True))    

        best_combination = list(combo_dicts.keys())[0]
        best_acc = list(combo_dicts.values())[0]
        
        print(best_combination, best_acc)
        print(f"Completed in {end_time}s")
    if to_save:
        results = pd.read_csv(results_csv)
    
        dependency = "Independent"
        model_name = "SVM"
        if baseline: model_name += "-Baseline"
        if custom: model_name += "-CUSTOM"

        if 'time_timefreq' in csv_path:
            domain = "T-TF"
        elif 'timefreq_' in csv_path and 'time_timefreq_' not in csv_path:
            domain = "TF"
        elif 'time_' in csv_path and 'time_timefreq_' not in csv_path:
            domain = "T"

        channels = int(str(str(csv_path.split("/")[-1]).split("_")[0]).split("chan")[0])

        window_size = int(str(str(str(csv_path.split("/")[-1]).split("_")[-1]).split("-")[0]).split("s")[0])
        step_size = str(str(str(csv_path.split("/")[-1]).split("_")[-1]).split("-")[1]).split("step")[0]

        combo_dicts = {"VALxARO":val * aro, "VALxAROVAL":val * aroval, "AROxVALARO":aro * valaro}
        combo_dicts = dict(sorted(combo_dicts.items(), key=lambda item: item[1], reverse=True))    

        best_combination = list(combo_dicts.keys())[0]
        best_acc = list(combo_dicts.values())[0]

        val_params = convert_dict_to_string(val_params)
        aro_params = convert_dict_to_string(aro_params)
        valaro_params = convert_dict_to_string(valaro_params)
        aroval_params = convert_dict_to_string(aroval_params)

        time_taken = end_time

        val = round(val*100,2)
        aro = round(aro*100,2)
        valaro = round(valaro*100,2)
        aroval = round(aroval*100,2)


        new_result = {"Dependency":dependency, "Model Name":model_name, "Domain":domain, "Channels":channels, 
                      "Window Size":window_size, "Step Size": step_size, "VAL":f"{val}", "ARO":f"{aro}", 
                      "VALARO":f"{valaro}", "AROVAL":f"{aroval}", "Best Combination":best_combination, "Best Acc":best_acc, 
                      "HP_VAL":val_params, "HP_ARO":aro_params, "HP_VALARO":valaro_params, "HP_AROVAL":aroval_params, 
                      "Time taken":time_taken}

        results = results.append(new_result, ignore_index=True)
        results = results[["Dependency","Model Name","Domain","Channels", "Window Size","Step Size",
                           "VAL","ARO","VALARO","AROVAL","Best Combination","Best Acc", 
                           "HP_VAL", "HP_ARO", "HP_VALARO", "HP_AROVAL", "Time taken"]]  

        results.to_csv(results_csv, index=False)
        print(f"Completed. - {dependency}, {domain}, {channels}, {window_size}-{step_size} in {time_taken}s")
        print(f"VAL:{val}±{val}\nARO:{aro}±{arostd}\nVALARO:{valaro}±{valarostd}\nAROVAL:{aroval}±{arovalstd}\n")

In [None]:
sample_rate = 128 
subject_list = ['01','02','03','04','05','06','07','08','09','10','11','12','13','14','15','16','17','18','19','20','21','22','23','24','25','26','27','28','29','30','31','32']
channels = [1,7,15,17,25]
timefreq = True
timedomain = False
window_in_sec = 3 #[2,4,6,8]:
baseline = True
window_size = window_in_sec * sample_rate
step_size = window_size
                
start_time = time.time()
data = feature_extraction(subjects=subject_list, channel=channels, window_size=window_size, step_size=step_size, timedomain=timedomain, timefreq=timefreq, baseline=baseline)
print(f"Time taken to process dataset: {round(time.time()-start_time,2)}s.")

subj_indept(csv_path=data['csv_path'], verbose=False, baseline=baseline)

In [16]:
sample_rate = 128 
subject_list = ['01','02','03','04','05','06','07','08','09','10','11','12','13','14','15','16','17','18','19','20','21','22','23','24','25','26','27','28','29','30','31','32']

#for baseline in [True, False]:
    #for channels in [[1,7,15,17,25], list(range(0,32))]:
for channels in [[0,1,2,3,4]]:
    for timefreq in [True]:
        for timedomain in [False]:
#             if timedomain == False and timefreq == False:
#                 continue
            for window_in_sec in [1,3,5,7,9]:
                for baseline in [False, True]:
                    if baseline == True and window_in_sec > 3: continue  
                    window_size = window_in_sec * sample_rate
                    step_size = window_size

                    start_time = time.time()
                    data = feature_extraction(subjects=subject_list, channel=channels, window_size=window_size, step_size=step_size, timedomain=timedomain, timefreq=timefreq, baseline=baseline, directory='DEAP_5chan_custom_preproc')
                    print(f"Time taken to process dataset: {round(time.time()-start_time,2)}s.")
                    if baseline == True and window_in_sec == 1:
                        subj_indept(csv_path=data['csv_path'], verbose=True, baseline=baseline, custom=True, hyperparameters={'C':50,'gamma':1})
                    elif baseline == False and window_in_sec == 1:
                        subj_indept(csv_path=data['csv_path'], verbose=True, baseline=baseline, custom=True, hyperparameters={'C':1,'gamma':1})
                    else: 
                        subj_indept(csv_path=data['csv_path'], verbose=True, baseline=baseline, custom=True)

datacustom/dwt/5chan_timefreq_1s-1.0step.csv
datacustom/dwt/5chan_timefreq_1s-1.0step.csv already exists.
Time taken to process dataset: 0.0s.
SVC(C=1, gamma=1)
Fold completed in 12.223158597946167
Fold completed in 12.417818546295166
Fold completed in 12.548962116241455
Fold completed in 12.566531896591187
Fold completed in 12.12250566482544
Fold completed in 12.820863008499146
val: 0.6335824782750281±0.0
SVC(C=1, gamma=1)
Fold completed in 12.569934129714966
Fold completed in 12.475698471069336
Fold completed in 12.782794952392578
Fold completed in 12.500403642654419
Fold completed in 12.942951440811157
Fold completed in 12.77852201461792
aro: 0.6434439621039162±0.0
SVC(C=1, gamma=1)
Fold completed in 12.23851490020752
Fold completed in 12.021023035049438
Fold completed in 11.923681497573853
Fold completed in 11.91264295578003
Fold completed in 12.108025550842285
Fold completed in 11.970342874526978
valaro: 0.6556016958486046±0.0
SVC(C=1, gamma=1)
Fold completed in 12.720830202102661

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   15.2s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   49.8s finished


SVC(C=1, gamma=1)
Fold completed in 2.5461699962615967
Fold completed in 2.4543802738189697
Fold completed in 2.5696070194244385
Fold completed in 2.483809471130371
Fold completed in 2.53094482421875
Fold completed in 2.5392110347747803
val: 0.6486419196734348±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   13.2s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   48.4s finished


SVC(C=1, gamma=1)
Fold completed in 2.6204307079315186
Fold completed in 2.5723185539245605
Fold completed in 2.4944629669189453
Fold completed in 2.559711456298828
Fold completed in 2.4872443675994873
Fold completed in 2.488931655883789
aro: 0.6596462431270698±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   13.4s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   47.9s finished


SVC(C=1, gamma=1)
Fold completed in 2.5098628997802734
Fold completed in 2.4668147563934326
Fold completed in 2.5609822273254395
Fold completed in 2.449598550796509
Fold completed in 2.4129414558410645
Fold completed in 2.4480979442596436
valaro: 0.6723816483517648±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   13.2s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   48.2s finished


SVC(C=1, gamma=1)
Fold completed in 2.575129985809326
Fold completed in 2.512087821960449
Fold completed in 2.581465005874634
Fold completed in 2.5916285514831543
Fold completed in 2.5717689990997314
Fold completed in 2.5498955249786377
aroval: 0.6841689943069146±0.0
VAL: 0.6486419196734348
ARO: 0.6596462431270698
VALARO: 0.6723816483517648
AROVAL: 0.6841689943069146
VAL x ARO: 0.4278742054473118
VAL x AROVAL: 0.44378068984828034
ARO x VALARO: 0.4435340282828282
VALxAROVAL 0.44378068984828034
Completed in 271.21s
Completed. - Independent, TF, 5, 3-3.0 in 271.21s
VAL:64.86±64.86
ARO:65.96±0.0
VALARO:67.24±0.0
AROVAL:68.42±0.0

datacustom/dwt_baseline/5chan_timefreq_3s-3.0step.csv
datacustom/dwt_baseline/5chan_timefreq_3s-3.0step.csv already exists.
Time taken to process dataset: 0.0s.
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   13.1s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   47.1s finished


SVC(C=50, gamma=1)
Fold completed in 8.535311698913574
Fold completed in 7.854887008666992
Fold completed in 8.179536581039429
Fold completed in 7.832632303237915
Fold completed in 7.980997800827026
Fold completed in 8.601992845535278
val: 0.8070311514392063±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   12.8s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   45.0s finished


SVC(C=50, gamma=1)
Fold completed in 7.436862945556641
Fold completed in 7.235740900039673
Fold completed in 7.485136032104492
Fold completed in 7.595627784729004
Fold completed in 7.479313611984253
Fold completed in 7.514719724655151
aro: 0.8140789146089364±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   11.9s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   42.8s finished


SVC(C=50, gamma=1)
Fold completed in 6.140585899353027
Fold completed in 5.97540545463562
Fold completed in 6.200981140136719
Fold completed in 5.965474367141724
Fold completed in 5.919128894805908
Fold completed in 5.846118450164795
valaro: 0.8274735188872603±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:   12.3s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   44.0s finished


SVC(C=50, gamma=1)
Fold completed in 6.03823447227478
Fold completed in 6.072747707366943
Fold completed in 6.271881103515625
Fold completed in 6.0506751537323
Fold completed in 6.062435150146484
Fold completed in 5.966654300689697
aroval: 0.8343569384637729±0.0
VAL: 0.8070311514392063
ARO: 0.8140789146089364
VALARO: 0.8274735188872603
AROVAL: 0.8343569384637729
VAL x ARO: 0.6569870438192293
VAL x AROVAL: 0.6733520407597097
ARO x VALARO: 0.6736287441233781
AROxVALARO 0.6736287441233781
Completed in 367.74s
Completed. - Independent, TF, 5, 3-3.0 in 367.74s
VAL:80.7±80.7
ARO:81.41±0.0
VALARO:82.75±0.0
AROVAL:83.44±0.0

datacustom/dwt/5chan_timefreq_5s-5.0step.csv
datacustom/dwt/5chan_timefreq_5s-5.0step.csv already exists.
Time taken to process dataset: 0.0s.
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   12.3s finished


SVC(C=1, gamma=1)
Fold completed in 1.2430827617645264
Fold completed in 1.224421501159668
Fold completed in 1.2023611068725586
Fold completed in 1.2240700721740723
Fold completed in 1.194342851638794
Fold completed in 1.224273443222046
val: 0.6530216746658807±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   12.5s finished


SVC(C=1, gamma=1)
Fold completed in 1.2546775341033936
Fold completed in 1.223121166229248
Fold completed in 1.2090294361114502
Fold completed in 1.2085647583007812
Fold completed in 1.223858118057251
Fold completed in 1.2181384563446045
aro: 0.6629183384640901±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   12.2s finished


SVC(C=1, gamma=1)
Fold completed in 1.217094898223877
Fold completed in 1.1931002140045166
Fold completed in 1.2146635055541992
Fold completed in 1.1827325820922852
Fold completed in 1.2028868198394775
Fold completed in 1.1944632530212402
valaro: 0.6769415254234197±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    3.3s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:   12.5s finished


SVC(C=1, gamma=1)
Fold completed in 1.2852039337158203
Fold completed in 1.195277452468872
Fold completed in 1.2140202522277832
Fold completed in 1.229724645614624
Fold completed in 1.230684757232666
Fold completed in 1.2236902713775635
aroval: 0.6917495254399567±0.0
VAL: 0.6530216746658807
ARO: 0.6629183384640901
VALARO: 0.6769415254234197
AROVAL: 0.6917495254399567
VAL x ARO: 0.43290004355054323
VAL x AROVAL: 0.45172743355212874
ARO x VALARO: 0.44875695127104
VALxAROVAL 0.45172743355212874
Completed in 84.57s
Completed. - Independent, TF, 5, 5-5.0 in 84.57s
VAL:65.3±65.3
ARO:66.29±0.0
VALARO:67.69±0.0
AROVAL:69.17±0.0

datacustom/dwt/5chan_timefreq_7s-7.0step.csv
datacustom/dwt/5chan_timefreq_7s-7.0step.csv already exists.
Time taken to process dataset: 0.0s.
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    1.6s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    6.2s finished


SVC(C=1, gamma=1)
Fold completed in 0.8904564380645752
Fold completed in 0.8322446346282959
Fold completed in 0.8450183868408203
Fold completed in 0.8591063022613525
Fold completed in 0.8448514938354492
Fold completed in 0.8562452793121338
val: 0.6528019377338992±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    1.6s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    6.3s finished


SVC(C=1, gamma=1)
Fold completed in 0.8975980281829834
Fold completed in 0.8547482490539551
Fold completed in 0.8674540519714355
Fold completed in 0.868049144744873
Fold completed in 0.8551244735717773
Fold completed in 0.8473021984100342
aro: 0.6619023731536178±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    1.6s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    6.3s finished


SVC(C=1, gamma=1)
Fold completed in 0.8847026824951172
Fold completed in 0.845505952835083
Fold completed in 0.8266406059265137
Fold completed in 0.8335590362548828
Fold completed in 0.8314483165740967
Fold completed in 0.8323071002960205
valaro: 0.6761955411972007±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed:    1.6s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    6.3s finished


SVC(C=1, gamma=1)
Fold completed in 0.9024190902709961
Fold completed in 0.8432130813598633
Fold completed in 0.864459753036499
Fold completed in 0.8597497940063477
Fold completed in 0.8640072345733643
Fold completed in 0.8579981327056885
aroval: 0.6841232460216862±0.0
VAL: 0.6528019377338992
ARO: 0.6619023731536178
VALARO: 0.6761955411972007
AROVAL: 0.6841232460216862
VAL x ARO: 0.43209115178534807
VAL x AROVAL: 0.44659698065176173
ARO x VALARO: 0.44757543343432205
AROxVALARO 0.44757543343432205
Completed in 49.1s
Completed. - Independent, TF, 5, 7-7.0 in 49.1s
VAL:65.28±65.28
ARO:66.19±0.0
VALARO:67.62±0.0
AROVAL:68.41±0.0

datacustom/dwt/5chan_timefreq_9s-9.0step.csv
datacustom/dwt/5chan_timefreq_9s-9.0step.csv already exists.
Time taken to process dataset: 0.0s.
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  28 tasks      | elapsed:    1.4s
[Parallel(n_jobs=-1)]: Done  64 out of  75 | elapsed:    2.9s remaining:    0.4s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    3.5s finished


SVC(C=1, gamma=1)
Fold completed in 0.627655029296875
Fold completed in 0.6191937923431396
Fold completed in 0.6268396377563477
Fold completed in 0.6183090209960938
Fold completed in 0.6066441535949707
Fold completed in 0.6151449680328369
val: 0.6588358131036283±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  28 tasks      | elapsed:    1.3s
[Parallel(n_jobs=-1)]: Done  64 out of  75 | elapsed:    2.9s remaining:    0.4s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    3.5s finished


SVC(C=1, gamma=1)
Fold completed in 0.6870689392089844
Fold completed in 0.6320264339447021
Fold completed in 0.6244521141052246
Fold completed in 0.6267085075378418
Fold completed in 0.6372458934783936
Fold completed in 0.6182699203491211
aro: 0.666536152440616±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  28 tasks      | elapsed:    1.4s
[Parallel(n_jobs=-1)]: Done  64 out of  75 | elapsed:    3.0s remaining:    0.4s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    3.5s finished


SVC(C=1, gamma=1)
Fold completed in 1.162623405456543
Fold completed in 0.6884472370147705
Fold completed in 0.6396713256835938
Fold completed in 0.622988224029541
Fold completed in 0.6209323406219482
Fold completed in 0.6255033016204834
valaro: 0.6752806055860089±0.0
Fitting 3 folds for each of 25 candidates, totalling 75 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  28 tasks      | elapsed:    1.4s
[Parallel(n_jobs=-1)]: Done  64 out of  75 | elapsed:    3.2s remaining:    0.5s
[Parallel(n_jobs=-1)]: Done  75 out of  75 | elapsed:    3.9s finished


SVC(C=1, gamma=1)
Fold completed in 0.699988842010498
Fold completed in 0.6784048080444336
Fold completed in 0.6935176849365234
Fold completed in 0.677016019821167
Fold completed in 0.6805086135864258
Fold completed in 0.681157112121582
aroval: 0.6842860871835031±0.0
VAL: 0.6588358131036283
ARO: 0.666536152440616
VALARO: 0.6752806055860089
AROVAL: 0.6842860871835031
VAL x ARO: 0.4391378879561772
VAL x AROVAL: 0.45083218064504355
ARO x VALARO: 0.4500989366650675
VALxAROVAL 0.45083218064504355
Completed in 32.42s
Completed. - Independent, TF, 5, 9-9.0 in 32.42s
VAL:65.88±65.88
ARO:66.65±0.0
VALARO:67.53±0.0
AROVAL:68.43±0.0

