In [2]:
from autotsc import utils
import os
from aeon.datasets.tsc_datasets import univariate
import random
from itertools import product
from tqdm import tqdm
import polars as pl
import numpy as np
from time import time
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from aeon.classification.convolution_based import RocketClassifier, MiniRocketClassifier
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from aeon.transformations.collection import Normalizer
from aeon.classification.sklearn import SklearnClassifierWrapper
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import RidgeClassifierCV
from time import perf_counter

In [3]:
# Directory in which every (dataset, run, model) result is cached as its own parquet file.
write_dir = "experiments/automl_ca_vs_time_correlation"
os.makedirs(write_dir, exist_ok=True)

# Sort first so the starting order does not depend on how `univariate` iterates,
# then shuffle with a dedicated seeded RNG: the processing order is randomised
# but identical after every kernel restart, and the global `random` state is
# left untouched.
shuffle_seed = 42
datasets = sorted(univariate)
random.Random(shuffle_seed).shuffle(datasets)
model_types = ['raw-scale-ridge', 'quant', 'minirocket', 'catch22', 'hydra', 'tabpfn']#, 'shapelet']

# Degree of parallelism handed to the estimators that accept it (see get_model).
n_jobs = 4
# Number of repetitions per (dataset, model) pair.
n_runs = 5

In [4]:
def get_model(model_name):
    """Build a fresh, unfitted classifier for the given model identifier.

    Optional backends are imported lazily inside the branch that needs them,
    so asking for one model does not require every other backend (for
    example ``tabpfn``) to be installed.

    Parameters
    ----------
    model_name : str
        One of ``'raw-scale-ridge'``, ``'tabpfn'``, ``'quant'``,
        ``'minirocket'``, ``'catch22'``, ``'hydra'`` or ``'shapelet'``.

    Returns
    -------
    estimator
        A new classifier instance. Estimators that take a parallelism
        argument receive the notebook-level ``n_jobs`` setting.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the identifiers listed above.
    """
    if model_name == 'raw-scale-ridge':
        # Standardise the raw values, then fit ridge with a
        # cross-validated regularisation strength.
        return SklearnClassifierWrapper(
            make_pipeline(
                StandardScaler(),
                RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
            )
        )
    elif model_name == 'tabpfn':
        from tabpfn import TabPFNClassifier

        return SklearnClassifierWrapper(
            TabPFNClassifier(n_preprocessing_jobs=n_jobs)
        )
    elif model_name == 'quant':
        from aeon.classification.interval_based import QUANTClassifier

        return QUANTClassifier()
    elif model_name == 'minirocket':
        return MiniRocketClassifier(n_jobs=n_jobs)
    elif model_name == 'catch22':
        from aeon.classification.feature_based import Catch22Classifier

        return Catch22Classifier(n_jobs=n_jobs)
    elif model_name == 'hydra':
        from aeon.classification.convolution_based import HydraClassifier

        return HydraClassifier(n_jobs=n_jobs)
    elif model_name == 'shapelet':
        from aeon.classification.shapelet_based import ShapeletTransformClassifier

        return ShapeletTransformClassifier(n_jobs=n_jobs)
    else:
        raise ValueError(f"Unknown model: {model_name}")
        

In [5]:
# Every (dataset, run, model) triple; datasets vary slowest, so all work for
# one dataset is contiguous and the data only needs to be loaded once.
all_combinations = list(product(datasets, range(n_runs), model_types))

last_dataset = None
X_train, y_train, X_test, y_test = None, None, None, None

