# Test dataset evaluation

In [1]:
import matplotlib.pyplot as plt

from mothernet.evaluation.baselines import tabular_baselines

import seaborn as sns
import numpy as np
import warnings
warnings.simplefilter("ignore", FutureWarning)  # openml deprecation of array return type
from mothernet.datasets import load_openml_list, open_cc_valid_dids, open_cc_dids
from mothernet.evaluation.baselines.tabular_baselines import knn_metric, catboost_metric, logistic_metric, xgb_metric, random_forest_metric, mlp_metric, hyperfast_metric, hyperfast_metric_tuning, resnet_metric, mothernet_init_metric
from mothernet.evaluation.tabular_evaluation import evaluate, eval_on_datasets, transformer_metric
from mothernet.evaluation import tabular_metrics
from mothernet.prediction.tabpfn import TabPFNClassifier

# Datasets

In [2]:
# Load the `open_cc_dids` OpenML datasets as multiclass tasks. Returns the list of
# datasets plus a metadata DataFrame (used below for the dataset table).
# max_samples / num_feats presumably cap rows and features per dataset — TODO confirm.
cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(
    open_cc_dids,
    multiclass=True,
    shuffled=True,
    filter_for_nan=False,
    max_samples=10000,
    num_feats=100,
    return_capped=True,
)

Number of datasets: 30


# Setting params

In [3]:
import os

# --- evaluation setup shared by all cells below ---
task_type = 'multiclass'
metric_used = tabular_metrics.auc_metric
eval_positions = [1000]
n_samples = 2000
max_features = 100
# Time budgets in seconds: 1s, 15s, 30s, 1min, 5min, 15min, 1h.
max_times = [1, 15, 30, 60, 5 * 60, 15 * 60, 60 * 60]

# --- result storage ---
base_path = '../'  # os.path.join() with a single argument returns it unchanged
overwrite = False

In [4]:
from mothernet.evaluation.baselines.distill_mlp import DistilledTabPFNMLP
from mothernet.prediction.mothernet import MotherNetClassifier
from functools import partial

# Baseline Evaluation
This section runs baselines and saves results locally.

In [5]:
# Create <base_path>/results/tabular/multiclass/ (and every missing parent) in pure
# Python. This is the same as `mkdir -p`, but it does not depend on a POSIX shell, and one
# call replaces three: makedirs also creates the intermediate results/ and results/tabular/.
os.makedirs(os.path.join(base_path, 'results', 'tabular', 'multiclass'), exist_ok=True)

In [6]:
# Flag "purely numeric" datasets: exactly one symbolic column (presumably the class
# target, as OpenML counts it — TODO confirm) and no instances with missing values.
cc_test_datasets_multiclass_df['isNumeric'] = (
    cc_test_datasets_multiclass_df['NumberOfSymbolicFeatures'].eq(1)
    & cc_test_datasets_multiclass_df['NumberOfInstancesWithMissingValues'].eq(0)
)

In [7]:
# Cast the OpenML count columns to int so the table prints "2000" rather than "2000.0".
cc_test_datasets_multiclass_df = cc_test_datasets_multiclass_df.astype(
    {'NumberOfInstances': int, 'NumberOfFeatures': int, 'NumberOfClasses': int}
)

# LaTeX dataset summary: d = number of features, n = number of instances, k = number of classes.
print(
    cc_test_datasets_multiclass_df[['did', 'name', 'NumberOfFeatures', 'NumberOfInstances', 'NumberOfClasses']]
    .rename(columns={'NumberOfFeatures': "d", "NumberOfInstances": "n", "NumberOfClasses": "k"})
    .to_latex(index=False)
)

