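This notebook should be run from outside the folder containing the CAAFE code. The CAAFE code must be importable as `cafe_feature_engineering`, i.e. it must live in a directory called 'cafe_feature_engineering' located next to this notebook (the notebook sits one level above the package directory).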

# Installation

In [None]:
!pip install openai
!pip install kaggle
!pip install openml
!pip install submitit
!pip install tabpfn[full]

In [None]:
!pip install autofeat
!pip install featuretools

In [None]:
import json
import os

# Write the Kaggle API credentials to ~/.kaggle/kaggle.json for the `kaggle` CLI.
# `open` does not expand `~`, so resolve the home directory explicitly; otherwise the token
# would be written to a literal `./~/.kaggle/` directory (or the write fails).
kaggle_dir = os.path.expanduser("~/.kaggle")
kaggle_config_path = os.path.join(kaggle_dir, "kaggle.json")
os.makedirs(kaggle_dir, exist_ok=True)

# Prefer credentials from the environment; otherwise replace the "XXX" placeholders.
kaggle_api_token = {
    "username": os.environ.get("KAGGLE_USERNAME", "XXX"),
    "key": os.environ.get("KAGGLE_KEY", "XXX"),
}

with open(kaggle_config_path, 'w') as file:
    json.dump(kaggle_api_token, file)

# Keep the API key readable by the current user only.
os.chmod(kaggle_config_path, 0o600)
os.makedirs("datasets_kaggle", exist_ok=True)

In [None]:
base_path = '.'

In [None]:
!mkdir {base_path}/results
!mkdir {base_path}/results/tabular/
!mkdir {base_path}/results/tabular/multiclass/

### Download from Kaggle

In [None]:
from cafe_feature_engineering import data

In [None]:
for (name, _, _, user) in data.kaggle_dataset_ids:
    !kaggle datasets download -d {user}/{name}
    !mkdir datasets_kaggle/{name}
    !unzip {name}.zip -d datasets_kaggle/{name}

In [None]:
# Accept rules at https://www.kaggle.com/c/spaceship-titanic/rules
for name in data.kaggle_competition_ids:
    print(name)
    !kaggle competitions download -c {name}
    !mkdir datasets_kaggle/{name}
    !unzip {name}.zip -d datasets_kaggle/{name}

# Code

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import copy
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import multiprocessing

from tabpfn.scripts import tabular_baselines

import numpy as np
from tabpfn.scripts.tabular_baselines import *
from tabpfn.scripts.tabular_evaluation import evaluate
from tabpfn.scripts import tabular_metrics
from tabpfn import TabPFNClassifier
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score
from cafe_feature_engineering import data, cafe, plotting, evaluate, feature_extension_baselines
import tabpfn
import submitit
import openai

In [None]:
# Keep an OPENAI_API_KEY that is already exported in the environment; only fall back to the
# placeholder, which must then be replaced before CAAFE can call the API.
# You can get an OpenAI access key by creating an account at openai.
os.environ.setdefault("OPENAI_API_KEY", "XXX")
openai.api_key = os.environ["OPENAI_API_KEY"]
os.environ["DATA_DIR"] = "cafe_feature_engineering/data"

In [None]:
metric_used = tabular_metrics.auc_metric
methods = ['transformer', 'logistic', 'gp', 'knn', 'catboost', 'xgb', 'autosklearn2', 'autogluon', 'random_forest']

In [None]:
from functools import partial
classifier = TabPFNClassifier(device="cpu", N_ensemble_configurations=16)
classifier.fit = partial(classifier.fit,  overwrite_warning=True)
tabpfn_method = partial(clf_dict["transformer"], classifier=classifier)

classifier_fast = TabPFNClassifier(device="cpu", N_ensemble_configurations=1)
classifier_fast.fit = partial(classifier_fast.fit,  overwrite_warning=True)
tabpfn_method_fast = partial(clf_dict["transformer"], classifier=classifier_fast)

### Load data

In [None]:
# Load all benchmark datasets. Entries are indexed by position below
# (e.g. `cc_test_datasets_multiclass[5]`), and `ds[0]` is used as the dataset name.
cc_test_datasets_multiclass = data.load_all_data()

### Test run

In [None]:
# Quick single-dataset sanity check (dataset index 5, seed 1) before the full experiments.
ds = cc_test_datasets_multiclass[5]
seed = 1
ds, df_train, df_test, df_train_old, df_test_old = data.get_data_split(ds, seed)

In [None]:
# Run CAAFE feature generation on the training split with model="gpt-4" (needs a valid OpenAI key).
# NOTE(review): `iterative=10` presumably means ten generation rounds, with candidates scored by
# `iterative_method` (TabPFN) under `metric_used` — confirm in cafe.generate_features.
code, prompt, messages = cafe.generate_features(ds,
                                                df_train,
                                                just_print_prompt=False,
                                                model="gpt-4",
                                                iterative=10,
                                                iterative_method=tabpfn_method,
                                                metric_used=metric_used)

### Run experiments

#### Setup queue

In [None]:
job_queue = {}
maximum_runtime = 0
log_folder = 'logs/'
# Executor used by the feature-generation cell below (`ex.submit`). submitit's AutoExecutor
# submits to SLURM when it is available and otherwise runs jobs as local processes, with logs
# written to `log_folder`. No `global` statement is needed: names assigned in a notebook cell
# are already module globals.
ex = submitit.AutoExecutor(folder=log_folder)

In [None]:
def run_locally(f, *args, **kwargs):
    """Call ``f(*args, **kwargs)`` synchronously in this process.

    Drop-in stand-in for ``executor.submit`` when no cluster is used; it returns ``f``'s
    result directly rather than a job handle.
    """
    result = f(*args, **kwargs)
    return result

#### Generate Feature Extension Code

In [None]:
import subprocess
import sys

jobs = []
# Submit through the cluster executor `ex` when one has been set up; otherwise run in-process.
executor = globals().get("ex")
submit_func = executor.submit if executor is not None else run_locally


def exec_(seed, dsid):
    """Run the CAAFE feature-generation script for dataset index ``dsid`` and ``seed``.

    The script runs in a subprocess that uses the current interpreter (``sys.executable``) instead
    of whichever ``python`` is first on PATH. Arguments are passed as a list, so no shell is involved.
    A non-zero exit code is reported but does not raise, so one failing job does not stop the
    others. Returns None.
    """
    completed = subprocess.run(
        [
            sys.executable, "-m", "cafe_feature_engineering.generate_features_script",
            "--seed", str(seed),
            "--dataset_id", str(dsid),
            "--prompt_id", "v3",
            "--iterations", "1",
        ]
    )
    if completed.returncode != 0:
        print(f"generate_features_script failed for seed={seed}, dataset_id={dsid} "
              f"(exit code {completed.returncode})")
    return None


for n in tqdm(range(len(cc_test_datasets_multiclass))):
    for seed in range(5):
        jobs.append(submit_func(exec_, seed, n))

#### Run Evaluations of Baselines

In [None]:
results = []
jobs = []
methods = ['autogluon', 'autosklearn', "random_forest", tabpfn_method, "logistic"]
prompts = ['v3', 'v4+dfs', 'autofeat', 'v4+autofeat', 'v3+autofeat', 'v3+dfs']
submit_func = run_locally
# Set to True to skip combinations whose result file already exists at `path`.
# NOTE(review): assumes this path matches what `evaluate` writes — confirm.
skip_existing = False
data_dir = os.environ.get("DATA_DIR", "data/")
for method in methods[::-1]:
    # Non-string methods (the TabPFN partial) are stored under the name "transformer".
    method_str = method if isinstance(method, str) else "transformer"
    for prompt_id in prompts:
        for n in tqdm(range(len(cc_test_datasets_multiclass))):
            ds = cc_test_datasets_multiclass[n]
            for seed in range(5):
                path = f"{data_dir}/evaluations/result_{ds[0]}_{prompt_id}_{seed}_{method_str}.txt"
                if skip_existing and os.path.exists(path):
                    continue
                jobs.append(submit_func(
                    evaluate.evaluate_dataset_with_and_without_cafe,
                    ds,
                    seed,
                    [method],
                    metric_used,
                    overwrite=True,
                    prompt_id=prompt_id,
                ))

## Visualize results

### Load Results

In [None]:
all_results = {}
all_prompts = ['', 'v4', 'v3', 'dfs', 'v4+dfs', 'autofeat', 'v4+autofeat']
all_methods = [tabpfn_method, "random_forest", "logistic", "autosklearn", "autogluon"]

In [None]:
for prompt_id in all_prompts:
    for method in all_methods: # tabpfn, "logistic",  "logistic", "random_forest", 
        method_str = method if type(method) == str else "transformer"
        for n in tqdm(range(0, len(cc_test_datasets_multiclass))): # len(cc_test_datasets_multiclass)-1
            for seed in range(0, 5):
                ds = cc_test_datasets_multiclass[n]
                r = evaluate.load_result(all_results, cc_test_datasets_multiclass[n],seed,method,prompt_id=prompt_id)
                

### Setup

In [None]:
# Display names for method columns.
clf_relabeler = {'transformer': 'Tabular PFN',
                 'autogluon': 'Autogluon',
                 'autosklearn2': 'Autosklearn2',
                 'ridge': 'Ridge',
                 'gp': 'GP (RBF)',
                 'bayes': 'BNN',
                 'tabnet': 'Tabnet',
                 'logistic': 'Log. Regr.',
                 'knn': 'KNN',
                 'catboost': 'Catboost',
                 'xgb': 'XGB'}

# Display names for dataset rows that are too long or need a Kaggle marker.
dataset_relabeler = {'blood-transfusion-service-center': 'blood-transfus..',
                     'jungle_chess_2pcs_raw_endgame_complete': 'jungle_chess..',
                     'bank-marketing': 'bank-market..',
                     'kaggle_spaceship-titanic': '[Kaggle] spaceship-titanic',
                     'kaggle_playground-series-s3e12': '[Kaggle] kidney-stone',
                     'kaggle_health-insurance-lead-prediction-raw-data': '[Kaggle] health-insurance',
                     'kaggle_pharyngitis': '[Kaggle] pharyngitis'}


def rename_table_vis(table):
    """Return a copy of ``table`` with display names applied.

    Columns are renamed with ``clf_relabeler`` and index labels with ``dataset_relabeler``;
    labels without an entry are left unchanged. Renaming the index directly (rather than
    transposing twice) keeps per-column dtypes intact for mixed-dtype frames.
    """
    return table.rename(columns=clf_relabeler, index=dataset_relabeler)

def table_sorter(x):
    """Sort key for ``"<method>_<prompt>"`` column labels: method order first, then prompt order.

    The label is split on its *last* underscore, because method names may contain underscores
    (``random_forest``) while prompt ids do not. Unknown methods or prompts, and labels without
    an underscore, get the digit 9 and therefore sort last. Returns a two-character string.
    """
    methods_sort = {'logistic': 0, 'random_forest': 1, 'autogluon': 2, 'autosklearn': 3, 'transformer': 4}
    prompts_sort = {'': 0, 'dfs': 1, 'v4+dfs': 2, 'autofeat': 3, 'v4+autofeat': 4, 'v3': 5, 'v4': 6}
    method, sep, prompt = x.rpartition('_')
    if not sep:
        method, prompt = x, None
    return str(methods_sort.get(method, 9)) + str(prompts_sort.get(prompt, 9))

### Creating dataframes

In [None]:
metric = 'roc'

In [None]:
df_all = pd.DataFrame(all_results).T
df_all = df_all.set_index('name')

# Filtering
df_all = df_all[df_all.seed < 5]
df_all = df_all[df_all.index != "wine"]
df_all = df_all[(df_all.method != "autogluon") & (df_all.method != "autosklearn")]

# How many features added?
feats_extended = df_all[np.logical_and.reduce((df_all.prompt == prompt_id, df_all.seed == 0, df_all.method == 'logistic'))].feats.sum()
feats_old = df_all[np.logical_and.reduce((df_all.prompt == '', df_all.seed == 0, df_all.method == 'logistic'))].feats.sum()
print('Features added', feats_extended, feats_old)

# Create results dataframe
df_all_agg_seeds = df_all.groupby(by=["name", "method", "prompt"])['acc'].mean()
rank_df = df_all_agg_seeds.groupby(by=["name", "method"]).rank(ascending=False)

df_all['rank_within_ds'] = rank_df
df_all['wins_within_ds'] = rank_df == 1.0
df_all['ties_within_ds'] = rank_df == 1.5

df_all_grouped_by_method = df_all.groupby(by=["prompt", "method"]).agg({'acc': ['mean'],
                                   'roc': ['mean'],
                                   'rank_within_ds': ['mean'],
                                   'wins_within_ds': ['sum'],
                                  'ties_within_ds': ['sum']
                                                                       }).T

df_all_grouped_by_ds = df_all.groupby(by=["name", "prompt", "method", "seed"]).agg({'acc': ['mean'], 'roc': ['mean']})
df_all_grouped_by_ds.columns = df_all_grouped_by_ds.columns.get_level_values(0)
df_all_grouped_by_ds = df_all_grouped_by_ds.reset_index()

### Print table

#### Check all results are ready

In [None]:
df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds.copy()

In [None]:
df_all_grouped_by_ds_print = df_all_grouped_by_ds_table_1.groupby(by=["name", 
                                                              "prompt", "method"])[metric].count().reset_index().pivot(index='name', columns=['method', 'prompt'], values=metric)
df_all_grouped_by_ds_print.columns = ['_'.join(col) for col in df_all_grouped_by_ds_print.columns]

In [None]:
pd.set_option('display.max_columns', 500)
df_all_grouped_by_ds_print

#### Table only CAFE

In [None]:
df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds.copy()
df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds_table_1[(df_all_grouped_by_ds.method== "transformer")]
df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds_table_1[(df_all_grouped_by_ds_table_1.prompt == "v4") | (df_all_grouped_by_ds_table_1.prompt == "") | (df_all_grouped_by_ds_table_1.prompt == "v3")]


In [None]:
df_all_grouped_by_ds_print = df_all_grouped_by_ds_table_1.groupby(by=["name", "prompt", "method"])[metric].mean().reset_index()


In [None]:
stds = df_all_grouped_by_ds_table_1.groupby(by=["name", "prompt", "method"])[metric].std().reset_index()

In [None]:
df_all_grouped_by_ds_print['ranks'] = df_all_grouped_by_ds_print.groupby(['name', 'method']).rank()['roc']

In [None]:
df_all_grouped_by_ds_print_ranks = df_all_grouped_by_ds_print.copy().pivot(index='name', columns=['method', 'prompt'], values='ranks')
df_all_grouped_by_ds_print_ranks.columns = ['_'.join(col) for col in df_all_grouped_by_ds_print_ranks.columns]

In [None]:
df_all_grouped_by_ds_print_ranks

In [None]:
df_all_grouped_by_ds_print = df_all_grouped_by_ds_print.pivot(index='name', columns=['method', 'prompt'], values=metric)
df_all_grouped_by_ds_print.columns = ['_'.join(col) for col in df_all_grouped_by_ds_print.columns]

In [None]:
df_all_grouped_by_ds_print

In [None]:
df_all_grouped_by_ds_print_stds = stds.pivot(index='name', columns=['method', 'prompt'], values=metric)
df_all_grouped_by_ds_print_stds.columns = ['_'.join(col) for col in df_all_grouped_by_ds_print_stds.columns]
df_all_grouped_by_ds_print_stds

In [None]:
df_all_grouped_by_ds_print.loc[f'Mean ROC'] = df_all_grouped_by_ds_print.mean(axis=1,level=0).mean().values
df_all_grouped_by_ds_print.loc[f'Mean ROC Stds'] = df_all_grouped_by_ds_print_stds.mean(axis=1,level=0).mean().values
df_all_grouped_by_ds_print.loc[f'Mean Rank'] = df_all_grouped_by_ds_print_ranks.mean(axis=1,level=0).mean().values

In [None]:
cols = df_all_grouped_by_ds_print.columns.tolist()
cols = sorted(cols)
N_end = 0
N_cols = 10
N_methods = len(all_methods) - 2
offset = 0
#cols = [cols[i // N_cols + (i % N_cols) * N_methods] for i in range(0, len(cols) - N_end)]# + cols[-N_end:]
df_all_grouped_by_ds_print = df_all_grouped_by_ds_print[cols]

In [None]:
def bold_extreme_values(data, format_string="%.5g", max_=True):
    """Format a numeric Series as strings, wrapping its best value(s) in ``\\textbf{...}``.

    Values are cast to float and rounded to 4 decimals before comparing, so entries tied after
    rounding are all bolded. ``max_=True`` treats the maximum as best, ``max_=False`` the minimum.
    """
    rounded = data.astype(float).round(4)
    best = rounded.max() if max_ else rounded.min()
    plain = rounded.apply(lambda value: format_string % value)
    bold = rounded.apply(lambda value: "\\textbf{%s}" % format_string % value)
    # `where` keeps the plain string where the condition holds and takes the bold one elsewhere.
    return plain.where(rounded != best, bold)

def to_str(data, format_string="%.3g", drop=False):
    """Format every value of the Series ``data`` with ``format_string``.

    With ``drop=True`` the leading zero of values in (-1, 1) is removed for compact table
    cells (``0.05 -> .05``, ``-0.05 -> -.05``). Other strings, such as ``1.50`` or ``nan``,
    are left intact instead of losing their first character.
    """
    def _strip_leading_zero(text):
        # Strip only a literal leading "0" so magnitudes >= 1, signs and "nan" are preserved.
        if text.startswith("0."):
            return text[1:]
        if text.startswith("-0."):
            return "-" + text[2:]
        return text

    formatted = data.apply(lambda x: format_string % x)
    if drop:
        formatted = formatted.apply(_strip_leading_zero)
    return formatted

In [None]:
# Dataset rows of the LaTeX table built in the next cell (everything except the three aggregate
# rows). Derived from `df_all_grouped_by_ds_print`, because `table` is only defined in the next cell.
rename_table_vis(df_all_grouped_by_ds_print).index[:-3]

In [None]:
# Cast to object up front: the numeric cells are replaced by LaTeX strings below.
table = rename_table_vis(df_all_grouped_by_ds_print).astype(object)
stds_table = rename_table_vis(df_all_grouped_by_ds_print_stds)
# Raw string: "\p" is not a valid Python escape sequence (SyntaxWarning on Python 3.12+).
pm_prefix = r' {\scriptsize $\pm$'

# Per-dataset rows: bold the best mean and append "± std" in small font.
aggregate_rows = ['Mean ROC', 'Mean ROC Stds', 'Mean Rank']
non_agg = table.index[~table.index.isin(aggregate_rows)]
table.loc[non_agg] = table.loc[non_agg].apply(lambda data : bold_extreme_values(data), axis=1)
table.loc[non_agg] = table.loc[non_agg] + pm_prefix + stds_table.loc[non_agg].apply(lambda data : to_str(data, format_string="%.2f", drop=True), axis=1) + '}'

# Mean ROC row gets the mean std appended; the std helper row is then dropped.
table.loc[['Mean ROC']] = table.loc[['Mean ROC']].apply(lambda data : bold_extreme_values(data), axis=1)
table.loc[['Mean ROC Stds']] = table.loc[['Mean ROC Stds']].apply(lambda data : to_str(data, format_string="%.2f", drop=True), axis=1)
table.loc['Mean ROC'] = table.loc['Mean ROC'] + pm_prefix + table.loc['Mean ROC Stds'] + '}'
table = table.drop(['Mean ROC Stds'])

table.loc[['Mean Rank']] = table.loc[['Mean Rank']].apply(lambda data : bold_extreme_values(data, format_string="%.2f"), axis=1)

table

In [None]:
import re
# Render as LaTeX; cells already contain LaTeX markup, hence escape=False.
tab_string = table.to_latex(escape=False).replace('[Kaggle]', '$\\langle Kaggle\\rangle$')
# Rewrite any `\font-weightbold <value>` markup as `\textbf{<value>}`.
tab_string = re.sub(r' \\font-weightbold ([0-9\.]*) ', ' \\\\textbf{\\1} ', tab_string)
# Swap the auto-generated header for readable column titles (transformer_v3 -> GPT-3.5,
# transformer_v4 -> GPT-4). This is an exact string match against the current `to_latex`
# output; if pandas formats the header differently, `replace` silently leaves it unchanged.
# NOTE(review): the replacement spec `l|r|r|rr` declares five columns for a four-column
# table (index + 3) — probably meant `l|r|r|r`.
tab_string = tab_string.replace(r"""\begin{tabular}{llll}
\toprule
{} &             transformer_ &           transformer_v3 &           transformer_v4 \\
name                       &                          &                          &                          \\""", r"""\begin{tabular}{l|r|r|rr}
\toprule
{} & \multicolumn{1}{c}{TabPFN} & \multicolumn{1}{c}{TabPFN + CAAFE (GPT-3.5)} & \multicolumn{1}{c}{TabPFN + CAAFE (GPT-4)} \\""")
print(tab_string)

#### Table 1

In [None]:
df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds.copy()
#df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds_table_1[df_all_grouped_by_ds_table_1.prompt != "v3"]
#df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds_table_1[np.logical_or(df_all_grouped_by_ds_table_1.prompt == "", df_all_grouped_by_ds_table_1.method == "transformer")]


In [None]:
df_all_grouped_by_ds_table_1['ranks'] = df_all_grouped_by_ds_table_1.groupby(by=["prompt", "method"])[metric].rank()

In [None]:
df_all_grouped_by_ds_table_1

In [None]:
df_all_grouped_by_ds_print = df_all_grouped_by_ds_table_1.groupby(by=["prompt", "method"])[metric].mean().reset_index()
#df_all_grouped_by_ds_print.columns = ['_'.join(col) for col in df_all_grouped_by_ds_print.columns]

#df_all_grouped_by_ds_print.loc[f'Mean ROC'] = df_all_grouped_by_ds_print.mean(axis=1,level=0).mean().values

In [None]:
df_all_grouped_by_ds_print = df_all_grouped_by_ds_print.reset_index().pivot(index='method', columns=['prompt'], values=metric)
df_all_grouped_by_ds_print

In [None]:
stds = df_all_grouped_by_ds_table_1.groupby(by=["prompt", "method"])[metric].std().reset_index()
df_all_grouped_by_ds_print_stds = stds.reset_index().pivot(index='method', columns=['prompt'], values=metric)
df_all_grouped_by_ds_print_stds

In [None]:
# Cast to object up front: the numeric cells are replaced by LaTeX strings below.
table = df_all_grouped_by_ds_print.reindex(["logistic", "random_forest", "transformer"]).astype(object)

non_agg = table.index
std_cells = rename_table_vis(df_all_grouped_by_ds_print_stds).loc[non_agg].apply(lambda data : to_str(data, format_string="%.2f", drop=True), axis=1)
table.loc[non_agg] = table.loc[non_agg].apply(lambda data : bold_extreme_values(data, format_string="%.3g"), axis=1)
# Raw string: "\p" is not a valid Python escape sequence (SyntaxWarning on Python 3.12+).
table.loc[non_agg] = table.loc[non_agg] + r' {\scriptsize $\pm$' + std_cells + '}'

table

In [None]:
import re
# Render as LaTeX; cells already contain LaTeX markup, hence escape=False.
tab_string = table.to_latex(escape=False).replace('[Kaggle]', '$\\langle Kaggle\\rangle$')
# Rewrite any `\font-weightbold <value>` markup as `\textbf{<value>}`.
tab_string = re.sub(r' \\font-weightbold ([0-9\.]*) ', ' \\\\textbf{\\1} ', tab_string)
# Human-readable method names (plain substring replacement over the whole LaTeX string).
tab_string = tab_string.replace("transformer", "TabPFN")
tab_string = tab_string.replace("logistic", "Log. Reg.")
tab_string = tab_string.replace("random_forest", "Random Forest")
# Swap the auto-generated header for grouped column titles. This is an exact string match
# against the current `to_latex` output; if the header differs (e.g. a different prompt set or
# pandas version), `replace` silently leaves it unchanged — check the printed output.
tab_string = tab_string.replace(r"""\begin{tabular}{llllllll}
\toprule
prompt &        autofeat &             dfs &              v3 &                       v4 &             v4+autofeat &                   v4+dfs \\
method        &                 &                 &                 &                 &                          &                         &                          \\"""
                                , r"""\begin{tabular}{llllllll}
\toprule
{} & \multicolumn{1}{c}{} & \multicolumn{2}{c}{Baselines} & \multicolumn{2}{c}{CAAFE} & \multicolumn{2}{c}{Baseline + CAAFE} \\""")
print(tab_string)

#### Table 2

In [None]:
# This table only shows TabPFN performance and then different extension baselines

In [None]:
df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds.copy()
#df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds_table_1[df_all_grouped_by_ds_table_1.prompt != "v3"]
#df_all_grouped_by_ds_table_1 = df_all_grouped_by_ds_table_1[np.logical_or(df_all_grouped_by_ds_table_1.prompt == "", df_all_grouped_by_ds_table_1.method == "transformer")]


In [None]:
df_all_grouped_by_ds_print = df_all_grouped_by_ds_table_1.groupby(by=["name", 
                                                              "prompt", "method"])[metric].mean().reset_index().pivot(index='name', columns=['method', 'prompt'], values=metric)
df_all_grouped_by_ds_print.columns = ['_'.join(col) for col in df_all_grouped_by_ds_print.columns]

df_all_grouped_by_ds_print.loc[f'Mean ROC'] = df_all_grouped_by_ds_print.mean(axis=1,level=0).mean().values


In [None]:
cols = df_all_grouped_by_ds_print.columns.tolist()
cols = sorted(cols, key=table_sorter)
N_end = 0
N_cols = len(all_prompts)
N_methods = len(all_methods)
offset = 0
#cols = [cols[i // N_cols + (i % N_cols) * N_methods] for i in range(0, len(cols) - N_end)]# + cols[-N_end:]
df_all_grouped_by_ds_print = df_all_grouped_by_ds_print[cols]

In [None]:
# Express every non-baseline column "<method>_<prompt>" as the difference to that method's
# baseline column "<method>_" (prompt '', no feature extension). Matching by label instead of by
# position stays correct when a method is missing some prompts.
# Not idempotent: re-running this cell subtracts the baseline again.
for column in df_all_grouped_by_ds_print.columns:
    method_name, _, prompt = column.rpartition('_')
    baseline_column = f"{method_name}_"
    if prompt == '' or baseline_column not in df_all_grouped_by_ds_print.columns:
        continue
    df_all_grouped_by_ds_print[column] = df_all_grouped_by_ds_print[column] - df_all_grouped_by_ds_print[baseline_column]


In [None]:
# Bold the largest improvement within each method's column group. The first column of a group is
# the baseline itself and is excluded. Loops over all groups instead of a fixed three.
table_styler = rename_table_vis(df_all_grouped_by_ds_print).round(decimals=4).style
for group_start in range(offset, len(df_all_grouped_by_ds_print.columns), N_cols):
    group_columns = df_all_grouped_by_ds_print.columns[group_start + 1:group_start + N_cols]
    if len(group_columns) > 0:
        table_styler = table_styler.highlight_max(subset=group_columns, axis=1, props='font-weight: bold;')
table = table_styler.format(precision=4)
table

In [None]:
import re

tab_string = table.to_latex()
tab_string = tab_string.replace('[Kaggle]', '$\\langle Kaggle\\rangle$')
# The highlight props ('font-weight: bold;') come out as `\font-weightbold <value>` in the LaTeX;
# rewrite those cells as `\textbf{<value>}`.
tab_string = re.sub(r' \\font-weightbold ([0-9\.]*) ', ' \\\\textbf{\\1} ', tab_string)
print(tab_string)

#### Print overview of datasets

In [None]:
def latex_escape(text):
    """Escape LaTeX special characters so ``text`` can appear in running text (e.g. a caption).

    Characters are replaced one at a time, so inserted backslashes are never escaped twice.
    (`re.escape` is not a LaTeX escaper: it turns '-' into the discretionary hyphen '\\-' and
    leaves '_' unescaped, which breaks compilation.)
    """
    replacements = {
        '\\': r'\textbackslash{}',
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
    }
    return ''.join(replacements.get(char, char) for char in text)


# Printing all dataset descriptions as LaTeX figures.
# NOTE(review): every figure reuses the label `fig:llm_prompt`, which LaTeX reports as multiply defined.
for n in tqdm(range(len(cc_test_datasets_multiclass))):
    ds = cc_test_datasets_multiclass[n]
    print("""\\begin{figure}[h]
    \\centering
    \\begin{minipage}{\\textwidth}
    \\begin{lstlisting}""")
    print(data.get_data_split(ds, 0)[0][-1])
    print("""\\end{lstlisting}
    \\end{minipage}
    \\caption{Dataset description for """ + latex_escape(ds[0]) + """.}
    \\label{fig:llm_prompt}
\\end{figure}""")

In [None]:
# OpenML ids are taken from `cc_test_datasets_multiclass_df` when that frame exists in the kernel.
# None of the visible cells creates it, so fall back to an empty id column instead of failing
# with NameError on a fresh kernel. TODO: create the frame explicitly.
openml_ids_df = globals().get("cc_test_datasets_multiclass_df")
n_openml_ids = 0 if openml_ids_df is None else len(openml_ids_df)
df = [
    {
        'Name': ds[0],
        '# Features': ds[1].shape[1],
        '# Samples': ds[1].shape[0],
        '# Classes': len(np.unique(ds[2])),
        'OpenML ID / Kaggle Name': openml_ids_df.iloc[i].did if i < n_openml_ids else '',
    }
    for i, ds in enumerate(cc_test_datasets_multiclass)
]
print(pd.DataFrame(df).set_index('Name').to_latex())


### Create a stripplot of results

In [None]:
method = 'transformer'
metric = 'roc'

In [None]:
diff_ds = (df_all_grouped_by_ds[df_all_grouped_by_ds.method == method].groupby(by=["name", "prompt", "method"]).agg({metric: ['mean']})
 .groupby(by=["name", "method"])).diff().groupby(by=["name"]).max()

In [None]:
df_all_grouped_by_ds['diff'] = df_all_grouped_by_ds.apply(lambda x : diff_ds.loc[x['name']], axis=1)

In [None]:
df_all_grouped_by_ds = df_all_grouped_by_ds.sort_values(by=['diff'], ascending=False)

In [None]:
df_all_grouped_by_ds[df_all_grouped_by_ds.prompt == ""].prompt = "none"

In [None]:
# Shorter or Kaggle-tagged display names for the plot's y axis; names without an entry are kept.
ren = {
    'blood-transfusion-service-center': 'blood-transfus..',
    'jungle_chess_2pcs_raw_endgame_complete': 'jungle_chess..',
    'bank-marketing': 'bank-market..',
    'kaggle_spaceship-titanic': '[Kaggle] spaceship-titanic',
    'kaggle_playground-series-s3e12': '[Kaggle] kidney-stone',
    'kaggle_health-insurance-lead-prediction-raw-data': '[Kaggle] health-insurance',
    'kaggle_pharyngitis': '[Kaggle] pharyngitis',
}
df_all_grouped_by_ds['name'] = df_all_grouped_by_ds['name'].map(lambda label: ren.get(label, label))

In [None]:
plotting.draw_stripplot(
    df_all_grouped_by_ds[df_all_grouped_by_ds.method == method],
    x=metric,
    y="name",
    hue="prompt",
    size=(15, 6),
    xbound=[0.5, 1.05],
    legend_title=' ',
    legend_loc='upper left',
)
plt.subplots_adjust(left=0.2, right=1.0, top=1.0, bottom=0.0)
plt.savefig(f"results_{method}_{metric}.pdf")

# Optional TikZ export. tikzplotlib is not installed by this notebook, so it is only
# imported when the export is switched on.
export_tikz = False
if export_tikz:
    import tikzplotlib
    tikzplotlib.save(f"results_{method}_{metric}.tex")
