In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# from pathlib import Path

# # --- Config ---
# input_dir = Path("results")
# plot_dir = Path("plots")
# plot_dir.mkdir(parents=True, exist_ok=True)

# # --- Load Data ---
# df_id = pd.read_csv(input_dir / "last_layer_laplace_metrics_id.csv")
# df_ood = pd.read_csv(input_dir / "last_layer_laplace_metrics_ood.csv")

# # --- Grouping Keys for Grid ---
# group_keys = ['model_type', 'hessian_structure', 'link_approx', 'diagonal_output']

# # --- Plot ---
# for group_vals, df_id_group in df_id.groupby(group_keys):
#     df_ood_group = df_ood[
#         (df_ood['model_type'] == group_vals[0]) &
#         (df_ood['hessian_structure'] == group_vals[1]) &
#         (df_ood['link_approx'] == group_vals[2]) &
#         (df_ood['diagonal_output'] == group_vals[3])
#     ]
    
#     fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
#     title_str = f"Model={group_vals[0]} | Hess={group_vals[1]} | Link={group_vals[2]} | DiagOut={group_vals[3]}"
    
#     x_labels = df_id_group['prior_precision'].astype(str) + "\nτ=" + df_id_group['temperature'].astype(str)
    
#     # ID subplot
#     axes[0].errorbar(
#         x_labels,
#         df_id_group['acc_id_mean'], yerr=df_id_group['acc_id_std'],
#         fmt='-o', label='Accuracy'
#     )
#     axes[0].errorbar(
#         x_labels,
#         df_id_group['ece_id_mean'], yerr=df_id_group['ece_id_std'],
#         fmt='-s', label='ECE'
#     )
#     axes[0].set_title("In-Distribution (MNIST)")
#     axes[0].set_ylabel("Metric")
#     axes[0].set_xlabel("Prior Precision / Temperature")
#     axes[0].legend()
#     axes[0].tick_params(axis='x', rotation=45)
    
#     # OOD subplot
#     x_labels_ood = df_ood_group['prior_precision'].astype(str) + "\nτ=" + df_ood_group['temperature'].astype(str)
#     axes[1].errorbar(
#         x_labels_ood,
#         df_ood_group['acc_ood_mean'], yerr=df_ood_group['acc_ood_std'],
#         fmt='-o', label='Accuracy'
#     )
#     axes[1].errorbar(
#         x_labels_ood,
#         df_ood_group['ece_ood_mean'], yerr=df_ood_group['ece_ood_std'],
#         fmt='-s', label='ECE'
#     )
#     axes[1].set_title("Out-of-Distribution (FashionMNIST)")
#     axes[1].set_xlabel("Prior Precision / Temperature")
#     axes[1].legend()
#     axes[1].tick_params(axis='x', rotation=45)
    
#     fig.suptitle(title_str, fontsize=14)
#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    
#     # Save figure
#     fname = f"{group_vals[0]}_{group_vals[1]}_{group_vals[2]}_diag{group_vals[3]}.png".replace(".", "")
#     fig.savefig(plot_dir / fname)
#     plt.close()


In [5]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# ========== CONFIG ==========
input_dir = Path("results")  # ← your folder with the CSV files
plot_dir = Path("plots")
plot_dir.mkdir(parents=True, exist_ok=True)

# ========== LOAD ==========
df_id = pd.read_csv(input_dir / "last_layer_laplace_metrics_id.csv")
df_ood = pd.read_csv(input_dir / "last_layer_laplace_metrics_ood.csv")

# Add config ID
key_cols = ['model_type', 'prior_precision', 'temperature', 'hessian_structure', 'link_approx', 'diagonal_output']
df_id['config_id'] = df_id[key_cols].astype(str).agg('-'.join, axis=1)
df_ood['config_id'] = df_ood[key_cols].astype(str).agg('-'.join, axis=1)

# Rename columns for clarity after merging
df_id = df_id.add_suffix('_id')
df_ood = df_ood.add_suffix('_ood')

# ========== MERGE ID & OOD ==========
df = pd.merge(df_id, df_ood, left_on='config_id_id', right_on='config_id_ood', suffixes=('', ''))
df['config_id'] = df['config_id_id']

# ========== 1. HEATMAPS ==========
for metric in ['acc_id_mean_id', 'ece_id_mean_id']:
    for model in df['model_type_id'].unique():
        for hess in df['hessian_structure_id'].unique():
            subset = df[(df['model_type_id'] == model) & (df['hessian_structure_id'] == hess)]
            if subset.empty:
                continue
            pivot = subset.pivot_table(
                index='prior_precision_id', columns='temperature_id', values=metric, aggfunc='mean'
            )
            plt.figure(figsize=(8, 6))
            sns.heatmap(pivot, annot=True, fmt=".3f", cmap="viridis")
            plt.title(f"{metric} | {model} | {hess}")
            plt.ylabel("Prior Precision")
            plt.xlabel("Temperature")
            plt.tight_layout()
            plt.savefig(plot_dir / f"heatmap_{metric}_{model}_{hess}.png")
            plt.close()

