# WESAD Validation Notebook for FLIRT


In [1]:
# Import Packages
import pandas as pd
import numpy as np

import matplotlib; matplotlib.use('agg')
import matplotlib.pyplot as plt

import multiprocessing
from joblib import Parallel, delayed
from tqdm.autonotebook import trange

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

from datetime import datetime, timedelta
from typing import List
import lightgbm as lgb
import glob2
import os 

import sys
sys.path.insert(1, '/home/fefespinola/ETHZ_Fall_2020/flirt-1')
import flirt.simple

The following function computes the features of each subject with the FLIRT pipeline. With the current flags only EDA features are computed; HRV, ACC, BVP and temperature features are disabled.


In [2]:
def get_features_per_subject(path, window_length, window_step_size=0.25):
    """Compute FLIRT features for one subject's Empatica archive.

    Only EDA features are enabled; HRV, ACC, BVP and temperature features are
    switched off through the flags passed to FLIRT below.

    Parameters
    ----------
    path : str
        Path to the subject's Empatica ``*_Data.zip`` archive.
    window_length : int or float
        Length of the sliding feature window, forwarded to FLIRT (the caller in
        this notebook passes it in seconds).
    window_step_size : float, optional
        Step between consecutive windows, forwarded to FLIRT. Defaults to 0.25,
        the value that was previously hard-coded here.

    Returns
    -------
    The result of ``flirt.simple.get_features_for_empatica_archive``; the caller
    in this notebook uses it as a time-indexed DataFrame.
    """
    features = flirt.simple.get_features_for_empatica_archive(zip_file_path=path,
                                                              window_length=window_length,
                                                              window_step_size=window_step_size,
                                                              hrv_features=False,
                                                              eda_features=True,
                                                              acc_features=False,
                                                              bvp_features=False,
                                                              temp_features=False,
                                                              debug=True)
    return features

The following function determines the time offsets of the start and end of each relevant analysis period (baseline, stress, amusement). These offsets are added to the timestamp marking the start of the recording to obtain the absolute timestamps of each subject's sections of interest.

In [3]:
def _min_sec_to_timedelta(value):
    """Convert a ``minutes.seconds`` offset from a WESAD quest file to a timedelta.

    The digits after the decimal point are seconds (e.g. ``7.08`` -> 7 min 8 s).
    The value is first formatted with exactly two decimals. That way a float whose
    trailing zero was dropped (``7.1`` for ``7.10``) is read as 10 s rather than
    1 s, and a whole-minute value such as ``39`` does not fail to split.
    """
    minutes_part, seconds_part = '{:.2f}'.format(float(value)).split('.')
    return timedelta(minutes=int(minutes_part), seconds=int(seconds_part))


def find_label_timestamps(csv_path, StartingTime):
    """Return the absolute start/end timestamps and label of each analysis period.

    Parameters
    ----------
    csv_path : str
        Path of a file inside the subject folder, shaped like
        ``project_data/WESAD/<ID>/...``; the subject ID is its third component.
    StartingTime : datetime-like
        Timestamp of the start of the recording. Offsets from the quest file are
        added to it.

    Returns
    -------
    dict
        Maps each of ``'Base'``, ``'TSST'`` and ``'Fun'`` that is present in the
        quest file to ``[start_timestamp, end_timestamp, label]``. The labels are
        0 (baseline), 1 (stress / TSST) and 2 (amusement / Fun).
    """
    ID = csv_path.split('/', 3)[2]
    quest_file = glob2.glob('project_data/WESAD/' + ID + '/*quest.csv')[0]
    # The first two data rows after the header row hold the start and end offsets
    # of every protocol period; columns with missing offsets are dropped.
    df_timestamp = pd.read_csv(quest_file, delimiter=';', header=1).iloc[:2, :].dropna(axis=1)
    print('===================================')
    print('Printing the timestamp for {0}'.format(ID))
    print('===================================')
    print(df_timestamp.head())

    # Start/End of experiment periods
    print('\nStart of the baseline: ' + str(df_timestamp['Base'][0]))
    print('End of the baseline: ' + str(df_timestamp['Base'][1]))
    print('Start of the fun: ' + str(df_timestamp['Fun'][0]))
    print('End of the fun: ' + str(df_timestamp['Fun'][1]))
    print('Start of the stress: ' + str(df_timestamp['TSST'][0]))
    print('End of the stress: ' + str(df_timestamp['TSST'][1]))

    # Get start and end time and assign label into a dict (column order preserved)
    lab_dict = {'Base': 0, 'TSST': 1, 'Fun': 2}
    labels_times_dict = {}
    for mode in df_timestamp.columns.tolist():
        if mode not in lab_dict:
            continue
        start_offset = df_timestamp[mode].iloc[0]
        end_offset = df_timestamp[mode].iloc[1]
        labels_times_dict[mode] = [StartingTime + _min_sec_to_timedelta(start_offset),
                                   StartingTime + _min_sec_to_timedelta(end_offset),
                                   lab_dict[mode]]

    return labels_times_dict

