In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pickle
import datetime
import warnings
# Warning filters for the rest of the notebook session.
# Currently disabled: would silence FutureWarning (openml's deprecated array return type).
#warnings.simplefilter("ignore", FutureWarning)  # openml deprecation of array return type
# NOTE: this ignores *every* UserWarning, not only the scikit-learn select-k-best one
# it was added for.
warnings.simplefilter("ignore", UserWarning)  # scikit-learn select k best

from mothernet.datasets import load_openml_list, open_cc_valid_dids, open_cc_dids
from mothernet.evaluation.baselines.tabular_baselines import knn_metric, catboost_metric, logistic_metric, xgb_metric, random_forest_metric, mlp_metric, hyperfast_metric
from mothernet.evaluation.tabular_evaluation import evaluate, eval_on_datasets, transformer_metric
from mothernet.evaluation import tabular_metrics
from mothernet.prediction.tabpfn import TabPFNClassifier
from mothernet.evaluation.baselines import tabular_baselines

# Datasets

In [2]:
from mothernet.datasets import load_openml_list, open_cc_dids, open_cc_valid_dids, test_dids_classification

# Load the OpenML validation suite. The call yields two objects: the datasets
# themselves and a companion table (used as a DataFrame further down).
cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(
    open_cc_valid_dids,
    multiclass=True,
    shuffled=True,
    filter_for_nan=False,
    max_samples=10000,
    num_feats=100,
    return_capped=True,
)


  openml_list = openml.datasets.list_datasets(dids)


Number of datasets: 149


  arr = np.asarray(values, dtype=dtype)


In [3]:
# Cast the three count columns of the dataset table to int (assigned back in place).
for count_column in ('NumberOfInstances', 'NumberOfFeatures', 'NumberOfClasses'):
    cc_valid_datasets_multiclass_df[count_column] = cc_valid_datasets_multiclass_df[count_column].astype(int)

# uncomment for latex table of datasets
# print(cc_valid_datasets_multiclass_df[['did', 'name', 'NumberOfFeatures', 'NumberOfInstances', 'NumberOfClasses']].rename(columns={'NumberOfFeatures': "d", "NumberOfInstances":"n", "NumberOfClasses": "k"}).to_latex(index=False))

# Setting params

In [4]:
import os
# --- Evaluation settings read by the cells below ---------------------------
# Task label and the metric object handed to the evaluation helper.
task_type = 'multiclass'
metric_used = tabular_metrics.auc_metric

# Sizes forwarded to the evaluation helper (eval_positions, n_samples).
eval_positions = [1000]
n_samples = 2000
max_features = 100

# Parent of the current working directory; the results/ tree is created under it.
base_path = os.pardir
overwrite = False

# Baseline Evaluation
This section runs baselines and saves results locally.

In [5]:
# Create <base_path>/results/tabular/multiclass (including missing parents).
# A single os.makedirs call replaces three `!mkdir -p` shell escapes: it does not
# depend on a POSIX shell, works outside IPython, and is a no-op when the
# directories already exist.
os.makedirs(os.path.join(base_path, 'results', 'tabular', 'multiclass'), exist_ok=True)

In [6]:
# Sanity check: how many datasets were loaded above.
len(cc_valid_datasets_multiclass)

149

In [7]:
# Time budgets tried for every baseline (presumably seconds — TODO confirm
# against eval_on_datasets).
max_times = [1, 5, 15, 60, 5 * 60, 15 * 60, 60 * 60]

# Baselines are supplied as metric callables rather than estimator objects;
# per the original note, that is why they are all evaluated on CPU.
clf_dict = {
    'knn': knn_metric,
    'rf_new_params': random_forest_metric,
    'xgb': xgb_metric,
    'logistic': logistic_metric,
    'mlp': mlp_metric,
}

# One eval_on_datasets call per baseline. The task label comes from the shared
# `task_type` setting instead of a second hard-coded 'multiclass' literal, so
# the settings cell stays the single source of truth.
results_baselines = [
    eval_on_datasets(
        task_type,
        model,
        model_name,
        cc_valid_datasets_multiclass,
        eval_positions=eval_positions,
        max_times=max_times,
        metric_used=metric_used,
        split_numbers=[1, 2, 3, 4, 5],
        n_samples=n_samples,
        base_path=base_path,
        n_jobs=1,
    )
    for model_name, model in clf_dict.items()
]

evaluating knn on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.4s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    2.0s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.6s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.4s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.2s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.2s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.5s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   13.7s