for ds, run, model_name in tqdm(all_combinations, desc="Processing"):
    try:
        stats = {
            "dataset": ds,
            "run": run,
            "model": model_name,
        }

        # The cache file name is derived from the identifying triple.
        # NOTE(review): polars documents hash_rows as not guaranteed to be
        # stable across polars versions - confirm the cache is still hit
        # after upgrading polars.
        hash_val = pl.DataFrame([stats]).hash_rows(
            seed=42, seed_1=1, seed_2=2, seed_3=3
        ).item()
        file = f"{write_dir}/{hash_val}.parquet"

        if os.path.exists(file):
            #print(f'Skipping {stats}')
            continue
        print(f'Processing {stats}')

        # Reload the data only when the dataset changes.
        if ds != last_dataset:
            X_train, y_train, X_test, y_test = utils.load_dataset(ds)
            last_dataset = ds

        # Build the model only for combinations that actually have to run,
        # so cached entries do not pay for estimator construction/imports.
        model = get_model(model_name)

        # Only fitting is timed; prediction is excluded from training_time.
        start_time = perf_counter()
        model.fit(X_train, y_train)
        training_time = perf_counter() - start_time
        y_pred = model.predict(X_test)
        test_accuracy = accuracy_score(y_test, y_pred)

        stats["test_accuracy"] = test_accuracy
        stats["training_time"] = training_time

        # Write to a temporary name and rename afterwards, so an interrupted
        # write never leaves a truncated file that would be mistaken for a
        # finished result. The ".tmp" suffix keeps leftovers out of a
        # "*.parquet" glob.
        df_stat = pl.DataFrame([stats])
        tmp_file = f"{file}.tmp"
        df_stat.write_parquet(tmp_file)
        os.replace(tmp_file, file)

    except Exception as e:
        # Best effort: report the failure and move on to the next combination.
        print(f"Error processing {ds}, run {run}, model {model_name}: {e}")

Processing:   0%|          | 1/3840 [00:00<22:57,  2.79it/s]

Processing {'dataset': 'Crop', 'run': 0, 'model': 'tabpfn'}


Processing:   0%|          | 12/3840 [00:00<03:31, 18.08it/s]

Error processing Crop, run 0, model tabpfn: Number of classes 24 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'Crop', 'run': 1, 'model': 'tabpfn'}
Error processing Crop, run 1, model tabpfn: Number of classes 24 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'Crop', 'run': 2, 'model': 'tabpfn'}


Processing:   1%|          | 24/3840 [00:00<01:46, 35.99it/s]

Error processing Crop, run 2, model tabpfn: Number of classes 24 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'Crop', 'run': 3, 'model': 'tabpfn'}
Error processing Crop, run 3, model tabpfn: Number of classes 24 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'Crop', 'run': 4, 'model': 'tabpfn'}
Error processing Crop, run 4, model tabpfn: Number of classes 24 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/t

Processing:   1%|▏         | 48/3840 [00:01<01:09, 54.74it/s]

Error processing FacesUCR, run 0, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 1, 'model': 'tabpfn'}
Error processing FacesUCR, run 1, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 2, 'model': 'tabpfn'}
Error processing FacesUCR, run 2, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extens

Processing:   2%|▏         | 60/3840 [00:01<01:00, 62.46it/s]

