In [None]:
import pandas as pd
import numpy as np

from mothernet.evaluation.cd_plot_new.cd_plot_code import cd_evaluation
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.simplefilter("ignore", FutureWarning)

# Models whose results vary with the time budget (``max_time``), listed by the
# raw identifiers used in the result files. ``compare_splits`` keeps only the
# largest budget for these models.
varies_over_time = ['XGBoost', 'RF', 'MLP', 'KNN', 'LogReg']

validation_result_file = 'results_validation_2024-05-22.csv'
#test_result_file = 'results_test_2024-04-18.csv'
# Leave as None to skip every test-set cell further down.
test_result_file = None

valid_results = pd.read_csv(validation_result_file, index_col=0)

# Raw model identifier -> display name.
# Any name that still contains an underscore after renaming is removed by the
# filter below, so targets such as 'mothernet_old' or 'tabpfn_ours_hide' hide
# the corresponding model from all figures.
rename_dict = {
    #'additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_420': 'additive class average factorized',
    'additive_Dclass_average_02_29_2024_04_15_55_epoch_1050': 'additive_class_average',
    'additive_1_gpu_02_14_2024_16_34_15': 'additive dense',
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8': 'MotherNet',
    'MotherNet': 'mothernet_old',
    'additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270': 'steps_prior',
    #'baam_nfeatures_20_no_ensemble_e1520': 'GammaNet',
    'batapfn_no_ensemble_e410': 'bi-attention TabPFN',
    #'ebm_default': 'EBM (interactions)',
    #'ebm_bins_main_effects': 'EBM (main effects)',
    'hyperfast_no_optimize_cpu': 'HyperFast (no GD)',
    'hyperfast_defaults_cpu': 'HyperFast (default)',
    'LogReg': 'Logistic Regression',
    'RF': 'RandomForest',
    #'MLP-Distill': 'mlp_distill',
    'MLP': 'neural_network_old_drop',
    'TabPFN (ours)': 'tabpfn_ours_hide',
}

# ``compare_splits`` tests the *renamed* model column against
# ``varies_over_time``. Without the display names in the list, renamed models
# ('RF' -> 'RandomForest', 'LogReg' -> 'Logistic Regression') would be treated
# as untuned and keep the rows of every time budget. The raw names are kept
# too, so frames that were not renamed still match.
varies_over_time = varies_over_time + [
    rename_dict[name] for name in varies_over_time if name in rename_dict
]

#valid_results.model.unique()
#valid_results['model'] = valid_results.model.replace({'hyperfast_defaults_cpu': 'hyperfast_defaults_gpu'}) ## HACCCKKK
valid_results['model'] = valid_results.model.replace(rename_dict)
# NOTE(review): this filter runs after the rename, which already mapped
# 'hyperfast_no_optimize_cpu' to 'HyperFast (no GD)', so it removes nothing.
# Confirm whether that model is meant to be shown or dropped.
valid_results = valid_results[valid_results.model != "hyperfast_no_optimize_cpu"]
# Drop every model whose (renamed) identifier contains an underscore.
valid_results = valid_results[~valid_results.model.str.contains('_')]
all_models = valid_results.model.unique()



In [None]:
# Display the model names that remain after renaming and filtering.
all_models

In [None]:
# Optional: load the test-set results and give them the same preprocessing as
# the validation results so both frames describe the same set of models.
if test_result_file:
    test_results = pd.read_csv(test_result_file, index_col=0)
    test_results['model'] = test_results.model.replace(rename_dict)
    # Apply the same underscore filter as for the validation frame; otherwise
    # models hidden from validation would still be present here and the
    # equality check below could not hold.
    test_results = test_results[~test_results.model.str.contains('_')]
    # Sort the union: the model -> colour assignment is built by iterating
    # ``all_models``, and the iteration order of a set of strings is not
    # stable between interpreter runs.
    all_models = sorted(set(valid_results.model.unique()).union(test_results.model.unique()))

    print(test_results.model.unique())
    print(valid_results.model.unique())
    # Both result files must cover exactly the same models.
    assert set(test_results.model.unique()) == set(valid_results.model.unique())

In [None]:
# One palette colour per model, keyed by model name; the box plots below pass
# this mapping as their ``palette`` so a model keeps its colour across figures.
color_palette = sns.color_palette(n_colors=len(all_models))
color_mapping = {model: color for model, color in zip(all_models, color_palette)}

In [None]:
def get_best_over_time(results):
    """Average over splits and keep one time budget per tuned model.

    ``results`` is grouped by (dataset, model, max_time) and the columns
    ``mean_metric``, ``fit_time`` and ``inference_time`` are averaged within
    each group. Models that have rows at the largest ``max_time`` found in the
    data contribute only those rows; all other models contribute every row
    they have.

    Returns the two parts concatenated (largest-budget rows first), with the
    index left over from the intermediate frame.
    """
    averaged_columns = ["mean_metric", "fit_time", "inference_time"]
    per_budget = (
        results
        .groupby(["dataset", "model", "max_time"])[averaged_columns]
        .mean()
        .reset_index()
    )

    at_largest_budget = per_budget.max_time == per_budget.max_time.max()
    tuned_rows = per_budget[at_largest_budget]

    # A model counts as tuned when it appears at the largest budget.
    is_tuned_model = per_budget.model.isin(tuned_rows.model.unique())
    untuned_rows = per_budget[~is_tuned_model]

    return pd.concat([tuned_rows, untuned_rows])

