# Import libraries and modules

In [3]:
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

In [9]:
import sys, os
# Make the project root importable. It is resolved as "../.." relative to the
# current working directory (by default Jupyter starts the kernel in the
# notebook's own folder), i.e. two directories above the notebook.
# Presumably this lets project modules referenced by the pickled results be
# imported — TODO confirm.
PROJECT_ROOT = os.path.abspath("../..")
if PROJECT_ROOT not in sys.path:  # guard so re-running the cell doesn't add duplicates
    sys.path.append(PROJECT_ROOT)
import pickle
import pandas as pd

# Evaluation

## Load data

In [7]:
# Pickled output of the multi-simulation run; point `path` at another .pkl
# file to evaluate a different run.
path = "results/multi_simulation/results_200_tau_45_20251203_015401.pkl"

# Deserialize the results. pickle.load can execute arbitrary code, so only
# open result files you produced yourself.
with open(path, mode="rb") as pkl_file:
    results = pickle.load(pkl_file)

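## Extract all metrics and build the summary table

Each metric is summarised as mean ± sample standard deviation, per number of layers and pooled over all seeds (the `All` row).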

In [10]:
def format_mean_std(means, stds, decimals=4):
    """Format aligned mean/std Series as ``"<mean> ± <std>"`` strings.

    Both numbers use fixed-point notation with exactly `decimals` digits after
    the point, so trailing zeros are kept (``0.3810`` rather than ``0.381``)
    and every cell in a column has the same width. NaN renders as ``"nan"``.

    Args:
        means: Series of means.
        stds: Series of standard deviations with the same index as `means`.
        decimals: Digits after the decimal point.

    Returns:
        Series of strings indexed like `means`.
    """
    fmt = f"{{:.{decimals}f}}".format  # e.g. "{:.4f}".format
    return means.map(fmt) + " ± " + stds.map(fmt)


# Step 1: Flatten the {seed: per-run dict} results into one row per seed.
records = [{'seed': seed, **data} for seed, data in results.items()]
df = pd.DataFrame(records)

# Metrics to summarise (each must be a column of `df`).
metrics = [
    'true_relative_rmse',
    'true_mape',
    'true_spearman_corr',
    'model_relative_rmse',
    'model_mape',
    'model_spearman_corr',
]

missing = set(metrics + ['n_layers']) - set(df.columns)
assert not missing, f"results are missing expected fields: {sorted(missing)}"

# Step 2: Mean and sample std (pandas default ddof=1) of each metric per layer count.
grouped = df.groupby('n_layers')[metrics].agg(['mean', 'std'])

# Step 3: The same statistics pooled over every seed, as one extra row 'All'.
# agg() yields a (stat x metric) frame; unstack() turns it into a Series indexed
# by (metric, stat), which becomes a single row with grouped's column layout.
all_row = (
    df[metrics]
    .agg(['mean', 'std'])
    .unstack()
    .to_frame('All')
    .T[grouped.columns]
)
grouped_with_all = pd.concat([grouped, all_row])

# Step 4: Render every metric as a fixed-width "mean ± std" string.
DECIMALS = 4
summary = grouped_with_all.copy()
for metric in metrics:
    summary[(metric, 'mean±std')] = format_mean_std(
        summary[(metric, 'mean')], summary[(metric, 'std')], decimals=DECIMALS
    )

formatted = summary[[(metric, 'mean±std') for metric in metrics]].set_axis(metrics, axis=1)

print(formatted.to_string())

    true_relative_rmse         true_mape true_spearman_corr model_relative_rmse         model_mape model_spearman_corr
2      0.7876 ± 0.4364    0.6941 ± 0.381     0.983 ± 0.0502   12.9218 ± 10.2008    9.3105 ± 7.3807     0.7795 ± 0.2875
3      7.7055 ± 7.3577   4.4263 ± 4.0787     0.915 ± 0.1581   20.6487 ± 12.9721  15.9683 ± 10.7167      0.627 ± 0.3752
4     10.4814 ± 9.3757   6.3233 ± 5.7999    0.8715 ± 0.2382   25.5001 ± 13.7165  22.3435 ± 26.9063      0.5005 ± 0.442
5    13.5624 ± 10.8122   9.0031 ± 7.7712    0.8355 ± 0.2192   29.3402 ± 15.3677  24.5187 ± 14.1251      0.401 ± 0.4697
6    14.9618 ± 12.6695  10.1757 ± 9.0035    0.8365 ± 0.2324   32.7568 ± 15.4305  26.8508 ± 14.0662      0.365 ± 0.4883
All   9.4997 ± 10.4357   6.1245 ± 7.0466    0.8883 ± 0.2005   24.2335 ± 15.3126   19.7983 ± 17.256     0.5346 ± 0.4451


In [11]:
# Same table with metrics as rows and the layer groups (plus the pooled
# 'All' column) as columns.
formatted_transposed = formatted.transpose()
print(formatted_transposed.to_string())

                                     2                  3                  4                  5                  6                All
true_relative_rmse     0.7876 ± 0.4364    7.7055 ± 7.3577   10.4814 ± 9.3757  13.5624 ± 10.8122  14.9618 ± 12.6695   9.4997 ± 10.4357
true_mape               0.6941 ± 0.381    4.4263 ± 4.0787    6.3233 ± 5.7999    9.0031 ± 7.7712   10.1757 ± 9.0035    6.1245 ± 7.0466
true_spearman_corr      0.983 ± 0.0502     0.915 ± 0.1581    0.8715 ± 0.2382    0.8355 ± 0.2192    0.8365 ± 0.2324    0.8883 ± 0.2005
model_relative_rmse  12.9218 ± 10.2008  20.6487 ± 12.9721  25.5001 ± 13.7165  29.3402 ± 15.3677  32.7568 ± 15.4305  24.2335 ± 15.3126
model_mape             9.3105 ± 7.3807  15.9683 ± 10.7167  22.3435 ± 26.9063  24.5187 ± 14.1251  26.8508 ± 14.0662   19.7983 ± 17.256
model_spearman_corr    0.7795 ± 0.2875     0.627 ± 0.3752     0.5005 ± 0.442     0.401 ± 0.4697     0.365 ± 0.4883    0.5346 ± 0.4451