Error processing FacesUCR, run 3, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 4, 'model': 'tabpfn'}
Error processing FacesUCR, run 4, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'Lightning2', 'run': 0, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'Lightning2', 'run': 0, 'model': 'quant'}
Processing {'dataset': 'Lightning2', 'run': 0, 'model': 'minirocket'}
Processing {'dataset': 'Lightning2', 'run': 0, 'model': 'catch22'}
Processing {'dataset': 'Light

Processing:   2%|▏         | 67/3840 [00:08<15:11,  4.14it/s]

Processing {'dataset': 'Lightning2', 'run': 1, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'Lightning2', 'run': 1, 'model': 'quant'}
Processing {'dataset': 'Lightning2', 'run': 1, 'model': 'minirocket'}
Processing {'dataset': 'Lightning2', 'run': 1, 'model': 'catch22'}
Processing {'dataset': 'Lightning2', 'run': 1, 'model': 'hydra'}
Processing {'dataset': 'Lightning2', 'run': 1, 'model': 'tabpfn'}


Processing:   2%|▏         | 72/3840 [00:10<18:30,  3.39it/s]

Processing {'dataset': 'Lightning2', 'run': 2, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'Lightning2', 'run': 2, 'model': 'quant'}
Processing {'dataset': 'Lightning2', 'run': 2, 'model': 'minirocket'}
Processing {'dataset': 'Lightning2', 'run': 2, 'model': 'catch22'}


Processing:   2%|▏         | 76/3840 [00:11<17:35,  3.57it/s]

Processing {'dataset': 'Lightning2', 'run': 2, 'model': 'hydra'}
Processing {'dataset': 'Lightning2', 'run': 2, 'model': 'tabpfn'}


Processing:   2%|▏         | 79/3840 [00:13<20:11,  3.11it/s]

Processing {'dataset': 'Lightning2', 'run': 3, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'Lightning2', 'run': 3, 'model': 'quant'}


Processing:   2%|▏         | 81/3840 [00:13<19:46,  3.17it/s]

Processing {'dataset': 'Lightning2', 'run': 3, 'model': 'minirocket'}
Processing {'dataset': 'Lightning2', 'run': 3, 'model': 'catch22'}
Processing {'dataset': 'Lightning2', 'run': 3, 'model': 'hydra'}


Processing:   2%|▏         | 83/3840 [00:14<20:47,  3.01it/s]

Processing {'dataset': 'Lightning2', 'run': 3, 'model': 'tabpfn'}


Processing:   2%|▏         | 84/3840 [00:15<25:33,  2.45it/s]

Processing {'dataset': 'Lightning2', 'run': 4, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'Lightning2', 'run': 4, 'model': 'quant'}


Processing:   2%|▏         | 87/3840 [00:16<20:40,  3.02it/s]

Processing {'dataset': 'Lightning2', 'run': 4, 'model': 'minirocket'}
Processing {'dataset': 'Lightning2', 'run': 4, 'model': 'catch22'}


Processing:   2%|▏         | 88/3840 [00:16<20:21,  3.07it/s]

Processing {'dataset': 'Lightning2', 'run': 4, 'model': 'hydra'}


Processing:   2%|▏         | 89/3840 [00:17<22:45,  2.75it/s]

Processing {'dataset': 'Lightning2', 'run': 4, 'model': 'tabpfn'}


Processing:   2%|▏         | 90/3840 [00:18<34:19,  1.82it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 0, 'model': 'raw-scale-ridge'}


Processing:   2%|▏         | 91/3840 [00:19<36:43,  1.70it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 0, 'model': 'quant'}


Processing:   2%|▏         | 92/3840 [00:19<38:32,  1.62it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 0, 'model': 'minirocket'}


Processing:   2%|▏         | 93/3840 [00:20<33:00,  1.89it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 0, 'model': 'catch22'}


Processing:   2%|▏         | 94/3840 [00:20<32:01,  1.95it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 0, 'model': 'hydra'}


Processing:   2%|▏         | 95/3840 [00:21<31:02,  2.01it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 0, 'model': 'tabpfn'}


Processing:   2%|▎         | 96/3840 [00:22<42:43,  1.46it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 1, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 1, 'model': 'quant'}


Processing:   3%|▎         | 98/3840 [00:22<33:07,  1.88it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 1, 'model': 'minirocket'}


Processing:   3%|▎         | 99/3840 [00:23<29:19,  2.13it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 1, 'model': 'catch22'}


Processing:   3%|▎         | 100/3840 [00:23<29:14,  2.13it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 1, 'model': 'hydra'}


Processing:   3%|▎         | 101/3840 [00:24<29:08,  2.14it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 1, 'model': 'tabpfn'}


Processing:   3%|▎         | 102/3840 [00:25<41:13,  1.51it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 2, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 2, 'model': 'quant'}


Processing:   3%|▎         | 104/3840 [00:26<33:38,  1.85it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 2, 'model': 'minirocket'}


Processing:   3%|▎         | 105/3840 [00:26<37:04,  1.68it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 2, 'model': 'catch22'}


Processing:   3%|▎         | 106/3840 [00:27<35:08,  1.77it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 2, 'model': 'hydra'}


Processing:   3%|▎         | 107/3840 [00:27<33:13,  1.87it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 2, 'model': 'tabpfn'}


Processing:   3%|▎         | 108/3840 [00:28<43:45,  1.42it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 3, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 3, 'model': 'quant'}


Processing:   3%|▎         | 110/3840 [00:29<33:40,  1.85it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 3, 'model': 'minirocket'}


Processing:   3%|▎         | 111/3840 [00:29<29:40,  2.09it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 3, 'model': 'catch22'}


Processing:   3%|▎         | 112/3840 [00:30<29:19,  2.12it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 3, 'model': 'hydra'}


Processing:   3%|▎         | 113/3840 [00:30<28:49,  2.15it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 3, 'model': 'tabpfn'}


Processing:   3%|▎         | 114/3840 [00:31<40:09,  1.55it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 4, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 4, 'model': 'quant'}


Processing:   3%|▎         | 116/3840 [00:32<32:17,  1.92it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 4, 'model': 'minirocket'}


Processing:   3%|▎         | 117/3840 [00:32<29:02,  2.14it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 4, 'model': 'catch22'}


Processing:   3%|▎         | 118/3840 [00:33<29:07,  2.13it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 4, 'model': 'hydra'}


Processing:   3%|▎         | 119/3840 [00:33<28:58,  2.14it/s]

Processing {'dataset': 'ProximalPhalanxOutlineCorrect', 'run': 4, 'model': 'tabpfn'}


Processing:   3%|▎         | 120/3840 [00:34<40:33,  1.53it/s]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'raw-scale-ridge'}