\begin{tabular}{rlrrr}
\toprule
did & name & d & n & k \\
\midrule
11 & balance-scale & 5 & 625 & 3 \\
14 & mfeat-fourier & 77 & 2000 & 10 \\
15 & breast-w & 10 & 699 & 2 \\
16 & mfeat-karhunen & 65 & 2000 & 10 \\
18 & mfeat-morphological & 7 & 2000 & 10 \\
22 & mfeat-zernike & 48 & 2000 & 10 \\
23 & cmc & 10 & 1473 & 3 \\
29 & credit-approval & 16 & 690 & 2 \\
31 & credit-g & 21 & 1000 & 2 \\
37 & diabetes & 9 & 768 & 2 \\
50 & tic-tac-toe & 10 & 958 & 2 \\
54 & vehicle & 19 & 846 & 4 \\
188 & eucalyptus & 20 & 736 & 5 \\
458 & analcatdata_authorship & 71 & 841 & 4 \\
469 & analcatdata_dmft & 5 & 797 & 6 \\
1049 & pc4 & 38 & 1458 & 2 \\
1050 & pc3 & 38 & 1563 & 2 \\
1063 & kc2 & 22 & 522 & 2 \\
1068 & pc1 & 22 & 1109 & 2 \\
1462 & banknote-authentication & 5 & 1372 & 2 \\
1464 & blood-transfusion-service-center & 5 & 748 & 2 \\
1480 & ilpd & 11 & 583 & 2 \\
1494 & qsar-biodeg & 42 & 1055 & 2 \\
1510 & wdbc & 31 & 569 & 2 \\
6332 & cylinder-bands & 40 & 540 & 2 \\
23381 & dresses-sales

overlap:
balance-scale
mfeat-fourier
mfeat-karhunen
mfeat-morphological
credit-g
tic-tac-toe
vehicle
analcatdata_authorship
analcatdata_dmft
pc3
pc1
blood-transfusion-service-center
ilpd
qsar-biodeg
MiceProtein
car
steel-plates-fault
climate-model-simulation-crashes

non-overlap:
breast-w (valid)
mfeat-zernike (valid)
cmc (valid)
eucalyptus (valid)
wdbc (valid)
cylinder-bands (valid)
dresses-sales (valid)


banknote-authentication (test)
credit-approval (test)
diabetes (test)
pc4 (test)
kc2 (test)

In [None]:

# MotherNet-init (gradient-descent) run: one 1h budget, on the dataset slice [12:16],
# on cuda:3. The slice appears to be a manual shard of the datasets across runs/GPUs.
max_times = [60 * 60]
clf_dict = {'mothernet_gd_gpu4': mothernet_init_metric}

results_mothernet_gd = [
    eval_on_datasets(
        'multiclass',
        model,
        model_name,
        cc_test_datasets_multiclass[12:16],
        eval_positions=eval_positions,
        max_times=max_times,
        metric_used=metric_used,
        split_numbers=[1, 2, 3, 4, 5],
        n_samples=n_samples,
        base_path=base_path,
        n_jobs=1,
        device="cuda:3",
        verbose=1,
    )
    for model_name, model in clf_dict.items()
]

evaluating mothernet_gd_gpu4 on cuda:3


  0%|                                                                                                                                                                                                                                                  | 0/20 [00:00<?, ?it/s]

Evaluating eucalyptus with 736 samples


  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = ms

Evaluating eucalyptus with 736 samples


  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = ms

Evaluating eucalyptus with 736 samples


  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = ms

Evaluating eucalyptus with 736 samples


  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = ms

Evaluating eucalyptus with 736 samples


  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = msb / msw
  f = ms

Evaluating analcatdata_authorship with 841 samples


 30%|████████████████████████████████████████████████████████████████████                                                                                                                                                               | 6/20 [6:00:30<14:01:00, 3604.31s/it]

Evaluating analcatdata_authorship with 841 samples


 35%|███████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                   | 7/20 [7:00:38<13:01:12, 3605.58s/it]

Evaluating analcatdata_authorship with 841 samples


In [17]:
max_times = [60 * 60]
# ResNet baseline with a single 1h budget, on the dataset slice [24:] (the last
# datasets of the list). It is evaluated with device="cuda:3" (not on CPU), which is
# why the key is named 'resnet_gpu'.
clf_dict = {'resnet_gpu': resnet_metric}

