# Load data

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
from dotenv import dotenv_values

np.set_printoptions(precision=3, suppress=True)

def load_data(fpath):
    """Read one dataset TSV and prepare it for the configured problem type.

    The file is tab-separated and indexed by its ``participant_id`` column.
    Rows with a missing ``TARGET_VAR`` are discarded. For classification the
    target column is cast to bool; for regression only rows with
    ``AGE >= 55`` are kept.

    Parameters
    ----------
    fpath : path-like
        Location of the TSV file.

    Returns
    -------
    pandas.DataFrame
        The filtered table.
    """
    table = pd.read_csv(fpath, sep='\t', index_col='participant_id')
    # rows without a target value cannot be used for fitting or scoring
    table = table.dropna(subset=[TARGET_VAR], axis='index')

    if PROBLEM_TYPE == 'regression':
        table = table.query('AGE >= 55')
        # table = table[['AGE', 'SEX', 'Left-Lateral-Ventricle', 'Right-Lateral-Ventricle', 'Right-Hippocampus', 'Left-Hippocampus']]
    elif PROBLEM_TYPE == 'classification':
        table[TARGET_VAR] = table[TARGET_VAR].astype(bool)

    has_sex_column = 'SEX' in table.columns
    if has_sex_column:
        sex_as_category = table['SEX'].astype('category')
        # NOTE(review): assigning through .loc may keep the column's existing
        # dtype in recent pandas versions -- confirm 'SEX' really ends up
        # categorical if downstream steps rely on it.
        table.loc[:, 'SEX'] = sex_as_category
    return table

# Task run by the notebook; switch tasks by swapping the two lines below.
PROBLEM_TYPE = 'classification'
# PROBLEM_TYPE = 'regression'

# Number of cross-validation folds and the seed reused for every shuffle.
N_FOLDS, RNG_SEED = 5, 3791

# Plot colour for each dataset, keyed by dataset name.
DATASET_COLOUR_MAP = dict(PPMI='#D0A441', ADNI='#0CA789', QPN='#A6A6C6')

# (target column, file-name tag) for each supported problem type
_SETTINGS_BY_PROBLEM_TYPE = {
    'classification': ('COG_DECLINE', 'decline-age-case-aparc'),
    'regression': ('AGE', 'age-sex-hc-aseg'),
}
if PROBLEM_TYPE not in _SETTINGS_BY_PROBLEM_TYPE:
    raise ValueError('PROBLEM_TYPE must be either classification or regression')
TARGET_VAR, COMMON_TAGS = _SETTINGS_BY_PROBLEM_TYPE[PROBLEM_TYPE]

# Directory locations are read from the local .env file.
ENV_VARS = dotenv_values('.env')
DPATH_DATA, DPATH_FIGS, DPATH_RESULTS = (
    Path(ENV_VARS[env_key])
    for env_key in ('DPATH_FL_DATA', 'DPATH_FL_FIGS', 'DPATH_FL_RESULTS')
)

# One table per dataset; file names follow '<dataset>-<COMMON_TAGS>.tsv'.
df_ppmi, df_adni, df_qpn = (
    load_data(DPATH_DATA / f'{prefix}-{COMMON_TAGS}.tsv')
    for prefix in ('ppmi', 'adni', 'qpn')
)

# Stack the three tables; the dict keys become the outer index level.
df_all = pd.concat({'ppmi': df_ppmi, 'adni': df_adni, 'qpn': df_qpn}, axis=0)
df_all.index = df_all.index.set_names('dataset', level=0)

# Report the shape of every table (and the class balance for classification).
for label, df in (('ppmi', df_ppmi), ('adni', df_adni), ('qpn', df_qpn), ('all', df_all)):
# for label, df in (('ppmi', df_ppmi), ('qpn', df_qpn), ('all', df_all)):
    print(f'{label}: {df.shape}')
    if PROBLEM_TYPE == 'classification':
        print(f'\t{TARGET_VAR}: {df[TARGET_VAR].value_counts(dropna=False).to_dict()}')

# reorder columns (needed for coefficient comparison/averaging)
df_ppmi, df_adni, df_qpn = (
    table[df_all.columns] for table in (df_ppmi, df_adni, df_qpn)
)

# Long-format results table: one row per (method, dataset, metric, fold).
DF_RESULTS_TO_SAVE = pd.DataFrame(
    columns=['method', 'is_null', 'test_dataset', 'metric', 'i_fold', 'score'],
)

def add_scores(method, test_dataset, metric, scores, is_null):
    """Append one row per fold score to the global ``DF_RESULTS_TO_SAVE``.

    Parameters
    ----------
    method : label of the approach that produced the scores.
    test_dataset : label of the dataset the scores were computed on.
    metric : name of the metric.
    scores : iterable of per-fold values; the position is stored as ``i_fold``.
    is_null : whether the scores come from a shuffled-target (null) run.
    """
    for fold_number, fold_score in enumerate(scores):
        # the current row count is the next free integer label
        next_row = len(DF_RESULTS_TO_SAVE)
        DF_RESULTS_TO_SAVE.loc[next_row] = [
            method, is_null, test_dataset, metric, fold_number, fold_score,
        ]


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Overview figure of the target variable in each dataset.
# Regression: one filled density curve of the target per dataset.
# Classification: one count plot per dataset (one row each).
sns.set_theme(style='ticks')

if PROBLEM_TYPE == 'regression':
    fig, ax = plt.subplots(figsize=(2, 6))

    # 'dataset' becomes a regular column; upper-case it so the values match
    # the keys of DATASET_COLOUR_MAP
    df_to_plot = df_all.reset_index().copy()
    df_to_plot['dataset'] = df_to_plot['dataset'].map(lambda x: x.upper())
    sns.kdeplot(
        data=df_to_plot,
        y=TARGET_VAR,
        hue='dataset',
        ax=ax,
        palette=DATASET_COLOUR_MAP,
        fill=True,
        linewidth=0,
        alpha=0.5,
        common_norm=False,
    )
    ax.set_ylabel('Age')
    legend = ax.get_legend()
    legend.set_title('')
    legend.set_loc('lower right')
    legend.get_frame().set_linewidth(0.0)
    # dashed line at 55, the same value load_data() uses as its age cut-off
    ax.axhline(y=55, color='grey', linestyle='--')
    sns.despine(ax=ax)
    xticks = [0, 0.05]
    ax.set_xticks(xticks)
    ax.set_xticklabels([f'{x:.2f}' for x in xticks])
    ax.set_ylim(bottom=15, top=95)