# ========== 2. BOX PLOTS ==========
for metric in ['acc_id_mean_id', 'ece_id_mean_id', 'acc_ood_mean_ood', 'ece_ood_mean_ood']:
    factor = 'hessian_structure_id'
    plt.figure(figsize=(10, 6))
    sns.boxplot(data=df, x=factor, y=metric)
    plt.title(f"{metric} by {factor}")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(plot_dir / f"box_{metric}_by_{factor}.png")
    plt.close()

# ========== 3. ECE vs Accuracy SCATTER ==========
for mode in ['id', 'ood']:
    plt.figure(figsize=(10, 7))
    sns.scatterplot(
        data=df,
        x=f'ece_{mode}_mean_{mode}',
        y=f'acc_{mode}_mean_{mode}',
        hue='model_type_id',
        style='hessian_structure_id',
        s=80
    )
    plt.title(f"Accuracy vs ECE ({mode.upper()})")
    plt.xlabel("ECE")
    plt.ylabel("Accuracy")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(plot_dir / f"scatter_acc_vs_ece_{mode}.png")
    plt.close()

# ========== 4. TOP-N CONFIGURATIONS ==========
top_n = 10
metrics = {
    'acc_id_mean': 'acc_id_mean_id',
    'acc_ood_mean': 'acc_ood_mean_ood',
    'ece_id_mean': 'ece_id_mean_id',
    'ece_ood_mean': 'ece_ood_mean_ood'
}

for label, metric in metrics.items():
    ascending = 'ece' in label
    top_df = df.sort_values(by=metric, ascending=ascending).head(top_n)
    plt.figure(figsize=(10, 6))
    sns.barplot(data=top_df, x=metric, y='config_id', palette='crest')
    plt.title(f"Top-{top_n} configs by {label}")
    plt.xlabel(label)
    plt.ylabel("Config ID")
    plt.tight_layout()
    plt.savefig(plot_dir / f"top_{top_n}_{label}.png")
    plt.close()



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=top_df, x=metric, y='config_id', palette='crest')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=top_df, x=metric, y='config_id', palette='crest')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=top_df, x=metric, y='config_id', palette='crest')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=top_df, x=metric, y='config_id', palette='crest')


In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# from pathlib import Path

# # ========== CONFIG ==========
# input_dir = Path("results")  # ← CSVs in current directory
# plot_dir = Path("plots_insightful_all")
# plot_dir.mkdir(parents=True, exist_ok=True)

# # ========== LOAD ==========
# df_id = pd.read_csv(input_dir / "last_layer_laplace_metrics_id.csv")
# df_ood = pd.read_csv(input_dir / "last_layer_laplace_metrics_ood.csv")

# # Add config ID
# key_cols = ['model_type', 'prior_precision', 'temperature', 'hessian_structure', 'link_approx', 'diagonal_output']
# df_id['config_id'] = df_id[key_cols].astype(str).agg('-'.join, axis=1)
# df_ood['config_id'] = df_ood[key_cols].astype(str).agg('-'.join, axis=1)

# # Rename columns for clarity after merging
# df_id = df_id.add_suffix('_id')
# df_ood = df_ood.add_suffix('_ood')
# df = pd.merge(df_id, df_ood, left_on='config_id_id', right_on='config_id_ood')
# df['config_id'] = df['config_id_id']

# # ========== DERIVED METRICS ==========
# df['delta_acc'] = df['acc_id_mean_id'] - df['acc_ood_mean_ood']
# df['delta_ece'] = df['ece_ood_mean_ood'] - df['ece_id_mean_id']
# df['ece_ratio'] = df['ece_ood_mean_ood'] / df['ece_id_mean_id'].replace(0, 1e-8)

# # ========== USEFUL PLOTS ==========

# # 1. ID vs OOD accuracy (scatter)
# plt.figure(figsize=(8, 8))
# sns.scatterplot(data=df, x='acc_id_mean_id', y='acc_ood_mean_ood', hue='hessian_structure_id', style='model_type_id', s=80)
# plt.plot([0, 1], [0, 1], 'k--')
# plt.title("ID vs OOD Accuracy")
# plt.tight_layout()
# plt.savefig(plot_dir / "scatter_id_vs_ood_accuracy.png")
# plt.close()

