In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pickle
import datetime
import warnings
#warnings.simplefilter("ignore", FutureWarning)  # openml deprecation of array return type
warnings.simplefilter("ignore", UserWarning)  # scikit-learn select k best
warnings.simplefilter("ignore", RuntimeWarning)  # scikit-learn select k best

from ticl.datasets import load_openml_list, open_cc_valid_dids, open_cc_dids
from ticl.evaluation.baselines.tabular_baselines import knn_metric, catboost_metric, logistic_metric, xgb_metric, random_forest_metric, mlp_metric, hyperfast_metric, resnet_metric, mothernet_init_metric
from ticl.evaluation.tabular_evaluation import evaluate, eval_on_datasets, transformer_metric
from ticl.evaluation import tabular_metrics
from ticl.prediction.tabpfn import TabPFNClassifier
from ticl.evaluation.baselines import tabular_baselines

# Datasets

In [2]:
from ticl.datasets import load_openml_list, open_cc_dids, open_cc_valid_dids, test_dids_classification

cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True, classification=True)


  openml_list = openml.datasets.list_datasets(dids)


Number of datasets: 149


In [3]:
cc_valid_datasets_multiclass_df['NumberOfInstances'] =  cc_valid_datasets_multiclass_df['NumberOfInstances'].astype(int)
cc_valid_datasets_multiclass_df['NumberOfFeatures'] =  cc_valid_datasets_multiclass_df['NumberOfFeatures'].astype(int)
cc_valid_datasets_multiclass_df['NumberOfClasses'] =  cc_valid_datasets_multiclass_df['NumberOfClasses'].astype(int)

# uncomment for latex table of datasets
# print(cc_valid_datasets_multiclass_df[['did', 'name', 'NumberOfFeatures', 'NumberOfInstances', 'NumberOfClasses']].rename(columns={'NumberOfFeatures': "d", "NumberOfInstances":"n", "NumberOfClasses": "k"}).to_latex(index=False))

# Setting params

In [4]:
import os
eval_positions = [1000]
max_features = 100
n_samples = 2000
base_path = os.path.join('..')
overwrite = False
metric_used = tabular_metrics.auc_metric
task_type = 'multiclass'

# Baseline Evaluation
This section runs baselines and saves results locally.

In [5]:
!mkdir -p {base_path}/results
!mkdir -p {base_path}/results/tabular/
!mkdir -p {base_path}/results/tabular/multiclass/

In [6]:
len(cc_valid_datasets_multiclass)

149

In [7]:
from sklearn import set_config
set_config(skip_parameter_validation=True, assume_finite=True)

In [8]:
from ticl.evaluation.tabular_evaluation import eval_on_datasets
from ticl.prediction.mothernet import ShiftClassifier, EnsembleMeta, MotherNetClassifier, MotherNetInitMLPClassifier
from ticl.prediction.mothernet_additive import MotherNetAdditiveClassifier
from ticl.evaluation.baselines.distill_mlp import DistilledTabPFNMLP
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from interpret.glassbox import ExplainableBoostingClassifier
from functools import partial
from hyperfast import HyperFastClassifier

import warnings
max_times = [1]
device = "cuda:3"
#device = "cpu"

In [9]:
max_times = [15]
clf_dict= {
    'mothernet_gd_gpu_no_learn_default': mothernet_init_metric}

results_mlp = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, n_jobs=1, device="cuda:1", verbose=0)
    for model_name, model in clf_dict.items()
]

evaluating mothernet_gd_gpu_no_learn_default on cuda:1


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 745/745 [00:01<00:00, 402.91it/s]


In [10]:
max_times = [1]
clf_dict= {
    'mothernet_gd_gpu4': mothernet_init_metric}

results_mothernet_gd = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, n_jobs=1, device="cuda:1", verbose=0)
    for model_name, model in clf_dict.items()
]

evaluating mothernet_gd_gpu4 on cuda:1


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 745/745 [00:01<00:00, 404.27it/s]


In [11]:
# MLP GPU
max_times = [15]
# these will all be evaluated on CPU because they are given as callables, which is a weird way to do it.
clf_dict= {
    'mlp_gpu2': mlp_metric}

results_mlp = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, n_jobs=1, device="cuda:3", verbose=1)
    for model_name, model in clf_dict.items()
]

evaluating mlp_gpu2 on cuda:3


  3%|██████████▉                                                                                                                                                                                                                                                                                                                                        | 24/745 [00:00<00:03, 224.24it/s]

Evaluating breast-cancer with 286 samples
Evaluating breast-cancer with 286 samples
Evaluating breast-cancer with 286 samples
Evaluating breast-cancer with 286 samples
Evaluating breast-cancer with 286 samples
Evaluating colic with 368 samples
Evaluating colic with 368 samples
Evaluating colic with 368 samples
Evaluating colic with 368 samples
Evaluating colic with 368 samples
Evaluating dermatology with 366 samples
Evaluating dermatology with 366 samples
Evaluating dermatology with 366 samples
Evaluating dermatology with 366 samples
Evaluating dermatology with 366 samples
Evaluating sonar with 208 samples
Evaluating sonar with 208 samples
Evaluating sonar with 208 samples
Evaluating sonar with 208 samples
Evaluating sonar with 208 samples
Evaluating glass with 214 samples
Evaluating glass with 214 samples
Evaluating glass with 214 samples
Evaluating glass with 214 samples
Evaluating glass with 214 samples
Evaluating haberman with 306 samples
Evaluating haberman with 306 samples
Evalua

 15%|██████████████████████████████████████████████████▊                                                                                                                                                                                                                                                                                               | 112/745 [00:00<00:01, 368.00it/s]