Plots how the training and validation classification metric evolves with the number of boosting iterations.

In [4]:
def render_metric(eval_results, metric_name,
                  save_path='/home/fefespinola/ETHZ_Fall_2020/plots/render_metric_all_ekf_feat.png'):
    """Plot a recorded evaluation metric against boosting iterations and save it.

    Parameters
    ----------
    eval_results : dict
        The evaluation record filled by ``lgb.train(..., evals_result=...)``.
        The original code ignored this argument and read a global instead.
    metric_name : str
        Name of the metric to plot (e.g. ``'auc_mu'``).
    save_path : str, optional
        Where the figure is written; defaults to the previously hard-coded path.

    Returns
    -------
    The matplotlib Axes produced by ``lgb.plot_metric``.
    """
    ax = lgb.plot_metric(eval_results, metric=metric_name, figsize=(10, 5))
    ax.figure.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(ax.figure)
    return ax

Plots the 10 most important classification features, i.e. the ones that influence the model output the most.

In [5]:
def render_plot_importance(gbm, importance_type, max_features=10, ignore_zero=True, precision=3,
                           save_path='/home/fefespinola/ETHZ_Fall_2020/plots/feature_importance_all_ekf_feat.png'):
    """Plot the most important features of a trained LightGBM model and save the figure.

    Parameters
    ----------
    gbm : lightgbm.Booster or model accepted by ``lgb.plot_importance``
        The trained model.
    importance_type : str
        Importance measure forwarded to ``lgb.plot_importance`` (e.g. ``'split'``).
    max_features : int, optional
        Maximum number of features shown (default 10).
    ignore_zero : bool, optional
        Whether features with zero importance are hidden.
    precision : int, optional
        Number of decimals used for the importance labels.
    save_path : str, optional
        Where the figure is written; defaults to the previously hard-coded path.

    Returns
    -------
    The matplotlib Axes produced by ``lgb.plot_importance``.
    """
    ax = lgb.plot_importance(gbm, importance_type=importance_type,
                             max_num_features=max_features,
                             ignore_zero=ignore_zero, figsize=(12, 8),
                             precision=precision)
    ax.figure.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(ax.figure)
    return ax


Main function that calls the above functions. It uses the timestamp offsets to select the relevant data (i.e. the data within the useful recording periods of baseline, stress and amusement), assigns the appropriate label to each sample, and returns the full dataset of training samples with their corresponding labels.

In [6]:
def _period_mask(index, period):
    """Boolean mask of the index entries inside the closed interval [period[0], period[1]]."""
    return (index >= period[0]) & (index <= period[1])


def main():
    """Build the labelled feature table for every WESAD subject.

    For each subject readme found under ``project_data/WESAD``, the function:

    1. computes FLIRT features from the subject's ``*_Data.zip``;
    2. keeps only the rows inside the baseline, amusement (Fun) or stress (TSST)
       periods;
    3. prepends an ``ID`` column and appends a ``label`` column. Base rows get 0;
       Fun and TSST rows get the label stored by ``find_label_timestamps``.

    Side effect: changes the working directory to the project root.

    Returns
    -------
    pandas.DataFrame
        All subjects' rows concatenated (empty if no subject was found).
    """
    os.chdir('/home/fefespinola/ETHZ_Fall_2020/')  # local directory where the script is
    readme_paths = glob2.glob('project_data/WESAD/**/*_readme.txt', recursive=True)
    window_length = 60  # in seconds
    # The window step size (0.25 s) is set inside get_features_per_subject.

    per_subject = []
    for subject_path in readme_paths:
        print(subject_path)
        ID = subject_path.split('/', 3)[2]
        print(ID)
        zip_path = glob2.glob('project_data/WESAD/' + ID + '/*_Data.zip')[0]
        print(zip_path)
        features = get_features_per_subject(zip_path, window_length)
        features.index.name = 'timedata'
        StartingTime = features.index[0]
        print(features)
        labels_times = find_label_timestamps(subject_path, StartingTime)

        idx = features.index
        in_any_period = (_period_mask(idx, labels_times['Base'])
                         | _period_mask(idx, labels_times['Fun'])
                         | _period_mask(idx, labels_times['TSST']))
        # .copy() so the column insertions below modify an independent frame,
        # not a view of `features` (avoids SettingWithCopyWarning).
        relevant_features = features.loc[in_any_period].copy()

        relevant_features.insert(0, 'ID', ID)
        relevant_features['label'] = np.zeros(len(relevant_features))
        # Baseline rows keep label 0; TSST is assigned last so it wins on overlap.
        for mode in ('Fun', 'TSST'):
            in_mode = _period_mask(relevant_features.index, labels_times[mode])
            relevant_features.loc[in_mode, 'label'] = labels_times[mode][2]

        per_subject.append(relevant_features)

    # Concatenate once instead of growing a frame inside the loop (quadratic).
    df_all = pd.concat(per_subject) if per_subject else pd.DataFrame(None)
    print(df_all)

    return df_all