In [None]:
# Critical-difference diagram for the validation results: one row per dataset,
# one column per model, passed to ``cd_evaluation`` and saved as a PDF.
combined_best_valid = get_best_over_time(valid_results)
pivoted_for_cd = combined_best_valid.pivot(index="dataset", columns="model", values="mean_metric")
fig, ax = plt.subplots(figsize=(8, 4), dpi=300)
_ = cd_evaluation(pivoted_for_cd, maximize_metric=True, ax=ax)
plt.savefig("../figures/cd_diagram_validation.pdf", bbox_inches="tight")

In [None]:
# Mean validation score per model across datasets, lowest first.
combined_best_valid.groupby("model")["mean_metric"].mean().sort_values()

In [None]:
# Same critical-difference diagram as for validation, built from the test
# results; only runs when a test result file is configured.
if test_result_file:
    combined_best_test = get_best_over_time(test_results)
    pivoted_for_cd = combined_best_test.pivot(index="dataset", columns="model", values="mean_metric")
    fig, ax = plt.subplots(figsize=(8, 3), dpi=300)
    _ = cd_evaluation(pivoted_for_cd, maximize_metric=True, ax=ax)
    plt.savefig("../figures/cd_diagram_test.pdf", bbox_inches="tight")

In [None]:
def compare_splits(results):
    """Average ``mean_metric`` and ``fit_time`` per (model, split, max_time).

    Returns a pair of frames:

    * the full per-(model, split, max_time) averages, and
    * a reduced frame in which models listed in the module-level
      ``varies_over_time`` keep only the rows at the largest ``max_time``
      found among those models, while all other models keep every row.
    """
    per_split = (
        results
        .groupby(["model", "split", "max_time"])[['mean_metric', 'fit_time']]
        .mean()
        .reset_index()
    )

    is_time_dependent = per_split.model.isin(varies_over_time)
    time_dependent = per_split[is_time_dependent]
    fixed = per_split[~is_time_dependent]

    # Keep only the final (largest) budget for the time-dependent models.
    final_budget = time_dependent[time_dependent.max_time == time_dependent.max_time.max()]

    return per_split, pd.concat([fixed, final_budget])

In [None]:
# Per-split summaries for validation (and for test, when a test file is set).
# The first frame of each pair keeps every time budget; the second keeps only
# the final budget for the models listed in ``varies_over_time``.
compare_splits_over_time_valid, compare_splits_valid = compare_splits(valid_results)
if test_result_file:
    compare_splits_over_time_test, compare_splits_test = compare_splits(test_results)

In [None]:
# Box plot of the per-split validation scores, one box per model, with the
# models ordered by their mean score.
plt.figure(figsize=(8, 4))
# Select the score column explicitly: the first positional argument of
# ``GroupBy.mean`` is ``numeric_only``, not a column name, so passing
# "mean_metric" there only worked because a non-empty string is truthy.
order = compare_splits_valid.groupby("model")["mean_metric"].mean().sort_values().index
sns.boxplot(data=compare_splits_valid, y="model", x="mean_metric", order=order, ax=plt.gca(), palette=color_mapping)
#sns.boxplot(data=compare_splits_valid, y="model", x="mean_metric",order=order, ax=plt.gca(), hue="model")