elif PROBLEM_TYPE == 'classification':
    df_to_plot = df_all.reset_index().copy()
    df_to_plot['dataset'] = df_to_plot['dataset'].map(lambda x: x.upper())
    # readable class names for the axis tick labels
    df_to_plot[TARGET_VAR] = df_to_plot[TARGET_VAR].map({True: 'Cog. decline', False: 'Stable'})
    fig = sns.catplot(
        data=df_to_plot,
        y=TARGET_VAR,
        row='dataset',
        kind='count',
        sharex=False,
        height=2,
        aspect=1.5,
        palette=DATASET_COLOUR_MAP,
        hue='dataset',
        legend=False,
        saturation=1,
        alpha=0.8,
    )
    # hand-picked x ticks for each dataset's panel
    xticks = {'PPMI': [0, 1000], 'ADNI': [0, 400], 'QPN': [0, 40]}
    for i_ax, (dataset, ax) in enumerate(fig.axes_dict.items()):
        print(f'{dataset}: {ax.get_xlim()[1]}')
        ax.set_ylabel('')
        # ax.set_yticklabels(['Stable', 'Decline'])
        # ax.set_xlabel('Count')
        ax.set_title(dataset)
        ax.set_xticks(xticks[dataset])
        ax.set_xticklabels(xticks[dataset])
        # only the last panel gets an x-axis label
        if i_ax == len(fig.axes_dict) - 1:
            ax.set_xlabel('Count')
        # NOTE(review): makes the second bar of this panel's container more
        # transparent; presumably containers are ordered like the hue levels
        # and which class is "second" depends on category order -- confirm.
        ax.containers[i_ax].patches[1].set_alpha(0.4)

else:
    raise ValueError('PROBLEM_TYPE must be either classification or regression')

# NOTE(review): without parents=True this fails if the parent directory is missing.
DPATH_FIGS.mkdir(exist_ok=True)

fpath_out = DPATH_FIGS / f'data-{PROBLEM_TYPE}.png'
# saving is currently disabled; only the intended output path is printed
# fig.savefig(fpath_out, dpi=300, bbox_inches='tight')
print(fpath_out)

# for dataset in ['ppmi', 'adni', 'qpn']:
#     df_all.loc[dataset, 'AGE'].hist(ax=ax, alpha=0.5, label=dataset)
# ax.legend()
# fig

In [None]:
from skrub import TableReport
# Display skrub's report for the combined table as this cell's output.
TableReport(df_all)

# Base ML model and helper function for fake data

In [None]:
import numpy as np
from sklearn.ensemble import HistGradientBoostingClassifier, GradientBoostingClassifier, HistGradientBoostingRegressor
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LogisticRegressionCV, SGDClassifier, SGDRegressor, LassoCV, RidgeCV, LinearRegression
from sklearn.model_selection import cross_validate, StratifiedKFold, GridSearchCV, KFold, GroupKFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, r2_score, root_mean_squared_error, mean_absolute_error, explained_variance_score
from skrub import tabular_learner
from scipy.stats import pearsonr
from sklearn.metrics import make_scorer

def neg_root_mean_squared_error(*args, **kwargs):
    """Return ``root_mean_squared_error`` of the given arguments, negated."""
    rmse = root_mean_squared_error(*args, **kwargs)
    return -rmse

def neg_mean_absolute_error(*args, **kwargs):
    """Return ``mean_absolute_error`` of the given arguments, negated."""
    mae = mean_absolute_error(*args, **kwargs)
    return -mae

def get_metrics_map(include_roc_auc=True):
    """Return a new ``{metric name: metric function}`` dict for PROBLEM_TYPE.

    Each function is called as ``func(y_true, y_pred)``. A fresh dict is
    built on every call, so callers may modify the result.

    Parameters
    ----------
    include_roc_auc : bool
        Currently has no effect: the ROC AUC entry is disabled because it
        would be computed from predicted labels rather than scores.

    Raises
    ------
    ValueError
        If PROBLEM_TYPE is neither 'classification' nor 'regression'.
    """
    if PROBLEM_TYPE == 'regression':
        return {
            "r2": r2_score,
            "neg_mean_absolute_error": neg_mean_absolute_error,
            "explained_variance": explained_variance_score,
            # Pearson r between the (squeezed) true and predicted values
            'corr': (lambda x, y: pearsonr(np.squeeze(x), np.squeeze(y))[0]),
            "mean_absolute_error": mean_absolute_error,
        }

    if PROBLEM_TYPE == 'classification':
        # if include_roc_auc:
        #     # invalid, this is getting the preds instead of the scores
        #     metrics_dict["roc_auc"] = roc_auc_score
        return {"balanced_accuracy": balanced_accuracy_score}

    raise ValueError("PROBLEM_TYPE must be either 'classification' or 'regression'")

def get_X_y(df_Xy):
    """Split a table into features ``X`` and target ``y``.

    Rows whose ``TARGET_VAR`` value is missing are left out of both parts.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
        Every column except the target, and the target column.
    """
    target = df_Xy[TARGET_VAR].dropna()
    rows_with_target = df_Xy.loc[target.index]
    features = rows_with_target.drop(columns=TARGET_VAR)
    return features, target

def ml_helper(df_Xy, pipeline, null=False, cv=None):
    """Cross-validate ``pipeline`` on one dataset and print the fold scores.

    Parameters
    ----------
    df_Xy : pandas.DataFrame
        Table holding the features and the ``TARGET_VAR`` column.
    pipeline : estimator
        Estimator handed to ``cross_validate``.
    null : bool
        If True, the target values are shuffled (seeded with ``RNG_SEED``)
        before cross-validation to obtain a null baseline.
    cv : optional
        Anything ``cross_validate`` accepts as ``cv``. If None, a shuffled
        ``StratifiedKFold`` with ``N_FOLDS`` splits is used; for regression
        it is stratified on the age groups from ``get_age_groups``.

    Returns
    -------
    dict
        The ``cross_validate`` output (train/test scores, fitted estimators,
        fold indices) with the feature table and the target actually used
        added under ``'X'`` and ``'y'``.

    Raises
    ------
    ValueError
        If ``cv`` is None and PROBLEM_TYPE is not supported.
    """
    # cross_validate expects scorer objects rather than bare metric functions
    scorers = {
        metric_name: make_scorer(metric_func)
        for metric_name, metric_func in get_metrics_map().items()
    }

    X, y = get_X_y(df_Xy)

    if null:
        # Reorder the targets while X keeps its order, so features and targets
        # are no longer paired (assumes scikit-learn pairs X and y by
        # position, not by index label).
        y = y.sample(frac=1, replace=False, random_state=RNG_SEED)

    if cv is None:
        if PROBLEM_TYPE not in ('classification', 'regression'):
            raise ValueError("PROBLEM_TYPE must be either 'classification' or 'regression'")
        cv = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RNG_SEED)
        if PROBLEM_TYPE == 'regression':
            # a continuous target cannot be stratified directly, so the folds
            # are stratified on binned ages instead
            cv = cv.split(X, get_age_groups(df_Xy))

    results = cross_validate(
        pipeline,
        X,
        y,
        cv=cv,
        return_train_score=True,
        return_estimator=True,
        return_indices=True,
        scoring=scorers,
    )
    # keep the exact data that was used so the positional fold indices can be
    # applied to it later
    results['X'] = X
    results['y'] = y

    for metric in scorers:
        train_scores = results[f"train_{metric}"]
        test_scores = results[f"test_{metric}"]
        print(f"{metric}")
        print(
            f"\ttrain: {train_scores.mean():.2f} ({train_scores.std():.2f})\t {train_scores}"
        )
        print(f"\ttest: {test_scores.mean():.2f} ({test_scores.std():.2f})\t {test_scores}")
    return results