results_resnet = [
    eval_on_datasets(
        'multiclass',
        model,
        model_name,
        cc_test_datasets_multiclass[24:],
        eval_positions=eval_positions,
        max_times=max_times,
        metric_used=metric_used,
        split_numbers=[1, 2, 3, 4, 5],
        n_samples=n_samples,
        base_path=base_path,
        n_jobs=1,
        device="cuda:3",
        verbose=1,
    )
    for model_name, model in clf_dict.items()
]

evaluating resnet_gpu on cuda:3


  0%|                                                                                                                                                                                                                                                  | 0/30 [00:00<?, ?it/s]

Evaluating cylinder-bands with 540 samples


  3%|███████▌                                                                                                                                                                                                                           | 1/30 [1:00:12<29:05:54, 3612.23s/it]

Evaluating cylinder-bands with 540 samples


  7%|███████████████▏                                                                                                                                                                                                                   | 2/30 [2:00:16<28:03:32, 3607.58s/it]

Evaluating cylinder-bands with 540 samples


 10%|██████████████████████▋                                                                                                                                                                                                            | 3/30 [3:00:23<27:03:14, 3607.19s/it]

Evaluating cylinder-bands with 540 samples


 13%|██████████████████████████████▎                                                                                                                                                                                                    | 4/30 [4:00:36<26:04:08, 3609.56s/it]

Evaluating cylinder-bands with 540 samples


 17%|█████████████████████████████████████▊                                                                                                                                                                                             | 5/30 [5:00:43<25:03:39, 3608.76s/it]

Evaluating dresses-sales with 500 samples


 20%|█████████████████████████████████████████████▍                                                                                                                                                                                     | 6/30 [6:00:52<24:03:31, 3608.82s/it]

Evaluating dresses-sales with 500 samples


 23%|████████████████████████████████████████████████████▉                                                                                                                                                                              | 7/30 [7:00:59<23:03:09, 3608.23s/it]

Evaluating dresses-sales with 500 samples


 27%|████████████████████████████████████████████████████████████▌                                                                                                                                                                      | 8/30 [8:01:12<22:03:29, 3609.52s/it]

Evaluating dresses-sales with 500 samples


 30%|████████████████████████████████████████████████████████████████████                                                                                                                                                               | 9/30 [9:01:24<21:03:39, 3610.46s/it]

Evaluating dresses-sales with 500 samples


 33%|███████████████████████████████████████████████████████████████████████████                                                                                                                                                      | 10/30 [10:01:25<20:02:29, 3607.47s/it]

Evaluating MiceProtein with 1080 samples


 37%|██████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                              | 11/30 [11:01:39<19:03:00, 3609.51s/it]

Evaluating MiceProtein with 1080 samples


 40%|██████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                       | 12/30 [12:01:52<18:03:09, 3610.51s/it]

Evaluating MiceProtein with 1080 samples


 43%|█████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                               | 13/30 [13:02:01<17:02:50, 3610.00s/it]

Evaluating MiceProtein with 1080 samples


 47%|█████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                        | 14/30 [14:02:24<16:03:42, 3613.93s/it]

Evaluating MiceProtein with 1080 samples


 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                | 15/30 [15:02:32<15:03:01, 3612.10s/it]

Evaluating car with 1728 samples


 53%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                         | 16/30 [16:02:47<14:03:04, 3613.21s/it]

Evaluating car with 1728 samples


 57%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                 | 17/30 [17:02:59<13:02:46, 3612.84s/it]

Evaluating car with 1728 samples


 60%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                          | 18/30 [18:03:09<12:02:23, 3611.93s/it]

Evaluating car with 1728 samples


 63%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                  | 19/30 [19:03:41<11:03:17, 3617.92s/it]

Evaluating car with 1728 samples


 67%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                           | 20/30 [20:03:56<10:02:50, 3617.03s/it]