Evaluating wine with 178 samples
Evaluating wine with 178 samples
Evaluating wine with 178 samples
Evaluating wine with 178 samples
Evaluating hayes-roth with 160 samples
Evaluating hayes-roth with 160 samples
Evaluating hayes-roth with 160 samples
Evaluating hayes-roth with 160 samples
Evaluating hayes-roth with 160 samples
Evaluating monks-problems-1 with 556 samples
Evaluating monks-problems-1 with 556 samples
Evaluating monks-problems-1 with 556 samples
Evaluating monks-problems-1 with 556 samples
Evaluating monks-problems-1 with 556 samples
Evaluating monks-problems-2 with 601 samples
Evaluating monks-problems-2 with 601 samples
Evaluating monks-problems-2 with 601 samples
Evaluating monks-problems-2 with 601 samples
Evaluating monks-problems-2 with 601 samples
Evaluating monks-problems-3 with 554 samples
Evaluating monks-problems-3 with 554 samples
Evaluating monks-problems-3 with 554 samples
Evaluating monks-problems-3 with 554 samples
Evaluating monks-problems-3 with 554 sample

 20%|███████████████████████████████████████████████████████████████████▌                                                                                                                                                                                                                                                                              | 149/745 [00:00<00:02, 290.47it/s]

Evaluating analcatdata_broadwaymult with 285 samples
Evaluating analcatdata_broadwaymult with 285 samples
Evaluating analcatdata_reviewer with 379 samples
Evaluating analcatdata_reviewer with 379 samples
Evaluating analcatdata_reviewer with 379 samples
Evaluating analcatdata_reviewer with 379 samples
Evaluating analcatdata_reviewer with 379 samples
Evaluating backache with 180 samples
Evaluating backache with 180 samples
Evaluating backache with 180 samples
Evaluating backache with 180 samples
Evaluating backache with 180 samples
Evaluating prnn_synth with 250 samples
Evaluating prnn_synth with 250 samples
Evaluating prnn_synth with 250 samples
Evaluating prnn_synth with 250 samples
Evaluating prnn_synth with 250 samples
Evaluating schizo with 340 samples
Evaluating schizo with 340 samples
Evaluating schizo with 340 samples
Evaluating schizo with 340 samples
Evaluating schizo with 340 samples
Evaluating profb with 672 samples
Evaluating profb with 672 samples
Evaluating profb with 672 

 34%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                                                                                                                                                                | 251/745 [00:00<00:01, 345.54it/s]

Evaluating diggle_table_a2 with 310 samples
Evaluating diggle_table_a2 with 310 samples
Evaluating rmftsa_ladata with 508 samples
Evaluating rmftsa_ladata with 508 samples
Evaluating rmftsa_ladata with 508 samples
Evaluating rmftsa_ladata with 508 samples
Evaluating rmftsa_ladata with 508 samples
Evaluating pwLinear with 200 samples
Evaluating pwLinear with 200 samples
Evaluating pwLinear with 200 samples
Evaluating pwLinear with 200 samples
Evaluating pwLinear with 200 samples
Evaluating analcatdata_vineyard with 468 samples
Evaluating analcatdata_vineyard with 468 samples
Evaluating analcatdata_vineyard with 468 samples
Evaluating analcatdata_vineyard with 468 samples
Evaluating analcatdata_vineyard with 468 samples
Evaluating machine_cpu with 209 samples
Evaluating machine_cpu with 209 samples
Evaluating machine_cpu with 209 samples
Evaluating machine_cpu with 209 samples
Evaluating machine_cpu with 209 samples
Evaluating pharynx with 195 samples
Evaluating pharynx with 195 samples


 53%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                                                                               | 393/745 [00:00<00:00, 515.34it/s]

Evaluating cholesterol with 303 samples
Evaluating cholesterol with 303 samples
Evaluating cholesterol with 303 samples
Evaluating cholesterol with 303 samples
Evaluating cholesterol with 303 samples
Evaluating chscase_funds with 185 samples
Evaluating chscase_funds with 185 samples
Evaluating chscase_funds with 185 samples
Evaluating chscase_funds with 185 samples
Evaluating chscase_funds with 185 samples
Evaluating pbcseq with 1945 samples
Evaluating pbcseq with 1945 samples
Evaluating pbcseq with 1945 samples
Evaluating pbcseq with 1945 samples
Evaluating pbcseq with 1945 samples
Evaluating pbc with 418 samples
Evaluating pbc with 418 samples
Evaluating pbc with 418 samples
Evaluating pbc with 418 samples
Evaluating pbc with 418 samples
Evaluating rmftsa_ctoarrivals with 264 samples
Evaluating rmftsa_ctoarrivals with 264 samples
Evaluating rmftsa_ctoarrivals with 264 samples
Evaluating rmftsa_ctoarrivals with 264 samples
Evaluating rmftsa_ctoarrivals with 264 samples
Evaluating chsc

 72%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                              | 537/745 [00:01<00:00, 609.57it/s]