def get_fake_X_y():
    """Build a small placeholder dataset shaped like the real features.

    ``X_fake`` is the first row of ``df_all`` (target column removed)
    repeated ``2 * N_FOLDS`` times; ``y_fake`` alternates 0 and 1. When a
    'SEX' column is present, the first two rows are set to 0 and 1 so the
    column holds two distinct values.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
    """
    n_rows = 2 * N_FOLDS
    features = df_all.drop(columns=TARGET_VAR)
    X_fake = features.iloc[[0] * n_rows]
    y_fake = pd.Series([0, 1] * N_FOLDS)

    if 'SEX' not in X_fake.columns:
        return X_fake, y_fake

    # every row carries the same index label, so switch to a plain positional
    # index while editing single rows, then restore the original index columns
    index_cols = X_fake.index.names
    X_fake = X_fake.reset_index()
    for row_number, sex_value in enumerate((0, 1)):
        X_fake.loc[X_fake.index[row_number], 'SEX'] = sex_value
    return X_fake.set_index(index_cols), y_fake

def get_age_groups(df, bin_size=10):
    """Assign each row to an age bin of width ``bin_size``.

    Bin edges are multiples of ``bin_size``: the lowest edge is the minimum
    age rounded down, the highest edge lies strictly above the maximum age.
    Bins are closed on the left and open on the right, so with the default
    width an age of exactly 60 falls in '60-70'. This matches the
    floor-based edges and guarantees that no age is left unbinned (with
    right-closed bins an age equal to the lowest edge would become NaN).

    Parameters
    ----------
    df : pandas.DataFrame
        Table with an 'AGE' column.
    bin_size : number
        Width of each bin.

    Returns
    -------
    pandas.Series
        Categorical series, aligned with ``df``, labelled '<low>-<high>'.
    """
    min_age = df['AGE'].min() // bin_size * bin_size
    max_age = (df['AGE'].max() // bin_size + 1) * bin_size
    # +1 so that the upper edge itself is included by arange
    bins = np.arange(min_age, max_age + 1, bin_size)
    labels = [f"{int(low)}-{int(high)}" for low, high in zip(bins[:-1], bins[1:])]
    # right=False -> [low, high) intervals, consistent with the floored edges
    age_groups = pd.cut(df['AGE'], bins=bins, labels=labels, right=False)
    return age_groups

# Base estimator pipeline for the configured problem type.
if PROBLEM_TYPE not in ('classification', 'regression'):
    raise ValueError("PROBLEM_TYPE must be either 'classification' or 'regression'")

if PROBLEM_TYPE == 'classification':
    # logistic regression with its own internal CV and balanced class weights
    model = LogisticRegressionCV(
        max_iter=1000,
        cv=StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RNG_SEED),
        class_weight="balanced",
    )
    # model = LogisticRegressionCV(max_iter=1000, cv=StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RNG_SEED), penalty='l1', solver='saga', class_weight="balanced")
    pipeline = tabular_learner(model)
    # no missing-value indicator columns, and no feature scaling
    pipeline.set_params(simpleimputer__add_indicator=False, standardscaler='passthrough')
else:
    # ridge regression choosing its penalty from 100 values in [0.1, 10000]
    model = RidgeCV(alphas=np.linspace(0.1, 10000, 100))
    # model = HistGradientBoostingRegressor(max_iter=5)
    # model = LinearRegression()
    # model = LassoCV(cv=KFold(n_splits=N_FOLDS, shuffle=True, random_state=RNG_SEED), max_iter=10000)
    # model = PLSRegression(n_components=2)
    pipeline = tabular_learner(model)
    if 'simpleimputer' in pipeline.named_steps:
        pipeline.set_params(simpleimputer__add_indicator=False)
    # if 'standardscaler' in pipeline.named_steps:
    #     pipeline.set_params(standardscaler='passthrough')

pipeline


# Experiment 1: individual datasets

In [None]:
from collections import defaultdict

# Experiment 1: cross-validate the pipeline separately on each dataset.
print(f"===== PPMI (sample size {len(df_ppmi)}) =====")
results_ppmi = ml_helper(df_ppmi, pipeline)

print(f"===== ADNI (sample size {len(df_adni)}) =====")
results_adni = ml_helper(df_adni, pipeline)

print(f"===== QPN (sample size {len(df_qpn)}) =====")
results_qpn = ml_helper(df_qpn, pipeline)

# Pooled ("all") score per fold: every dataset's own fold-i estimator predicts
# that dataset's fold-i test rows, and the metric is computed once on the
# concatenation of all three test sets.
results_all = defaultdict(list)
for metric_name, metric_func in get_metrics_map().items():
    for i_fold in range(N_FOLDS):
        y_test_all = []
        y_pred_all = []
        for results in (results_ppmi, results_adni, results_qpn):
        # for results in (results_ppmi, results_qpn):
            estimator_silo = results['estimator'][i_fold]
            idx_test = results['indices']['test'][i_fold]
            X_test = results['X'].iloc[idx_test]
            y_test = results['y'].iloc[idx_test]
            if metric_name == 'roc_auc':
                # technically scores but renaming to pred for concatenation
                y_test_pred = pd.Series(estimator_silo.predict_proba(X_test)[:, 1])
            else:
                y_test_pred = pd.Series(estimator_silo.predict(X_test))
            # label the predictions with the rows they belong to (the roc_auc
            # branch previously assigned the prediction index to itself)
            y_test_pred.index = y_test.index
            y_test_all.append(y_test)
            y_pred_all.append(y_test_pred)
        y_test_all_concat = pd.concat(y_test_all)
        y_pred_all_concat = pd.concat(y_pred_all)
        results_all[f'test_{metric_name}'].append(metric_func(y_test_all_concat, y_pred_all_concat))

