In [203]:
import numpy as np
import pandas as pd

from scipy.stats import wilcoxon, ttest_rel

In [204]:
scenario = 'dtr'

In [205]:
data = pd.read_csv(f'evaluation_{scenario}.csv')

In [206]:
if scenario == 'dtc' or scenario == 'dtr':
    data = data[data['mode'].isin(['<=', '<', 'avg_full'])]
if scenario == 'rfc' or scenario == 'rfr':
    data = data[data['mode'].isin(['<=', '<', 'avg_half'])]

In [207]:
data['mode'].drop_duplicates()

0           <
1          <=
2    avg_full
Name: mode, dtype: object

In [208]:
score = 'auc' if 'auc' in data.columns else 'r2'

In [209]:
def do_testing(pdf, metric):
    params = pdf['mode'].drop_duplicates().values
    pdf0 = pdf[pdf['mode'] == params[0]].sort_values('fold')
    pdf1 = pdf[pdf['mode'] == params[1]].sort_values('fold')

    return pd.Series({
        'wilcoxon': wilcoxon(pdf0[metric], pdf1[metric], zero_method='zsplit').pvalue
    })

In [210]:
grouped = data\
    .groupby(['name', 'mode'])\
    .apply(lambda pdf: pdf.sort_values('fold')[score].values.tolist())\
    .reset_index(drop=False)\
    .rename(columns={0: score})

In [211]:
def evaluate_one(grouped, pivot_postfix='<='):
    pivot_row = grouped[grouped['mode'] == f"{pivot_postfix}"].iloc[0]
    pivot_label = pivot_row['mode']

    other_rows = grouped[grouped['mode'] != pivot_label]

    result = {f'{score}_{pivot_label}': np.mean(pivot_row[score])}

    for idx, row in other_rows.iterrows():
        result[f'{score}_{row["mode"]}'] = np.mean(row[score])
        if row["mode"] in ['<', '<=']:
            result[f'p_{row["mode"]}'] = wilcoxon(pivot_row[score], row[score], zero_method="zsplit").pvalue
        else:
            result[f'p_{row["mode"]}_l'] = wilcoxon(pivot_row[score], row[score], zero_method="zsplit", alternative="less").pvalue
            result[f'p_{row["mode"]}_g'] = wilcoxon(pivot_row[score], row[score], zero_method="zsplit", alternative="greater").pvalue

    return pd.Series(result)

def evaluate_min_max(grouped):
    pivot_leq = grouped[grouped['mode'] == "<="].iloc[0]
    pivot_l = grouped[grouped['mode'] == "<"].iloc[0]

    score_leq = np.mean(pivot_leq[score])
    score_l = np.mean(pivot_l[score])

    if score_leq < score_l:
        min_row = pivot_leq
        max_row = pivot_l
    else:
        min_row = pivot_l
        max_row = pivot_leq

    other_rows = grouped[~grouped['mode'].isin(['<=', '<'])]

    result = {f'{score}_min': np.mean(min_row[score]),
                f'{score}_max': np.mean(max_row[score])}

    for idx, row in other_rows.iterrows():
        result[f'{score}_{row["mode"]}'] = np.mean(row[score])

        result[f'p_{row["mode"]}_min'] = wilcoxon(row[score], min_row[score], zero_method="zsplit", alternative="greater").pvalue
        result[f'p_{row["mode"]}_max'] = wilcoxon(row[score], max_row[score], zero_method="zsplit", alternative="less").pvalue

    return pd.Series(result)

def evaluate_joint(grouped):
    pivot_leq = grouped[grouped['mode'] == "<="].iloc[0]
    pivot_l = grouped[grouped['mode'] == "<"].iloc[0]

    score_leq = np.mean(pivot_leq[score])
    score_l = np.mean(pivot_l[score])

    #if score_leq < score_l:
    if wilcoxon(pivot_leq[score], pivot_l[score], zero_method="zsplit", alternative="less").pvalue < wilcoxon(pivot_l[score], pivot_leq[score], zero_method="zsplit", alternative="less").pvalue:
        min_row = pivot_leq
        max_row = pivot_l
    else:
        min_row = pivot_l
        max_row = pivot_leq

    other_rows = grouped[~grouped['mode'].isin(['<=', '<'])]

    result = {f'{score}_min': np.mean(min_row[score]),
                f'{score}_max': np.mean(max_row[score]),
                'p_neq': wilcoxon(max_row[score], min_row[score], zero_method="zsplit").pvalue}

    for idx, row in other_rows.iterrows():
        result[f'{score}_{row["mode"]}'] = np.mean(row[score])

        result[f'p_{row["mode"]}_better'] = min(wilcoxon(row[score], min_row[score], zero_method="zsplit", alternative="greater").pvalue,
                                            wilcoxon(row[score], max_row[score], zero_method="zsplit", alternative="greater").pvalue)

        result[f'p_{row["mode"]}_worse'] = min(wilcoxon(row[score], min_row[score], zero_method="zsplit", alternative="less").pvalue,
                                            wilcoxon(row[score], max_row[score], zero_method="zsplit", alternative="less").pvalue)

        result[f'p_{row["mode"]}_tbetter'] = min(ttest_rel(row[score], min_row[score], alternative="greater").pvalue,
                                            wilcoxon(row[score], max_row[score], zero_method="zsplit", alternative="greater").pvalue)

        result[f'p_{row["mode"]}_tworse'] = min(ttest_rel(row[score], min_row[score], alternative="less").pvalue,
                                            wilcoxon(row[score], max_row[score], zero_method="zsplit", alternative="less").pvalue)

    return pd.Series(result)

