In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import itertools

# ---------------------------------------------------------------------------
# Setup: locate the results CSV, create output folders, define shared constants
# ---------------------------------------------------------------------------
input_dir = Path("results")

# Two output folders: per-model comparison figures and model-averaged summaries.
plot_dir_comparison, plot_dir_aggregated = (
    Path("plots_model_comparison"),
    Path("plots_aggregated_summary"),
)
for output_folder in (plot_dir_comparison, plot_dir_aggregated):
    output_folder.mkdir(parents=True, exist_ok=True)

df = pd.read_csv(input_dir / "sensitivity_results_plus_avg.csv")

# Architectures present in the data, in the order returned by Series.unique().
arches = df['arch'].unique()

# Metric key -> (ID column, OOD column, display name used in axis labels/titles).
metrics = dict(
    nll=("nll_id", "nll_ood", "Negative Log-Likelihood"),
    ece=("ece_id", "ece_ood", "Expected Calibration Error"),
    brier=("brier_id", "brier_ood", "Brier Score"),
)

# Rotation angles in ascending order (one OOD curve is drawn per angle).
angles = sorted(df['rotation'].unique())

# Hyperparameter placed on the x-axis, and the hyperparameters held fixed per figure.
hp_to_sweep = 'prior_precision'
context_hps = ['hessian_structure', 'link_approx']

# Every combination of the observed values of the context hyperparameters.
context_values = list(map(lambda col: df[col].unique(), context_hps))
context_combinations = [*itertools.product(*context_values)]

# ==============================================================================
# --- PART 1: DETAILED SIDE-BY-SIDE MODEL COMPARISON PLOTS ---
# One figure per (context combination, metric) with one subplot per architecture.
# Each subplot shows the ID curve plus one OOD curve per rotation angle against
# prior precision. Rows sharing a prior-precision value are averaged, so every
# column not fixed here (e.g. temperature) is averaged out.
# ==============================================================================
print(f"--- PART 1: Generating {len(context_combinations) * len(metrics)} detailed model comparison plots ---")
print(f"Plots will be saved to: {plot_dir_comparison}")

# Loop-invariant values, computed once instead of once per subplot.
n_arches = len(arches)
angle_palette = sns.color_palette("viridis_r", n_colors=len(angles))