evaluating rf_new_params on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.3s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.9s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.6s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.4s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.2s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.2s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.5s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   14.0s


evaluating xgb on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.4s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    2.0s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.6s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.4s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.2s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.1s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.4s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   13.7s


evaluating logistic on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.3s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.9s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.6s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.4s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.2s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.2s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.5s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   13.8s


evaluating mlp on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.4s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    2.0s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.6s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.4s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.2s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.2s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.4s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   13.7s


In [8]:
from mothernet.evaluation.tabular_evaluation import eval_on_datasets
from mothernet.prediction.mothernet import ShiftClassifier, EnsembleMeta, MotherNetClassifier, MotherNetInitMLPClassifier
from mothernet.prediction.mothernet_additive import MotherNetAdditiveClassifier
from mothernet.evaluation.baselines.distill_mlp import DistilledTabPFNMLP
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from interpret.glassbox import ExplainableBoostingClassifier
from functools import partial
from hyperfast import HyperFastClassifier

# transformers don't have max times
# NOTE: this rebinds `max_times`, which the baseline cell earlier in the notebook
# set to a longer list; re-running that cell after this one would pick up [1].
max_times = [1]
# Passed as `device=` to the model constructors below.
device = "cpu"

# Checkpoint identifiers reused by the TabPFNClassifier instances below.
_tabpfn_ours_model = "tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33"
_batabpfn_e128_model = "batabpfn_e128_inputembeddingfourier_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_11_35"
_batabpfn_e256_model = "batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45"

# Same checkpoint and epoch; only the number of ensemble configurations varies.
_tabpfn_ours_factory = partial(TabPFNClassifier, device=device, model_string=_tabpfn_ours_model, epoch="1650")
tabpfn_ours = _tabpfn_ours_factory(N_ensemble_configurations=3)

tabpfn_ours_ensemble_8 = _tabpfn_ours_factory(N_ensemble_configurations=8)
tabpfn_ours_ensemble_32 = _tabpfn_ours_factory(N_ensemble_configurations=32)

# Single-configuration variants of one checkpoint, taken at different epochs.
_batabpfn_e128_factory = partial(TabPFNClassifier, device=device, model_string=_batabpfn_e128_model, N_ensemble_configurations=1)
batapfn_no_ensemble = _batabpfn_e128_factory(epoch="330")
batapfn_no_ensemble_e410 = _batabpfn_e128_factory(epoch="410")
batapfn_no_ensemble_e530 = _batabpfn_e128_factory(epoch="530")
batapfn_no_ensemble_exit = _batabpfn_e128_factory(epoch="on_exit")


# Single-configuration variants of the e256 checkpoint at two epochs.
_batabpfn_e256_factory = partial(TabPFNClassifier, device=device, model_string=_batabpfn_e256_model, N_ensemble_configurations=1)
batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630 = _batabpfn_e256_factory(epoch="630")
batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130 = _batabpfn_e256_factory(epoch="1130")


# DistilledTabPFNMLP configured against the TabPFN checkpoint string below,
# preceded by a StandardScaler in a scikit-learn pipeline.
# NOTE(review): `epoch` is an int here while the TabPFNClassifier calls in this
# cell pass it as a string — confirm both forms are accepted.
_distilled_tabpfn_mlp = DistilledTabPFNMLP(
    n_epochs=1000,
    device=device,
    hidden_size=128,
    n_layers=2,
    dropout_rate=.1,
    learning_rate=0.01,
    model_string="tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33",
    epoch=1650,
    N_ensemble_configurations=3,
)
mlp_distill = make_pipeline(StandardScaler(), _distilled_tabpfn_mlp)

# MotherNetClassifier wrapped in a 3-estimator EnsembleMeta.
mothernet_21_46_25_3940_ensemble3 = EnsembleMeta(
    MotherNetClassifier(
        path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle",
        device=device,
    ),
    n_estimators=3,
)

#mothernet_init_gd = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=100)
#mothernet_init_gd_no_learning = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=1, learning_rate=0.0)
#mothernet_init_gd_epochs_10 = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=10, learning_rate=0.001, verbose=10)
#mothernet_init_gd_epochs_10_lr0001 = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=10, learning_rate=0.0001, verbose=10)