Evaluating disclosure_z with 662 samples
Evaluating disclosure_z with 662 samples
Evaluating disclosure_z with 662 samples
Evaluating disclosure_z with 662 samples
Evaluating disclosure_z with 662 samples
Evaluating socmob with 1156 samples
Evaluating socmob with 1156 samples
Evaluating socmob with 1156 samples
Evaluating socmob with 1156 samples
Evaluating socmob with 1156 samples
Evaluating chscase_whale with 228 samples
Evaluating chscase_whale with 228 samples
Evaluating chscase_whale with 228 samples
Evaluating chscase_whale with 228 samples
Evaluating chscase_whale with 228 samples
Evaluating water-treatment with 527 samples
Evaluating water-treatment with 527 samples
Evaluating water-treatment with 527 samples
Evaluating water-treatment with 527 samples
Evaluating water-treatment with 527 samples
Evaluating lowbwt with 189 samples
Evaluating lowbwt with 189 samples
Evaluating lowbwt with 189 samples
Evaluating lowbwt with 189 samples
Evaluating lowbwt with 189 samples
Evaluating

 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                 | 602/745 [00:01<00:00, 586.84it/s]

Evaluating sa-heart with 462 samples
Evaluating sa-heart with 462 samples
Evaluating sa-heart with 462 samples
Evaluating sa-heart with 462 samples
Evaluating seeds with 210 samples
Evaluating seeds with 210 samples
Evaluating seeds with 210 samples
Evaluating seeds with 210 samples
Evaluating seeds with 210 samples
Evaluating thoracic-surgery with 470 samples
Evaluating thoracic-surgery with 470 samples
Evaluating thoracic-surgery with 470 samples
Evaluating thoracic-surgery with 470 samples
Evaluating thoracic-surgery with 470 samples
Evaluating user-knowledge with 403 samples
Evaluating user-knowledge with 403 samples
Evaluating user-knowledge with 403 samples
Evaluating user-knowledge with 403 samples
Evaluating user-knowledge with 403 samples
Evaluating wholesale-customers with 440 samples
Evaluating wholesale-customers with 440 samples
Evaluating wholesale-customers with 440 samples
Evaluating wholesale-customers with 440 samples
Evaluating wholesale-customers with 440 samples
Ev

 96%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████               | 712/745 [00:01<00:00, 361.06it/s]

Evaluating thyroid-new with 215 samples
Evaluating thyroid-new with 215 samples
Evaluating solar-flare with 315 samples
Evaluating solar-flare with 315 samples
Evaluating solar-flare with 315 samples
Evaluating solar-flare with 315 samples
Evaluating solar-flare with 315 samples
Evaluating threeOf9 with 512 samples
Evaluating threeOf9 with 512 samples
Evaluating threeOf9 with 512 samples
Evaluating threeOf9 with 512 samples
Evaluating threeOf9 with 512 samples
Evaluating xd6 with 973 samples
Evaluating xd6 with 973 samples
Evaluating xd6 with 973 samples
Evaluating xd6 with 973 samples
Evaluating xd6 with 973 samples
Evaluating tokyo1 with 959 samples
Evaluating tokyo1 with 959 samples
Evaluating tokyo1 with 959 samples
Evaluating tokyo1 with 959 samples
Evaluating tokyo1 with 959 samples
Evaluating parity5_plus_5 with 1124 samples
Evaluating parity5_plus_5 with 1124 samples
Evaluating parity5_plus_5 with 1124 samples
Evaluating parity5_plus_5 with 1124 samples
Evaluating parity5_plus_

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 745/745 [00:01<00:00, 384.15it/s]

Evaluating Touch2 with 265 samples
Evaluating Touch2 with 265 samples
Evaluating Touch2 with 265 samples
Evaluating penguins with 344 samples
Evaluating penguins with 344 samples
Evaluating penguins with 344 samples
Evaluating penguins with 344 samples
Evaluating penguins with 344 samples
Evaluating titanic with 891 samples
Evaluating titanic with 891 samples
Evaluating titanic with 891 samples
Evaluating titanic with 891 samples
Evaluating titanic with 891 samples





In [12]:
max_times = [15, 60]
# these will all be evaluated on CPU because they are given as callables, which is a weird way to do it.
clf_dict= {
    'resnet_gpu': resnet_metric}

results_resnet = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, n_jobs=1, device="cuda:3", verbose=0)
    for model_name, model in clf_dict.items()
]

evaluating resnet_gpu on cuda:3


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1490/1490 [00:03<00:00, 394.92it/s]


In [13]:
max_times = [1, 5, 15, 60, 5 * 60, 15 * 60, 60* 60]
# these will all be evaluated on CPU because they are given as callables, which is a weird way to do it.
clf_dict= {
    'knn': knn_metric,
    'rf_new_params': random_forest_metric,
    'xgb': xgb_metric,
    'logistic': logistic_metric,
    'mlp': mlp_metric}

results_baselines = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, n_jobs=1)
    for model_name, model in clf_dict.items()
]