# # 2. delta accuracy vs ECE ratio
# plt.figure(figsize=(10, 7))
# sns.scatterplot(data=df, x='ece_ratio', y='delta_acc', hue='model_type_id', style='hessian_structure_id')
# plt.axhline(0, linestyle='--', color='gray')
# plt.title("Δ Accuracy vs ECE Ratio")
# plt.tight_layout()
# plt.savefig(plot_dir / "scatter_delta_acc_vs_ece_ratio.png")
# plt.close()

# # 3. delta accuracy per hessian structure
# plt.figure(figsize=(10, 6))
# sns.boxplot(data=df, x='hessian_structure_id', y='delta_acc', hue='model_type_id')
# plt.title("Δ Accuracy per Hessian Structure")
# plt.tight_layout()
# plt.savefig(plot_dir / "box_delta_acc_per_hessian.png")
# plt.close()

# # 4. ECE OOD / ID heatmap by temperature & prior
# for model in df['model_type_id'].unique():
#     for hess in df['hessian_structure_id'].unique():
#         subset = df[(df['model_type_id'] == model) & (df['hessian_structure_id'] == hess)]
#         if not subset.empty:
#             pivot = subset.pivot_table(index='prior_precision_id', columns='temperature_id', values='ece_ratio', aggfunc='mean')
#             plt.figure(figsize=(8, 6))
#             sns.heatmap(pivot, annot=True, fmt=".2f", cmap="coolwarm")
#             plt.title(f"ECE Ratio | {model} | {hess}")
#             plt.tight_layout()
#             plt.savefig(plot_dir / f"heatmap_ece_ratio_{model}_{hess}.png")
#             plt.close()

# # 5. delta accuracy heatmap
# for model in df['model_type_id'].unique():
#     for hess in df['hessian_structure_id'].unique():
#         subset = df[(df['model_type_id'] == model) & (df['hessian_structure_id'] == hess)]
#         if not subset.empty:
#             pivot = subset.pivot_table(index='prior_precision_id', columns='temperature_id', values='delta_acc', aggfunc='mean')
#             plt.figure(figsize=(8, 6))
#             sns.heatmap(pivot, annot=True, fmt=".3f", cmap="viridis")
#             plt.title(f"Δ Accuracy | {model} | {hess}")
#             plt.tight_layout()
#             plt.savefig(plot_dir / f"heatmap_delta_acc_{model}_{hess}.png")
#             plt.close()

# # 6. Lineplot of OOD acc vs temperature grouped by prior + model
# for model in df['model_type_id'].unique():
#     plt.figure(figsize=(10, 6))
#     sns.lineplot(
#         data=df[df['model_type_id'] == model],
#         x='temperature_id', y='acc_ood_mean_ood',
#         hue='prior_precision_id', style='hessian_structure_id', markers=True, dashes=False
#     )
#     plt.title(f"OOD Accuracy vs Temperature | {model}")
#     plt.tight_layout()
#     plt.savefig(plot_dir / f"line_ood_accuracy_vs_temp_{model}.png")
#     plt.close()

# # 7. Swarm plot of ECE OOD
# plt.figure(figsize=(12, 6))
# sns.swarmplot(data=df, x='hessian_structure_id', y='ece_ood_mean_ood', hue='model_type_id', dodge=True)
# plt.title("ECE OOD by Hessian Structure")
# plt.tight_layout()
# plt.savefig(plot_dir / "swarm_ece_ood_by_hessian.png")
# plt.close()

# # 8. Violin + Stripplot: OOD accuracy
# plt.figure(figsize=(12, 6))
# sns.violinplot(data=df, x='hessian_structure_id', y='acc_ood_mean_ood', hue='model_type_id', inner=None)
# sns.stripplot(data=df, x='hessian_structure_id', y='acc_ood_mean_ood', hue='model_type_id', dodge=True, marker='o', alpha=0.5)
# plt.title("OOD Accuracy Distribution by Hessian + Model")
# plt.tight_layout()
# plt.savefig(plot_dir / "violin_ood_accuracy_dist.png")
# plt.close()

# # 9. Correlation matrix
# metrics_only = df[['acc_id_mean_id', 'acc_ood_mean_ood', 'ece_id_mean_id', 'ece_ood_mean_ood', 'delta_acc', 'delta_ece', 'ece_ratio']]
# corr = metrics_only.corr()
# plt.figure(figsize=(10, 8))
# sns.heatmap(corr, annot=True, cmap='coolwarm')
# plt.title("Correlation between Metrics")
# plt.tight_layout()
# plt.savefig(plot_dir / "heatmap_metric_correlations.png")
# plt.close()