Evaluating steel-plates-fault with 1941 samples


 70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                   | 21/30 [21:04:06<9:02:13, 3614.85s/it]

Evaluating steel-plates-fault with 1941 samples


 73%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                            | 22/30 [22:04:08<8:01:28, 3611.08s/it]

Evaluating steel-plates-fault with 1941 samples


 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                    | 23/30 [23:04:11<7:01:00, 3608.68s/it]

Evaluating steel-plates-fault with 1941 samples


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                             | 24/30 [24:04:24<6:00:58, 3609.83s/it]

Evaluating steel-plates-fault with 1941 samples


 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                     | 25/30 [25:04:50<5:01:14, 3614.86s/it]

Evaluating climate-model-simulation-crashes with 540 samples


 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                              | 26/30 [26:04:53<4:00:44, 3611.17s/it]

Evaluating climate-model-simulation-crashes with 540 samples


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 27/30 [27:05:08<3:00:37, 3612.52s/it]

Evaluating climate-model-simulation-crashes with 540 samples


 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉               | 28/30 [28:05:30<2:00:30, 3615.14s/it]

Evaluating climate-model-simulation-crashes with 540 samples


 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍       | 29/30 [29:05:44<1:00:15, 3615.03s/it]

Evaluating climate-model-simulation-crashes with 540 samples


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [30:05:53<00:00, 3611.77s/it]


In [11]:
from mothernet.evaluation.tabular_evaluation import eval_on_datasets

# Classic baselines over every dataset and every time budget (in seconds).
# fetch_only=True presumably only loads previously stored results instead of
# re-training. TODO confirm in eval_on_datasets.
max_times = [1, 5, 15, 60, 5 * 60, 15 * 60, 60 * 60]

clf_dict = {
    'knn': knn_metric,
    'rf_new_params': random_forest_metric,
    'xgb': xgb_metric,
    'logistic': logistic_metric,
    'mlp': mlp_metric,
}

results_baselines = [
    eval_on_datasets(
        'multiclass', model, model_name, cc_test_datasets_multiclass,
        eval_positions=eval_positions,
        max_times=max_times,
        metric_used=metric_used,
        split_numbers=[1, 2, 3, 4, 5],
        n_samples=n_samples,
        base_path=base_path,
        fetch_only=True,
    )
    for model_name, model in clf_dict.items()
]

evaluating knn on cpu


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    8.7s
[Parallel(n_jobs=-1)]: Done 524 tasks      | elapsed:    9.3s
[Parallel(n_jobs=-1)]: Done 1050 out of 1050 | elapsed:   18.3s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating rf_new_params on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 859 out of 1050 | elapsed:    0.6s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done 1050 out of 1050 | elapsed:    0.6s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating xgb on cpu


[Parallel(n_jobs=-1)]: Done 180 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 1050 out of 1050 | elapsed:    0.8s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating logistic on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 1050 out of 1050 | elapsed:    0.6s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.


evaluating mlp on cpu


[Parallel(n_jobs=-1)]: Done 174 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 1050 out of 1050 | elapsed:    0.7s finished


In [12]:
# Disabled cell: HyperFast with hyperparameter tuning ('hyperfast_tuning_gpu') on the
# first 8 datasets, with a single 1h budget on cuda:2 (fetch_only=False, so it would run
# the evaluation). Kept for reference; uncomment the lines below to re-run it.
# from mothernet.evaluation.tabular_evaluation import eval_on_datasets

# max_times = [60 * 60]
# clf_dict= {
#     'hyperfast_tuning_gpu': hyperfast_metric_tuning}
# results_hyperfast = [
#     eval_on_datasets('multiclass', model, model_name, cc_test_datasets_multiclass[:8], eval_positions=eval_positions, max_times=max_times,
#                      metric_used=metric_used, split_numbers=[1, 2, 3, 4, 5],
#                      n_samples=n_samples, base_path=base_path, fetch_only=False, device='cuda:2')
#     for model_name, model in clf_dict.items()
# ]