This function selects the feature columns of an individual physiological signal (or of all physiological signals, for `'physio'`) and saves them as a CSV feature matrix.

In [7]:
def __get_subset_features(df_all, feature_name: str, eda_method: str = 'lpf'):
    """Select the feature columns of one signal and save them to a CSV file.

    Parameters
    ----------
    df_all : pandas.DataFrame
        Feature table whose column names start with the signal prefix
        (e.g. ``'eda_...'``, ``'hrv_...'``).
    feature_name : str
        Column prefix to keep. ``'physio'`` keeps every column that starts with
        ``'hrv'``, ``'eda'``, ``'bvp'`` or ``'temp'``.
    eda_method : str, optional
        Tag added to the file name for ``'eda'`` and ``'physio'`` subsets.

    Returns
    -------
    pandas.DataFrame
        The selected columns (also written to
        ``features_all_<name>[_<eda_method>]_feat.csv`` in the working directory).
    """
    column_names = df_all.columns.str
    if feature_name == 'physio':
        # OR the prefixes: a column can start with at most one of them, so
        # combining them with AND would never select anything.
        mask = (column_names.startswith('hrv') | column_names.startswith('eda')
                | column_names.startswith('bvp') | column_names.startswith('temp'))
    else:
        mask = column_names.startswith(feature_name)
    small_df = df_all.loc[:, mask]

    if feature_name in ('physio', 'eda'):
        filename = 'features_all_' + feature_name + '_' + eda_method + '_feat.csv'
    else:
        filename = 'features_all_' + feature_name + '_feat.csv'
    small_df.to_csv(filename)
    return small_df


The following function builds the training and testing data for one fold of leave-one-subject-out (LOSO) cross-validation. It also handles missing data (NaN and ±inf are replaced by a sentinel value) and standardizes the features.

In [8]:
def __get_train_valid_data(df_all, cv_subject, missing_value=-1000):
    """Split the feature table into a LOSO training fold and a held-out subject.

    Parameters
    ----------
    df_all : pandas.DataFrame
        Feature table with an ``ID`` column (subject), a ``label`` column (class
        0/1/2) and feature columns in between.
    cv_subject : str
        Subject ID held out for validation; all other subjects are used for training.
    missing_value : float, optional
        Sentinel that replaces NaN and +/-inf before scaling (default -1000).

    Returns
    -------
    tuple
        ``(X_train, y_train, train_data, X_test, y_test, test_data)``. The ``X_*``
        are scaled numpy arrays, the ``y_*`` are one-column DataFrames of labels, and
        the ``*_data`` are ``lgb.Dataset`` objects (the validation set references
        the training set).
    """
    scaler = StandardScaler()
    is_held_out = df_all['ID'] == cv_subject
    # Select features by name rather than position so that a reordered frame
    # cannot silently leak the ID or label column into the features.
    feature_cols = df_all.columns.drop(['ID', 'label'])

    def _clean_features(rows):
        # Map NaN/inf to a sentinel so the scaler receives finite values.
        return df_all.loc[rows, feature_cols].replace([np.nan, np.inf, -np.inf], missing_value)

    # training data: every subject except the held-out one; scaler fitted here only
    X_train = scaler.fit_transform(_clean_features(~is_held_out))
    y_train = df_all.loc[~is_held_out, ['label']]  # multiclass target (0/1/2)
    train_data = lgb.Dataset(X_train, label=y_train)

    # validation data: the held-out subject, scaled with the training statistics
    X_test = scaler.transform(_clean_features(is_held_out))
    y_test = df_all.loc[is_held_out, ['label']]  # multiclass target (0/1/2)
    test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

    return X_train, y_train, train_data, X_test, y_test, test_data
 