Processing:   3%|▎         | 121/3840 [00:37<1:09:08,  1.12s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'quant'}


Processing:   3%|▎         | 122/3840 [00:55<6:16:57,  6.08s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'minirocket'}


Processing:   3%|▎         | 123/3840 [00:57<4:59:48,  4.84s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'catch22'}


Processing:   3%|▎         | 124/3840 [01:02<5:04:08,  4.91s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'hydra'}


Processing:   3%|▎         | 125/3840 [01:23<9:53:17,  9.58s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 0, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'raw-scale-ridge'}


Processing:   3%|▎         | 127/3840 [01:23<5:26:30,  5.28s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'quant'}


Processing:   3%|▎         | 128/3840 [01:41<8:43:11,  8.46s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'minirocket'}


Processing:   3%|▎         | 129/3840 [01:43<6:56:59,  6.74s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'catch22'}


Processing:   3%|▎         | 130/3840 [01:48<6:28:00,  6.27s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'hydra'}


Processing:   3%|▎         | 131/3840 [02:10<10:51:57, 10.55s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 1, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'raw-scale-ridge'}


Processing:   3%|▎         | 133/3840 [02:10<6:07:48,  5.95s/it] 

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'quant'}


Processing:   3%|▎         | 134/3840 [02:28<9:08:31,  8.88s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'minirocket'}


Processing:   4%|▎         | 135/3840 [02:35<8:42:27,  8.46s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'catch22'}


Processing:   4%|▎         | 136/3840 [02:41<7:47:21,  7.57s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'hydra'}


Processing:   4%|▎         | 137/3840 [03:02<11:47:43, 11.47s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 2, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'raw-scale-ridge'}


Processing:   4%|▎         | 139/3840 [03:03<6:45:06,  6.57s/it] 

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'quant'}


Processing:   4%|▎         | 140/3840 [03:21<9:35:09,  9.33s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'minirocket'}


Processing:   4%|▎         | 141/3840 [03:23<7:41:02,  7.48s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'catch22'}


Processing:   4%|▎         | 142/3840 [03:28<7:02:53,  6.86s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'hydra'}


Processing:   4%|▎         | 143/3840 [03:54<12:26:45, 12.12s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 3, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'raw-scale-ridge'}


Processing:   4%|▍         | 145/3840 [03:55<7:05:04,  6.90s/it] 

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'quant'}


Processing:   4%|▍         | 146/3840 [04:15<10:13:25,  9.96s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'minirocket'}


Processing:   4%|▍         | 147/3840 [04:17<8:11:04,  7.98s/it] 

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'catch22'}


Processing:   4%|▍         | 148/3840 [04:22<7:25:48,  7.24s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'hydra'}


Processing:   4%|▍         | 149/3840 [04:44<11:45:48, 11.47s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 4, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'TwoPatterns', 'run': 0, 'model': 'raw-scale-ridge'}


Processing:   4%|▍         | 151/3840 [04:46<6:49:19,  6.66s/it] 

Processing {'dataset': 'TwoPatterns', 'run': 0, 'model': 'quant'}


Processing:   4%|▍         | 152/3840 [04:47<5:30:25,  5.38s/it]

Processing {'dataset': 'TwoPatterns', 'run': 0, 'model': 'minirocket'}


Processing:   4%|▍         | 153/3840 [04:48<4:21:00,  4.25s/it]

Processing {'dataset': 'TwoPatterns', 'run': 0, 'model': 'catch22'}


Processing:   4%|▍         | 154/3840 [04:50<3:41:24,  3.60s/it]

Processing {'dataset': 'TwoPatterns', 'run': 0, 'model': 'hydra'}


Processing:   4%|▍         | 155/3840 [04:54<3:49:10,  3.73s/it]

Processing {'dataset': 'TwoPatterns', 'run': 0, 'model': 'tabpfn'}