def _mothernet_ensemble(path, **ensemble_kwargs):
    """Build EnsembleMeta around a MotherNetClassifier constructed with `path` and the notebook's `device`.

    Extra keyword arguments (e.g. n_estimators, onehot) are forwarded to EnsembleMeta unchanged.
    """
    return EnsembleMeta(MotherNetClassifier(path=path, device=device), **ensemble_kwargs)


mn_P512_SFalse_L2_1_gpu_01_24_2024 = _mothernet_ensemble("../models_diff/mn_P512_SFalse_L2_1_gpu_01_24_2024_00_31_59_epoch_3950.cpkt", n_estimators=3)
mn_SFalse_L2_1_gpu_01_25_2024_21_20_32 = _mothernet_ensemble("../models_diff/mn_SFalse_L2_1_gpu_01_25_2024_21_20_32_epoch_4000.cpkt", n_estimators=3)

# One checkpoint, four ensemble configurations.
_mn_dclass_average_path = "../models_diff/mn_Dclass_average_02_29_2024_04_16_00_epoch_4000.cpkt"
mn_Dclass_average_02_29_2024_04_16_00 = _mothernet_ensemble(_mn_dclass_average_path, n_estimators=3)
mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble = _mothernet_ensemble(_mn_dclass_average_path, n_estimators=3, onehot=True)
mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_32 = _mothernet_ensemble(_mn_dclass_average_path, n_estimators=32, onehot=True)
mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8 = _mothernet_ensemble(_mn_dclass_average_path, n_estimators=8, onehot=True)

mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100 = _mothernet_ensemble("../models_diff/mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100.cpkt", n_estimators=8, onehot=True)


#mn_Dclass_average_03_25_2024_17_14_32_epoch_1760_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_1760.cpkt", device=device), n_estimators=8, onehot=True)
#mn_Dclass_average_03_25_2024_17_14_32_epoch_2270_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2270.cpkt", device=device), n_estimators=8, onehot=True)
#mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True)

mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8 = _mothernet_ensemble("../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_3970.cpkt", n_estimators=8, onehot=True)

def _additive_ensemble(path):
    """Build the 3-estimator, power=False EnsembleMeta used for every additive checkpoint in this group."""
    return EnsembleMeta(MotherNetAdditiveClassifier(path=path, device=device), n_estimators=3, power=False)


additive_1_gpu_02_14_2024_16_34_15 = _additive_ensemble("../models_diff/additive_1_gpu_02_14_2024_16_34_15_epoch_950_fixed2.cpkt")

additive_step_prior_02_08_2024 = _additive_ensemble("../models_diff/additive_b16_reducelronspikeTrue_multiclasstypesteps_1_gpu_02_08_2024_04_51_33_epoch_790.cpkt")
additive_11_08_2023 = _additive_ensemble("../models_diff/additive_1_gpu_11_08_2023_23_02_58_continue_11_10_2023_03_01_40_epoch_3170_no_optimizer.cpkt")
additive_02_20_2024_factorized_weight_decay = _additive_ensemble("../models_diff/additive_w0.01_factorizedoutputTrue_outputrank64_1_gpu_02_20_2024_22_39_06_epoch_1260_fixed.cpkt")

# Successive epochs of one training run; epoch 190 is also kept without the ensemble wrapper.
additive_Dclass_average_02_29_2024_04_15_55_epoch_190 = _additive_ensemble("../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_190.cpkt")
additive_Dclass_average_02_29_2024_04_15_55_epoch_190_no_ensemble = MotherNetAdditiveClassifier(
    path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_190.cpkt",
    device=device,
)
additive_Dclass_average_02_29_2024_04_15_55_epoch_560 = _additive_ensemble("../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_560.cpkt")
additive_Dclass_average_02_29_2024_04_15_55_epoch_730 = _additive_ensemble("../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_730.cpkt")
additive_Dclass_average_02_29_2024_04_15_55_epoch_780 = _additive_ensemble("../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_780.cpkt")
additive_Dclass_average_02_29_2024_04_15_55_epoch_850 = _additive_ensemble("../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_850.cpkt")
additive_Dclass_average_02_29_2024_04_15_55_epoch_1050 = _additive_ensemble("../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_1050.cpkt")


additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340 = _additive_ensemble("../models_diff/additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340.cpkt")



# additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270.cpkt", device=device), n_estimators=3, power=False)


# additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_340 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_w0.01_03_02_2024_02_21_10_epoch_340.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_420 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_w0.01_03_02_2024_02_21_10_epoch_420.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_1210 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_w0.01_03_02_2024_02_21_10_epoch_1210.cpkt", device=device), n_estimators=3, power=False)

# additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580.cpkt", device=device), n_estimators=3, power=False)

# additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280.cpkt", device=device), n_estimators=3, power=False)


# ExplainableBoostingClassifier baselines: library defaults, a 64-bin variant,
# and a 64-bin variant with interactions=0.
_ebm_max_bins = 64
ebm_default = ExplainableBoostingClassifier()
ebm_bins = ExplainableBoostingClassifier(max_bins=_ebm_max_bins)
ebm_bins_main_effects = ExplainableBoostingClassifier(max_bins=_ebm_max_bins, interactions=0)


def _additive_classifier(path):
    """Construct a bare (non-ensembled) MotherNetAdditiveClassifier for `path` on the notebook's `device`."""
    return MotherNetAdditiveClassifier(path=path, device=device)


# One training run (numfeatures20, e128) sampled at increasing epochs.
baam_nfeatures_20_no_ensemble = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_400.cpkt")
baam_nfeatures_20_no_ensemble_e500 = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_500.cpkt")
baam_nfeatures_20_no_ensemble_e650 = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_650.cpkt")
baam_nfeatures_20_no_ensemble_e840 = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_840.cpkt")
baam_nfeatures_20_no_ensemble_e1210 = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1210.cpkt")
baam_nfeatures_20_no_ensemble_e1520 = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1520.cpkt")
baam_nfeatures_20_no_ensemble_e1970 = _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1970.cpkt")


# Epoch-400 checkpoint again, this time inside a 3-estimator EnsembleMeta.
baam_nfeatures_20 = EnsembleMeta(
    _additive_classifier("../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_400.cpkt"),
    n_estimators=3,
    power=False,
)

baam_nfeatures_100_no_ensemble = _additive_classifier("../models_diff/baam_H512_Dclass_average_e64_nsamples500_N6_padzerosFalse_03_13_2024_20_28_36_epoch_360.cpkt")


# fourierfeatures32 run at increasing epochs.
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520.cpkt")
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940 = _additive_classifier("../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940.cpkt")

baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210 = _additive_classifier("../models_diff/baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210.cpkt")

baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220 = _additive_classifier("../models_diff/baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220.cpkt")

baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490 = _additive_classifier("../models_diff/baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490.cpkt")
baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780 = _additive_classifier("../models_diff/baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780.cpkt")

baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430 = _additive_classifier("../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430.cpkt")


baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970 = _additive_classifier("../models_diff/baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970.cpkt")


baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160 = _additive_classifier("../models_diff/baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160.cpkt")

baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310 = _additive_classifier("../models_diff/baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310.cpkt")

baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130 = _additive_classifier("../models_diff/baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130.cpkt")
baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180 = _additive_classifier("../models_diff/baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180.cpkt")


baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610 = _additive_classifier("../models_diff/baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610.cpkt")
baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660 = _additive_classifier("../models_diff/baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660.cpkt")

baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250 = _additive_classifier("../models_diff/baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250.cpkt")
baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230 = _additive_classifier("../models_diff/baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230.cpkt")

baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410 = _additive_classifier("../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410.cpkt")
baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550 = _additive_classifier("../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550.cpkt")
baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690 = _additive_classifier("../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690.cpkt")
baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890 = _additive_classifier("../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890.cpkt")

baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470 = _additive_classifier("../models_diff/baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470.cpkt")

baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190 = _additive_classifier("../models_diff/baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190.cpkt")



#hyperfast_no_optimization = HyperFastClassifier(device=device, optimization=None)
#hyperfast_defaults = HyperFastClassifier(device=device)