[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s


evaluating knn on cpu


[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.1s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.6s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.0s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    4.6s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    5.3s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    6.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    7.0s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    8.2s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   11.9s


evaluating rf_new_params on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.3s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.9s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    4.1s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    6.0s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.7s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.6s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.6s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.8s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   13.7s


evaluating xgb on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.3s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.9s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.5s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.2s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    5.9s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    6.7s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    7.6s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    8.8s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   12.6s


evaluating logistic on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.2s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.7s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.1s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    4.8s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    5.5s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    6.3s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    7.4s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    8.6s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   12.7s


evaluating mlp on cpu


[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 364 tasks      | elapsed:    1.2s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    1.8s
[Parallel(n_jobs=1)]: Done 1012 tasks      | elapsed:    3.7s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    5.4s
[Parallel(n_jobs=1)]: Done 1984 tasks      | elapsed:    6.2s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    7.1s
[Parallel(n_jobs=1)]: Done 3280 tasks      | elapsed:    8.0s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    9.2s
[Parallel(n_jobs=1)]: Done 4900 tasks      | elapsed:   13.3s


In [14]:
from ticl.evaluation.tabular_evaluation import eval_on_datasets
from ticl.prediction.mothernet import ShiftClassifier, EnsembleMeta, MotherNetClassifier, MotherNetInitMLPClassifier
from ticl.prediction.mothernet_additive import MotherNetAdditiveClassifier
from ticl.evaluation.baselines.distill_mlp import DistilledTabPFNMLP
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from interpret.glassbox import ExplainableBoostingClassifier
from functools import partial
from hyperfast import HyperFastClassifier

import warnings
max_times = [1]
device = "cuda:3"
#device = "cpu"

model_string = "tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33"
tabpfn_ours = TabPFNClassifier(device=device, model_string=model_string, epoch="1650", N_ensemble_configurations=3)
mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True)
mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_no_power = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True, power=False)
mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_quantile = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True, power=False, always_quantile=True)
mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_3_quantile = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=3, onehot=True, power=False, always_quantile=True)

mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8_quantile = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290.cpkt", device=device), n_estimators=8, onehot=True, power=False, always_quantile=True)
mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290.cpkt", device=device), n_estimators=8, onehot=True, power=True, always_quantile=False)
mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_quantile_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290.cpkt", device=device), n_estimators=8, onehot=True, power='quantile')
mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_quantile_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True, power='quantile')

mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_categoricalfeaturep0.9_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060.cpkt", device=device), n_estimators=8, onehot=True, power=True)
mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_quantile_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_categoricalfeaturep0.9_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060.cpkt", device=device), n_estimators=8, onehot=True, power="quantile")

mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490_ohe_quantile_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490.cpkt", device=device), n_estimators=8, onehot=True, power="quantile")
mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490_ohe_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490.cpkt", device=device), n_estimators=8, onehot=True)


mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1270_ohe_quantile_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_categoricalfeaturep0.9_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1270.cpkt", device=device), n_estimators=8, onehot=True, power="quantile")


mlp_distill = make_pipeline(StandardScaler(), DistilledTabPFNMLP(n_epochs=1000, device=device, hidden_size=128, n_layers=2, dropout_rate=.1, learning_rate=0.01, model_string=model_string, epoch=1650, N_ensemble_configurations=3))
mothernet_21_46_25_3940_ensemble3 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device), n_estimators=3)
ebm_bins_main_effects = ExplainableBoostingClassifier(max_bins=64, interactions=0)
baam_nfeatures_20_no_ensemble_e1520 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1520.cpkt", device=device)

clf_dict= {
    'mothernet': partial(transformer_metric, classifier=mothernet_21_46_25_3940_ensemble3, onehot=True),
    'mlp_distill': mlp_distill,
    'tabpfn': transformer_metric,
    #'tabpfn_ours': tabpfn_ours,
    
    
    
    "hyperfast_no_optimize_gpu": partial(hyperfast_metric, optimization=None),
    "hyperfast_defaults_gpu": hyperfast_metric,
    #'ebm_bins_main_effects': ebm_bins_main_effects,
    #'baam_nfeatures_20_no_ensemble_e1520': baam_nfeatures_20_no_ensemble_e1520,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_no_power': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_no_power,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_quantile2':mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_quantile': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_quantile,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_3_quantile': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_3_quantile,
    'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8_quantile': mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8_quantile,
    'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8': mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8,
    'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_quantile_8_fixed2': mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_quantile_8,
    'mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_8': mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_8,
    'mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_quantile_8_fixed2': mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_quantile_8,
    'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490_ohe_quantile_8': mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490_ohe_quantile_8,
    'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490_ohe_8': mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1490_ohe_8,
    'mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1270_ohe_quantile_8': mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1270_ohe_quantile_8,

    }
results_transformers = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, overwrite=False, n_jobs=-1, device=device)
    for model_name, model in clf_dict.items()
]

TypeError: EnsembleMeta.__init__() got an unexpected keyword argument 'always_quantile'

In [31]:
from ticl.evaluation.tabular_evaluation import eval_on_datasets
from ticl.prediction.mothernet import ShiftClassifier, EnsembleMeta, MotherNetClassifier, MotherNetInitMLPClassifier
from ticl.prediction.mothernet_additive import MotherNetAdditiveClassifier
from ticl.evaluation.baselines.distill_mlp import DistilledTabPFNMLP
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from interpret.glassbox import ExplainableBoostingClassifier
from functools import partial
from hyperfast import HyperFastClassifier

# transformers don't have max times
max_times = [1]
device = "cpu"