Processing:   4%|▍         | 156/3840 [05:02<5:07:56,  5.02s/it]

Processing {'dataset': 'TwoPatterns', 'run': 1, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'TwoPatterns', 'run': 1, 'model': 'quant'}


Processing:   4%|▍         | 158/3840 [05:03<3:09:24,  3.09s/it]

Processing {'dataset': 'TwoPatterns', 'run': 1, 'model': 'minirocket'}


Processing:   4%|▍         | 159/3840 [05:04<2:36:21,  2.55s/it]

Processing {'dataset': 'TwoPatterns', 'run': 1, 'model': 'catch22'}


Processing:   4%|▍         | 160/3840 [05:06<2:25:04,  2.37s/it]

Processing {'dataset': 'TwoPatterns', 'run': 1, 'model': 'hydra'}


Processing:   4%|▍         | 161/3840 [05:10<2:49:26,  2.76s/it]

Processing {'dataset': 'TwoPatterns', 'run': 1, 'model': 'tabpfn'}


Processing:   4%|▍         | 162/3840 [05:18<4:15:18,  4.16s/it]

Processing {'dataset': 'TwoPatterns', 'run': 2, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'TwoPatterns', 'run': 2, 'model': 'quant'}


Processing:   4%|▍         | 164/3840 [05:19<2:41:24,  2.63s/it]

Processing {'dataset': 'TwoPatterns', 'run': 2, 'model': 'minirocket'}


Processing:   4%|▍         | 165/3840 [05:20<2:16:16,  2.22s/it]

Processing {'dataset': 'TwoPatterns', 'run': 2, 'model': 'catch22'}


Processing:   4%|▍         | 166/3840 [05:22<2:10:48,  2.14s/it]

Processing {'dataset': 'TwoPatterns', 'run': 2, 'model': 'hydra'}


Processing:   4%|▍         | 167/3840 [05:26<2:39:58,  2.61s/it]

Processing {'dataset': 'TwoPatterns', 'run': 2, 'model': 'tabpfn'}


Processing:   4%|▍         | 168/3840 [05:34<4:07:57,  4.05s/it]

Processing {'dataset': 'TwoPatterns', 'run': 3, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'TwoPatterns', 'run': 3, 'model': 'quant'}


Processing:   4%|▍         | 170/3840 [05:35<2:37:28,  2.57s/it]

Processing {'dataset': 'TwoPatterns', 'run': 3, 'model': 'minirocket'}


Processing:   4%|▍         | 171/3840 [05:37<2:27:57,  2.42s/it]

Processing {'dataset': 'TwoPatterns', 'run': 3, 'model': 'catch22'}


Processing:   4%|▍         | 172/3840 [05:39<2:19:26,  2.28s/it]

Processing {'dataset': 'TwoPatterns', 'run': 3, 'model': 'hydra'}


Processing:   5%|▍         | 173/3840 [05:43<2:45:55,  2.71s/it]

Processing {'dataset': 'TwoPatterns', 'run': 3, 'model': 'tabpfn'}


Processing:   5%|▍         | 174/3840 [05:51<4:11:56,  4.12s/it]

Processing {'dataset': 'TwoPatterns', 'run': 4, 'model': 'raw-scale-ridge'}
Processing {'dataset': 'TwoPatterns', 'run': 4, 'model': 'quant'}


Processing:   5%|▍         | 176/3840 [05:52<2:40:02,  2.62s/it]

Processing {'dataset': 'TwoPatterns', 'run': 4, 'model': 'minirocket'}


Processing:   5%|▍         | 177/3840 [05:53<2:14:30,  2.20s/it]

Processing {'dataset': 'TwoPatterns', 'run': 4, 'model': 'catch22'}


Processing:   5%|▍         | 178/3840 [05:55<2:09:34,  2.12s/it]

Processing {'dataset': 'TwoPatterns', 'run': 4, 'model': 'hydra'}


Processing:   5%|▍         | 179/3840 [05:59<2:38:46,  2.60s/it]

Processing {'dataset': 'TwoPatterns', 'run': 4, 'model': 'tabpfn'}


Processing:   5%|▌         | 192/3840 [06:07<43:32,  1.40it/s]  

Processing {'dataset': 'GestureMidAirD3', 'run': 0, 'model': 'tabpfn'}
Error processing GestureMidAirD3, run 0, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD3', 'run': 1, 'model': 'tabpfn'}
Error processing GestureMidAirD3, run 1, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD3', 'run': 2, 'model': 'tabpfn'}