# Registry of the models evaluated by the eval_on_datasets comprehension that follows.
#   key   -> model name passed to eval_on_datasets (and the value that ends up in the
#            'model' column of the results table)
#   value -> the object passed as `model`: either a name defined earlier in the notebook,
#            or a functools.partial of transformer_metric / hyperfast_metric with extra
#            keyword arguments bound.
# Dict insertion order is the evaluation order and the order of `results_transformers`.
# Commented-out entries are checkpoints/variants excluded from this run; they are kept
# as a record of what was tried.
# NOTE(review): keys are presumably also used to locate cached results under base_path
# (overwrite=False below) - confirm before renaming any key.
clf_dict= {
    'mothernet': partial(transformer_metric, classifier=mothernet_21_46_25_3940_ensemble3, onehot=True),
    'mlp_distill': mlp_distill,
    'tabpfn': transformer_metric,
    'tabpfn_ours': tabpfn_ours,
    'tabpfn_ours_ensemble_8': tabpfn_ours_ensemble_8,
    'tabpfn_ours_ensemble_32': tabpfn_ours_ensemble_32,
    # 'mothernet_init_gd': mothernet_init_gd,
    # 'mothernet_init_gd_no_learning': mothernet_init_gd_no_learning,
    # 'mothernet_init_gd_epochs_10': mothernet_init_gd_epochs_10,
    # 'mothernet_init_gd_epochs_10_lr0001': mothernet_init_gd_epochs_10_lr0001,
    # 'mothernet_init_gd_no_learning_ohe' : partial(transformer_metric, classifier=mothernet_init_gd_no_learning, onehot=True),
    # 'mothernet_init_gd_ohe' : partial(transformer_metric, classifier=mothernet_init_gd, onehot=True),


    #'batapfn_no_ensemble': batapfn_no_ensemble,
    #'batapfn_no_ensemble_e410': batapfn_no_ensemble_e410,
    #'batapfn_no_ensemble_exit': batapfn_no_ensemble_exit,
    'batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630': batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630,
    'batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130': batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130,
  #"hyperfast_no_optimize_gpu": partial(hyperfast_metric, optimization=None),
     "hyperfast_no_optimize_cpu":  partial(hyperfast_metric, optimization=None),
    "hyperfast_defaults_cpu":hyperfast_metric,
    
    'ebm_default': partial(transformer_metric, classifier=ebm_default),
    #'ebm_bins': partial(transformer_metric, classifier=ebm_bins),
    'ebm_bins_main_effects': partial(transformer_metric, classifier=ebm_bins_main_effects),
    #'batapfn_no_ensemble_e530_redo': batapfn_no_ensemble_e530,
   # 'mn_P512_SFalse_L2_1_gpu_01_24_2024_onehot':  partial(transformer_metric, classifier=mn_P512_SFalse_L2_1_gpu_01_24_2024, onehot=True),
   # 'mn_SFalse_L2_1_gpu_01_25_2024_21_20_32_onehot':  partial(transformer_metric, classifier=mn_SFalse_L2_1_gpu_01_25_2024_21_20_32, onehot=True),
    #'mn_Dclass_average_02_29_2024_04_16_00' : partial(transformer_metric, classifier=mn_Dclass_average_02_29_2024_04_16_00, onehot=True),
    #    'mn_Dclass_average_02_29_2024_04_16_00_no_onehot' : mn_Dclass_average_02_29_2024_04_16_00,
        #'mn_Dclass_average_02_29_2024_04_16_00_onehot_and_passthrough2': partial(transformer_metric, classifier=mn_Dclass_average_02_29_2024_04_16_00, onehot=True),
    #'mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_again': mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble,
   # 'mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_32_again': mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_32,
   # 'mn_Dclass_average_03_25_2024_17_14_32_epoch_1760_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_1760_ohe_ensemble_8,
    #'mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8': mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8,
    #'mn_Dclass_average_03_25_2024_17_14_32_epoch_2270_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_2270_ohe_ensemble_8,
    #'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8,

    'mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100': mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100,
    
        'baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490': baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490,
   'baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780': baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780,  #  <- TRY THIS
    'baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210': baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210,
    #'additive_1_gpu_02_14_2024_16_34_15': additive_1_gpu_02_14_2024_16_34_15,
    #'additive_step_prior_02_08_2024': additive_step_prior_02_08_2024,
  # 'additive_02_20_2024_factorized_weight_decay': additive_02_20_2024_factorized_weight_decay,
   # 'additive_Dclass_average_02_29_2024_04_15_55_epoch_190_no_ensemble': additive_Dclass_average_02_29_2024_04_15_55_epoch_190_no_ensemble,
   # 'additive_Dclass_average_02_29_2024_04_15_55_epoch_190': additive_Dclass_average_02_29_2024_04_15_55_epoch_190,
   # 'additive_Dclass_average_02_29_2024_04_15_55_epoch_560': additive_Dclass_average_02_29_2024_04_15_55_epoch_560,
    #'additive_Dclass_average_02_29_2024_04_15_55_epoch_730': additive_Dclass_average_02_29_2024_04_15_55_epoch_730,
    #'additive_Dclass_average_02_29_2024_04_15_55_epoch_780': additive_Dclass_average_02_29_2024_04_15_55_epoch_780,
    #'additive_Dclass_average_02_29_2024_04_15_55_epoch_850': additive_Dclass_average_02_29_2024_04_15_55_epoch_850,
    'additive_Dclass_average_02_29_2024_04_15_55_epoch_1050': additive_Dclass_average_02_29_2024_04_15_55_epoch_1050,
    'baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220': baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220,
    #'additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340_retry': additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150_redo': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150,
     #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220,
    
    'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550,
    # this one overfitted and then kinda recovered? but not very good in the end.
#'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940,
    

    #'additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100': additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100,
   # 'additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270': additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270,

  #  'additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360': additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360,
  #  'additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_340': additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_340,
    #'additive_Dclass_average_factorizedoutputTrue_w001_03_02_024_02_21_10_epoch_420': additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_420,
    #'additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_1210': additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_1210,
    #'additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580': additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580,
    #'additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280': additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280,
    #'baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610': baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610,
    #'baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660': baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660,
    # 'baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160': baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160,
    # 'baam_nfeatures_20_no_ensemble': baam_nfeatures_20_no_ensemble,
    # 'baam_nfeatures_20': baam_nfeatures_20,
    #'baam_nfeatures_100_no_ensemble': baam_nfeatures_100_no_ensemble,
    # 'baam_nfeatures_20_no_ensemble_e500': baam_nfeatures_20_no_ensemble_e500,
#     'baam_nfeatures_20_no_ensemble_e650': baam_nfeatures_20_no_ensemble_e650,
#    'baam_nfeatures_20_no_ensemble_e840': baam_nfeatures_20_no_ensemble_e840,
    "baam_nfeatures_20_no_ensemble_e1210": baam_nfeatures_20_no_ensemble_e1210,
    "baam_nfeatures_20_no_ensemble_e1520": baam_nfeatures_20_no_ensemble_e1520,
    "baam_nfeatures_20_no_ensemble_e1970": baam_nfeatures_20_no_ensemble_e1970,

    'baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470': baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470,

    #'baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310': baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310,
    #'baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130': baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130,
    #    'baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180': baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180,

#    'baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250': baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250,
    'baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230': baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230,
    'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410,
#        'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690,
#    'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890,
        'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430,

    #'baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970': baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970,
    #'baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190': baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190,
    }
