In [None]:
from pathlib import Path
import pandas as pd
from dotenv import dotenv_values

# Notebook configuration.
#
# Two result files are resolved here, one per task:
#   - regression results      -> "brain age" task        (fpath_results_ba)
#   - classification results  -> "cognitive decline" task (fpath_results_cog)
# Which of them is actually read is decided where the results are loaded.

# Directory locations come from the '.env' file (relative path, so it is
# looked up from the current working directory).
ENV_VARS = dotenv_values('.env')
DPATH_RESULTS = Path(ENV_VARS['DPATH_FL_RESULTS'])
DPATH_FIGS = Path(ENV_VARS['DPATH_FL_FIGS'])

fpath_results_ba = DPATH_RESULTS.joinpath('results-regression-age-sex-hc-aseg-5-3791.tsv')
fpath_results_cog = DPATH_RESULTS.joinpath('results-classification-decline-age-case-aparc-5-3791.tsv')


def _format_two_decimals(x):
    """Render a float with two digits after the decimal point."""
    return '%.2f' % x


# Show floats with two decimals in every DataFrame/Series display.
pd.set_option('display.float_format', _format_two_decimals)

# Hex colour assigned to each test dataset (used as the plotting palette).
DATASET_COLOUR_MAP = dict(
    PPMI='#D0A441',
    ADNI='#0CA789',
    QPN='#A6A6C6',
)

# Load the results table(s), keep only the rows that get plotted, and give the
# method / dataset columns their display labels.
#
# The whole pipeline is a single method chain: the relabelling is done with
# `.assign` on the filtered frame rather than by item assignment on the output
# of `.query`, which avoids pandas' SettingWithCopyWarning.
#
# Row selection:
#   - the "fl_voting" method is excluded;
#   - only the "balanced_accuracy" and "mean_absolute_error" metrics are kept.
#   (An alternative selection used previously also dropped test_dataset "all"
#   and kept "r2" instead of "mean_absolute_error".)
#
# NOTE: `Series.map` with a dict turns any method name that is not a key of
# the mapping into NaN.
df_results = (
    pd.concat(
        [
            # pd.read_csv(fpath_results_ba, sep='\t'),
            pd.read_csv(fpath_results_cog, sep='\t'),
        ],
        axis='index',
    )
    .query('method != "fl_voting" and (metric == "balanced_accuracy" or metric == "mean_absolute_error")')
    .assign(
        method=lambda df: df['method'].map({'silo': 'Siloed', 'mega': 'Mega-analysis', 'fl_fedavg': 'Federated'}),
        test_dataset=lambda df: df['test_dataset'].str.upper(),
    )
    .reset_index(drop=True)
)
df_results

In [None]:
# df_results_null = df_results.query('is_null == True')

# dataset = 'ADNI'
# df_results_null.query(f'test_dataset == "{dataset}"').groupby(['method', 'metric'])['score'].describe()

# for metric, df in df_results_null.groupby('metric'):
#     print(df_results_null.query(f'metric == "{metric}"').describe())
#     print(f'===== {metric.upper()} =====')
#     df = df.groupby(['method', 'test_dataset'])
#     # print(df['score'].max())
#     print(df['score'].mean())

In [None]:
import seaborn as sns

# Split the results three ways:
#   - df_results_all:     rows whose test dataset is "ALL";
#   - df_results_null:    rows flagged is_null on an individual test dataset;
#   - df_results_nonnull: rows not flagged is_null on an individual test dataset.
df_results_all = df_results.query('test_dataset == "ALL"')
df_results_null = df_results.query('is_null == True and test_dataset != "ALL"')
df_results_nonnull = df_results.query('is_null == False and test_dataset != "ALL"')

# Widths in x-axis (category) units: `bar_width` is passed to the bar plot,
# `null_width` sizes the horizontal reference segments drawn per method later.
bar_width = null_width = 0.8

# One row of axes per metric, each with its own x and y axes.
facet_layout = dict(
    row='metric',
    height=2.5,
    aspect=3,
    sharex=False,
    sharey=False,
)

# Appearance of the bars; colours follow the per-dataset palette.
bar_style = dict(
    width=bar_width,
    palette=DATASET_COLOUR_MAP,
    alpha=0.8,
    saturation=1,
)

# Bars show the score per method, split by test dataset, with 'sd' error bars.
grid = sns.catplot(
    data=df_results_nonnull,
    kind='bar',
    x='method',
    y='score',
    hue='test_dataset',
    order=['Siloed', 'Federated', 'Mega-analysis'],
    errorbar='sd',
    **facet_layout,
    **bar_style,
)