# Store the per-dataset and pooled test scores in the results table.
for results, test_dataset in zip([results_ppmi, results_adni, results_qpn, results_all], ['ppmi', 'adni', 'qpn', 'all']):
# for results, test_dataset in zip([results_ppmi, results_qpn, results_all], ['ppmi', 'qpn', 'all']):
    for metric in get_metrics_map():
        add_scores('silo', test_dataset, metric, results[f'test_{metric}'], is_null=False)



In [None]:
# from sklearn.metrics import ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

# fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12, 3))

# for i_ax, (y_test, y_pred) in enumerate(zip(y_test_all, y_pred_all)):
#     ax = axes[i_ax]
#     ConfusionMatrixDisplay.from_predictions(y_test, y_pred, ax=ax)
#     ax.set_title(['ppmi', 'adni', 'qpn'][i_ax])
#     # print('-----')
#     # print(f'{score_fn(y_test, y_pred):.3f} ({len(y_test)})')
#     # manual_avg += score_fn(y_test, y_pred) * len(y_test)
#     # assert y_test_all_concat.loc[y_test.index].equals(y_test)
#     # assert y_pred_all_concat.loc[y_pred.index].equals(y_pred)

# ConfusionMatrixDisplay.from_predictions(y_test_all_concat, y_pred_all_concat, ax=axes[-1])
# axes[-1].set_title('all')
# fig.tight_layout()

In [None]:
# manual_avg = 0
# if PROBLEM_TYPE == 'classification':
#     score_fn = get_metrics_map()['balanced_accuracy']
# elif PROBLEM_TYPE == 'regression':
#     score_fn = get_metrics_map()['mean_absolute_error']
# for y_test, y_pred in zip(y_test_all, y_pred_all):
#     print('-----')
#     print(f'{score_fn(y_test, y_pred):.3f} ({len(y_test)})')
#     manual_avg += score_fn(y_test, y_pred) * len(y_test)
#     assert y_test_all_concat.loc[y_test.index].equals(y_test)
#     assert y_pred_all_concat.loc[y_pred.index].equals(y_pred)
# print('-----')
# print(f'{score_fn(y_test_all_concat, y_pred_all_concat):.3f} ({len(y_test_all_concat)})')
# print(manual_avg / len(y_test_all_concat))

In [None]:
# DF_RESULTS_TO_SAVE.query('method == "silo" and test_dataset == "all" and metric == "mean_absolute_error"')
# DF_RESULTS_TO_SAVE

In [None]:
# Null baseline for experiment 1: the same per-dataset cross-validation, but
# with the targets shuffled inside ml_helper.
print("===== PPMI (NULL) =====")
results_ppmi_null = ml_helper(df_ppmi, pipeline, null=True)

print("===== ADNI (NULL) =====")
results_adni_null = ml_helper(df_adni, pipeline, null=True)

print("===== QPN (NULL) =====")
results_qpn_null = ml_helper(df_qpn, pipeline, null=True)

# Store the null test scores alongside the real ones (is_null=True).
for test_dataset, results in (
    ('ppmi', results_ppmi_null),
    ('adni', results_adni_null),
    ('qpn', results_qpn_null),
):
    for metric in get_metrics_map():
        add_scores('silo', test_dataset, metric, results[f'test_{metric}'], is_null=True)

# Experiment 2: aggregated dataset

In [None]:
# _ = ml_helper(df_all, pipeline, null=False)

In [None]:
# from sklearn.dummy import DummyClassifier, DummyRegressor
# if PROBLEM_TYPE == 'classification':
#     model_null = DummyClassifier()
#     pipeline_null = tabular_learner(model_null)
#     pipeline_null.set_params(simpleimputer__add_indicator=False, standardscaler='passthrough')
# elif PROBLEM_TYPE == 'regression':
#     model_null = DummyRegressor(strategy='mean')
#     pipeline_null = tabular_learner(model_null)
#     if 'simpleimputer' in pipeline.named_steps:
#         pipeline_null.set_params(simpleimputer__add_indicator=False)

# results_all_null = ml_helper(df_all, pipeline_null)

# # for metric in get_metrics_map().keys():
# #     add_scores('mega', 'all', metric, results[f'test_{metric}'], is_null=False)

In [None]:
from collections import defaultdict
import numpy as np
from sklearn.base import clone