def _evaluate_on_validation_suite(model_name, model):
    """Run one clf_dict entry through eval_on_datasets on the multiclass validation datasets.

    All evaluation settings (eval_positions, max_times, metric_used, n_samples,
    base_path, device) are read from the notebook-level configuration; splits 1-5
    are evaluated, and overwrite=False is passed through unchanged.
    """
    return eval_on_datasets(
        'multiclass',
        model,
        model_name,
        cc_valid_datasets_multiclass,
        eval_positions=eval_positions,
        max_times=max_times,
        metric_used=metric_used,
        split_numbers=list(range(1, 6)),  # fresh [1, 2, 3, 4, 5] for every call
        n_samples=n_samples,
        base_path=base_path,
        overwrite=False,
        n_jobs=-1,
        device=device,
    )


# One result list per registered model, in clf_dict insertion order.
results_transformers = [
    _evaluate_on_validation_suite(name, clf) for name, clf in clf_dict.items()
]

evaluating mothernet on cpu


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    6.7s
[Parallel(n_jobs=-1)]: Done 456 tasks      | elapsed:    7.3s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:   17.8s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mlp_distill on cpu


[Parallel(n_jobs=-1)]: Done 178 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn_ours on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn_ours_ensemble_8 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn_ours_ensemble_32 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating hyperfast_no_optimize_cpu on cpu


[Parallel(n_jobs=-1)]: Done 178 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    1.2s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating hyperfast_defaults_cpu on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating ebm_default on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating ebm_bins_main_effects on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s


evaluating mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8 on cpu