In [212]:
def evaluate_all(data, pivot_postfix='<='):
    return data.groupby("name").apply(lambda x: evaluate_one(x, pivot_postfix))


In [213]:
result = grouped.groupby("name").apply(evaluate_joint)

In [214]:
result

Unnamed: 0_level_0,r2_min,r2_max,p_neq,r2_avg_full,p_avg_full_better,p_avg_full_worse,p_avg_full_tbetter,p_avg_full_tworse
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
airfoil,0.858849,0.858857,1.6343550000000003e-17,0.858869,4.5042129999999995e-34,3.498213e-07,1.425827e-19,3.498213e-07
autoMPG6,0.824958,0.825267,0.001706841,0.82525,3.94412e-07,0.3245531,2.929387e-21,0.3245531
baseball,0.632275,0.633562,0.6431537,0.63451,1.860197e-12,0.9999886,2.092854e-27,0.9999886
cpu_performance,0.793773,0.795909,3.5494319999999996e-21,0.799377,2.896063e-121,0.7631098,5.274883e-55,0.7631098
daily-demand,0.675897,0.676236,0.03138893,0.676122,0.01560123,0.01946283,0.003126,0.01946283
diabetes,-0.171306,-0.17124,0.9643561,-0.17126,0.4821781,0.4821781,0.1587158,0.4821781
excitation_current,0.999823,0.999823,6.382882e-06,0.999823,4.994452e-37,1.0,3.338693e-41,1.0
laser,0.922326,0.92327,0.06117692,0.923891,1.522854e-187,1.0,1.9696310000000002e-157,1.0
maternal_health_risk,0.713822,0.715197,2.262532e-11,0.715366,1.574436e-15,4.920788e-08,3.176403e-30,4.920788e-08
medical_cost,0.761533,0.761533,1.0,0.761533,0.5,0.5,,


In [215]:
(result < 0.05).sum()

r2_min                 1
r2_max                 1
p_neq                 11
r2_avg_full            1
p_avg_full_better     13
p_avg_full_worse       6
p_avg_full_tbetter    14
p_avg_full_tworse      6
dtype: int64

In [216]:
final = evaluate_all(grouped, pivot_postfix='<=')
final

Unnamed: 0_level_0,r2_<=,r2_<,p_<,r2_avg_full,p_avg_full_l,p_avg_full_g
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
airfoil,0.858857,0.858849,1.6343550000000003e-17,0.858869,0.9999997,3.498213e-07
autoMPG6,0.824958,0.825267,0.001706841,0.82525,3.94412e-07,0.9999996
baseball,0.632275,0.633562,0.6431537,0.63451,1.860197e-12,1.0
cpu_performance,0.795909,0.793773,3.5494319999999996e-21,0.799377,0.2368902,0.7631098
daily-demand,0.675897,0.676236,0.03138893,0.676122,0.01560123,0.9843988
diabetes,-0.171306,-0.17124,0.9643561,-0.17126,0.4821781,0.5178219
excitation_current,0.999823,0.999823,6.382882e-06,0.999823,8.477609e-30,1.0
laser,0.922326,0.92327,0.06117692,0.923891,1.522854e-187,1.0
maternal_health_risk,0.715197,0.713822,2.262532e-11,0.715366,1.0,4.920788e-08
medical_cost,0.761533,0.761533,1.0,0.761533,0.5,0.5


In [217]:
(final <= 0.05).sum()

r2_<=            1
r2_<             1
p_<             11
r2_avg_full      1
p_avg_full_l     9
p_avg_full_g     3
dtype: int64

In [218]:
final = evaluate_all(grouped, pivot_postfix='<')
final

Unnamed: 0_level_0,r2_<,r2_<=,p_<=,r2_avg_full,p_avg_full_l,p_avg_full_g
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
airfoil,0.858849,0.858857,1.6343550000000003e-17,0.858869,4.5042129999999995e-34,1.0
autoMPG6,0.825267,0.824958,0.001706841,0.82525,0.6754469,0.324553
baseball,0.633562,0.632275,0.6431537,0.63451,1.14264e-05,0.999989
cpu_performance,0.793773,0.795909,3.5494319999999996e-21,0.799377,2.896063e-121,1.0
daily-demand,0.676236,0.675897,0.03138893,0.676122,0.9805372,0.019463
diabetes,-0.17124,-0.171306,0.9643561,-0.17126,0.5178219,0.482178
excitation_current,0.999823,0.999823,6.382882e-06,0.999823,4.994452e-37,1.0
laser,0.92327,0.922326,0.06117692,0.923891,1.9696310000000002e-157,1.0
maternal_health_risk,0.713822,0.715197,2.262532e-11,0.715366,1.574436e-15,1.0
medical_cost,0.761533,0.761533,1.0,0.761533,0.5,0.5


In [219]:
(final <= 0.05).sum()

r2_<             1
r2_<=            1
p_<=            11
r2_avg_full      1
p_avg_full_l     7
p_avg_full_g     3
dtype: int64