def mega_helper(null=False):
    """Train one model on the pooled datasets and score it fold by fold.

    For every fold, the training rows of PPMI, ADNI and QPN (taken from the
    fold indices stored in ``results_ppmi`` / ``results_adni`` /
    ``results_qpn``) are concatenated and a fresh clone of ``pipeline`` is fit
    on them. The fitted model is scored on each dataset's test rows, on the
    concatenated test rows ('all') and on the training rows ('train').
    Scores are printed; all except the 'train' scores are also recorded with
    ``add_scores`` under the method name 'mega'.

    Parameters
    ----------
    null : bool
        If True, the pooled training targets are shuffled before fitting.
        The flag is also stored as ``is_null`` with the recorded scores.

    Notes
    -----
    Several intermediate objects are declared ``global`` and therefore stay
    available at module level after the call (presumably for inspection in
    later cells); they hold the values from the last fold processed.
    """

    metric_map = get_metrics_map()

    # results_mega[dataset label][metric name] -> list with one score per fold
    results_mega = defaultdict(lambda: defaultdict(list))
    for i_fold in range(N_FOLDS):
        global estimator_mega
        # unfitted copy, so folds do not share fitted state
        estimator_mega = clone(pipeline)
        # estimator_ppmi = results_ppmi['estimator'][i_fold]
        # estimator_adni = results_adni['estimator'][i_fold]
        # estimator_qpn = results_qpn['estimator'][i_fold]
        # reuse the silo experiment's fold assignment for each dataset
        indices_train_ppmi = results_ppmi['indices']['train'][i_fold]
        indices_train_adni = results_adni['indices']['train'][i_fold]
        indices_train_qpn = results_qpn['indices']['train'][i_fold]
        global X_train, y_train
        X_train = pd.concat({
            'ppmi': results_ppmi['X'].iloc[indices_train_ppmi],
            'adni': results_adni['X'].iloc[indices_train_adni],
            'qpn': results_qpn['X'].iloc[indices_train_qpn],
            # 'ppmi': pd.DataFrame(estimator_ppmi[:-1].transform(results_ppmi['X'].iloc[indices_train_ppmi]), columns=results_ppmi['X'].columns, index=results_ppmi['X'].iloc[indices_train_ppmi].index),
            # 'adni': pd.DataFrame(estimator_adni[:-1].transform(results_adni['X'].iloc[indices_train_adni]), columns=results_adni['X'].columns, index=results_adni['X'].iloc[indices_train_adni].index),
            # 'qpn': pd.DataFrame(estimator_qpn[:-1].transform(results_qpn['X'].iloc[indices_train_qpn]), columns=results_qpn['X'].columns, index=results_qpn['X'].iloc[indices_train_qpn].index),
        })
        X_train.index.names = ['dataset', X_train.index.names[-1]]
        y_train = pd.concat({
            'ppmi': results_ppmi['y'].iloc[indices_train_ppmi],
            'adni': results_adni['y'].iloc[indices_train_adni],
            'qpn': results_qpn['y'].iloc[indices_train_qpn],
        })
        y_train.index.names = ['dataset', y_train.index.names[-1]]
        if null:
            # print('Shuffling y_train')
            # reorder the targets only; X_train keeps its order (assumes the
            # estimator pairs X and y by position)
            y_train = y_train.sample(frac=1, replace=False, random_state=RNG_SEED)
        # y_train =  y_train.sample(frac=1, replace=False, random_state=RNG_SEED)
        # if not null:
        #     X_train = X_train.loc[y_train.index]

        estimator_mega.fit(X_train, y_train)

        # --- score the pooled model on each dataset's own test fold ---
        dataset_map = {
            'ppmi': results_ppmi,
            'adni': results_adni,
            'qpn': results_qpn,
        }
        for label, results in dataset_map.items():
            # estimator = results['estimator'][i_fold]
            idx_test = results["indices"]["test"][i_fold]
            global X_test, y_test
            # X_test = estimator[:-1].transform(results['X'].iloc[idx_test])
            X_test = results['X'].iloc[idx_test]
            y_test = results['y'].iloc[idx_test]

            for metric_name, metric_func in metric_map.items():
                if metric_name == 'roc_auc':
                    # ROC AUC is computed from the second predict_proba column
                    y_score = estimator_mega.predict_proba(X_test)[:, 1]
                    metric = metric_func(y_test, y_score)
                else:
                    global y_pred
                    y_pred = estimator_mega.predict(X_test)
                    metric = metric_func(y_test, y_pred)
                results_mega[label][metric_name].append(metric)

        # --- score the pooled model on all test folds together, and on the
        # --- training rows
        indices_test_ppmi = results_ppmi['indices']['test'][i_fold]
        indices_test_adni = results_adni['indices']['test'][i_fold]
        indices_test_qpn = results_qpn['indices']['test'][i_fold]
        # global X_test_all, y_test_all
        X_test_all = pd.concat({
            'ppmi': results_ppmi['X'].iloc[indices_test_ppmi],
            'adni': results_adni['X'].iloc[indices_test_adni],
            'qpn': results_qpn['X'].iloc[indices_test_qpn],
            # 'ppmi': pd.DataFrame(estimator_ppmi[:-1].transform(results_ppmi['X'].iloc[indices_test_ppmi]), columns=results_ppmi['X'].columns, index=results_ppmi['X'].iloc[indices_test_ppmi].index),
            # 'adni': pd.DataFrame(estimator_adni[:-1].transform(results_adni['X'].iloc[indices_test_adni]), columns=results_adni['X'].columns, index=results_adni['X'].iloc[indices_test_adni].index),
            # 'qpn': pd.DataFrame(estimator_qpn[:-1].transform(results_qpn['X'].iloc[indices_test_qpn]), columns=results_qpn['X'].columns, index=results_qpn['X'].iloc[indices_test_qpn].index),

        })
        X_test_all.index.names = ['dataset', X_test_all.index.names[-1]]
        global y_test_all
        y_test_all = pd.concat({
            'ppmi': results_ppmi['y'].iloc[indices_test_ppmi],
            'adni': results_adni['y'].iloc[indices_test_adni],
            'qpn': results_qpn['y'].iloc[indices_test_qpn],
        })
        y_test_all.index.names = ['dataset', y_test_all.index.names[-1]]
        for metric_name, metric_func in metric_map.items():
            if metric_name == 'roc_auc':
                y_test_all_score = estimator_mega.predict_proba(X_test_all)[:, 1]
                metric_test = metric_func(y_test_all, y_test_all_score)
                y_train_score = estimator_mega.predict_proba(X_train)[:, 1]
                metric_train = metric_func(y_train, y_train_score)
            else:
                global y_test_all_pred
                # give the predictions the same index as the true values
                y_test_all_pred = pd.Series(estimator_mega.predict(X_test_all))
                y_test_all_pred.index = y_test_all.index
                metric_test = metric_func(y_test_all, y_test_all_pred)
                y_train_pred = pd.Series(estimator_mega.predict(X_train))
                y_train_pred.index = y_train.index
                metric_train = metric_func(y_train, y_train_pred)
            results_mega['all'][metric_name].append(metric_test)
            results_mega['train'][metric_name].append(metric_train)
        # break

    # print mean (std) and the per-fold values for every dataset and metric
    for dataset_name in results_mega:
        print(f"===== {dataset_name.upper()} =====")
        for metric_name, metric_values in results_mega[dataset_name].items():
            metric_values = np.array(metric_values)
            print(f"{metric_name}: {metric_values.mean():.2f} ({metric_values.std():.2f})\t{metric_values}")
            # training scores are printed only, never recorded
            if dataset_name not in ['train']:
                # add_scores('mega-split_test', dataset_name, metric_name, metric_values, is_null=null)
                add_scores('mega', dataset_name, metric_name, metric_values, is_null=null)

# Experiment 2 with the real (unshuffled) training targets.
mega_helper(null=False)

In [None]:
# # manual_avg = 0
# # if PROBLEM_TYPE == 'classification':
# #     score_fn = get_metrics_map()['balanced_accuracy']
# # elif PROBLEM_TYPE == 'regression':
# #     score_fn = get_metrics_map()['mean_absolute_error']

# # for dataset in ['ppmi', 'adni', 'qpn']:
# #     y_test = y_test_all.loc[dataset]
# #     y_pred = y_test_all_pred.loc[dataset]
# #     print('-----')
# #     print(f'{score_fn(y_test, y_pred):.3f} ({len(y_test)})')
# #     manual_avg += score_fn(y_test, y_pred) * len(y_test)
# #     # assert y_test_all_concat.loc[y_test.index].equals(y_test)
# #     # assert y_pred_all_concat.loc[y_pred.index].equals(y_pred)
# # print('-----')
# # print(f'{score_fn(y_test_all, y_test_all_pred):.3f} ({len(y_test_all)})')
# # print(manual_avg / len(y_test_all))

# from sklearn.metrics import ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

# fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12, 3))