tabpfn_ours = TabPFNClassifier(device=device, model_string="tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33", epoch="1650", N_ensemble_configurations=3)

tabpfn_ours_ensemble_8 = TabPFNClassifier(device=device, model_string="tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33", epoch="1650", N_ensemble_configurations=8)
tabpfn_ours_ensemble_32 = TabPFNClassifier(device=device, model_string="tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33", epoch="1650", N_ensemble_configurations=32)

batapfn_no_ensemble = TabPFNClassifier(device=device, model_string="batabpfn_e128_inputembeddingfourier_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_11_35", epoch="330", N_ensemble_configurations=1)
batapfn_no_ensemble_e410 = TabPFNClassifier(device=device, model_string="batabpfn_e128_inputembeddingfourier_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_11_35", epoch="410", N_ensemble_configurations=1)
batapfn_no_ensemble_e530 = TabPFNClassifier(device=device, model_string="batabpfn_e128_inputembeddingfourier_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_11_35", epoch="530", N_ensemble_configurations=1)
batapfn_no_ensemble_exit = TabPFNClassifier(device=device, model_string="batabpfn_e128_inputembeddingfourier_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_11_35", epoch="on_exit", N_ensemble_configurations=1)


batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630 = TabPFNClassifier(device=device, model_string="batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45", epoch="630", N_ensemble_configurations=1)
batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130 = TabPFNClassifier(device=device, model_string="batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45", epoch="1130", N_ensemble_configurations=1)


mlp_distill = make_pipeline(StandardScaler(), DistilledTabPFNMLP(n_epochs=1000, device=device, hidden_size=128, n_layers=2, dropout_rate=.1, learning_rate=0.01, model_string="tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33", epoch=1650, N_ensemble_configurations=3))
mothernet_21_46_25_3940_ensemble3 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device), n_estimators=3)

mothernet_init_gd = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=100)
mothernet_init_gd_no_learning = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=1, learning_rate=0.0)
mothernet_init_gd_epochs_10 = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=10, learning_rate=0.001, verbose=10)
mothernet_init_gd_epochs_10_lr0001 = MotherNetInitMLPClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device, n_epochs=10, learning_rate=0.0001, verbose=10)


mn_P512_SFalse_L2_1_gpu_01_24_2024 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_P512_SFalse_L2_1_gpu_01_24_2024_00_31_59_epoch_3950.cpkt", device=device), n_estimators=3)
mn_SFalse_L2_1_gpu_01_25_2024_21_20_32 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_SFalse_L2_1_gpu_01_25_2024_21_20_32_epoch_4000.cpkt", device=device), n_estimators=3)
mn_Dclass_average_02_29_2024_04_16_00 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_02_29_2024_04_16_00_epoch_4000.cpkt", device=device), n_estimators=3)
mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_02_29_2024_04_16_00_epoch_4000.cpkt", device=device), n_estimators=3, onehot=True)
mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_32 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_02_29_2024_04_16_00_epoch_4000.cpkt", device=device), n_estimators=32, onehot=True)
mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_02_29_2024_04_16_00_epoch_4000.cpkt", device=device), n_estimators=8, onehot=True)

mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True)


mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100.cpkt", device=device), n_estimators=8, onehot=True)


#mn_Dclass_average_03_25_2024_17_14_32_epoch_1760_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_1760.cpkt", device=device), n_estimators=8, onehot=True)
#mn_Dclass_average_03_25_2024_17_14_32_epoch_2270_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2270.cpkt", device=device), n_estimators=8, onehot=True)
#mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device), n_estimators=8, onehot=True)

mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8 = EnsembleMeta(MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_3970.cpkt", device=device), n_estimators=8, onehot=True)

additive_1_gpu_02_14_2024_16_34_15 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_1_gpu_02_14_2024_16_34_15_epoch_950_fixed2.cpkt", device=device), n_estimators=3, power=False)

additive_step_prior_02_08_2024 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_b16_reducelronspikeTrue_multiclasstypesteps_1_gpu_02_08_2024_04_51_33_epoch_790.cpkt", device=device), n_estimators=3, power=False)
additive_11_08_2023 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_1_gpu_11_08_2023_23_02_58_continue_11_10_2023_03_01_40_epoch_3170_no_optimizer.cpkt", device=device), n_estimators=3, power=False)
additive_02_20_2024_factorized_weight_decay = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_w0.01_factorizedoutputTrue_outputrank64_1_gpu_02_20_2024_22_39_06_epoch_1260_fixed.cpkt", device=device), n_estimators=3, power=False)

additive_Dclass_average_02_29_2024_04_15_55_epoch_190 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_190.cpkt", device=device), n_estimators=3, power=False)
additive_Dclass_average_02_29_2024_04_15_55_epoch_190_no_ensemble = MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_190.cpkt", device=device)
additive_Dclass_average_02_29_2024_04_15_55_epoch_560 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_560.cpkt", device=device), n_estimators=3, power=False)
additive_Dclass_average_02_29_2024_04_15_55_epoch_730 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_730.cpkt", device=device), n_estimators=3, power=False)
additive_Dclass_average_02_29_2024_04_15_55_epoch_780 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_780.cpkt", device=device), n_estimators=3, power=False)
additive_Dclass_average_02_29_2024_04_15_55_epoch_850 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_850.cpkt", device=device), n_estimators=3, power=False)
additive_Dclass_average_02_29_2024_04_15_55_epoch_1050 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_02_29_2024_04_15_55_epoch_1050.cpkt", device=device), n_estimators=3, power=False)


additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340.cpkt", device=device), n_estimators=3, power=False)



# additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270.cpkt", device=device), n_estimators=3, power=False)


# additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_340 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_w0.01_03_02_2024_02_21_10_epoch_340.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_420 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_w0.01_03_02_2024_02_21_10_epoch_420.cpkt", device=device), n_estimators=3, power=False)
# additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_1210 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_w0.01_03_02_2024_02_21_10_epoch_1210.cpkt", device=device), n_estimators=3, power=False)

# additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580.cpkt", device=device), n_estimators=3, power=False)

# additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280.cpkt", device=device), n_estimators=3, power=False)


ebm_default = ExplainableBoostingClassifier()
ebm_bins = ExplainableBoostingClassifier(max_bins=64)
ebm_bins_main_effects = ExplainableBoostingClassifier(max_bins=64, interactions=0)


baam_nfeatures_20_no_ensemble = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_400.cpkt", device=device)
baam_nfeatures_20_no_ensemble_e500 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_500.cpkt", device=device)
baam_nfeatures_20_no_ensemble_e650 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_650.cpkt", device=device)
baam_nfeatures_20_no_ensemble_e840 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_840.cpkt", device=device)
baam_nfeatures_20_no_ensemble_e1210 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1210.cpkt", device=device)
baam_nfeatures_20_no_ensemble_e1520 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1520.cpkt", device=device)
baam_nfeatures_20_no_ensemble_e1970 = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1970.cpkt", device=device)


baam_nfeatures_20 = EnsembleMeta(MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_400.cpkt", device=device), n_estimators=3, power=False)

baam_nfeatures_100_no_ensemble = MotherNetAdditiveClassifier(path="../models_diff/baam_H512_Dclass_average_e64_nsamples500_N6_padzerosFalse_03_13_2024_20_28_36_epoch_360.cpkt", device=device)


baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520.cpkt", device=device)
baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940.cpkt", device=device)

baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210 = MotherNetAdditiveClassifier(path="../models_diff/baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210.cpkt", device=device)

baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220.cpkt", device=device)

baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490 = MotherNetAdditiveClassifier(path="../models_diff/baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490.cpkt", device=device)
baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780 = MotherNetAdditiveClassifier(path="../models_diff/baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780.cpkt", device=device)

baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430.cpkt", device=device)


baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970 = MotherNetAdditiveClassifier(path="../models_diff/baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970.cpkt", device=device)

baam_categoricalfeaturep09_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_15_2024_20_58_13_epoch_280 = MotherNetAdditiveClassifier(path="../models_diff/baam_categoricalfeaturep0.9_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_15_2024_20_58_13_epoch_280.cpkt", device=device)
baam_categoricalembeddingTrue_categoricalfeaturep09_l1e05_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_17_2024_00_02_36_epoch_230 = MotherNetAdditiveClassifier(
    path="../models_diff/baam_categoricalembeddingTrue_categoricalfeaturep0.9_l1e-05_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_17_2024_00_02_36_epoch_230.cpkt", device=device)




baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160.cpkt", device=device)

baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310.cpkt", device=device)

baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130 = MotherNetAdditiveClassifier(path="../models_diff/baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130.cpkt", device=device)
baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180 = MotherNetAdditiveClassifier(path="../models_diff/baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180.cpkt", device=device)


baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610.cpkt", device=device)
baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660 = MotherNetAdditiveClassifier(path="../models_diff/baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660.cpkt", device=device)

baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250.cpkt", device=device)
baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230.cpkt", device=device)

baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410.cpkt", device=device)
baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550.cpkt", device=device)
baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690.cpkt", device=device)
baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890 = MotherNetAdditiveClassifier(path="../models_diff/baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890.cpkt", device=device)

baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470 = MotherNetAdditiveClassifier(path="../models_diff/baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470.cpkt", device=device)

baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190 = MotherNetAdditiveClassifier(path="../models_diff/baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190.cpkt", device=device)



#hyperfast_no_optimization = HyperFastClassifier(device=device, optimization=None)
#hyperfast_defaults = HyperFastClassifier(device=device)

clf_dict= {
    'mothernet': partial(transformer_metric, classifier=mothernet_21_46_25_3940_ensemble3, onehot=True),
    'mlp_distill': mlp_distill,
    'tabpfn': transformer_metric,
    'tabpfn_ours': tabpfn_ours,
    'tabpfn_ours_ensemble_8': tabpfn_ours_ensemble_8,
    'tabpfn_ours_ensemble_32': tabpfn_ours_ensemble_32,
    'mothernet_init_gd': mothernet_init_gd,
    'mothernet_init_gd_no_learning': mothernet_init_gd_no_learning,
    'mothernet_init_gd_epochs_10': mothernet_init_gd_epochs_10,
    'mothernet_init_gd_epochs_10_lr0001': mothernet_init_gd_epochs_10_lr0001,
    'mothernet_init_gd_no_learning_ohe' : partial(transformer_metric, classifier=mothernet_init_gd_no_learning, onehot=True),
    'mothernet_init_gd_ohe' : partial(transformer_metric, classifier=mothernet_init_gd, onehot=True),


    #'batapfn_no_ensemble': batapfn_no_ensemble,
    #'batapfn_no_ensemble_e410': batapfn_no_ensemble_e410,
    #'batapfn_no_ensemble_exit': batapfn_no_ensemble_exit,
    'batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630': batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630,
    'batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130': batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130,
  #"hyperfast_no_optimize_gpu": partial(hyperfast_metric, optimization=None),
     "hyperfast_no_optimize_cpu":  partial(hyperfast_metric, optimization=None),
    "hyperfast_defaults_cpu":hyperfast_metric,
    
    'ebm_default': partial(transformer_metric, classifier=ebm_default),
    #'ebm_bins': partial(transformer_metric, classifier=ebm_bins),
    'ebm_bins_main_effects': partial(transformer_metric, classifier=ebm_bins_main_effects),

    'mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8,
    'mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8': mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8,
    'baam_categoricalfeaturep09_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_15_2024_20_58_13_epoch_280': baam_categoricalfeaturep09_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_15_2024_20_58_13_epoch_280,
    'baam_categoricalembeddingTrue_categoricalfeaturep09_l1e05_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_17_2024_00_02_36_epoch_230': baam_categoricalembeddingTrue_categoricalfeaturep09_l1e05_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_17_2024_00_02_36_epoch_230,
    

    'mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100': mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100,
    
        'baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490': baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490,
   'baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780': baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780,  #  <- TRY THIS
    'baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210': baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210,

    'additive_Dclass_average_02_29_2024_04_15_55_epoch_1050': additive_Dclass_average_02_29_2024_04_15_55_epoch_1050,
    'baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220': baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220,
    #'additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340_retry': additive_Dclass_average_inputlayernormTrue_02_29_2024_20_52_12_epoch_1340,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_140,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150_redo': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_150,
     #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_220,
    
    'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550,
    # this one overfitted and then kinda recovered? but not very good in the end.
#'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_410,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_780,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1010,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1210,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_1520,
    #'baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940': baam_e128_fourierfeatures32_nsamples500_numfeatures20_03_20_2024_00_05_35_epoch_2940,
    

    #'additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100': additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_100,
   # 'additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270': additive_Dclass_average_multiclassmaxsteps3_multiclasstypesteps_03_04_2024_19_04_03_epoch_270,

  #  'additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360': additive_H512_Dclass_average_factorizedoutputTrue_L6_03_01_2024_21_38_55_epoch_360,
  #  'additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_340': additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_340,
    #'additive_Dclass_average_factorizedoutputTrue_w001_03_02_024_02_21_10_epoch_420': additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_420,
    #'additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_1210': additive_Dclass_average_factorizedoutputTrue_w001_03_02_2024_02_21_10_epoch_1210,
    #'additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580': additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_07_2024_00_39_41_epoch_580,
    #'additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280': additive_Dclass_average_factorizedoutputTrue_nshapefunctions128_outputrank64_shapeattentionTrue_shapeattentionheads8_03_08_2024_21_19_43_epoch_1280,
    #'baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610': baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_610,
    #'baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660': baam_e128_featurecurriculumTrue_nsamples500_N6_03_21_2024_17_58_54_epoch_660,
    # 'baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160': baam_e128_nbins512_nsamples500_numfeatures20_03_19_2024_22_53_00_epoch_160,
    # 'baam_nfeatures_20_no_ensemble': baam_nfeatures_20_no_ensemble,
    # 'baam_nfeatures_20': baam_nfeatures_20,
    #'baam_nfeatures_100_no_ensemble': baam_nfeatures_100_no_ensemble,
    # 'baam_nfeatures_20_no_ensemble_e500': baam_nfeatures_20_no_ensemble_e500,
#     'baam_nfeatures_20_no_ensemble_e650': baam_nfeatures_20_no_ensemble_e650,
#    'baam_nfeatures_20_no_ensemble_e840': baam_nfeatures_20_no_ensemble_e840,
    "baam_nfeatures_20_no_ensemble_e1210": baam_nfeatures_20_no_ensemble_e1210,
    "baam_nfeatures_20_no_ensemble_e1520": baam_nfeatures_20_no_ensemble_e1520,
    "baam_nfeatures_20_no_ensemble_e1970": baam_nfeatures_20_no_ensemble_e1970,

    'baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470': baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470,

    #'baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310': baam_fourierfeatures64_nbins512_nsamples500_numfeatures20_03_21_2024_01_02_32_epoch_310,
    #'baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130': baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_130,
    #    'baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180': baam_marginalresidualTrue_nsamples500_numfeatures20_shapeinitzero_03_29_2024_19_20_09_epoch_180,

#    'baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250': baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_250,
    'baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230': baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230,
    'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410,
#        'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_690,
#    'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_890,
        'baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430': baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430,

    #'baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970': baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_970,
    #'baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190': baam_categoricalembeddingTrue_nsamples500_nanbinTrue_numfeatures20_04_04_2024_14_06_39_epoch_190,
    }
results_transformers = [
    eval_on_datasets('multiclass', model, model_name, cc_valid_datasets_multiclass, eval_positions=eval_positions, max_times=max_times,
                     metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
                     n_samples=n_samples, base_path=base_path, overwrite=False, n_jobs=-1, device=device)
    for model_name, model in clf_dict.items()
]

evaluating mothernet on cpu


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.6s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mlp_distill on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn_ours on cpu


[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn_ours_ensemble_8 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating tabpfn_ours_ensemble_32 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mothernet_init_gd on cpu


[Parallel(n_jobs=-1)]: Done 179 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mothernet_init_gd_no_learning on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mothernet_init_gd_epochs_10 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done 178 tasks      | elapsed:    0.2s


evaluating mothernet_init_gd_epochs_10_lr0001 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mothernet_init_gd_no_learning_ohe on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mothernet_init_gd_ohe on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e630 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    1.1s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating batabpfn_e256_nsamples500_numfeatures20_03_20_2024_22_14_45_e1130 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating ebm_default on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating ebm_bins_main_effects on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mn_Dclass_average_03_25_2024_17_14_32_epoch_3970_ohe_ensemble_8 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mn_Dclass_average_02_29_2024_04_16_00_ohe_ensemble_8 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_categoricalfeaturep09_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_15_2024_20_58_13_epoch_280 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_categoricalembeddingTrue_categoricalfeaturep09_l1e05_nsamples500_numfeatures20_numfeaturessamplerdouble_sample_sklearnbinningTrue_05_17_2024_00_02_36_epoch_230 on cpu


[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:  1.3min
[Parallel(n_jobs=-1)]: Done 456 tasks      | elapsed:  1.6min


evaluating mn_Dclass_average_fourierfeatures16_05_09_2024_01_03_23_epoch_100 on cpu


[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:  2.5min finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:    0.9s finished
[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    1.1s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    1.4s remaining:    0.5s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    6.3s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1490 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780 on cpu


[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures20_04_04_2024_03_07_12_epoch_1210 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating additive_Dclass_average_02_29_2024_04_15_55_epoch_1050 on cpu


[Parallel(n_jobs=-1)]: Done 178 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_e128_nsamples500_numfeatures20_04_01_2024_15_38_54_epoch_2220 on cpu


[Parallel(n_jobs=-1)]: Done 177 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.6s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_550 on cpu


[Parallel(n_jobs=-1)]: Done 179 tasks      | elapsed:    1.0s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    1.2s remaining:    0.4s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    1.4s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nfeatures_20_no_ensemble_e1210 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nfeatures_20_no_ensemble_e1520 on cpu


[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nfeatures_20_no_ensemble_e1970 on cpu


[Parallel(n_jobs=-1)]: Done 178 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_nsamples500_numfeatures20_03_27_2024_17_57_59_epoch_470 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nsamples500_numfeatures20_03_27_2024_17_56_13_epoch_1230 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_410 on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating baam_fourierfeatures64_nbins128_nsamples500_numfeatures20_03_24_2024_21_54_58_epoch_1430 on cpu


[Parallel(n_jobs=-1)]: Done 175 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 554 out of 745 | elapsed:    0.4s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 745 out of 745 | elapsed:    0.5s finished


In [37]:
flat_results = []
for per_dataset in results_baselines + results_transformers:
    for result in per_dataset:
        row = {}
        for key in ['dataset', 'model', 'mean_metric', 'split', 'max_time']:
            row[key] = result[key]
        best_configs_key, = [k for k in result.keys() if "best_configs" in k]
        if result[best_configs_key][0] is not None:
            row.update(result[best_configs_key][0])
        row['mean_metric'] = float(row["mean_metric"].numpy())
        flat_results.append(row)

results_df = pd.DataFrame(flat_results)

In [38]:
result['titanic_best_configs_at_1000'][0]

{'fit_time': 0.11869001388549805, 'inference_time': 0.044876813888549805}

In [39]:
results_df.model.unique()

array(['knn', 'rf_new_params', 'xgb', 'logistic', 'mlp', 'mothernet',
       'mlp_distill', 'tabpfn', 'hyperfast_no_optimize_gpu',
       'hyperfast_defaults_gpu',
       'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8',
       'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_no_power',
       'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_quantile2',
       'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8_quantile',
       'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_3_quantile',
       'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8_quantile',
       'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_8',
       'mn_Dclass_average_numfeaturessamplerdouble_sample_05_08_2024_22_58_18_epoch_1290_ohe_quantile_8_fixed2',
       'mn_categoricalfeaturep09_numfeaturessamplerdouble_sample_05_09_2024_23_39_30_epoch_1060_ohe_8',
       'm

In [40]:
# with open(f"results_validation_{datetime.date.today()}.pickle", "wb") as f:
#    pickle.dump(results_baselines + results_transformers, f)

In [41]:
results_df['model'] = results_df.model.replace({'knn': "KNN", 'rf_new_params': 'RF', 'mlp': "MLP",'mlp_distill': 'MLP-Distill', 'xgb':'XGBoost', 'logistic': 'LogReg',  'mothernet': 'MotherNet', 'tabpfn': 'TabPFN'})

In [42]:
filename = f"results/results_validation_{datetime.date.today()}.csv"
results_df.to_csv(filename)
filename

'results/results_validation_2024-05-20.csv'