# Annotate every panel of the grid (one panel per metric) with:
#   - a bold panel letter (A, B, ...);
#   - for each method, a red dotted segment at the mean score obtained on the
#     "ALL" test dataset;
#   - a dashed black line across the panel at the best mean null score;
#   - axis label / title, and a "Better model" arrow showing which direction
#     of the y axis is better.
for i_ax, (metric, ax) in enumerate(grid.axes_dict.items()):

    task_name = {'balanced_accuracy': 'cognitive decline', 'mean_absolute_error': 'brain age', 'r2': 'brain age'}[metric]
    print(f'===== {task_name.upper()} =====')

    # Mean absolute error is the only metric here where a smaller value is
    # better; every other metric is treated as "larger is better".
    lower_is_better = metric == 'mean_absolute_error'

    ax.text(-0.08, 1.05, 'ABCDE'[i_ax], transform=ax.transAxes, size=16, weight='bold')

    # Summary statistics per (method, test dataset) for this metric. They only
    # depend on the metric, so they are computed once per panel instead of
    # once per x tick. (`@metric` refers to the Python variable `metric`.)
    df_mean_null_values = df_results_null.query('metric == @metric').groupby(['method', 'test_dataset'])['score'].describe()
    df_mean_all_values = df_results_all.query('metric == @metric').groupby(['method', 'test_dataset'])['score'].describe()

    # The null reference is always taken from the 'Siloed' null results,
    # whichever method is being annotated: one mean per test dataset, of which
    # the best one is kept. Computing it before the tick loop also means the
    # horizontal line below no longer depends on the loop having run.
    mean_null_values = df_mean_null_values.loc['Siloed', 'mean']
    best_null_value = mean_null_values.min() if lower_is_better else mean_null_values.max()

    for xticklabel, xtick in zip(ax.get_xticklabels(), ax.get_xticks()):

        method = xticklabel.get_text()
        if method not in ['Siloed', 'Federated', 'Mega-analysis']:
            raise ValueError(f'Unknown method: {method}')

        # Mean score of this method on the "ALL" test dataset.
        mean_all_values = df_mean_all_values.loc[method, 'mean'].item()

        print(f'----- {method.capitalize()} -----')
        print('MEAN NULL')
        print(mean_null_values)
        print('MEAN All')
        print(mean_all_values)

        # Red dotted segment, centred on the method's tick.
        ax.plot([xtick - null_width/2, xtick + null_width/2], [mean_all_values, mean_all_values], 'r:', alpha=0.75)

    # Dashed black line spanning the panel at the best mean null score.
    ax.axhline(best_null_value, color='k', linestyle='--', alpha=0.5)

    ax.set_ylabel(metric.capitalize().replace('_', ' '))
    ax.set_title(f"{task_name.capitalize()} task")
    ax.set_xlabel('')

    # The arrow is drawn between y=0.75 (xytext) and y=0.25 (xy) just right of
    # the axes: '->' puts the head at the bottom (lower is better), '<-' puts
    # it at the top (higher is better).
    arrowstyle = '->' if lower_is_better else '<-'

    ax.annotate(
        '', xy=(1.05, 0.25), xycoords='axes fraction', xytext=(1.05, 0.75),
        arrowprops=dict(arrowstyle=arrowstyle, linewidth=2, mutation_scale=20))
    ax.annotate(
        'Better\nmodel', xy=(1.1, 0.5), xycoords='axes fraction', ha='center', va='center',
    )

grid.legend.set_title('Test dataset')

In [None]:
# Display the rows whose test dataset is "ALL" (the frame used for the red
# dotted reference segments in the figure).
df_results_all

In [None]:
# Re-display the figure behind `grid`, including the annotations added to its axes.
grid.figure

In [None]:
# Descriptive statistics of the non-null scores for a single test dataset,
# broken down by method and metric.
dataset = 'QPN'
(
    df_results_nonnull
    .query(f'test_dataset == "{dataset}"')
    .groupby(['method', 'metric'])['score']
    .describe()
)


In [None]:
# Make sure the figure output directory exists. `parents=True` also creates
# any missing intermediate directories instead of raising FileNotFoundError.
DPATH_FIGS.mkdir(parents=True, exist_ok=True)

# Destination of the combined-metrics figure.
fpath_fig = DPATH_FIGS / 'metrics-combined.png'
# Saving is currently disabled; uncomment to write the figure to disk.
# grid.savefig(fpath_fig, dpi=300)