In [22]:
from mothernet.evaluation.tabular_evaluation import eval_on_datasets
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import BaggingClassifier
from mothernet.prediction.mothernet import ShiftClassifier, EnsembleMeta, MotherNetClassifier
from sklearn.impute import SimpleImputer
from mothernet.prediction.mothernet_additive import MotherNetAdditiveClassifier
from interpret.glassbox import ExplainableBoostingClassifier
from hyperfast import HyperFastClassifier
import warnings

# Transformer-based models have no time budget to tune, so a single nominal budget is used.
max_times = [1]
device = "cuda:2"

# --- pretrained models / checkpoints (all placed on `device`) ---
model_string = "tabpfn_nooptimizer_emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33"
# NOTE(review): tabpfn_ours is built but its clf_dict entry below is commented out.
tabpfn_ours = TabPFNClassifier(
    device=device,
    model_string=model_string,
    epoch="1650",
    N_ensemble_configurations=3,
)
mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 = EnsembleMeta(
    MotherNetClassifier(path="../models_diff/mn_Dclass_average_03_25_2024_17_14_32_epoch_2910.cpkt", device=device),
    n_estimators=8,
    onehot=True,
)

# MLP distilled from the same TabPFN checkpoint as tabpfn_ours (note: epoch is an int here, a str above).
mlp_distill = make_pipeline(
    StandardScaler(),
    DistilledTabPFNMLP(
        n_epochs=1000,
        device=device,
        hidden_size=128,
        n_layers=2,
        dropout_rate=.1,
        learning_rate=0.01,
        model_string=model_string,
        epoch=1650,
        N_ensemble_configurations=3,
    ),
)
mothernet_21_46_25_3940_ensemble3 = EnsembleMeta(
    MotherNetClassifier(path="../models_diff/mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device),
    n_estimators=3,
)
ebm_bins_main_effects = ExplainableBoostingClassifier(max_bins=64, interactions=0)
baam_nfeatures_20_no_ensemble_e1520 = MotherNetAdditiveClassifier(
    path="../models_diff/baam_H512_Dclass_average_e128_nsamples500_numfeatures20_padzerosFalse_03_14_2024_15_03_22_epoch_1520.cpkt",
    device=device,
)

# Models to evaluate; dict order is the evaluation order.
# Currently disabled entries: 'tabpfn_ours', "hyperfast_no_optimize_gpu"
# (partial(hyperfast_metric, optimization=None)) and "hyperfast_defaults_gpu" (hyperfast_metric).
clf_dict = {
    'mothernet': partial(transformer_metric, classifier=mothernet_21_46_25_3940_ensemble3, onehot=True),
    'mlp_distill': mlp_distill,
    'tabpfn': transformer_metric,
    'ebm_bins_main_effects': ebm_bins_main_effects,
    'baam_nfeatures_20_no_ensemble_e1520': baam_nfeatures_20_no_ensemble_e1520,
    'mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8': mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8,
}

results_transformers = [
    eval_on_datasets(
        'multiclass', model, model_name, cc_test_datasets_multiclass,
        eval_positions=eval_positions,
        max_times=max_times,
        metric_used=metric_used,
        split_numbers=[1, 2, 3, 4, 5],
        n_samples=n_samples,
        base_path=base_path,
        overwrite=False,
        n_jobs=1,
        device=device,
    )
    for model_name, model in clf_dict.items()
]

evaluating mothernet on cuda:2


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:01<00:00, 122.14it/s]


evaluating mlp_distill on cuda:2


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:01<00:00, 120.55it/s]


evaluating tabpfn on cuda:2


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:01<00:00, 123.47it/s]

evaluating ebm_bins_main_effects on cpu



  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
[Parallel(n_jobs=1)]: Done  40 tasks      | elapsed:   39.1s
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


evaluating baam_nfeatures_20_no_ensemble_e1520 on cuda:2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:16<00:00,  8.83it/s]


