In [1]:
from autotsc import utils
import os
from aeon.datasets.tsc_datasets import univariate
import random
from itertools import product
from tqdm import tqdm
import polars as pl
import numpy as np
from time import time
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from aeon.classification.convolution_based import RocketClassifier, MiniRocketClassifier
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from aeon.transformations.collection import Normalizer
from aeon.classification.sklearn import SklearnClassifierWrapper
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import RidgeClassifierCV
from time import perf_counter

In [2]:
# One parquet file per (dataset, run, model) result is cached under this directory.
write_dir = "experiments/automl_ca_vs_time_correlation"
os.makedirs(write_dir, exist_ok=True)

# Shuffle aeon's univariate dataset list so a partially finished sweep covers a
# varied subset. A fixed seed (on a private RNG, leaving the global `random`
# state untouched) keeps the processing order reproducible across restarts.
shuffle_seed = 42
datasets = list(univariate)
random.Random(shuffle_seed).shuffle(datasets)
model_types = ['raw-scale-ridge', 'quant', 'minirocket', 'catch22', 'hydra', 'tabpfn', 'shapelet']

n_jobs = 4  # parallelism setting passed to the estimators built in get_model
n_runs = 5  # number of repeated fits of each (dataset, model) pair

In [3]:
def get_model(model_name):
    """Build a fresh, unfitted classifier for one entry of ``model_types``.

    Estimator classes are imported lazily inside their own branch, so an
    optional dependency such as ``tabpfn`` is only required when that model is
    actually requested. Unknown names fail without importing anything.

    Args:
        model_name: One of 'raw-scale-ridge', 'tabpfn', 'quant', 'minirocket',
            'catch22', 'hydra' or 'shapelet'.

    Returns:
        An unfitted aeon classifier. The sklearn-based models are wrapped in
        ``SklearnClassifierWrapper``. Estimators that accept it receive the
        module-level ``n_jobs`` setting.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    if model_name == 'raw-scale-ridge':
        # StandardScaler followed by ridge with CV over 10 log-spaced alphas in [1e-3, 1e3].
        return SklearnClassifierWrapper(
            make_pipeline(
                StandardScaler(),
                RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
            )
        )
    if model_name == 'tabpfn':
        from tabpfn import TabPFNClassifier
        # TabPFN raises on datasets with more classes than it supports; the
        # sweep loop catches and logs that error.
        return SklearnClassifierWrapper(
            TabPFNClassifier(n_preprocessing_jobs=n_jobs)
        )
    if model_name == 'quant':
        from aeon.classification.interval_based import QUANTClassifier
        return QUANTClassifier()
    if model_name == 'minirocket':
        return MiniRocketClassifier(n_jobs=n_jobs)
    if model_name == 'catch22':
        from aeon.classification.feature_based import Catch22Classifier
        return Catch22Classifier(n_jobs=n_jobs)
    if model_name == 'hydra':
        from aeon.classification.convolution_based import HydraClassifier
        return HydraClassifier(n_jobs=n_jobs)
    if model_name == 'shapelet':
        from aeon.classification.shapelet_based import ShapeletTransformClassifier
        return ShapeletTransformClassifier(n_jobs=n_jobs)
    raise ValueError(f"Unknown model: {model_name}")
        

In [4]:
# Sweep every (dataset, run, model) combination. Each finished combination is
# cached as its own parquet file, so an interrupted sweep resumes where it stopped.
all_combinations = list(product(datasets, range(n_runs), model_types))

# `product` keeps all runs/models of one dataset contiguous, so consecutive
# combinations share a single dataset load.
last_dataset = None
X_train, y_train, X_test, y_test = None, None, None, None

for ds, run, model_name in tqdm(all_combinations, desc="Processing"):
    try:
        stats = {
            "dataset": ds,
            "run": run,
            "model": model_name,
        }

        # Deterministic cache key for this combination.
        # NOTE(review): polars documents that `hash_rows` output is not guaranteed
        # to be stable across polars versions; upgrading polars may orphan the cache.
        hash_val = pl.DataFrame([stats]).hash_rows(
            seed=42, seed_1=1, seed_2=2, seed_3=3
        ).item()
        file = f"{write_dir}/{hash_val}.parquet"

        if os.path.exists(file):
            continue
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f'Processing {stats}')

        # Build the model only for combinations that actually need to run.
        model = get_model(model_name)

        if ds != last_dataset:
            X_train, y_train, X_test, y_test = utils.load_dataset(ds)
            last_dataset = ds

        # Only `fit` is timed; prediction is excluded from training_time.
        start_time = perf_counter()
        model.fit(X_train, y_train)
        training_time = perf_counter() - start_time
        y_pred = model.predict(X_test)

        stats["test_accuracy"] = accuracy_score(y_test, y_pred)
        stats["training_time"] = training_time

        # Write under a temporary name, then rename atomically. An interrupt
        # mid-write must not leave a truncated `.parquet` that would pass the
        # exists() check above and break the later `*.parquet` glob read.
        tmp_file = f"{file}.tmp"
        pl.DataFrame([stats]).write_parquet(tmp_file)
        os.replace(tmp_file, file)

    except Exception as e:
        tqdm.write(f"Error processing {ds}, run {run}, model {model_name}: {e}")

Processing:   0%|          | 1/4480 [00:00<27:27,  2.72it/s]

Processing {'dataset': 'Chinatown', 'run': 0, 'model': 'shapelet'}


Processing:   0%|          | 7/4480 [00:13<2:32:15,  2.04s/it]

Processing {'dataset': 'Chinatown', 'run': 1, 'model': 'shapelet'}


Processing:   0%|          | 14/4480 [00:21<1:50:04,  1.48s/it]

Processing {'dataset': 'Chinatown', 'run': 2, 'model': 'shapelet'}


Processing:   0%|          | 21/4480 [00:29<1:36:43,  1.30s/it]

Processing {'dataset': 'Chinatown', 'run': 3, 'model': 'shapelet'}


Processing:   1%|          | 28/4480 [00:36<1:30:41,  1.22s/it]

Processing {'dataset': 'Chinatown', 'run': 4, 'model': 'shapelet'}


Processing:   1%|          | 41/4480 [00:44<1:01:23,  1.21it/s]

Processing {'dataset': 'GestureMidAirD2', 'run': 0, 'model': 'tabpfn'}
Error processing GestureMidAirD2, run 0, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD2', 'run': 0, 'model': 'shapelet'}


Processing:   1%|          | 42/4480 [01:34<5:08:27,  4.17s/it]

Processing {'dataset': 'GestureMidAirD2', 'run': 1, 'model': 'tabpfn'}
Error processing GestureMidAirD2, run 1, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD2', 'run': 1, 'model': 'shapelet'}


Processing:   1%|          | 49/4480 [02:24<6:36:31,  5.37s/it]

Processing {'dataset': 'GestureMidAirD2', 'run': 2, 'model': 'tabpfn'}
Error processing GestureMidAirD2, run 2, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD2', 'run': 2, 'model': 'shapelet'}


Processing:   1%|▏         | 56/4480 [03:14<7:21:28,  5.99s/it]

Processing {'dataset': 'GestureMidAirD2', 'run': 3, 'model': 'tabpfn'}
Error processing GestureMidAirD2, run 3, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD2', 'run': 3, 'model': 'shapelet'}


Processing:   1%|▏         | 63/4480 [04:03<7:47:54,  6.36s/it]

Processing {'dataset': 'GestureMidAirD2', 'run': 4, 'model': 'tabpfn'}
Error processing GestureMidAirD2, run 4, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD2', 'run': 4, 'model': 'shapelet'}


Processing:   2%|▏         | 76/4480 [04:53<5:45:29,  4.71s/it]

Processing {'dataset': 'GestureMidAirD1', 'run': 0, 'model': 'tabpfn'}
Error processing GestureMidAirD1, run 0, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD1', 'run': 0, 'model': 'shapelet'}


Processing:   2%|▏         | 77/4480 [05:43<9:18:47,  7.61s/it]

Processing {'dataset': 'GestureMidAirD1', 'run': 1, 'model': 'tabpfn'}
Error processing GestureMidAirD1, run 1, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD1', 'run': 1, 'model': 'shapelet'}


Processing:   2%|▏         | 84/4480 [06:33<9:03:28,  7.42s/it]

Processing {'dataset': 'GestureMidAirD1', 'run': 2, 'model': 'tabpfn'}
Error processing GestureMidAirD1, run 2, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD1', 'run': 2, 'model': 'shapelet'}


Processing:   2%|▏         | 91/4480 [07:22<8:52:54,  7.29s/it]

Processing {'dataset': 'GestureMidAirD1', 'run': 3, 'model': 'tabpfn'}
Error processing GestureMidAirD1, run 3, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD1', 'run': 3, 'model': 'shapelet'}


Processing:   2%|▏         | 98/4480 [08:12<8:49:50,  7.25s/it]

Processing {'dataset': 'GestureMidAirD1', 'run': 4, 'model': 'tabpfn'}
Error processing GestureMidAirD1, run 4, model tabpfn: Number of classes 26 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'GestureMidAirD1', 'run': 4, 'model': 'shapelet'}


Processing:   2%|▏         | 105/4480 [09:03<8:47:31,  7.23s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'tabpfn'}


Processing:   2%|▏         | 111/4480 [09:03<6:18:13,  5.19s/it]

Error processing NonInvasiveFetalECGThorax1, run 0, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 0, 'model': 'shapelet'}


Processing:   3%|▎         | 118/4480 [18:34<32:16:49, 26.64s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 1, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 1, 'model': 'shapelet'}


Processing:   3%|▎         | 125/4480 [28:02<49:52:56, 41.23s/it]

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 2, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 2, 'model': 'shapelet'}


Processing:   3%|▎         | 132/4480 [37:30<60:11:03, 49.83s/it] 

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 3, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 3, 'model': 'shapelet'}


Processing:   3%|▎         | 139/4480 [46:58<65:45:03, 54.53s/it] 

Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'tabpfn'}
Error processing NonInvasiveFetalECGThorax1, run 4, model tabpfn: Number of classes 42 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'NonInvasiveFetalECGThorax1', 'run': 4, 'model': 'shapelet'}


Processing:   3%|▎         | 140/4480 [56:31<125:14:24, 103.89s/it]

Processing {'dataset': 'UWaveGestureLibraryAll', 'run': 0, 'model': 'shapelet'}


Processing:   3%|▎         | 147/4480 [1:15:35<159:56:11, 132.88s/it]

Processing {'dataset': 'UWaveGestureLibraryAll', 'run': 1, 'model': 'shapelet'}


Processing:   3%|▎         | 154/4480 [1:34:32<174:16:56, 145.03s/it]

Processing {'dataset': 'UWaveGestureLibraryAll', 'run': 2, 'model': 'shapelet'}


Processing:   4%|▎         | 161/4480 [1:53:37<182:13:28, 151.89s/it]

Processing {'dataset': 'UWaveGestureLibraryAll', 'run': 3, 'model': 'shapelet'}


Processing:   4%|▍         | 168/4480 [2:12:34<186:17:16, 155.53s/it]

Processing {'dataset': 'UWaveGestureLibraryAll', 'run': 4, 'model': 'shapelet'}


Processing:   4%|▍         | 181/4480 [2:31:38<134:25:09, 112.56s/it]

Processing {'dataset': 'FacesUCR', 'run': 0, 'model': 'tabpfn'}
Error processing FacesUCR, run 0, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 0, 'model': 'shapelet'}


Processing:   4%|▍         | 182/4480 [2:32:22<129:07:29, 108.15s/it]

Processing {'dataset': 'FacesUCR', 'run': 1, 'model': 'tabpfn'}
Error processing FacesUCR, run 1, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 1, 'model': 'shapelet'}


Processing:   4%|▍         | 189/4480 [2:33:05<81:24:13, 68.30s/it]  

Processing {'dataset': 'FacesUCR', 'run': 2, 'model': 'tabpfn'}
Error processing FacesUCR, run 2, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 2, 'model': 'shapelet'}


Processing:   4%|▍         | 196/4480 [2:33:49<54:49:21, 46.07s/it]

Processing {'dataset': 'FacesUCR', 'run': 3, 'model': 'tabpfn'}
Error processing FacesUCR, run 3, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 3, 'model': 'shapelet'}


Processing:   5%|▍         | 203/4480 [2:34:33<38:43:15, 32.59s/it]

Processing {'dataset': 'FacesUCR', 'run': 4, 'model': 'tabpfn'}
Error processing FacesUCR, run 4, model tabpfn: Number of classes 14 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'FacesUCR', 'run': 4, 'model': 'shapelet'}


Processing:   5%|▍         | 216/4480 [2:35:17<20:19:30, 17.16s/it]

Processing {'dataset': 'CricketY', 'run': 0, 'model': 'tabpfn'}
Error processing CricketY, run 0, model tabpfn: Number of classes 12 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'CricketY', 'run': 0, 'model': 'shapelet'}


Processing:   5%|▍         | 223/4480 [2:36:43<16:31:16, 13.97s/it]

Processing {'dataset': 'CricketY', 'run': 1, 'model': 'tabpfn'}
Error processing CricketY, run 1, model tabpfn: Number of classes 12 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'CricketY', 'run': 1, 'model': 'shapelet'}


Processing:   5%|▌         | 230/4480 [2:38:10<13:55:36, 11.80s/it]

Processing {'dataset': 'CricketY', 'run': 2, 'model': 'tabpfn'}
Error processing CricketY, run 2, model tabpfn: Number of classes 12 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'CricketY', 'run': 2, 'model': 'shapelet'}


Processing:   5%|▌         | 237/4480 [2:39:37<12:23:25, 10.51s/it]

Processing {'dataset': 'CricketY', 'run': 3, 'model': 'tabpfn'}
Error processing CricketY, run 3, model tabpfn: Number of classes 12 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'CricketY', 'run': 3, 'model': 'shapelet'}


Processing:   5%|▌         | 244/4480 [2:41:04<11:32:00,  9.80s/it]

Processing {'dataset': 'CricketY', 'run': 4, 'model': 'tabpfn'}
Error processing CricketY, run 4, model tabpfn: Number of classes 12 exceeds the maximal number of classes supported by TabPFN. Consider using a strategy to reduce the number of classes. For code see https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/many_class/many_class_classifier.py
Processing {'dataset': 'CricketY', 'run': 4, 'model': 'shapelet'}


Processing:   5%|▌         | 245/4480 [2:42:31<20:09:20, 17.13s/it]

Processing {'dataset': 'BirdChicken', 'run': 0, 'model': 'shapelet'}


Processing:   6%|▌         | 252/4480 [2:42:42<11:12:47,  9.55s/it]

Processing {'dataset': 'BirdChicken', 'run': 1, 'model': 'shapelet'}


Processing:   6%|▌         | 259/4480 [2:42:53<7:20:08,  6.26s/it] 

Processing {'dataset': 'BirdChicken', 'run': 2, 'model': 'shapelet'}


Processing:   6%|▌         | 266/4480 [2:43:03<5:16:44,  4.51s/it]

Processing {'dataset': 'BirdChicken', 'run': 3, 'model': 'shapelet'}


Processing:   6%|▌         | 273/4480 [2:43:14<4:04:03,  3.48s/it]

Processing {'dataset': 'BirdChicken', 'run': 4, 'model': 'shapelet'}


Processing:   6%|▋         | 280/4480 [2:43:25<3:18:19,  2.83s/it]

Processing {'dataset': 'SmallKitchenAppliances', 'run': 0, 'model': 'shapelet'}


Processing:   6%|▋         | 287/4480 [2:47:30<15:18:18, 13.14s/it]

Processing {'dataset': 'SmallKitchenAppliances', 'run': 1, 'model': 'shapelet'}


Processing:   7%|▋         | 294/4480 [2:51:34<23:13:29, 19.97s/it]

Processing {'dataset': 'SmallKitchenAppliances', 'run': 2, 'model': 'shapelet'}


Processing:   7%|▋         | 301/4480 [2:55:37<28:28:38, 24.53s/it]

Processing {'dataset': 'SmallKitchenAppliances', 'run': 3, 'model': 'shapelet'}


Processing:   7%|▋         | 308/4480 [2:59:38<31:58:42, 27.59s/it]

Processing {'dataset': 'SmallKitchenAppliances', 'run': 4, 'model': 'shapelet'}


Processing:   7%|▋         | 315/4480 [3:03:44<34:33:01, 29.86s/it]

Processing {'dataset': 'HouseTwenty', 'run': 0, 'model': 'shapelet'}


Processing:   7%|▋         | 322/4480 [3:06:49<33:18:44, 28.84s/it]

Processing {'dataset': 'HouseTwenty', 'run': 1, 'model': 'shapelet'}


Processing:   7%|▋         | 329/4480 [3:09:52<32:19:45, 28.04s/it]

Processing {'dataset': 'HouseTwenty', 'run': 2, 'model': 'shapelet'}


Processing:   8%|▊         | 336/4480 [3:12:54<31:33:49, 27.42s/it]

Processing {'dataset': 'HouseTwenty', 'run': 3, 'model': 'shapelet'}


Processing:   8%|▊         | 343/4480 [3:15:58<31:06:15, 27.07s/it]

Processing {'dataset': 'HouseTwenty', 'run': 4, 'model': 'shapelet'}


Processing:   8%|▊         | 349/4480 [3:18:11<39:05:54, 34.07s/it]


KeyboardInterrupt: 

In [None]:
# Gather every cached result file and keep only datasets from the current sweep list.
df = (
    pl.read_parquet(f'{write_dir}/*.parquet')
    .filter(pl.col("dataset").is_in(datasets))
)
df

In [None]:
# Display the (shuffled) dataset order used for the sweep.
datasets

In [None]:
# Mean test accuracy per (dataset, model), ordered from lowest to highest accuracy.
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('test_accuracy').mean())
    .sort(by='test_accuracy')
)

In [None]:
# Compare per-dataset mean accuracy of TabPFN and QUANT, restricted to the
# datasets where both produced results (TabPFN fails on many-class datasets).
together_dataset = (
    set(gdf.filter(pl.col('model') == 'tabpfn')['dataset'])
    & set(gdf.filter(pl.col('model') == 'quant')['dataset'])
)
shared = pl.col('dataset').is_in(together_dataset)
# Sorting both sides by dataset name lines the two series up point-for-point.
v1 = gdf.filter((pl.col('model') == 'tabpfn') & shared).sort('dataset')['test_accuracy']
v2 = gdf.filter((pl.col('model') == 'quant') & shared).sort('dataset')['test_accuracy']

fig, ax = plt.subplots()
ax.scatter(v1, v2)
# Dashed red y = x reference line: points above it favour QUANT.
ax.plot([0, 1], [0, 1], color='red', linestyle='--')
ax.set_xlabel('TabPFN Test Accuracy')
ax.set_ylabel('QUANT Test Accuracy')
ax.set_title('Model Performance Comparison: TabPFN vs QUANT')
ax.grid(True)
plt.show()

In [None]:
# Compare per-dataset mean accuracy of TabPFN and the raw-scale ridge baseline,
# restricted to the datasets where both produced results.
together_dataset = (
    set(gdf.filter(pl.col('model') == 'tabpfn')['dataset'])
    & set(gdf.filter(pl.col('model') == 'raw-scale-ridge')['dataset'])
)
shared = pl.col('dataset').is_in(together_dataset)
# Sorting both sides by dataset name lines the two series up point-for-point.
v1 = gdf.filter((pl.col('model') == 'tabpfn') & shared).sort('dataset')['test_accuracy']
v2 = gdf.filter((pl.col('model') == 'raw-scale-ridge') & shared).sort('dataset')['test_accuracy']

fig, ax = plt.subplots()
ax.scatter(v1, v2)
# Dashed red y = x reference line: points above it favour the ridge baseline.
ax.plot([0, 1], [0, 1], color='red', linestyle='--')
ax.set_xlabel('TabPFN Test Accuracy')
ax.set_ylabel('Raw-scale Ridge Classifier Accuracy')
ax.set_title('Model Performance Comparison: TabPFN vs Raw-scale Ridge Classifier')
ax.grid(True)
plt.show()

In [None]:
# Pair per-dataset mean accuracies with an inner join on `dataset`, so every
# point compares the same dataset. Sorting each model's rows independently would
# misalign (or length-mismatch) the series whenever one model lacks a dataset.
paired = (
    gdf.filter(pl.col('model') == 'raw-scale-ridge')
    .select('dataset', pl.col('test_accuracy').alias('acc_x'))
    .join(
        gdf.filter(pl.col('model') == 'quant')
        .select('dataset', pl.col('test_accuracy').alias('acc_y')),
        on='dataset',
        how='inner',
    )
    .sort('dataset')
)
v1 = paired['acc_x']
v2 = paired['acc_y']
plt.scatter(v1, v2)
# Dashed red y = x reference line.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('Raw-scale Ridge Classifier Accuracy')
plt.ylabel('QUANT Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Pair per-dataset mean accuracies with an inner join on `dataset`, so every
# point compares the same dataset. Sorting each model's rows independently would
# misalign (or length-mismatch) the series whenever one model lacks a dataset.
paired = (
    gdf.filter(pl.col('model') == 'quant')
    .select('dataset', pl.col('test_accuracy').alias('acc_x'))
    .join(
        gdf.filter(pl.col('model') == 'minirocket')
        .select('dataset', pl.col('test_accuracy').alias('acc_y')),
        on='dataset',
        how='inner',
    )
    .sort('dataset')
)
v1 = paired['acc_x']
v2 = paired['acc_y']
plt.scatter(v1, v2)
# Dashed red y = x reference line.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('QUANT Classifier Accuracy')
plt.ylabel('MiniRocket Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Pair per-dataset mean accuracies with an inner join on `dataset`, so every
# point compares the same dataset. Sorting each model's rows independently would
# misalign (or length-mismatch) the series whenever one model lacks a dataset.
paired = (
    gdf.filter(pl.col('model') == 'catch22')
    .select('dataset', pl.col('test_accuracy').alias('acc_x'))
    .join(
        gdf.filter(pl.col('model') == 'minirocket')
        .select('dataset', pl.col('test_accuracy').alias('acc_y')),
        on='dataset',
        how='inner',
    )
    .sort('dataset')
)
v1 = paired['acc_x']
v2 = paired['acc_y']
plt.scatter(v1, v2)
# Dashed red y = x reference line.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('Catch22 Classifier Accuracy')
plt.ylabel('MiniRocket Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Pair per-dataset mean accuracies with an inner join on `dataset`, so every
# point compares the same dataset. Sorting each model's rows independently would
# misalign (or length-mismatch) the series whenever one model lacks a dataset.
paired = (
    gdf.filter(pl.col('model') == 'catch22')
    .select('dataset', pl.col('test_accuracy').alias('acc_x'))
    .join(
        gdf.filter(pl.col('model') == 'raw-scale-ridge')
        .select('dataset', pl.col('test_accuracy').alias('acc_y')),
        on='dataset',
        how='inner',
    )
    .sort('dataset')
)
v1 = paired['acc_x']
v2 = paired['acc_y']
plt.scatter(v1, v2)
# Dashed red y = x reference line.
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel('Catch22 Classifier Accuracy')
plt.ylabel('Raw-scale Ridge Classifier Accuracy')
plt.title('Model Performance Comparison')
plt.grid(True)
plt.show()

In [None]:
# Mean training time (seconds, from perf_counter) per (dataset, model).
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('training_time').mean())
    .sort(['dataset', 'model'])
)

In [None]:
import seaborn as sns

# One row per dataset, one point per model. The x-axis is logarithmic because
# mean training times differ widely between models.
fig, ax = plt.subplots(figsize=(10, 20))
sns.stripplot(data=gdf, y="dataset", x="training_time", hue="model", dodge=False, ax=ax)
ax.grid(True)
ax.set_xscale('log')
fig.tight_layout()
plt.show()

In [None]:
# Mean test accuracy per (dataset, model), ordered by dataset then model name.
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('test_accuracy').mean())
    .sort(['dataset', 'model'])
)

In [None]:
# Distinct datasets that have at least one cached result.
df.get_column('dataset').unique()

In [None]:
import seaborn as sns

# One row per dataset, one point per model. Accuracy lies in [0, 1], so use a
# linear axis fixed to that range. A log axis would stretch low accuracies and
# compress the high-accuracy end of the scale.
fig, ax = plt.subplots(figsize=(10, 20))
sns.stripplot(data=gdf, y="dataset", x="test_accuracy", hue="model", dodge=False, ax=ax)
ax.set_xlim(0, 1)
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Mean test accuracy and mean training time per (dataset, model).
gdf = (
    df.group_by(['dataset', 'model'])
    .agg(pl.col('test_accuracy', 'training_time').mean())
    .sort(['dataset', 'model'])
)

In [None]:
# Accuracy/cost trade-off: one point per (dataset, model) mean, coloured by model.
ax = sns.scatterplot(data=gdf, x='training_time', y='test_accuracy', hue='model')
ax.set_xscale('log')
ax.set(
    xlabel='Mean training time (s, log scale)',
    ylabel='Mean test accuracy',
    title='Test accuracy vs. training time per dataset and model',
)
ax.grid(True)
plt.show()