Processing:   5%|▌         | 204/3840 [06:07<20:15,  2.99it/s]

Error processing GestureMidAirD3, run 2, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD3', 'run': 3, 'model': 'tabpfn'}
Error processing GestureMidAirD3, run 3, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD3', 'run': 4, 'model': 'tabpfn'}
Error processing GestureMidAirD3, run 4, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https:/

Processing:   6%|▌         | 216/3840 [06:08<14:45,  4.09it/s]

Processing {'dataset': 'ToeSegmentation1', 'run': 1, 'model': 'tabpfn'}


Processing:   6%|▌         | 221/3840 [06:09<1:40:47,  1.67s/it]


KeyboardInterrupt: 

In [None]:
# Read every cached result file and keep only the rows whose dataset is part
# of the current `datasets` list.
df = (
    pl.read_parquet(f'{write_dir}/*.parquet')
    .filter(pl.col("dataset").is_in(datasets))
)
df

In [None]:
# Display the (shuffled) list of dataset names used by the experiment loop.
datasets

In [None]:
# Mean test accuracy per (dataset, model), ordered from worst to best.
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('test_accuracy').mean())
    .sort('test_accuracy')
)

In [None]:
# Datasets for which both TabPFN and QUANT have an aggregated result.
together_dataset = (
    set(gdf.filter(pl.col('model') == 'tabpfn')['dataset'])
    & set(gdf.filter(pl.col('model') == 'quant')['dataset'])
)
# Restrict each model to the shared datasets and sort by dataset so that the
# two series line up element by element.
v1 = (
    gdf.filter((pl.col('model') == 'tabpfn') & pl.col('dataset').is_in(together_dataset))
    .sort('dataset')
    .get_column('test_accuracy')
)
v2 = (
    gdf.filter((pl.col('model') == 'quant') & pl.col('dataset').is_in(together_dataset))
    .sort('dataset')
    .get_column('test_accuracy')
)
plt.scatter(v1, v2)
# Red dashed y = x reference line: points above it favour QUANT.
plt.plot([0, 1], [0, 1], linestyle='--', color='red')
plt.title('Model Performance Comparison: TabPFN vs QUANT')
plt.xlabel('TabPFN Test Accuracy')
plt.ylabel('QUANT Test Accuracy')
plt.grid(True)
plt.show()

In [None]:
# Datasets for which both TabPFN and the raw-scale ridge baseline have an
# aggregated result.
together_dataset = (
    set(gdf.filter(pl.col('model') == 'tabpfn')['dataset'])
    & set(gdf.filter(pl.col('model') == 'raw-scale-ridge')['dataset'])
)
# Restrict each model to the shared datasets and sort by dataset so that the
# two series line up element by element.
v1 = (
    gdf.filter((pl.col('model') == 'tabpfn') & pl.col('dataset').is_in(together_dataset))
    .sort('dataset')
    .get_column('test_accuracy')
)
v2 = (
    gdf.filter((pl.col('model') == 'raw-scale-ridge') & pl.col('dataset').is_in(together_dataset))
    .sort('dataset')
    .get_column('test_accuracy')
)
plt.scatter(v1, v2)
# Red dashed y = x reference line: points above it favour the ridge baseline.
plt.plot([0, 1], [0, 1], linestyle='--', color='red')
plt.title('Model Performance Comparison: TabPFN vs Raw-scale Ridge Classifier')
plt.xlabel('TabPFN Test Accuracy')
plt.ylabel('Raw-scale Ridge Classifier Accuracy')
plt.grid(True)
plt.show()