evaluating mn_Dclass_average_03_25_2024_17_14_32_epoch_2910_ohe_ensemble_8 on cuda:2


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:01<00:00, 122.25it/s]


In [23]:
import pandas as pd

# Flatten the nested results into one row per result, then build the DataFrame once.
# The outer lists are one per evaluated model (one eval_on_datasets call each); every
# inner item is a dict describing one dataset / split / time-budget evaluation.
flat_results = []
for model_results in results_baselines + results_transformers:
    for result in model_results:
        row = {key: result[key] for key in ['dataset', 'model', 'mean_metric', 'split', 'max_time']}
        # Exactly one "*best_configs*" key is expected; the 1-tuple unpacking raises
        # ValueError otherwise, which surfaces malformed results early.
        best_configs_key, = [k for k in result.keys() if "best_configs" in k]
        best_config = result[best_configs_key][0]
        if best_config is not None:
            row.update(best_config)
        # float() accepts single-element torch tensors (CPU or CUDA, with or without grad),
        # numpy scalars/arrays and plain Python floats. The previous `.numpy()` call failed on
        # CUDA tensors, on tensors requiring grad, and on anything that is not a tensor.
        row['mean_metric'] = float(row['mean_metric'])
        flat_results.append(row)

results_df = pd.DataFrame(flat_results)

In [24]:
# Inspect the flattened results (pandas truncates the display of long frames).
results_df

Unnamed: 0,dataset,model,mean_metric,split,max_time,best,fit_time,inference_time
0,balance-scale,knn,0.898451,1,1,{'n_neighbors': 14},0.000790,0.031152
1,balance-scale,knn,0.848925,2,1,{'n_neighbors': 8},0.000795,0.026842
2,balance-scale,knn,0.852651,3,1,{'n_neighbors': 10},0.000851,0.027301
3,balance-scale,knn,0.885874,4,1,{'n_neighbors': 10},0.000786,0.028868
4,balance-scale,knn,0.895205,5,1,{'n_neighbors': 15},0.000829,0.027410
...,...,...,...,...,...,...,...,...
6145,climate-model-simulation-crashes,mn_Dclass_average_03_25_2024_17_14_32_epoch_29...,0.937723,1,1,,0.317008,0.049277
6146,climate-model-simulation-crashes,mn_Dclass_average_03_25_2024_17_14_32_epoch_29...,0.943915,2,1,,0.313077,0.034901
6147,climate-model-simulation-crashes,mn_Dclass_average_03_25_2024_17_14_32_epoch_29...,0.939702,3,1,,0.398591,0.036673
6148,climate-model-simulation-crashes,mn_Dclass_average_03_25_2024_17_14_32_epoch_29...,0.944275,4,1,,0.313154,0.035129


In [25]:
# Disabled: pickle the raw (unflattened) baseline + transformer results to
# results_test.pickle. Uncomment to keep a copy of the full result objects.
#import pickle
#with open("results_test.pickle", "wb") as f:
#    pickle.dump(results_baselines + results_transformers, f)

In [26]:
# Map internal model keys to display names. Names missing from the mapping
# (e.g. the EBM / BAAM / MotherNet checkpoint keys) are left unchanged by .replace().
results_df['model'] = results_df['model'].replace({
    'knn': "KNN",
    'rf_new_params': 'RF',
    'mlp': "MLP",
    'mlp_distill': 'MLP-Distill',
    'xgb': 'XGBoost',
    'logistic': 'LogReg',
    'mothernet': 'MotherNet',
    'tabpfn': 'TabPFN (Hollmann)',
    'tabpfn_ours': 'TabPFN (ours)',
})

In [27]:
import datetime

# Save the flattened results to a CSV in the working directory, stamped with
# today's date in ISO format (e.g. results_test_2024-04-18.csv), and show the name.
filename = "results_test_{}.csv".format(datetime.date.today())
results_df.to_csv(filename)
filename

'results_test_2024-04-18.csv'