In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import itertools

# ---------------------------------------------------------------------------
# Analysis configuration (independent of the data)
# ---------------------------------------------------------------------------
# Hyperparameter placed on the x-axis of every plot.
hp_to_sweep = 'prior_precision'
# Hyperparameters held fixed within one plot; each distinct combination of
# their values is a "context".
context_hps = ['hessian_structure', 'link_approx']
# Plot key (used in output file names) -> (ID column, OOD column, axis label).
metrics = {
    "nll": ("nll_id", "nll_ood", "Negative Log-Likelihood"),
    "ece": ("ece_id", "ece_ood", "Expected Calibration Error"),
    "brier": ("brier_id", "brier_ood", "Brier Score"),
}

# ---------------------------------------------------------------------------
# Input location and output folders (one folder per plot family)
# ---------------------------------------------------------------------------
input_dir = Path("results")
plot_dir_comparison = Path("plots_model_comparison")
plot_dir_aggregated = Path("plots_aggregated_summary")
plot_dir_comparison.mkdir(exist_ok=True, parents=True)
plot_dir_aggregated.mkdir(exist_ok=True, parents=True)

df = pd.read_csv(input_dir.joinpath("sensitivity_results_plus_avg.csv"))

# ---------------------------------------------------------------------------
# Values discovered from the data
# ---------------------------------------------------------------------------
arches = df['arch'].unique()              # architectures, first-seen order
angles = sorted(df['rotation'].unique())  # OOD rotation angles, ascending

# Cartesian product of every context hyperparameter's distinct values.
# A combination may have no rows in `df`; the plotting code must cope.
context_values = [df[name].unique() for name in context_hps]
context_combinations = [*itertools.product(*context_values)]

# --- PART 1: DETAILED SIDE-BY-SIDE MODEL COMPARISON PLOTS ---
# One figure per (context combination, metric). Each figure has one panel per
# architecture; each panel shows the ID metric (black, solid) and one OOD
# curve per rotation angle against `hp_to_sweep` on a log x-axis. All rows
# sharing a sweep value are averaged together (the figure title describes
# this as averaging over temperature).
print(f"--- PART 1: Generating {len(context_combinations) * len(metrics)} detailed model comparison plots ---")
print(f"Plots will be saved to: {plot_dir_comparison}")

# Loop-invariant: one colour per rotation angle, identical in every figure.
palette = sns.color_palette("viridis_r", n_colors=len(angles))
n_arch = len(arches)

