In [2]:
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
from collections import defaultdict
import seaborn as sns
from pprint import pprint
import json
from scipy.stats import mode
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

%matplotlib inline

In [8]:
result_files = glob.glob('../results/baseline-e-daic-woz-modalities+class-weights/*.csv')

dfs = []
for file in result_files:
    df = pd.read_csv(file)
    df['filename'] = file.split('/')[-1]

    dfs.append(df)

dfs = pd.concat(dfs)


In [11]:
dfs["prediction_kind"].unique()

array(['last', 'mean', 'mode', 'threshold', 'mode_threshold',
       'last_presence', 'mean_presence', 'mode_presence',
       'threshold_presence', 'mode_threshold_presence'], dtype=object)

In [13]:
dfs[""]]

Index(['name', 'run_id', 'f1', 'recall', 'precision', 'auc', 'accuracy',
       'dataset', 'dataset_kind', 'model', 'seconds_per_window',
       'presence_threshold', 'modalities', 'model_args.num_layers',
       'model_args.self_attn_num_heads', 'model_args.self_attn_dim_head',
       'prediction_kind', 'filename'],
      dtype='object')

In [40]:
mode_eval = dfs[dfs['prediction_kind'] == 'mode']
mode_eval = mode_eval[mode_eval['filename'].str.contains('audiovisual-run-')]
mode_eval = mode_eval[mode_eval['dataset_kind'] == 'test']
mode_eval.sort_values(by=['seconds_per_window', 'presence_threshold', 'run_id'], inplace=True)

In [41]:
mode_eval

Unnamed: 0,name,run_id,f1,recall,precision,auc,accuracy,dataset,dataset_kind,model,seconds_per_window,presence_threshold,modalities,model_args.num_layers,model_args.self_attn_num_heads,model_args.self_attn_dim_head,prediction_kind,filename
2,baseline-e-daic-woz-modalities+class-weights:a...,1,0.148148,0.117647,0.2,0.517345,0.589286,e-daic-woz,test,baseline,6,0.25,"['edaic_audio_mfcc', 'edaic_audio_egemaps', 'e...",8,8,32,mode,temporal-evaluator:baseline-e-daic-woz-modalit...
2,baseline-e-daic-woz-modalities+class-weights:a...,2,0.0,0.0,0.0,0.550528,0.625,e-daic-woz,test,baseline,6,0.25,"['edaic_audio_mfcc', 'edaic_audio_egemaps', 'e...",8,8,32,mode,temporal-evaluator:baseline-e-daic-woz-modalit...
2,baseline-e-daic-woz-modalities+class-weights:a...,3,0.24,0.176471,0.375,0.542986,0.660714,e-daic-woz,test,baseline,6,0.25,"['edaic_audio_mfcc', 'edaic_audio_egemaps', 'e...",8,8,32,mode,temporal-evaluator:baseline-e-daic-woz-modalit...
2,baseline-e-daic-woz-modalities+class-weights:a...,4,0.2,0.117647,0.666667,0.564103,0.714286,e-daic-woz,test,baseline,6,0.25,"['edaic_audio_mfcc', 'edaic_audio_egemaps', 'e...",8,8,32,mode,temporal-evaluator:baseline-e-daic-woz-modalit...
2,baseline-e-daic-woz-modalities+class-weights:a...,5,0.086957,0.058824,0.166667,0.532428,0.625,e-daic-woz,test,baseline,6,0.25,"['edaic_audio_mfcc', 'edaic_audio_egemaps', 'e...",8,8,32,mode,temporal-evaluator:baseline-e-daic-woz-modalit...


In [42]:
grouped_mean = mode_eval.groupby(['seconds_per_window', 'presence_threshold', 'model_args.num_layers', 'model_args.self_attn_num_heads', 'model_args.self_attn_dim_head']).agg(
    {'run_id': 'count', 'f1': ['mean', 'std'], 'precision': ['mean', 'std'], 'recall': ['mean', 'std'], 'accuracy': ['mean', 'std']}).reset_index()
grouped_mean

Unnamed: 0_level_0,seconds_per_window,presence_threshold,model_args.num_layers,model_args.self_attn_num_heads,model_args.self_attn_dim_head,run_id,f1,f1,precision,precision,recall,recall,accuracy,accuracy
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,count,mean,std,mean,std,mean,std,mean,std
0,6,0.25,8,8,32,5,0.135021,0.09481,0.281667,0.253065,0.094118,0.067069,0.642857,0.047246


In [76]:
grouped_mean.columns = grouped_mean.columns.map(''.join)
grouped_mean.sort_values(by=['f1mean'], ascending=False)

Unnamed: 0,seconds_per_window,presence_threshold,num_layers,num_heads,head_dim,run_idcount,f1mean,f1std,precisionmean,precisionstd,recallmean,recallstd,accuracymean,accuracystd
23,8,0.75,8,8,32,3,0.76356,0.018412,0.718521,0.022514,0.81877,0.073513,0.724868,0.005291
11,4,0.75,8,8,32,3,0.76266,0.012336,0.716906,0.027892,0.815534,0.016816,0.723104,0.021383
18,7,0.25,8,8,32,3,0.760204,0.021113,0.728573,0.024657,0.796117,0.044491,0.726631,0.021383
22,8,0.5,8,8,32,3,0.759512,0.012188,0.762549,0.008881,0.757282,0.033632,0.738977,0.00611
5,2,0.75,8,8,32,3,0.752484,0.007379,0.667034,0.025983,0.867314,0.066086,0.689594,0.011014
25,10,0.25,8,8,32,3,0.752165,0.026042,0.705985,0.050541,0.815534,0.100896,0.708995,0.024246
7,3,0.5,8,8,32,3,0.751901,0.021258,0.670197,0.026935,0.860841,0.073513,0.691358,0.021383
17,6,0.75,8,8,32,3,0.751367,0.028589,0.731831,0.031999,0.776699,0.079469,0.72134,0.021383
8,3,0.75,8,8,32,3,0.751051,0.025666,0.70355,0.023964,0.805825,0.035005,0.708995,0.029459
10,4,0.5,8,8,32,3,0.750279,0.031298,0.709888,0.039443,0.802589,0.092276,0.710758,0.026631