plt.xlabel("Average ROC AUC")
plt.savefig("../figures/mean_roc_auc_validation.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Box plot of the per-split test scores, one box per model, with the models
# ordered by their median score. Only runs when a test file is configured.
if test_result_file:
    plt.figure(figsize=(8, 6))
    # Select the score column explicitly: the first positional argument of
    # ``GroupBy.median`` is ``numeric_only``, not a column name.
    order = compare_splits_test.groupby("model")["mean_metric"].median().sort_values().index
    sns.boxplot(data=compare_splits_test, y="model", x="mean_metric", order=order, ax=plt.gca(), palette=color_mapping)
    plt.xlabel("Average ROC AUC")
    plt.savefig("../figures/mean_roc_auc_test.pdf", dpi=300, bbox_inches="tight")

In [None]:
def normalize_metric(results):
    """Min-max scale ``mean_metric`` within each dataset.

    For every dataset the smallest and largest ``mean_metric`` over all rows
    are computed, and each row's score is mapped to
    ``(score - min) / (max - min)``. The returned frame is a merged copy of
    ``results`` that also carries the per-dataset ``min`` and ``max`` columns;
    the input frame is not modified. A dataset whose scores are all equal
    yields NaN, because the expression becomes 0 / 0.
    """
    per_dataset_range = results.groupby("dataset").mean_metric.agg(["min", "max"])
    normalized = results.merge(per_dataset_range, on="dataset")

    lower, upper = normalized['min'], normalized['max']
    normalized['mean_metric'] = (normalized['mean_metric'] - lower) / (upper - lower)
    return normalized

In [None]:
# Same per-split summaries as before, computed on scores that were first
# min-max normalised within each dataset.
compare_splits_over_time_valid_normalized, compare_splits_valid_normalized = compare_splits(normalize_metric(valid_results))
if test_result_file:
    compare_splits_over_time_test_normalized, compare_splits_test_normalized = compare_splits(normalize_metric(test_results))

In [None]:
# Box plot of the normalised per-split test scores plus a markdown table of the
# per-model medians. Only runs when a test file is configured.
if test_result_file:
    plt.figure(figsize=(8, 6))
    # Select the score column explicitly: the first positional argument of
    # ``GroupBy.median`` is ``numeric_only``, not a column name.
    order = compare_splits_test_normalized.groupby("model")["mean_metric"].median().sort_values().index
    sns.boxplot(data=compare_splits_test_normalized, y="model", x="mean_metric", order=order, ax=plt.gca(), palette=color_mapping)
    plt.xlabel("Average ROC AUC (normalized)", loc="right")
    plt.savefig("../figures/mean_roc_auc_test_normalized.pdf", dpi=300, bbox_inches="tight")
    # Changes the pandas float display format for the rest of the session.
    pd.set_option("display.float_format", lambda x: f"{x:.2f}")
    # Pick the reported columns before aggregating so the median is computed
    # only on them and not on the remaining columns such as ``split``.
    table = compare_splits_test_normalized.groupby("model")[['mean_metric', 'fit_time']].median().sort_values("mean_metric")
    print(table.to_markdown(floatfmt=".2f"))

In [None]:
# table.drop(index=["mothernet_old", "mlp_distill", 'KNN']).rename({'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8': 'MotherNet'})

In [None]:
# Box plot of the normalised per-split validation scores, one box per model,
# with the models ordered by their mean normalised score.
plt.figure(figsize=(8, 4))
# Select the score column explicitly: the first positional argument of
# ``GroupBy.mean`` is ``numeric_only``, not a column name, so passing
# "mean_metric" there only worked because a non-empty string is truthy.
order = compare_splits_valid_normalized.groupby("model")["mean_metric"].mean().sort_values().index
sns.boxplot(data=compare_splits_valid_normalized, y="model", x="mean_metric", order=order, ax=plt.gca(), palette=color_mapping)
#sns.boxplot(data=compare_splits_valid_normalized, y="model", x="mean_metric", order=order, ax=plt.gca(), hue='model')

plt.xlabel("Average ROC AUC (normalized)", loc="right")
plt.savefig("../figures/mean_roc_auc_valid_normalized.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Wide table of validation scores: one row per dataset, one column per model.
per_dataset = combined_best_valid.pivot(index="dataset", columns="model", values="mean_metric")

In [None]:
# List the model columns available in the per-dataset table.
per_dataset.columns

In [None]:
# Independent copy of the per-dataset scores for the four models compared
# below; raises KeyError if any of them is missing from ``per_dataset``.
compare = per_dataset.loc[:, ['MotherNet', 'TabPFN (Hollmann)', 'XGBoost', 'MLP-Distill']].copy()

In [None]:
# Per-dataset score gap; positive values mean XGBoost scored higher than MotherNet.
compare['diff'] = compare['XGBoost'] - compare['MotherNet']

In [None]:
# LaTeX table of the datasets where the two scores differ by more than 0.05
# in absolute value, sorted by the signed difference.
print(
    compare.loc[compare['diff'].abs() > 0.05]
    .sort_values("diff")
    .to_latex(float_format="%.3f")
)

In [None]:
# Replace 'diff' with the gap between 'EBM (main effects)' and the baam
# checkpoint. ``compare`` only holds the four models selected above (plus
# 'diff'), so looking the two columns up there always raised KeyError; read
# them from ``per_dataset`` instead. Model names containing an underscore are
# filtered out when the results are loaded, so the columns may be absent; in
# that case the comparison is skipped and the existing 'diff' is kept.
ebm_column = 'EBM (main effects)'
baam_column = 'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430'
missing_columns = [name for name in (ebm_column, baam_column) if name not in per_dataset.columns]
if missing_columns:
    print(f"Skipping comparison, models not in per_dataset: {missing_columns}")
else:
    compare['diff'] = per_dataset[ebm_column] - per_dataset[baam_column]

In [None]:
# Show the comparison table ordered by the current 'diff' column, smallest first.
compare.sort_values("diff")

In [None]:
# List the columns currently held in the comparison table.
compare.columns

In [None]:
# Clustered heat map of the per-dataset scores; the helper 'diff' column is excluded.
sns.clustermap(compare.drop(columns="diff"))