In [None]:
# Pair the two models per dataset with an inner join. Filtering and sorting
# each model independently only lines the points up when both models have
# results for exactly the same datasets; after a failed or interrupted run
# that is not guaranteed, and the scatter would then pair accuracies of
# different datasets (or fail on a length mismatch).
paired = (
    gdf.filter(pl.col('model') == 'raw-scale-ridge')
    .join(
        gdf.filter(pl.col('model') == 'quant'),
        on='dataset',
        how='inner',
        suffix='_other',
    )
    .sort('dataset')
)
v1 = paired['test_accuracy']        # raw-scale-ridge
v2 = paired['test_accuracy_other']  # quant
plt.scatter(v1, v2)
# Red dashed y = x reference line: points above it favour QUANT.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('Raw-scale Ridge Classifier Accuracy')
plt.ylabel('QUANT Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Pair the two models per dataset with an inner join. Filtering and sorting
# each model independently only lines the points up when both models have
# results for exactly the same datasets; after a failed or interrupted run
# that is not guaranteed, and the scatter would then pair accuracies of
# different datasets (or fail on a length mismatch).
paired = (
    gdf.filter(pl.col('model') == 'quant')
    .join(
        gdf.filter(pl.col('model') == 'minirocket'),
        on='dataset',
        how='inner',
        suffix='_other',
    )
    .sort('dataset')
)
v1 = paired['test_accuracy']        # quant
v2 = paired['test_accuracy_other']  # minirocket
plt.scatter(v1, v2)
# Red dashed y = x reference line: points above it favour MiniRocket.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('QUANT Classifier Accuracy')
plt.ylabel('MiniRocket Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Pair the two models per dataset with an inner join. Filtering and sorting
# each model independently only lines the points up when both models have
# results for exactly the same datasets; after a failed or interrupted run
# that is not guaranteed, and the scatter would then pair accuracies of
# different datasets (or fail on a length mismatch).
paired = (
    gdf.filter(pl.col('model') == 'catch22')
    .join(
        gdf.filter(pl.col('model') == 'minirocket'),
        on='dataset',
        how='inner',
        suffix='_other',
    )
    .sort('dataset')
)
v1 = paired['test_accuracy']        # catch22
v2 = paired['test_accuracy_other']  # minirocket
plt.scatter(v1, v2)
# Red dashed y = x reference line: points above it favour MiniRocket.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('Catch22 Classifier Accuracy')
plt.ylabel('MiniRocket Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Pair the two models per dataset with an inner join. Filtering and sorting
# each model independently only lines the points up when both models have
# results for exactly the same datasets; after a failed or interrupted run
# that is not guaranteed, and the scatter would then pair accuracies of
# different datasets (or fail on a length mismatch).
paired = (
    gdf.filter(pl.col('model') == 'catch22')
    .join(
        gdf.filter(pl.col('model') == 'raw-scale-ridge'),
        on='dataset',
        how='inner',
        suffix='_other',
    )
    .sort('dataset')
)
v1 = paired['test_accuracy']        # catch22
v2 = paired['test_accuracy_other']  # raw-scale-ridge
plt.scatter(v1, v2)
# Red dashed y = x reference line: points above it favour the ridge baseline.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('Catch22 Classifier Accuracy')
plt.ylabel('Raw-scale Ridge Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Mean training time per (dataset, model), ordered by dataset then model.
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('training_time').mean())
    .sort(['dataset', 'model'])
)

In [None]:
import seaborn as sns

# One row per dataset, one point per model; training time is shown on a log
# axis.
plt.figure(figsize=(10, 20))
sns.stripplot(x="training_time", y="dataset", hue="model", data=gdf, dodge=False)
current_axes = plt.gca()
current_axes.grid(True)
current_axes.set_xscale('log')
plt.tight_layout()
plt.show()

In [None]:
# Mean test accuracy per (dataset, model), ordered by dataset then model.
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('test_accuracy').mean())
    .sort(['dataset', 'model'])
)

In [None]:
# Distinct dataset names that have at least one stored result.
df.get_column('dataset').unique()

In [None]:
import seaborn as sns

# One row per dataset, one point per model. Accuracy is bounded in [0, 1], so
# it is shown on a linear axis: a log axis (useful for the training-time plot)
# would compress the high-accuracy region where the models actually differ
# and cannot display an accuracy of 0.
plt.figure(figsize=(10, 20))
sns.stripplot(data=gdf, y="dataset", x="test_accuracy", hue="model", dodge=False)
plt.grid(True)
plt.xlabel('Mean test accuracy')
plt.tight_layout()
plt.show()

In [None]:
# Mean test accuracy and mean training time per (dataset, model), ordered by
# dataset then model.
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('test_accuracy', 'training_time').mean())
    .sort(['dataset', 'model'])
)

In [None]:
# Accuracy against training time, one point per (dataset, model), with
# training time on a log axis.
sns.scatterplot(x='training_time', y='test_accuracy', hue='model', data=gdf)
plt.gca().set_xscale('log')