[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490 on cpu


[Parallel(n_jobs=-1)]: Done 178 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating additive_Dclass_average_02_29_2024_04_15_55_epoch_1050 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220 on cpu


[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550 on cpu


[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    0.7s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.9s remaining:    0.3s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    1.0s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nfeatures_20_no_ensemble_e1210 on cpu


[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nfeatures_20_no_ensemble_e1520 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nfeatures_20_no_ensemble_e1970 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s


evaluating baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470 on cpu


[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410 on cpu


[Parallel(n_jobs=-1)]: Done 179 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished


In [15]:
# Flatten the nested evaluation output (one list of result dicts per model) into one
# row per result, then build the results DataFrame from those rows.
flat_results = []
for per_dataset in results_baselines + results_transformers:
    for result in per_dataset:
        # Identification and score columns shared by every result.
        row = {key: result[key]
               for key in ('dataset', 'model', 'mean_metric', 'split', 'max_time')}

        # Each result is expected to hold exactly one "...best_configs..." entry; the
        # single-target unpacking raises ValueError if there are none or several.
        best_configs_key, = [k for k in result.keys() if "best_configs" in k]
        best_config = result[best_configs_key][0]
        if best_config is not None:
            # Adds whatever per-run fields the evaluation recorded to the row.
            row.update(best_config)

        # float() accepts single-element torch tensors (CPU or CUDA, with or without
        # grad), NumPy scalars and plain Python numbers, whereas the previous
        # `.numpy()` call only worked for CPU tensors that do not require grad.
        row['mean_metric'] = float(row['mean_metric'])
        flat_results.append(row)

results_df = pd.DataFrame(flat_results)

In [16]:
# Show the flattened results table (one row per evaluation result) as the cell output.
results_df

Unnamed: 0,dataset,model,mean_metric,split,max_time,best,fit_time,inference_time
0,breast-cancer,knn,0.661028,1,1,{'n_neighbors': 13},0.000223,0.001148
1,breast-cancer,knn,0.687982,2,1,{'n_neighbors': 13},0.000218,0.001148
2,breast-cancer,knn,0.639348,3,1,{'n_neighbors': 14},0.000214,0.001040
3,breast-cancer,knn,0.697394,4,1,{'n_neighbors': 14},0.000223,0.001092
4,breast-cancer,knn,0.687948,5,1,{'n_neighbors': 11},0.000226,0.001065
...,...,...,...,...,...,...,...,...
46185,titanic,baam_fourierfeatures64_nbins128_nsamples500_nu...,0.865758,1,1,,6.587969,0.001242
46186,titanic,baam_fourierfeatures64_nbins128_nsamples500_nu...,0.872072,2,1,,6.986520,0.001358
46187,titanic,baam_fourierfeatures64_nbins128_nsamples500_nu...,0.865263,3,1,,6.652177,0.000900
46188,titanic,baam_fourierfeatures64_nbins128_nsamples500_nu...,0.874002,4,1,,6.489160,0.000972


In [17]:
# Sanity check: list the distinct raw model names present in the results table.
results_df.model.unique()

array(['knn', 'rf_new_params', 'xgb', 'logistic', 'mlp', 'mothernet',
       'mlp_distill', 'tabpfn', 'tabpfn_ours', 'tabpfn_ours_ensemble_8',
       'tabpfn_ours_ensemble_32',
       'batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630',
       'batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130',
       'hyperfast_no_optimize_cpu', 'hyperfast_defaults_cpu',
       'ebm_default', 'ebm_bins_main_effects',
       'mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8',
       'mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100',
       'baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490',
       'baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780',
       'baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210',
       'additive_Dclass_average_02_29_2024_04_15_55_epoch_1050',
       'baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220',
       'baam_fourierfeatures64_nbins128_nsamples500_num

In [18]:
# with open(f"results_validation_{datetime.date.today()}.pickle", "wb") as f:
#    pickle.dump(results_baselines + results_transformers, f)

In [19]:
# Display labels for the baselines and headline models. Model names that are not
# listed here (e.g. the checkpoint-named runs) pass through unchanged.
model_display_names = {
    'knn': "KNN",
    'rf_new_params': 'RF',
    'mlp': "MLP",
    'mlp_distill': 'MLP-Distill',
    'xgb': 'XGBoost',
    'logistic': 'LogReg',
    'mothernet': 'MotherNet',
    'tabpfn': 'TabPFN',
}
results_df['model'] = results_df['model'].replace(model_display_names)

In [20]:
# Write the results table to a CSV stamped with today's date, relative to the current
# working directory. to_csv is called with its defaults, so the DataFrame index is
# written as the first (unnamed) column.
filename = f"results_validation_{datetime.date.today()}.csv"
results_df.to_csv(filename)
# Echo the file name as the cell output so the written file is easy to find.
filename

'results_validation_2024-05-22.csv'