for context_combo in context_combinations:
    context = dict(zip(context_hps, context_combo))

    # Boolean mask rather than a DataFrame.query() string built with repr():
    # repr() of NumPy scalars (e.g. "np.float64(0.1)" under NumPy >= 2) is not
    # valid query syntax, while an == comparison works for any dtype.
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_in_context = df[context_mask]

    context_str = ', '.join([f'{k.split("_")[0]}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        # One column per architecture (previously hard-coded to 2, which raised
        # IndexError for more than two architectures). squeeze=False keeps a 2-D
        # axes array even when there is only one architecture.
        fig, axes = plt.subplots(
            1, n_arches, figsize=(10 * n_arches, 8), sharey=True, squeeze=False
        )

        for i, arch in enumerate(arches):
            ax = axes[0, i]
            ax.set_title(f'Model: {arch}', fontsize=14)
            df_context = df_in_context[df_in_context['arch'] == arch]
            if df_context.empty:
                continue

            # ID curve: mean over all rows for this model/context at each prior precision.
            id_means = df_context.groupby(hp_to_sweep)[col_id].mean()
            ax.plot(id_means.index, id_means.values, marker='o', linestyle='-', color='black', linewidth=2.5, label='ID')

            # One dashed OOD curve per rotation angle.
            for j, angle in enumerate(angles):
                df_angle = df_context[df_context['rotation'] == angle]
                if df_angle.empty:
                    continue
                ood_means = df_angle.groupby(hp_to_sweep)[col_ood].mean()
                ax.plot(ood_means.index, ood_means.values, marker='.', linestyle='--', color=angle_palette[j], label=f'OOD {angle}°')

            ax.set_xscale('log')
            ax.grid(True, which='major', linestyle='--', alpha=0.6)
            ax.set_xlabel('Prior Precision (log scale)')
            ax.legend(loc="upper right", title="Distribution", fontsize='medium')

        # The Y-axis is shared, so label only the leftmost subplot (even if it has no data).
        axes[0, 0].set_ylabel(mname)

        fig.suptitle(f"Model Comparison | {mname} vs. Prior Precision\nContext: {context_str} (Averaged over Temperature)", fontsize=18)
        # Leave room at the top for the two-line suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.93])

        filename = plot_dir_comparison / f"model_comparison_{mkey}_vs_{hp_to_sweep}_context_{fname_context}_avg_temp.png"
        fig.savefig(filename)
        plt.close(fig)

print("\n--- PART 1 Complete ---\n")

# ==============================================================================
# --- PART 2: AGGREGATED SUMMARY PLOTS ---
# One figure per (context combination, metric): the ID curve and one OOD curve
# per rotation angle vs. prior precision, averaged over all architectures and
# every other column not fixed by the context (e.g. temperature).
# ==============================================================================
print(f"--- PART 2: Generating {len(context_combinations) * len(metrics)} aggregated summary plots ---")
print(f"Plots will be saved to: {plot_dir_aggregated}")

angle_palette = sns.color_palette("viridis_r", n_colors=len(angles))
skipped_contexts = 0

for context_combo in context_combinations:
    context = dict(zip(context_hps, context_combo))

    # Boolean mask rather than a repr()-built DataFrame.query() string, which
    # breaks for NumPy scalar values (e.g. "np.float64(...)" under NumPy >= 2).
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_context = df[context_mask]

    # Skip empty contexts BEFORE creating any figure: the previous version called
    # plt.figure() first and then `continue`d, leaking one open figure per metric.
    if df_context.empty:
        skipped_contexts += 1
        continue

    context_str = ', '.join([f'{k.split("_")[0]}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        fig, ax = plt.subplots(figsize=(10, 7))

        id_means = df_context.groupby(hp_to_sweep)[col_id].mean()
        ax.plot(id_means.index, id_means.values, marker='o', linestyle='-', color='black', linewidth=2.5, label='ID')

        for j, angle in enumerate(angles):
            df_angle = df_context[df_context['rotation'] == angle]
            if df_angle.empty:
                continue
            ood_means = df_angle.groupby(hp_to_sweep)[col_ood].mean()
            ax.plot(ood_means.index, ood_means.values, marker='.', linestyle='--', color=angle_palette[j], label=f'OOD {angle}°')

        ax.set_xscale('log')
        ax.set_xlabel('Prior Precision (log scale)')
        ax.set_ylabel(mname)

        title = f"Aggregated Result | {mname} vs. Prior Precision\nContext: {context_str} (Averaged over Models & Temperature)"
        ax.set_title(title, fontsize=14)

        # Legend sits outside the axes on the right; tight_layout reserves that strip.
        ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left", title="Distribution")
        ax.grid(True, which='major', linestyle='--', alpha=0.6)
        fig.tight_layout(rect=[0, 0, 0.85, 0.95])

        filename = plot_dir_aggregated / f"aggregated_{mkey}_vs_{hp_to_sweep}_context_{fname_context}_avg_temp.png"
        fig.savefig(filename)
        plt.close(fig)

if skipped_contexts:
    print(f"Skipped {skipped_contexts} context(s) with no matching rows.")

print("\n--- PART 2 Complete ---")
print("\nAnalysis finished.")

--- PART 1: Generating 12 detailed model comparison plots ---
Plots will be saved to: plots_model_comparison

--- PART 1 Complete ---

--- PART 2: Generating 12 aggregated summary plots ---
Plots will be saved to: plots_aggregated_summary

--- PART 2 Complete ---

Analysis finished.


In [3]:
print("\n--- Generating special case summary plot for (kron, bridge) context [FINAL CORRECTED LAYOUT] ---")

# The single context to summarise: one subplot per metric, averaged over models & temperature.
special_context = {
    'hessian_structure': 'kron',
    'link_approx': 'bridge'
}

# Boolean-mask filter (works for any dtype, unlike a repr()-built query string).
context_mask = pd.Series(True, index=df.index)
for hp_name, hp_value in special_context.items():
    context_mask &= df[hp_name] == hp_value
df_context = df[context_mask]

if df_context.empty:
    # Nothing to plot: avoid saving a blank figure (and legend-without-artists warnings).
    print(f"No rows match {special_context}; special case summary plot not generated.")
else:
    # One column per metric (previously hard-coded to 3, which breaks if `metrics` changes).
    n_metrics = len(metrics)
    fig, axes = plt.subplots(1, n_metrics, figsize=(22 * n_metrics / 3, 6), squeeze=False)
    angle_palette = sns.color_palette("viridis_r", n_colors=len(angles))

    for i, (mkey, (col_id, col_ood, mname)) in enumerate(metrics.items()):
        ax = axes[0, i]

        # ID performance (averaged over models & temps)
        id_means = df_context.groupby('prior_precision')[col_id].mean()
        ax.plot(id_means.index, id_means.values, marker='o', linestyle='-', color='black', linewidth=2.5, label='ID')

        # OOD performance per rotation angle (averaged over models & temps)
        for j, angle in enumerate(angles):
            df_angle = df_context[df_context['rotation'] == angle]
            if df_angle.empty:
                continue
            ood_means = df_angle.groupby('prior_precision')[col_ood].mean()
            ax.plot(ood_means.index, ood_means.values, marker='.', linestyle='--', color=angle_palette[j], label=f'OOD {angle}°')

        ax.set_xscale('log')
        ax.set_xlabel('Prior Precision (log scale)')
        ax.set_ylabel(mname)
        ax.set_title(f'Metric: {mname}', fontsize=14)
        ax.grid(True, which='major', linestyle='--', alpha=0.6)

        # Only the rightmost subplot carries the legend, placed outside the axes.
        if i == n_metrics - 1:
            ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Distribution")

    context_str = ', '.join([f'{k.split("_")[0]}={v}' for k, v in special_context.items()])
    fig.suptitle(f"Aggregated Result for Insensitive Context\nContext: {context_str} (Averaged over Models & Temperature)", fontsize=16)

    # rect=[left, bottom, right, top]: reserve a strip on the right for the legend
    # and space at the top for the suptitle.
    fig.tight_layout(rect=[0, 0, 0.9, 0.9])

    # Same output folder as the aggregated summaries; re-created in case it is missing.
    plot_dir_aggregated = Path("plots_aggregated_summary")
    plot_dir_aggregated.mkdir(parents=True, exist_ok=True)
    save_path = plot_dir_aggregated / "special_case_summary_kron_bridge_horizontal_FINAL.png"
    # bbox_inches='tight' grows the saved area so the out-of-axes legend is not clipped.
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)

    print(f"Final special case summary plot saved to: {save_path}")


--- Generating special case summary plot for (kron, bridge) context [FINAL CORRECTED LAYOUT] ---
Final special case summary plot saved to: plots_aggregated_summary\special_case_summary_kron_bridge_horizontal_FINAL.png


In [6]:
# # ==============================================================================
# # --- PART 3: SENSITIVITY TO HESSIAN STRUCTURE (WITH TEMPERATURE SUBPLOTS) ---
# # ==============================================================================
# print("\n--- PART 3: Generating plots for 'hessian_structure' sensitivity with temperature subplots ---")
# hp_to_sweep_hessian = 'hessian_structure'
# context_hps_hessian = ['prior_precision', 'link_approx']
# temperature_values = sorted(df['temperature'].unique())

# # Create a new directory for these plots
# plot_dir_hessian = Path("plots_hessian_comparison")
# plot_dir_hessian.mkdir(parents=True, exist_ok=True)
# print(f"Plots will be saved to: {plot_dir_hessian}")

# # Get all possible combinations of the context hyperparameters
# context_values_hessian = [df[hp].unique() for hp in context_hps_hessian]
# context_combinations_hessian = list(itertools.product(*context_values_hessian))

# # --- Main Loop ---
# for context_combo in context_combinations_hessian:
#     context = dict(zip(context_hps_hessian, context_combo))
    
#     # Create one plot for each metric (NLL, ECE, Brier)
#     for mkey, (col_id, col_ood, mname) in metrics.items():
        
#         # Filter the main dataframe for this primary context
#         query = ' & '.join([f'`{k}` == {repr(v)}' for k, v in context.items()])
#         df_context = df.query(query)

#         if df_context.empty:
#             continue
            
#         # Create a figure with 1 row and 3 columns of subplots, one for each temperature
#         fig, axes = plt.subplots(1, 3, figsize=(24, 7), sharey=True)
        
#         # Loop through the temperatures and subplot axes
#         for i, temp_val in enumerate(temperature_values):
#             ax = axes[i]
#             df_temp = df_context[df_context['temperature'] == temp_val]
            
#             if df_temp.empty:
#                 ax.set_title(f'Temperature = {temp_val}\n(No data for this context)')
#                 continue

#             # Use lineplot to show performance degradation over rotation angle
#             sns.lineplot(
#                 data=df_temp, x='rotation', y=col_ood, hue=hp_to_sweep_hessian,
#                 style=hp_to_sweep_hessian, markers=True, dashes=False, errorbar='sd', ax=ax
#             )
            
#             # Add horizontal lines for the corresponding ID performance
#             id_means = df_temp.groupby(hp_to_sweep_hessian)[col_id].mean()
#             palette = sns.color_palette(n_colors=len(id_means))
#             for j, (hess_type, id_mean) in enumerate(id_means.items()):
#                 ax.axhline(id_mean, ls='--', color=palette[j], label=f'ID ({hess_type})')

#             # --- Formatting each subplot ---
#             ax.set_xlabel('Shift Intensity (Rotation Angle °)')
#             ax.set_title(f'Temperature = {temp_val}', fontsize=12)
#             ax.legend(title="Hessian", loc="upper left")
#             ax.grid(True, linestyle='--', alpha=0.6)
            
#             # Set Y-label only on the first plot since it's shared
#             if i == 0:
#                 ax.set_ylabel(mname)

#         # --- Formatting the entire figure ---
#         context_str = ', '.join([f'{k.replace("_", " ")}={v}' for k, v in context.items()])
#         fig.suptitle(f"Aggregated Result | {mname} vs. Shift Intensity\nComparing Hessian Structures | Context: {context_str}", fontsize=16)
        
#         fig.tight_layout(rect=[0, 0, 1, 0.94])

#         # --- Saving the plot with a unique filename ---
#         fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])
#         filename = plot_dir_hessian / f"hessian_comparison_by_temp_{mkey}_context_{fname_context}.png"
#         plt.savefig(filename)
#         plt.close(fig)

# print(f"\n--- PART 3 Complete ---")


# ==============================================================================
# --- PART 3: SENSITIVITY TO HESSIAN STRUCTURE ---
# ==============================================================================
# One figure per (prior_precision, link_approx) context and metric. Each figure
# plots the OOD metric vs. rotation angle with one line per Hessian structure
# (mean ± SD over the remaining rows, e.g. models and temperatures). It also draws
# a dashed horizontal line at each structure's mean ID metric.
print("\n--- PART 3: Generating plots for 'hessian_structure' sensitivity ---")
hp_to_sweep_hessian = 'hessian_structure'
context_hps_hessian = ['prior_precision', 'link_approx']

# Create a new directory for these plots
plot_dir_hessian = Path("plots_hessian_comparison_avgtemp")
plot_dir_hessian.mkdir(parents=True, exist_ok=True)
print(f"Plots will be saved to: {plot_dir_hessian}")

# All combinations of the observed context hyperparameter values
context_values_hessian = [df[hp].unique() for hp in context_hps_hessian]
context_combinations_hessian = list(itertools.product(*context_values_hessian))

for context_combo in context_combinations_hessian:
    context = dict(zip(context_hps_hessian, context_combo))

    # Boolean mask instead of a repr()-built DataFrame.query() string. The
    # prior_precision values are NumPy floats, whose repr() is "np.float64(...)"
    # under NumPy >= 2, and query() cannot parse that.
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_context = df[context_mask]
    if df_context.empty:
        continue

    # Fix the hue order and colours explicitly. seaborn orders string hue levels by
    # first appearance, whereas groupby() sorts them. Relying on the default palette
    # could therefore give an ID line a different colour from its OOD curve.
    hess_levels = sorted(df_context[hp_to_sweep_hessian].dropna().unique())
    hess_colors = dict(zip(hess_levels, sns.color_palette(n_colors=len(hess_levels))))

    context_str = ', '.join([f'{k.replace("_", " ")}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        fig, ax = plt.subplots(figsize=(10, 7))

        # OOD degradation over rotation angle, one line per Hessian structure
        sns.lineplot(
            data=df_context, x='rotation', y=col_ood,
            hue=hp_to_sweep_hessian, hue_order=hess_levels, palette=hess_colors,
            style=hp_to_sweep_hessian, style_order=hess_levels,
            markers=True, dashes=False, errorbar='sd', ax=ax,
        )

        # Horizontal reference lines at the mean ID metric, coloured to match
        id_means = df_context.groupby(hp_to_sweep_hessian)[col_id].mean()
        for hess_type, id_mean in id_means.items():
            ax.axhline(id_mean, ls='--', color=hess_colors[hess_type], label=f'ID ({hess_type})')

        ax.set_xlabel('Shift Intensity (Rotation Angle °)')
        ax.set_ylabel(mname)
        title = f"Aggregated Result | {mname} vs. Shift Intensity\nComparing Hessian Structures | Context: {context_str}"
        ax.set_title(title, fontsize=14)

        # Rebuild the legend so it holds both the lineplot entries and the ID lines
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels, title="Hessian Structure", loc="upper left")

        ax.grid(True, linestyle='--', alpha=0.6)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        filename = plot_dir_hessian / f"hessian_comparison_{mkey}_context_{fname_context}.png"
        fig.savefig(filename)
        plt.close(fig)

print("\n--- PART 3 Complete ---")


--- PART 3: Generating plots for 'hessian_structure' sensitivity ---
Plots will be saved to: plots_hessian_comparison_avgtemp

--- PART 3 Complete ---


In [None]:
# =====================================================================================
# --- PART 4: SUMMARY - Hessian Performance at Smallest and Biggest Prior Precisions ---
# =====================================================================================
# For each metric in `metrics_to_plot` and each link approximation, draw one figure.
# It has one subplot for the smallest and one for the largest prior precision observed
# with that link approximation. Each subplot compares Hessian structures across
# rotation angles, aggregated over all models (mean ± SD).
print("\n--- PART 4: Generating summary plots for Hessian Structure vs. Prior Precision extremes ---")

# --- Configuration for this summary ---
metrics_to_plot = ['ece', 'nll']
link_contexts = ['bridge', 'probit']
hp_to_compare = 'hessian_structure'
context_hp = 'prior_precision'

# Create a new directory for these summary plots
plot_dir_summary = Path("plots_hessian_summary")
plot_dir_summary.mkdir(parents=True, exist_ok=True)
print(f"Summary plots will be saved to: {plot_dir_summary}")

n_generated = 0

for metric_key in metrics_to_plot:
    col_id, col_ood, mname = metrics[metric_key]

    for link_ctx in link_contexts:
        df_summary = df[df['link_approx'] == link_ctx]
        if df_summary.empty:
            # min()/max() would be NaN and every subplot would be empty.
            print(f"No rows for link_approx = {link_ctx}; skipping {metric_key}.")
            continue

        # Only the lowest and highest prior precision for this link approximation
        precision_points = [df_summary[context_hp].min(), df_summary[context_hp].max()]

        # One hue order / colour map per figure, so both subplots and the ID lines use
        # the same colour for a given Hessian structure. seaborn orders string hue
        # levels by first appearance while groupby() sorts them, so the default
        # palette could mismatch.
        hess_levels = sorted(df_summary[hp_to_compare].dropna().unique())
        hess_colors = dict(zip(hess_levels, sns.color_palette(n_colors=len(hess_levels))))

        fig, axes = plt.subplots(1, 2, figsize=(20, 8), sharey=True)

        for i, precision_val in enumerate(precision_points):
            ax = axes[i]
            df_plot = df_summary[df_summary[context_hp] == precision_val]

            if df_plot.empty:
                ax.set_title(f'{context_hp} = {precision_val:.1e}\n(No data for this context)')
                continue

            # OOD metric vs. rotation angle, one line per Hessian structure
            sns.lineplot(
                data=df_plot, x='rotation', y=col_ood,
                hue=hp_to_compare, hue_order=hess_levels, palette=hess_colors,
                style=hp_to_compare, style_order=hess_levels,
                markers=True, dashes=False, errorbar='sd', ax=ax,
            )

            # Dashed reference lines at the mean ID metric per Hessian structure
            id_means = df_plot.groupby(hp_to_compare)[col_id].mean()
            for hess_type, id_mean in id_means.items():
                ax.axhline(id_mean, ls='--', color=hess_colors[hess_type], label=f'ID ({hess_type})')

            ax.set_xlabel('Shift Intensity (Rotation Angle °)')
            ax.set_title(f'{context_hp} = {precision_val:.1e}', fontsize=14)
            ax.legend(title="Hessian", loc="best")
            ax.grid(True, linestyle='--', alpha=0.6)

            # Y-axis is shared: label only the first subplot
            if i == 0:
                ax.set_ylabel(mname)

        fig.suptitle(f"Hessian Performance at Extreme Prior Precisions\nAggregated Result | Context: link approx = {link_ctx}", fontsize=18)
        fig.tight_layout(rect=[0, 0, 1, 0.93])

        filename = plot_dir_summary / f"summary_hessian_vs_precision_extremes_for_{metric_key}_{link_ctx}.png"
        fig.savefig(filename)
        plt.close(fig)
        n_generated += 1

# Report the number of figures actually written, not the nominal product.
print(f"\n--- PART 4 Complete: {n_generated} summary plots generated. ---")


--- PART 4: Generating summary plots for Hessian Structure vs. Prior Precision extremes ---
Summary plots will be saved to: plots_hessian_summary

--- PART 4 Complete: 4 summary plots generated. ---


In [13]:
# =====================================================================================
# --- PART 4: SUMMARY - Model-Specific Hessian Performance at Extreme Precisions ---
# =====================================================================================
# For each metric in `metrics_to_plot` and each link approximation, draw one figure
# with one row per architecture in `arches`. It has one column for the smallest and
# one for the largest prior precision observed with that link approximation. Each
# subplot compares Hessian structures across rotation angles (mean ± SD).
print("\n--- PART 4: Generating model-specific summary plots for Hessian Structure vs. Prior Precision extremes ---")

# --- Configuration for this summary ---
metrics_to_plot = ['ece', 'nll']
link_contexts = ['bridge', 'probit']
hp_to_compare = 'hessian_structure'
context_hp = 'prior_precision'

# Create a new directory for these summary plots
plot_dir_summary = Path("plots_hessian_summary2")
plot_dir_summary.mkdir(parents=True, exist_ok=True)
print(f"Summary plots will be saved to: {plot_dir_summary}")

n_rows = len(arches)
n_generated = 0

for metric_key in metrics_to_plot:
    col_id, col_ood, mname = metrics[metric_key]

    for link_ctx in link_contexts:
        df_summary = df[df['link_approx'] == link_ctx]
        if df_summary.empty:
            # min()/max() would be NaN and every subplot would be empty.
            print(f"No rows for link_approx = {link_ctx}; skipping {metric_key}.")
            continue

        # Lowest and highest prior precision for this link approximation
        precision_points = [df_summary[context_hp].min(), df_summary[context_hp].max()]

        # One hue order / colour map for the whole grid, so every subplot and every ID
        # line uses the same colour for a given Hessian structure. seaborn orders string
        # hue levels by first appearance while groupby() sorts them.
        hess_levels = sorted(df_summary[hp_to_compare].dropna().unique())
        hess_colors = dict(zip(hess_levels, sns.color_palette(n_colors=len(hess_levels))))

        # One row per architecture. This was previously a hard-coded 2x2 grid, which
        # raised IndexError for more than two architectures. squeeze=False keeps
        # `axes` 2-D even when there is a single architecture.
        fig, axes = plt.subplots(
            n_rows, 2, figsize=(20, 7 * n_rows), sharex=True, sharey='row', squeeze=False
        )

        for row_idx, arch in enumerate(arches):
            df_arch = df_summary[df_summary['arch'] == arch]

            for col_idx, precision_val in enumerate(precision_points):
                ax = axes[row_idx, col_idx]
                df_plot = df_arch[df_arch[context_hp] == precision_val]

                if df_plot.empty:
                    ax.set_title(f'Precision = {precision_val:.1e}\n(No data)')
                    continue

                # OOD metric vs. rotation angle, one line per Hessian structure
                sns.lineplot(
                    data=df_plot, x='rotation', y=col_ood,
                    hue=hp_to_compare, hue_order=hess_levels, palette=hess_colors,
                    style=hp_to_compare, style_order=hess_levels,
                    markers=True, dashes=False, errorbar='sd', ax=ax,
                )

                # Dashed reference lines at the mean ID metric per Hessian structure
                id_means = df_plot.groupby(hp_to_compare)[col_id].mean()
                for hess_type, id_mean in id_means.items():
                    ax.axhline(id_mean, ls='--', color=hess_colors[hess_type], label=f'ID ({hess_type})')

                ax.legend(title="Hessian", loc="best")
                ax.grid(True, linestyle='--', alpha=0.6)

                # Titles on the top row, model names on the left column and the
                # x-label on the bottom row only, to avoid clutter.
                if row_idx == 0:
                    ax.set_title(f'{context_hp} = {precision_val:.1e}', fontsize=14)
                if col_idx == 0:
                    ax.set_ylabel(f'Model: {arch}\n\n{mname}')
                if row_idx == n_rows - 1:
                    ax.set_xlabel('Shift Intensity (Rotation Angle °)')

        fig.suptitle(f"Hessian Performance at Extreme Prior Precisions | Context: link approx = {link_ctx}", fontsize=18)
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        filename = plot_dir_summary / f"model_specific_summary_hessian_vs_precision_for_{metric_key}_{link_ctx}.png"
        fig.savefig(filename)
        plt.close(fig)
        n_generated += 1

# Report the number of figures actually written, not the nominal product.
print(f"\n--- PART 4 Complete: {n_generated} model-specific summary plots generated. ---")


--- PART 4: Generating model-specific summary plots for Hessian Structure vs. Prior Precision extremes ---
Summary plots will be saved to: plots_hessian_summary2

--- PART 4 Complete: 4 model-specific summary plots generated. ---


In [10]:
# ==============================================================================
# --- PART 5: SENSITIVITY TO LINK APPROXIMATION (WITH TEMPERATURE SUBPLOTS) ---
# ==============================================================================
# One figure per (prior_precision, hessian_structure) context and metric, with one
# subplot per temperature value. Each subplot plots the OOD metric vs. rotation angle
# per link approximation (mean ± SD over the remaining rows). It also draws a dashed
# line at each link approximation's mean ID metric.
print("\n--- PART 5: Generating plots for 'link_approx' sensitivity with temperature subplots ---")
hp_to_sweep_link = 'link_approx'
context_hps_link = ['prior_precision', 'hessian_structure']
temperature_values = sorted(df['temperature'].unique())
n_temps = len(temperature_values)

# Create a new directory for these plots
plot_dir_link = Path("plots_link_comparison")
plot_dir_link.mkdir(parents=True, exist_ok=True)
print(f"Plots will be saved to: {plot_dir_link}")

# All combinations of the observed context hyperparameter values
context_values_link = [df[hp].unique() for hp in context_hps_link]
context_combinations_link = list(itertools.product(*context_values_link))

for context_combo in context_combinations_link:
    context = dict(zip(context_hps_link, context_combo))

    # Boolean mask instead of a repr()-built DataFrame.query() string. The
    # prior_precision values are NumPy floats, whose repr() is "np.float64(...)"
    # under NumPy >= 2, and query() cannot parse that.
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_context = df[context_mask]
    if df_context.empty:
        continue

    # Explicit hue order and colours, shared by all temperature subplots. seaborn
    # orders string hue levels by first appearance while groupby() sorts them, so
    # re-deriving 'magma' colours by position could mismatch ID and OOD lines.
    link_levels = sorted(df_context[hp_to_sweep_link].dropna().unique())
    link_colors = dict(zip(link_levels, sns.color_palette('magma', n_colors=len(link_levels))))

    context_str = ', '.join([f'{k.replace("_", " ")}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        # One column per temperature. This was previously hard-coded to 3, which raised
        # IndexError for more temperatures. squeeze=False keeps `axes` 2-D.
        fig, axes = plt.subplots(1, n_temps, figsize=(8 * n_temps, 7), sharey=True, squeeze=False)

        for i, temp_val in enumerate(temperature_values):
            ax = axes[0, i]
            df_temp = df_context[df_context['temperature'] == temp_val]

            if df_temp.empty:
                ax.set_title(f'Temperature = {temp_val}\n(No data for this context)')
                continue

            # OOD metric vs. rotation angle, one line per link approximation
            sns.lineplot(
                data=df_temp, x='rotation', y=col_ood,
                hue=hp_to_sweep_link, hue_order=link_levels, palette=link_colors,
                style=hp_to_sweep_link, style_order=link_levels,
                markers=True, dashes=False, errorbar='sd', ax=ax,
            )

            # Dashed reference lines at the mean ID metric, coloured to match
            id_means = df_temp.groupby(hp_to_sweep_link)[col_id].mean()
            for link_type, id_mean in id_means.items():
                ax.axhline(id_mean, ls='--', color=link_colors[link_type], label=f'ID ({link_type})')

            ax.set_xlabel('Shift Intensity (Rotation Angle °)')
            ax.set_title(f'Temperature = {temp_val}', fontsize=12)
            ax.legend(title="Link Approx", loc="best")
            ax.grid(True, linestyle='--', alpha=0.6)

            # Y-axis is shared: label only the first subplot
            if i == 0:
                ax.set_ylabel(mname)

        fig.suptitle(f"Aggregated Result | {mname} vs. Shift Intensity\nComparing Link Approximations | Context: {context_str}", fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.94])

        filename = plot_dir_link / f"link_comparison_by_temp_{mkey}_context_{fname_context}.png"
        fig.savefig(filename)
        plt.close(fig)

print("\n--- PART 5 Complete ---")


--- PART 5: Generating plots for 'link_approx' sensitivity with temperature subplots ---
Plots will be saved to: plots_link_comparison

--- PART 5 Complete ---


In [12]:
# =================================================================================
# --- PART 6: SENSITIVITY TO LINK APPROXIMATION (AVERAGED OVER TEMPERATURE) ---
# =================================================================================
# One figure per (prior_precision, hessian_structure) context and metric. Temperature
# is not filtered, so each point of the OOD-vs-rotation line pools every temperature
# (and model). The line shows the mean ± SD per link approximation, plus a dashed line
# at each link approximation's mean ID metric.
print("\n--- PART 6: Generating plots for 'link_approx' sensitivity, averaging over temperature ---")
hp_to_sweep_link_avg = 'link_approx'
context_hps_link_avg = ['prior_precision', 'hessian_structure']

# Create a new directory for these plots
plot_dir_link_avg = Path("plots_link_comparison_avg_temp")
plot_dir_link_avg.mkdir(parents=True, exist_ok=True)
print(f"Plots will be saved to: {plot_dir_link_avg}")

# All combinations of the observed context hyperparameter values
context_values_link_avg = [df[hp].unique() for hp in context_hps_link_avg]
context_combinations_link_avg = list(itertools.product(*context_values_link_avg))

for context_combo in context_combinations_link_avg:
    context = dict(zip(context_hps_link_avg, context_combo))

    # Boolean mask instead of a repr()-built DataFrame.query() string. The
    # prior_precision values are NumPy floats, whose repr() is "np.float64(...)"
    # under NumPy >= 2, and query() cannot parse that.
    context_mask = pd.Series(True, index=df.index)
    for hp_name, hp_value in context.items():
        context_mask &= df[hp_name] == hp_value
    df_context = df[context_mask]
    if df_context.empty:
        continue

    # Explicit hue order and colours. seaborn orders string hue levels by first
    # appearance while groupby() sorts them, so re-deriving 'magma' colours by
    # position could mismatch the ID lines and the OOD curves.
    link_levels = sorted(df_context[hp_to_sweep_link_avg].dropna().unique())
    link_colors = dict(zip(link_levels, sns.color_palette('magma', n_colors=len(link_levels))))

    context_str = ', '.join([f'{k.replace("_", " ")}={v}' for k, v in context.items()])
    fname_context = '_'.join([f'{k.split("_")[0]}{v}' for k, v in context.items()])

    for mkey, (col_id, col_ood, mname) in metrics.items():
        fig, ax = plt.subplots(figsize=(10, 7))

        # lineplot aggregates all rows per (rotation, link_approx): mean with SD band
        sns.lineplot(
            data=df_context, x='rotation', y=col_ood,
            hue=hp_to_sweep_link_avg, hue_order=link_levels, palette=link_colors,
            style=hp_to_sweep_link_avg, style_order=link_levels,
            markers=True, dashes=False, errorbar='sd', ax=ax,
        )

        # Dashed reference lines at the mean ID metric (also pooled over temperature)
        id_means = df_context.groupby(hp_to_sweep_link_avg)[col_id].mean()
        for link_type, id_mean in id_means.items():
            ax.axhline(id_mean, ls='--', color=link_colors[link_type], label=f'ID ({link_type})')

        ax.set_xlabel('Shift Intensity (Rotation Angle °)')
        ax.set_ylabel(mname)
        title = f"Aggregated Result | {mname} vs. Shift Intensity\nComparing Link Approximations | Context: {context_str} (Avg. over Temp)"
        ax.set_title(title, fontsize=14)

        # Rebuild the legend so it holds both the lineplot entries and the ID lines
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels, title="Link Approx", loc="best")

        ax.grid(True, linestyle='--', alpha=0.6)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        filename = plot_dir_link_avg / f"link_comparison_{mkey}_context_{fname_context}_avg_temp.png"
        fig.savefig(filename)
        plt.close(fig)

# Previously printed "PART 5 Complete" by copy-paste mistake.
print("\n--- PART 6 Complete ---")



--- PART 6: Generating plots for 'link_approx' sensitivity, averaging over temperature ---
Plots will be saved to: plots_link_comparison_avg_temp

--- PART 5 Complete ---