# for i_ax, dataset in enumerate(['ppmi', 'adni', 'qpn']):
# # for i_ax, (y_test, y_pred) in enumerate(zip(y_test_all, y_pred_all)):
#     ax = axes[i_ax]
#     y_test = y_test_all.loc[dataset]
#     y_pred = y_test_all_pred.loc[dataset]
#     ConfusionMatrixDisplay.from_predictions(y_test, y_pred, ax=ax)
#     ax.set_title(['ppmi', 'adni', 'qpn'][i_ax])
#     # print('-----')
#     # print(f'{score_fn(y_test, y_pred):.3f} ({len(y_test)})')
#     # manual_avg += score_fn(y_test, y_pred) * len(y_test)
#     # assert y_test_all_concat.loc[y_test.index].equals(y_test)
#     # assert y_pred_all_concat.loc[y_pred.index].equals(y_pred)

# ConfusionMatrixDisplay.from_predictions(y_test_all, y_test_all_pred, ax=axes[-1])
# axes[-1].set_title('all')
# fig.tight_layout()

In [None]:
# fig, ax = plt.subplots(figsize=(5,5))

# y_test_all_pred.name = 'y_pred'
# y_test_all.name = 'y_test'
# df_tmp = pd.merge(y_test_all, y_test_all_pred, left_index=True, right_index=True)
# # pd.concat(
# #     {
# #         'y_test': y_test_all.reset_index(),
# #         'y_pred': y_test_all_pred.reset_index(),
# #     },
# # ).reset_index(level=0)
# ax = sns.scatterplot(data=df_tmp, x='y_test', y='y_pred', hue='dataset')
# ax.set_xlim(55, 87)
# ax.set_ylim(55, 87)
# ax.plot([55, 87], [55, 87], color='grey', linestyle='--')


In [None]:
# df_tmp = pd.merge(y_test_all, y_test_all_pred, left_index=True, right_index=True)
# df_tmp['diff'] = (df_tmp['y_pred'] - df_tmp['y_test']).abs()
# df_tmp.reset_index().groupby('dataset')['diff'].mean()
# # df_tmp = df_tmp.sort_values('diff', ascending=False)
# # print(df_tmp['diff'].mean())
# # df_tmp = df_tmp.query('diff > 4.6')
# # print(df_tmp.reset_index()['dataset'].value_counts())
# # sns.stripplot(data=df_tmp, y='diff', hue='dataset')
# # df_tmp

In [None]:
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots(figsize=(5, 5))
# ax.scatter(y_test_all, y_test_all_pred, color=y_test_all.index.get_level_values('dataset').map({'ppmi': 'red', 'adni': 'blue', 'qpn': 'green'}))
# ax.set_xlim(20, 95)
# ax.set_ylim(20, 95)
# ax.plot([20, 95], [20, 95], color='black', linestyle='--', zorder=-1)
# ax.set_xlabel('true')
# ax.set_ylabel('predicted')
# fig

In [None]:
# df = pd.concat({'true': y_test_all['adni'], 'pred': y_test_all_pred['adni']}, axis=1)
# df['diff'] = (df['true'] - df['pred']).abs()
# df = df.sort_values('diff', ascending=False)

# df

In [None]:
# np.array_equal(y_test_all_pred['adni'].values, y_pred)

In [None]:

# print(r2_score(
#     y_test,
#     results_adni['estimator'][-1].predict(X_test),
# ))
# print(r2_score(
#     y_test,
#     estimator_mega.predict(X_test),
# ))

In [None]:
# print(results_ppmi['estimator'][-1][-1].alpha_)
# print(results_adni['estimator'][-1][-1].alpha_)
# print(results_qpn['estimator'][-1][-1].alpha_)
# print(estimator_mega[-1].alpha_)

In [None]:
# import matplotlib.pyplot as plt
# cols = ['Left-Lateral-Ventricle', 'Right-Lateral-Ventricle', 'Right-Hippocampus', 'Left-Hippocampus']
# fig, axes = plt.subplots(1, len(cols), figsize=(15, 5))
# for col, ax in zip(cols, axes):
#     x = pd.concat([X_train[col], X_test[col]])
#     y = pd.concat([y_train, y_test])
#     ax.scatter(x, y, alpha=0.5, c=[0]*len(X_train) + [1]*len(X_test))
#     ax.set_title(col)
# fig

In [None]:
# df_tmp = pd.DataFrame({'y_test': y_test, 'y_pred': y_pred})
# df_tmp['diff'] = (df_tmp['y_pred'] - df_tmp['y_test']).abs()
# df_tmp.sort_values('diff', ascending=False)

In [None]:
# for col in ['Left-Lateral-Ventricle', 'Right-Lateral-Ventricle', 'Right-Hippocampus', 'Left-Hippocampus']:
#     print(f'{col}: {pearsonr(y_train, X_train[col])}')


In [None]:
# for col in ['Left-Lateral-Ventricle', 'Right-Lateral-Ventricle', 'Right-Hippocampus', 'Left-Hippocampus']:
#     print(f'{col}: {pearsonr(y_test, X_test[col])}')


In [None]:
# Experiment 2 null baseline: same procedure with shuffled training targets.
print("===== ALL DATASETS (NULL) =====")
mega_helper(null=True)

# Experiment 3: federated learning

## Voting

In [None]:
from collections import defaultdict

import numpy as np
from sklearn.ensemble import VotingClassifier, VotingRegressor
from scipy.special import softmax


# Federated learning by voting: for every fold, the three per-dataset
# estimators from experiment 1 are combined in a weighted voting ensemble,
# which is then scored on each dataset's test fold and on all test folds
# together.
metric_map = get_metrics_map(include_roc_auc=False)

# placeholder data, used only so that the voter can be "fitted" below
X_fake, y_fake = get_fake_X_y()