Run the evaluation script: retrieve the labeled data, train a classifier with leave-one-subject-out (LOSO) cross-validation, and report the macro F1-score.

In [9]:
if __name__ == '__main__':
    df_all = main()
    # Round-trip through CSV so the feature table is cached on disk.
    df_all.to_csv('features_all_eda_lr_1_0_feat.csv')
    df_all = pd.read_csv('features_all_eda_lr_1_0_feat.csv')
    df_all.set_index('timedata', inplace=True)
    # Subjects present in the feature table.
    print(df_all['ID'].unique())
    print('---start classification---')

    # LightGBM parameters for the 3-class problem (0 = baseline, 1 = stress, 2 = amusement)
    param = {'metric': 'auc_mu', 'learning_rate': 0.01, 'num_leaves': 31, 'is_unbalance': True,
             'verbose': 1, 'objective': 'multiclass', 'num_class': 3, 'lambda_l1': 0, 'force_col_wise': True}

    subjects = ['S2', 'S3', 'S4', 'S10', 'S11', 'S13', 'S14', 'S15', 'S16', 'S17']
    # Boosting rounds ignored when choosing the best iteration.
    skip_rounds = 2

    ##### Start Classification
    f1_tot = 0
    f1_dict = {}
    # data with validation set (leave-one-subject-out)
    for subj in subjects:
        # get data
        print('===Training for LOSO ', subj, '===')
        _, _, train_data, X_test, y_test, test_data = __get_train_valid_data(df_all, subj)

        evals_result = {}  # to record eval results for plotting

        # train normally
        bst_norm = lgb.train(param, train_data, num_boost_round=550, valid_sets=[train_data, test_data],
                             evals_result=evals_result, verbose_eval=10)

        # Entry k of the recorded scores is the score after boosting round k + 1, so the
        # argmax over the slice starting at `skip_rounds` maps to round argmax + skip_rounds + 1.
        valid_scores = evals_result['valid_1']['auc_mu']
        best_iteration = int(np.argmax(valid_scores[skip_rounds:])) + skip_rounds + 1
        print('best iter', best_iteration)
        y_pred = bst_norm.predict(X_test, num_iteration=best_iteration)
        y_pred = np.argmax(y_pred, axis=1)

        # sklearn metrics expect a 1-D target, not a one-column frame
        y_true = np.ravel(y_test)
        f1_metric = f1_score(y_true, y_pred, average='macro')
        f1_dict[subj] = f1_metric
        print(f1_dict)
        f1_tot = f1_tot + f1_metric
        print(confusion_matrix(y_true, y_pred))

    f1_tot = f1_tot / len(subjects)
    print(f1_tot)
    # render_metric(evals_result, param['metric'])
    # render_plot_importance(bst_norm, importance_type='split')

project_data/WESAD/S10/S10_readme.txt
S10
project_data/WESAD/S10/S10_E4_Data.zip
Reading files
Calculating EDA features


HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=27287.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-07-25 07:06:08+00:00                0.025535        0.057329   
2017-07-25 07:06:08.250000+00:00         0.025544        0.057326   
2017-07-25 07:06:08.500000+00:00         0.025552        0.057322   
2017-07-25 07:06:08.750000+00:00         0.025001        0.056911   
2017-07-25 07:06:09+00:00                0.024195        0.055855   
...                                           ...             ...   
2017-07-25 08:59:48+00:00                0.030510        0.008256   
2017-07-25 08:59:48.250000+00:00         0.028520        0.006799   
2017-07-25 08:59:48.500000+00:00         0.026669        0.005471   
2017-07-25 08:59:48.750000+00:00         0.024950        0.004266   
2017-07-25 08:59:49+00:00                0.023358        0.003174   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=25841.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-07-25 11:15:19+00:00                0.253668        0.380638   
2017-07-25 11:15:19.250000+00:00         0.254212        0.380368   
2017-07-25 11:15:19.500000+00:00         0.254729        0.380107   
2017-07-25 11:15:19.750000+00:00         0.252718        0.379553   
2017-07-25 11:15:20+00:00                0.248962        0.376476   
...                                           ...             ...   
2017-07-25 13:02:58+00:00                0.082369        0.042572   
2017-07-25 13:02:58.250000+00:00         0.091421        0.041026   
2017-07-25 13:02:58.500000+00:00         0.104720        0.034924   
2017-07-25 13:02:58.750000+00:00         0.121633        0.021956   
2017-07-25 13:02:59+00:00                0.136336        0.008633   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=27449.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-08-08 11:14:07+00:00                0.277352        0.312769   
2017-08-08 11:14:07.250000+00:00         0.278264        0.312278   
2017-08-08 11:14:07.500000+00:00         0.279153        0.311788   
2017-08-08 11:14:07.750000+00:00         0.277695        0.311326   
2017-08-08 11:14:08+00:00                0.275106        0.309431   
...                                           ...             ...   
2017-08-08 13:08:28+00:00                0.261045        0.015091   
2017-08-08 13:08:28.250000+00:00         0.265889        0.011511   
2017-08-08 13:08:28.500000+00:00         0.265238        0.012787   
2017-08-08 13:08:28.750000+00:00         0.260853        0.011879   
2017-08-08 13:08:29+00:00                0.254000        0.008412   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=27941.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-08-09 07:10:31+00:00                0.037355        0.054893   
2017-08-09 07:10:31.250000+00:00         0.037364        0.054887   
2017-08-09 07:10:31.500000+00:00         0.037371        0.054882   
2017-08-09 07:10:31.750000+00:00         0.036883        0.054677   
2017-08-09 07:10:32+00:00                0.036141        0.053928   
...                                           ...             ...   
2017-08-09 09:06:55+00:00                0.056980        0.005644   
2017-08-09 09:06:55.250000+00:00         0.059246        0.002721   
2017-08-09 09:06:55.500000+00:00         0.060423        0.001525   
2017-08-09 09:06:55.750000+00:00         0.060363        0.001757   
2017-08-09 09:06:56+00:00                0.059456        0.001470   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=26567.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-08-10 07:11:56+00:00                0.031792        0.068394   
2017-08-10 07:11:56.250000+00:00         0.031815        0.068384   
2017-08-10 07:11:56.500000+00:00         0.031836        0.068376   
2017-08-10 07:11:56.750000+00:00         0.031179        0.067878   
2017-08-10 07:11:57+00:00                0.030226        0.066635   
...                                           ...             ...   
2017-08-10 09:02:36+00:00                0.034758        0.009504   
2017-08-10 09:02:36.250000+00:00         0.032459        0.007806   
2017-08-10 09:02:36.500000+00:00         0.030328        0.006268   
2017-08-10 09:02:36.750000+00:00         0.028355        0.004879   
2017-08-10 09:02:37+00:00                0.026532        0.003625   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=28427.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-08-10 12:00:25+00:00                0.027419        0.058596   
2017-08-10 12:00:25.250000+00:00         0.027419        0.058596   
2017-08-10 12:00:25.500000+00:00         0.027419        0.058596   
2017-08-10 12:00:25.750000+00:00         0.026923        0.058321   
2017-08-10 12:00:26+00:00                0.026177        0.057510   
...                                           ...             ...   
2017-08-10 13:58:50+00:00                0.024589        0.006874   
2017-08-10 13:58:50.250000+00:00         0.022913        0.005615   
2017-08-10 13:58:50.500000+00:00         0.021371        0.004489   
2017-08-10 13:58:50.750000+00:00         0.019953        0.003481   
2017-08-10 13:58:51+00:00                0.018650        0.002579   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=28925.0), HTML(value='')))


                                  eda_phasic_mean  eda_phasic_std  \
timedata                                                            
2017-08-11 07:20:22+00:00                0.301725        0.583928   
2017-08-11 07:20:22.250000+00:00         0.302082        0.583770   
2017-08-11 07:20:22.500000+00:00         0.302438        0.583611   
2017-08-11 07:20:22.750000+00:00         0.300527        0.583577   
2017-08-11 07:20:23+00:00                0.296718        0.581990   
...                                           ...             ...   
2017-08-11 09:20:52+00:00                1.700792        0.042859   
2017-08-11 09:20:52.250000+00:00         1.697484        0.046246   
2017-08-11 09:20:52.500000+00:00         1.703569        0.049881   
2017-08-11 09:20:52.750000+00:00         1.720663        0.046354   
2017-08-11 09:20:53+00:00                1.750995        0.021516   

                                  eda_phasic_min  eda_phasic_max  \
timedata                         

HBox(children=(HTML(value='EDA features'), FloatProgress(value=0.0, max=31493.0), HTML(value='')))

# Get relevant features
