In [None]:
import json
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
def get_from_json(json_path: Path, key: str) -> Any:
    """Load a JSON file and return the value stored under a top-level key.

    Args:
        json_path: Path of the JSON file to read.
        key: Top-level key whose value should be returned.

    Returns:
        The decoded JSON value stored under ``key`` (number, string, list,
        dict, ...), exactly as produced by ``json.load``.

    Raises:
        KeyError: If ``key`` is not present in the decoded document.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding when decoding the file.
    with open(json_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    if key not in data:
        # Include the path: this helper is called over many files and the
        # bare key alone does not tell which file is malformed.
        raise KeyError(f"Key '{key}' not found in JSON file '{json_path}'.")
    return data[key]

In [None]:
# Directory holding the per-problem result folders read below; the path is
# relative, so it is resolved against the notebook's working directory.
RESULTS_PATH = Path("..") / "results"

### Synthetic

In [None]:
# Build one row per JSON result file found under ../results/synthetic/<method>/.
synthetic_columns = ["problem", "meta-model", "mae", "corr"]
synthetic_records = []
# `iterdir`/`rglob` yield entries in arbitrary, filesystem-dependent order;
# sorting makes the row order (and therefore the bar order in the plots)
# reproducible across machines.
for method in sorted((RESULTS_PATH / "synthetic").iterdir()):
    if not method.is_dir():
        continue
    # Walk the files once and read both metrics from the same file, so the
    # "mae" and "corr" values of a row are guaranteed to belong together.
    for json_path in sorted(method.rglob("*.json")):
        synthetic_records.append(
            {
                "problem": "synthetic",
                # `.name`, not `.stem`: `.stem` drops everything after the
                # last dot, which would truncate a directory such as
                # "model_v1.2" to "model_v1".
                "meta-model": method.name,
                "mae": get_from_json(json_path, "mae_by_model"),
                "corr": get_from_json(json_path, "spearmanr_corr"),
            }
        )
# Explicit columns keep the frame well-formed (empty, but with the expected
# schema) when no result files are found, instead of failing in `pd.concat`.
synthetic_results_df = pd.DataFrame(synthetic_records, columns=synthetic_columns)

In [None]:
# Build one row per JSON result file found under ../results/uci/<method>/.
uci_columns = ["problem", "meta-model", "accuracy"]
uci_records = []
# `iterdir`/`rglob` yield entries in arbitrary, filesystem-dependent order;
# sorting makes the row order (and therefore the bar order in the plots)
# reproducible across machines.
for method in sorted((RESULTS_PATH / "uci").iterdir()):
    if not method.is_dir():
        continue
    for json_path in sorted(method.rglob("*.json")):
        uci_records.append(
            {
                "problem": "uci",
                # `.name`, not `.stem`: `.stem` drops everything after the
                # last dot, which would truncate a directory such as
                # "model_v1.2" to "model_v1".
                "meta-model": method.name,
                "accuracy": get_from_json(json_path, "accuracy"),
            }
        )
# Explicit columns keep the frame well-formed (empty, but with the expected
# schema) when no result files are found, instead of failing in `pd.concat`.
uci_results_df = pd.DataFrame(uci_records, columns=uci_columns)

In [None]:
# One entry per panel, left to right:
# (source frame, metric column, panel title, y-axis label).
panels = [
    (synthetic_results_df, "mae", "Synthetic - MAE", "MAE (lower is better)"),
    (
        synthetic_results_df,
        "corr",
        "Synthetic - Spearman correlation",
        "Correlation (higher is better)",
    ),
    (uci_results_df, "accuracy", "UCI - Accuracy", "Accuracy (higher is better)"),
]

fig, axes = plt.subplots(figsize=(15, 5), ncols=len(panels))

for panel_ax, (frame, metric, title, ylabel) in zip(axes, panels):
    # Bars show the mean per meta-model; error bars show the standard error.
    sns.barplot(data=frame, x="meta-model", y=metric, errorbar="se", ax=panel_ax)
    panel_ax.set_title(title, fontsize=16)
    panel_ax.set_ylabel(ylabel, fontsize=14)
    panel_ax.set_xlabel("Meta-model", fontsize=14)

# Restrict the accuracy panel to the [0.8, 0.9] range.
# NOTE(review): the bars are cut off at 0.8, so differences between
# meta-models look larger than on a zero-based axis.
axes[2].set_ylim(0.8, 0.9)
fig.tight_layout()

# Save through the figure object so the call does not depend on pyplot's
# "current figure". The file name was previously misspelled "resutls.png";
# update any document that still links to the old name.
fig.savefig("../results.png")