# results_fl_voting[dataset label][metric name] -> list with one score per fold
results_fl_voting = defaultdict(lambda: defaultdict(list))
for i_fold in range(N_FOLDS):

    # already-fitted estimators of this fold from the silo experiment
    estimator_ppmi = results_ppmi["estimator"][i_fold]
    estimator_adni = results_adni["estimator"][i_fold]
    estimator_qpn = results_qpn["estimator"][i_fold]

    estimators = [
        ("ppmi", estimator_ppmi),
        ("adni", estimator_adni),
        ("qpn", estimator_qpn)
    ]

    if PROBLEM_TYPE == 'classification':
        # NOTE(review): the voting weights come from each estimator's score on
        # the same test fold that is evaluated further down -- confirm this
        # use of test-fold information is intended.
        score_ppmi = results_ppmi["test_balanced_accuracy"][i_fold]
        score_adni = results_adni["test_balanced_accuracy"][i_fold]
        score_qpn = results_qpn["test_balanced_accuracy"][i_fold]
        scores_all = np.array([score_ppmi, score_adni, score_qpn])
        # weights = (scores_all-0.5) / scores_all.sum()
        # softmax turns the three scores into positive weights that sum to 1
        weights = softmax(scores_all)

        # n_samples_ppmi = len(results_ppmi['indices']['test'][i_fold])
        # n_samples_adni = len(results_adni['indices']['test'][i_fold])
        # n_samples_qpn = len(results_qpn['indices']['test'][i_fold])
        # n_samples_all = np.array([n_samples_ppmi, n_samples_adni, n_samples_qpn])
        # weights = n_samples_all / n_samples_all.sum()

        voter = VotingClassifier(
            estimators=estimators,
            voting="hard",
            weights=weights,
        )
    elif PROBLEM_TYPE == 'regression':
        # same weighting scheme as above, based on the test R^2 of this fold
        score_ppmi = results_ppmi["test_r2"][i_fold]
        score_adni = results_adni["test_r2"][i_fold]
        score_qpn = results_qpn["test_r2"][i_fold]
        scores_all = np.array([score_ppmi, score_adni, score_qpn])
        weights = softmax(scores_all)
        # weights = scores_all / scores_all.sum()
        voter = VotingRegressor(
            estimators=estimators,
            weights=weights,
        )
    else:
        raise ValueError(f"PROBLEM_TYPE must be either 'classification' or 'regression'")

    # Fit on the placeholder data (presumably only to create the fitted
    # attributes the voter needs for predict), then swap in the estimators
    # that were trained on the real data.
    voter.fit(X_fake, y_fake)  # not sure this is valid
    voter.estimators_ = [estimator for _, estimator in estimators]

    # the second tuple element (the raw table) is not used in the loop below
    dataset_map = {
        'ppmi': (results_ppmi, df_ppmi),
        'adni': (results_adni, df_adni),
        'qpn': (results_qpn, df_qpn),
    }
    for metric_name, metric_func in metric_map.items():
        # start as lists and are replaced by their concatenation after the
        # inner loop; they are reset for every metric
        y_test_all = []
        y_pred_all = []
        for label, (results, df_data) in dataset_map.items():
            idx_test = results["indices"]["test"][i_fold]
            X_test = results['X'].iloc[idx_test]
            y_test = results['y'].iloc[idx_test]
            y_pred = voter.predict(X_test)
            metric = metric_func(y_test, y_pred)
            results_fl_voting[label][metric_name].append(metric)
            y_test_all.append(y_test)
            y_pred_all.append(y_pred)
        y_test_all = pd.concat(y_test_all)
        y_pred_all = np.concatenate(y_pred_all)
        results_fl_voting['all'][metric_name].append(metric_func(y_test_all, y_pred_all))

# print mean (std) and the per-fold values, and record them as 'fl_voting'
for dataset_name in results_fl_voting:
    print(f"===== {dataset_name} =====")
    for metric_name, metric_values in results_fl_voting[dataset_name].items():
        metric_values = np.array(metric_values)
        print(f"{metric_name}: {metric_values.mean():.2f} ({metric_values.std():.2f})\t{metric_values}")
        add_scores('fl_voting', dataset_name, metric_name, metric_values, is_null=False)




In [None]:
# y_pred1 = estimator_ppmi.predict(X_test)
# y_pred2 = estimator_adni.predict(X_test)
# y_pred3 = estimator_qpn.predict(X_test)

# df = pd.DataFrame({
#     'ppmi': y_pred1,
#     'adni': y_pred2,
#     'qpn': y_pred3,
#     'voter': voter.predict(X_test),
#     'true': y_test,
# })
# print(balanced_accuracy_score(df['true'], df['voter']))
# df

In [None]:
from collections import defaultdict

from sklearn.base import clone

import matplotlib.pyplot as plt
import seaborn as sns

# fig, axes = plt.subplots(nrows=N_FOLDS, figsize=(2*N_FOLDS, 15))

# Federated-averaging evaluation.
#
# For every cross-validation fold, build a "federated" pipeline whose scaler
# statistics and linear-model parameters are the weighted average of the
# per-dataset (PPMI / ADNI / QPN) models fitted for that fold, then score it on
# each dataset's test split and on the pooled test splits.

metric_map = get_metrics_map(include_roc_auc=True)

# results_fl_avg[dataset_label][metric_name] -> list with one score per fold
results_fl_avg = defaultdict(lambda: defaultdict(list))

# The averaged pipeline is fitted on this data only so that its fitted
# attributes exist before they are overwritten below.
# NOTE(review): presumably placeholder data with the real feature columns
# (X_fake.columns is used to label the coefficients) -- confirm in get_fake_X_y.
X_fake, y_fake = get_fake_X_y()

# One long-format coefficient table per fold (kept for the optional plot).
dfs_coefs = []

# Loop-invariant, so built once instead of once per fold.
dataset_map = {
    'ppmi': (results_ppmi, df_ppmi),
    'adni': (results_adni, df_adni),
    'qpn': (results_qpn, df_qpn),
}

