In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's "whitegrid" theme globally so every figure below shares the same look.
sns.set_theme(style="whitegrid")

In [None]:
# === CONFIGURATION ===
# Root folder with one `scale_<pct>/` subfolder per labeled-data percentage;
# each subfolder holds per-fold result files named `fold<k>.csv`.
RESULTS_DIR = "../results/swdpf_folds"

# Labeled-data percentages that were evaluated.
SCALES = [10, 20, 30]

# Cross-validation fold indices, 0 through 7 inclusive.
FOLDS = [0, 1, 2, 3, 4, 5, 6, 7]

In [None]:
# === Load fold-level results ===
# Read RESULTS_DIR/scale_<scale>/fold<fold>.csv for every configured (scale, fold)
# pair, tag each frame with its scale and fold, and stack them into `results_df`.
# Missing files are reported but not fatal, so partial runs can still be inspected;
# the summary line makes it obvious when statistics rest on fewer folds than expected.
all_results = []
missing_files = []

for scale in SCALES:
    for fold in FOLDS:
        file_path = os.path.join(RESULTS_DIR, f"scale_{scale}", f"fold{fold}.csv")
        if not os.path.exists(file_path):
            missing_files.append(file_path)
            print(f"Missing file: {file_path}")
            continue
        fold_df = pd.read_csv(file_path)
        fold_df["scale"] = scale
        fold_df["fold"] = fold
        all_results.append(fold_df)

if missing_files:
    print(
        f"{len(missing_files)} of {len(SCALES) * len(FOLDS)} expected result files are missing; "
        "aggregated metrics below use only the folds that were found."
    )

# pd.concat([]) fails with an opaque "No objects to concatenate"; explain the real cause instead.
if not all_results:
    raise FileNotFoundError(
        f"No fold result CSVs found under {RESULTS_DIR!r} for scales {SCALES} and folds {FOLDS}."
    )

results_df = pd.concat(all_results, ignore_index=True)


In [None]:
# === Aggregate metrics by scale ===
# One row per labeled-data scale: mean and sample standard deviation (pandas'
# default ddof=1) of each metric over that scale's rows in `results_df`.
# `fold_count` is the number of distinct folds that contributed. `nunique` is used
# rather than `count`, which would count rows and overstate the fold count if a
# fold CSV holds more than one row. With a single fold, the *_std columns are NaN.
summary_df = results_df.groupby("scale").agg(
    rmse_mean=("rmse", "mean"),
    rmse_std=("rmse", "std"),
    mae_mean=("mae", "mean"),
    mae_std=("mae", "std"),
    rse_mean=("rse", "mean"),
    rse_std=("rse", "std"),
    fold_count=("fold", "nunique"),
).reset_index()

In [None]:
# === Display fold-level results, then the per-scale summary ===
# Each table is preceded by a plain-text heading and rendered with rich display.
for heading, table in (
    ("Fold-level evaluation results:", results_df.sort_values(by=["scale", "fold"])),
    ("Aggregated summary by scale:", summary_df),
):
    print(heading)
    display(table)

The bar charts below show how each evaluation metric (RMSE, MAE, RSE) changes with the percentage of labeled data. Bar height is the mean across folds, and the error bars show one standard deviation across folds.

In [None]:
def plot_metric(metric_name, y_label, data=None):
    """Bar-plot the across-fold mean of one metric for each labeled-data scale.

    Error bars are the across-fold standard deviation taken from the
    ``f"{metric_name}_std"`` column.

    Parameters
    ----------
    metric_name : str
        Metric prefix in the summary table, e.g. ``"rmse"``. The columns
        ``f"{metric_name}_mean"`` and ``f"{metric_name}_std"`` must exist.
    y_label : str
        Text for the y-axis label.
    data : pandas.DataFrame, optional
        Summary table with a ``scale`` column plus the mean/std columns.
        Defaults to the notebook-level ``summary_df``.
    """
    if data is None:
        data = summary_df

    # Sort so bars, error bars and tick labels stay aligned left-to-right.
    ordered = data.sort_values("scale")
    means = ordered[f"{metric_name}_mean"].to_numpy()
    stds = ordered[f"{metric_name}_std"].to_numpy()
    positions = list(range(len(ordered)))

    # Draw with matplotlib directly. sns.barplot computes its own error bars (our
    # `yerr` was only a pass-through kwarg), and passing `palette` without `hue`
    # is deprecated in seaborn >= 0.13. The "crest" colours are kept via color_palette.
    colors = sns.color_palette("crest", n_colors=len(ordered))

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(positions, means, yerr=stds, color=colors, capsize=4)
    ax.set_xticks(positions)
    ax.set_xticklabels(ordered["scale"].astype(str))
    ax.set_title(f"{metric_name.upper()} vs. % Labeled Data")
    ax.set_xlabel("Labeled Data (%)")
    ax.set_ylabel(y_label)
    fig.tight_layout()
    plt.show()

# === Plot all key metrics ===
# One figure per metric, drawn in the order RMSE, MAE, RSE.
for metric_key, axis_label in (
    ("rmse", "Root Mean Squared Error (RMSE)"),
    ("mae", "Mean Absolute Error (MAE)"),
    ("rse", "Relative Squared Error (RSE)"),
):
    plot_metric(metric_key, axis_label)