for context_combo in context_combinations:
    context = dict(zip(context_hps, context_combo))

    # Select this context's rows with a boolean mask rather than a
    # repr()-built query string. The mask works for any value type; a repr
    # such as NumPy 2's `np.float64(0.1)` is not a valid query literal.
    # It depends only on the context, so it is built once here.
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_ctx = df[context_mask]
    if df_ctx.empty:
        # No data at all for this combination: don't write blank figures.
        continue

    context_str = ', '.join([f'{k.split("_")[0]}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        # One column per architecture instead of a hard-coded 2. With 2
        # columns, more than 2 architectures would raise IndexError.
        # squeeze=False keeps `axes` indexable even for a single
        # architecture; 10 in per panel reproduces 20x8 for two models.
        fig, axes = plt.subplots(1, n_arch, figsize=(10 * n_arch, 8), sharey=True, squeeze=False)
        axes = axes[0]
        # The y-axis is shared, so only the leftmost panel is labelled. Done
        # unconditionally so the label survives if that panel has no data.
        axes[0].set_ylabel(mname)

        for ax, arch in zip(axes, arches):
            # Title first, so an empty panel still shows which model it is.
            ax.set_title(f'Model: {arch}', fontsize=14)
            df_context = df_ctx[df_ctx['arch'] == arch]
            if df_context.empty:
                continue

            id_means = df_context.groupby(hp_to_sweep)[col_id].mean()
            ax.plot(id_means.index, id_means.values, marker='o', linestyle='-', color='black', linewidth=2.5, label='ID')

            for color, angle in zip(palette, angles):
                df_angle = df_context[df_context['rotation'] == angle]
                if df_angle.empty:
                    continue
                ood_means = df_angle.groupby(hp_to_sweep)[col_ood].mean()
                ax.plot(ood_means.index, ood_means.values, marker='.', linestyle='--', color=color, label=f'OOD {angle}°')

            ax.set_xscale('log')
            ax.grid(True, which='major', linestyle='--', alpha=0.6)
            ax.set_xlabel('Prior Precision (log scale)')
            # Each panel carries its own legend.
            ax.legend(loc="upper right", title="Distribution", fontsize='medium')

        fig.suptitle(f"Model Comparison | {mname} vs. Prior Precision\nContext: {context_str} (Averaged over Temperature)", fontsize=18)
        # Reserve the top 7% of the figure for the two-line suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.93])

        filename = plot_dir_comparison / f"model_comparison_{mkey}_vs_{hp_to_sweep}_context_{fname_context}_avg_temp.png"
        fig.savefig(filename)
        plt.close(fig)

print("\n--- PART 1 Complete ---\n")

# --- PART 2: AGGREGATED SUMMARY PLOTS ---
# One single-panel figure per (context combination, metric). Rows are pooled
# across all architectures before averaging per `hp_to_sweep` value. The ID
# curve is black and solid; there is one dashed OOD curve per rotation angle.
print(f"--- PART 2: Generating {len(context_combinations) * len(metrics)} aggregated summary plots ---")
print(f"Plots will be saved to: {plot_dir_aggregated}")

# Loop-invariant: one colour per rotation angle, identical in every figure.
palette = sns.color_palette("viridis_r", n_colors=len(angles))

for context_combo in context_combinations:
    context = dict(zip(context_hps, context_combo))

    # Boolean mask instead of a repr()-built query string (robust to any
    # value type). It depends only on the context, so it is built once here.
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_context = df[context_mask]
    # Check for data BEFORE creating any figure. Opening a figure and then
    # skipping would leave it open (never closed) for every empty context.
    if df_context.empty:
        continue

    context_str = ', '.join([f'{k.split("_")[0]}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        fig, ax = plt.subplots(figsize=(10, 7))

        id_means = df_context.groupby(hp_to_sweep)[col_id].mean()
        ax.plot(id_means.index, id_means.values, marker='o', linestyle='-', color='black', linewidth=2.5, label='ID')

        for color, angle in zip(palette, angles):
            df_angle = df_context[df_context['rotation'] == angle]
            if df_angle.empty:
                continue
            ood_means = df_angle.groupby(hp_to_sweep)[col_ood].mean()
            ax.plot(ood_means.index, ood_means.values, marker='.', linestyle='--', color=color, label=f'OOD {angle}°')

        ax.set_xscale('log')
        ax.set_xlabel('Prior Precision (log scale)')
        ax.set_ylabel(mname)

        title = f"Aggregated Result | {mname} vs. Prior Precision\nContext: {context_str} (Averaged over Models & Temperature)"
        ax.set_title(title, fontsize=14)

        # Legend is placed outside the axes on the right. The tight_layout
        # rect below keeps the rightmost 15% of the figure free for it.
        ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left", title="Distribution")

        ax.grid(True, which='major', linestyle='--', alpha=0.6)
        fig.tight_layout(rect=[0, 0, 0.85, 0.95])

        filename = plot_dir_aggregated / f"aggregated_{mkey}_vs_{hp_to_sweep}_context_{fname_context}_avg_temp.png"
        fig.savefig(filename)
        plt.close(fig)

print("\n--- PART 2 Complete ---")
print("\nAnalysis finished.")

--- PART 1: Generating 12 detailed model comparison plots ---
Plots will be saved to: plots_model_comparison

--- PART 1 Complete ---

--- PART 2: Generating 12 aggregated summary plots ---
Plots will be saved to: plots_aggregated_summary

--- PART 2 Complete ---

Analysis finished.


In [None]:
print("\n--- Generating special case summary plot for (kron, bridge) context [FINAL CORRECTED LAYOUT] ---")

# Hand-picked context summarised on one row of panels (one panel per metric),
# averaged over all models and over every row sharing a sweep value.
special_context = {
    'hessian_structure': 'kron',
    'link_approx': 'bridge'
}

# Rows matching the context. A boolean mask is used instead of a repr()-built
# query string so that any value type compares correctly.
context_mask = pd.Series(True, index=df.index)
for hp_name, hp_value in special_context.items():
    context_mask &= df[hp_name] == hp_value
df_context = df[context_mask]
# Fail loudly, before any figure exists, rather than silently save an empty plot.
if df_context.empty:
    raise ValueError(f"No rows in df match the special context {special_context}")

n_metrics = len(metrics)
# One colour per rotation angle, shared by all panels.
palette = sns.color_palette("viridis_r", n_colors=len(angles))

# One column per metric instead of a hard-coded 3. squeeze=False keeps `axes`
# indexable even for a single metric. Width scales so three panels give 22 in.
fig, axes = plt.subplots(1, n_metrics, figsize=(22 * n_metrics / 3, 6), squeeze=False)
axes = axes[0]

for i, (mkey, (col_id, col_ood, mname)) in enumerate(metrics.items()):
    ax = axes[i]

    # ID performance (averaged over models & all rows per sweep value).
    id_means = df_context.groupby(hp_to_sweep)[col_id].mean()
    ax.plot(id_means.index, id_means.values, marker='o', linestyle='-', color='black', linewidth=2.5, label='ID')

    # OOD performance, one dashed curve per rotation angle.
    for color, angle in zip(palette, angles):
        df_angle = df_context[df_context['rotation'] == angle]
        if df_angle.empty:
            continue
        ood_means = df_angle.groupby(hp_to_sweep)[col_ood].mean()
        ax.plot(ood_means.index, ood_means.values, marker='.', linestyle='--', color=color, label=f'OOD {angle}°')

    ax.set_xscale('log')
    ax.set_xlabel('Prior Precision (log scale)')
    ax.set_ylabel(mname)
    ax.set_title(f'Metric: {mname}', fontsize=14)
    ax.grid(True, which='major', linestyle='--', alpha=0.6)

    # All panels share the same curves, so only the rightmost gets a legend,
    # placed outside the axes.
    if i == n_metrics - 1:
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Distribution")


context_str = ', '.join([f'{k.split("_")[0]}={v}' for k, v in special_context.items()])
fig.suptitle(f"Aggregated Result for Insensitive Context\nContext: {context_str} (Averaged over Models & Temperature)", fontsize=16)

# rect=[left, bottom, right, top]: reserve space on the right for the legend
# and at the top for the two-line suptitle.
fig.tight_layout(rect=[0, 0, 0.9, 0.9])


# Ensure the output folder exists (mkdir is a no-op if it already does).
plot_dir_aggregated = Path("plots_aggregated_summary")
plot_dir_aggregated.mkdir(parents=True, exist_ok=True)
save_path = plot_dir_aggregated / "special_case_summary_kron_bridge_horizontal_FINAL.png"
# bbox_inches='tight' expands the saved area so the outside legend is not clipped.
fig.savefig(save_path, bbox_inches='tight')
plt.close(fig)

print(f"Final special case summary plot saved to: {save_path}")


--- Generating special case summary plot for (kron, bridge) context [FINAL CORRECTED LAYOUT] ---
Final special case summary plot saved to: plots_aggregated_summary\special_case_summary_kron_bridge_horizontal_FINAL.png