# Use the configured fold count rather than a hard-coded 5 so this cell stays
# in sync with N_FOLDS (which is also embedded in the results file name).
for i_fold in range(N_FOLDS):
    estimator_avg = clone(pipeline)
    estimator_avg.fit(X_fake, y_fake)

    # NOTE(review): the averaging weights are proportional to the size of each
    # dataset's *test* split for this fold; FedAvg is usually weighted by the
    # number of training samples -- confirm this is intended.
    n_samples_ppmi = len(results_ppmi['indices']['test'][i_fold])
    n_samples_adni = len(results_adni['indices']['test'][i_fold])
    n_samples_qpn = len(results_qpn['indices']['test'][i_fold])
    n_samples_all = np.array([n_samples_ppmi, n_samples_adni, n_samples_qpn])
    # n_samples_all = np.array([n_samples_ppmi, n_samples_qpn])
    weights = n_samples_all / n_samples_all.sum()

    # Average the scaler statistics only when the pipeline actually scales.
    if 'standardscaler' in estimator_avg.named_steps and estimator_avg['standardscaler'] != 'passthrough':
        scaler_ppmi = results_ppmi['estimator'][i_fold]['standardscaler']
        scaler_adni = results_adni['estimator'][i_fold]['standardscaler']
        scaler_qpn = results_qpn['estimator'][i_fold]['standardscaler']
        scaler_avg = estimator_avg['standardscaler']

        scales_all = np.vstack([scaler_ppmi.scale_, scaler_adni.scale_, scaler_qpn.scale_])
        means_all = np.vstack([scaler_ppmi.mean_, scaler_adni.mean_, scaler_qpn.mean_])
        # scales_all = np.vstack([scaler_ppmi.scale_, scaler_qpn.scale_])
        # means_all = np.vstack([scaler_ppmi.mean_, scaler_qpn.mean_])

        # Only scale_ and mean_ are replaced; any other fitted scaler
        # attributes keep the values obtained from the fit on the fake data.
        scaler_avg.scale_ = np.average(scales_all, axis=0, weights=weights, keepdims=True)
        scaler_avg.mean_ = np.average(means_all, axis=0, weights=weights, keepdims=True)

    # The final pipeline step is the linear model whose parameters get averaged.
    lr_avg = estimator_avg[-1]
    lr_ppmi = results_ppmi["estimator"][i_fold][-1]
    lr_adni = results_adni["estimator"][i_fold][-1]
    lr_qpn = results_qpn["estimator"][i_fold][-1]

    coefs_all = np.vstack([lr_ppmi.coef_, lr_adni.coef_, lr_qpn.coef_])
    intercepts_all = np.vstack([lr_ppmi.intercept_, lr_adni.intercept_, lr_qpn.intercept_])
    # coefs_all = np.vstack([lr_ppmi.coef_, lr_qpn.coef_])
    # intercepts_all = np.vstack([lr_ppmi.intercept_, lr_qpn.intercept_])

    # The averaged coef_ keeps its leading axis except in the regression case,
    # matching how coef_ is indexed per problem type further down.
    keepdims = PROBLEM_TYPE != 'regression'

    lr_avg.intercept_ = np.average(intercepts_all, axis=0, weights=weights, keepdims=False)
    lr_avg.coef_ = np.average(coefs_all, axis=0, weights=weights, keepdims=keepdims)

    # for plotting
    models_by_label = {
        'ppmi': lr_ppmi,
        'adni': lr_adni,
        'qpn': lr_qpn,
        'fed': lr_avg,
    }
    if PROBLEM_TYPE == 'classification':
        # Classification: use the first row of coef_.
        coefs_by_label = {name: model.coef_[0] for name, model in models_by_label.items()}
    elif PROBLEM_TYPE == 'regression':
        # Regression: coef_ is used as-is.
        coefs_by_label = {name: model.coef_ for name, model in models_by_label.items()}
    df_coefs = pd.concat(
        {
            name: pd.Series(coef_values, index=X_fake.columns)
            for name, coef_values in coefs_by_label.items()
        },
        # axis='columns',
    ).reset_index()
    df_coefs['i_fold'] = i_fold
    dfs_coefs.append(df_coefs)

    for metric_name, metric_func in metric_map.items():
        y_test_all = []
        y_pred_all = []
        for label, (results, _df_data) in dataset_map.items():
            idx_test = results["indices"]["test"][i_fold]
            X_test = results['X'].iloc[idx_test]
            y_test = results['y'].iloc[idx_test]

            if metric_name == 'roc_auc':
                # technically scores but renaming to pred for concatenation
                y_pred = estimator_avg.predict_proba(X_test)[:, 1]
            else:
                y_pred = estimator_avg.predict(X_test)
            metric = metric_func(y_test, y_pred)

            results_fl_avg[label][metric_name].append(metric)
            y_test_all.append(y_test)
            y_pred_all.append(y_pred)

        # Pooled score over the concatenated test splits of all datasets.
        y_test_all = pd.concat(y_test_all)
        y_pred_all = np.concatenate(y_pred_all)
        results_fl_avg['all'][metric_name].append(metric_func(y_test_all, y_pred_all))

# Summarise the federated-averaging scores: for each evaluated dataset, print
# the mean (std) across folds followed by the raw per-fold values, then hand
# the per-fold array to add_scores under the 'fl_fedavg' method label.
for dataset_name, metrics_for_dataset in results_fl_avg.items():
    print(f"===== {dataset_name} =====")
    for metric_name, fold_values in metrics_for_dataset.items():
        metric_values = np.array(fold_values)
        print(f"{metric_name}: {metric_values.mean():.2f} ({metric_values.std():.2f})\t{metric_values}")
        add_scores('fl_fedavg', dataset_name, metric_name, metric_values, is_null=False)

# df_coefs = pd.concat(dfs_coefs)
# grid = sns.catplot(data=df_coefs, kind='bar', x='level_1', hue='level_0', y=0, row='i_fold', aspect=3)
# grid.tick_params(axis='x', rotation=30)
# # grid.set_xticklabels(rotation=30)



In [None]:
# print(PROBLEM_TYPE)
# print(lr_ppmi.coef_.shape)
# print(lr_adni.coef_.shape)
# print(lr_qpn.coef_.shape)
# print(lr_avg.coef_.shape)

In [None]:
# Write the accumulated scores to a TSV whose name encodes the run
# configuration (problem type, dataset tags, fold count, RNG seed).

# Remove exact duplicate rows before saving.
# NOTE(review): presumably these arise when scoring cells are re-run -- confirm.
DF_RESULTS_TO_SAVE = DF_RESULTS_TO_SAVE.drop_duplicates()

# parents=True: the results directory comes from .env and may be nested under
# directories that do not exist yet; without it mkdir raises FileNotFoundError.
DPATH_RESULTS.mkdir(parents=True, exist_ok=True)
fpath_out = DPATH_RESULTS / f'results-{PROBLEM_TYPE}-{COMMON_TAGS}-{N_FOLDS}-{RNG_SEED}.tsv'
DF_RESULTS_TO_SAVE.to_csv(fpath_out, sep='\t', index=False)
# print(fpath_out)
 

In [None]:
# Bare expression as the last statement of the cell: shows the results table
# as the cell's output when run in a notebook.
DF_RESULTS_TO_SAVE

In [None]:
# import seaborn as sns

# # df_tmp = DF_RESULTS_TO_SAVE.query('is_null == False and metric == "mean_absolute_error" and method != "fl_voting" and method != "fl_fedavg"')#.groupby(['method', 'test_dataset']).score.mean()
# df_tmp = DF_RESULTS_TO_SAVE.query('is_null == False and metric == "balanced_accuracy" and method != "fl_voting" and method != "fl_fedavg"')#.groupby(['method', 'test_dataset']).score.mean()
# fg = sns.catplot(
#     data = df_tmp,
#     hue='test_dataset',
#     x='i_fold',
#     y='score',
#     col='method',
#     kind='strip',
#     aspect=3,
#     height=2,
#     sharex=False,
# )

# for method, ax in fg.axes_dict.items():
#     mean = df_tmp.query(f'method == "{method}" and test_dataset == "all"')['score'].mean()
#     std = df_tmp.query(f'method == "{method}" and test_dataset == "all"')['score'].std()
#     ax.set_title(f'{ax.get_title()} (mean: {mean:.2f}, std: {std:.2f})')