In [None]:
import sys
import os

# Insert the parent folder of this notebook
notebook_path = os.getcwd()
sys.path.append(os.path.dirname(notebook_path))

import os
import copy
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from hydra import compose, initialize
from hydra.utils import instantiate
from hydra.core.global_hydra import GlobalHydra
import lightning as L
import json
import wandb
from topobench.data.preprocessor import PreProcessor
from topobench.dataloader import TBDataloader

# CONFIGURATION
METRIC = 'mae'  # Triangle counting is scored by MAE (lower is better)
PROJECT_NAME = 'graphuniverse/final_triangle_experiments'
PARAMETER_TO_VARY = 'degree_separation_range'
# Root folder of locally stored checkpoints (e.g. ../checkpoints/<group>/<model_name>/);
# adapt it to wherever your checkpoints live.
CHECKPOINT_BASE_DIR = "../checkpoints"

# Publication look: serif fonts, light grid without top/right spines,
# inward ticks, small error-bar caps and 300 dpi figures.
_PUBLICATION_RC = {
    'font.size': 12,
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif'],
    'axes.linewidth': 0.8,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': True,
    'grid.alpha': 0.3,
    'grid.linewidth': 0.5,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'errorbar.capsize': 3,
    'figure.dpi': 300,
}

plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update(_PUBLICATION_RC)

def safe_get_nested(obj, keys, default=None):
    """Follow ``keys`` (dict keys / sequence indices) down into ``obj``.

    Returns the value at the end of the path, or ``default`` as soon as any step
    (including iterating ``keys`` itself) raises KeyError, TypeError or IndexError.
    """
    current = obj
    try:
        for step in keys:
            current = current[step]
    except (KeyError, TypeError, IndexError):
        return default
    else:
        return current

def load_triangle_experiment_data():
    """Download every run of ``PROJECT_NAME`` from W&B into a DataFrame (one row per run).

    Each row holds the run's summary metrics, ``run_name``, ``run_id`` and every
    config entry under a ``config_<key>`` column.

    Returns:
        The DataFrame, or an empty DataFrame when the project has no runs or any
        W&B call fails (the error is printed; callers treat "empty" as "no data").
    """
    try:
        api = wandb.Api(timeout=100)
        runs = api.runs(PROJECT_NAME)

        all_data = []
        for run in runs:
            # Work on a copy: summary._json_dict may be the Run object's own summary
            # store, and our bookkeeping keys must not leak back into it.
            record = dict(run.summary._json_dict)
            record['run_name'] = run.name
            record['run_id'] = run.id
            # Config parameters get a 'config_' prefix to avoid clashing with metrics.
            record.update({f'config_{key}': value for key, value in run.config.items()})
            all_data.append(record)

        if not all_data:
            print(f"No runs found in {PROJECT_NAME}")
            return pd.DataFrame()

        return pd.DataFrame(all_data)

    except Exception as e:
        # Deliberate best-effort boundary: report and fall back to "no data".
        print(f"Error loading triangle experiments: {e}")
        return pd.DataFrame()

def dict_to_sorted_string_no_seed(d, seed_keys=('seed', 'random_seed', 'random_state')):
    """Serialize a nested config to a canonical JSON string with seed keys removed.

    Two configs that differ only in their seeds map to the same string, so runs
    of one hyper-parameter configuration can be grouped across seeds.

    Args:
        d: nested dict/list structure (or None).
        seed_keys: dict keys to drop at every nesting level. A tuple default is used
            so the default cannot be mutated between calls.

    Returns:
        Compact JSON with sorted keys, or the string "None" when ``d`` is None.
    """
    if d is None:
        return "None"

    def remove_seeds(obj):
        if isinstance(obj, dict):
            return {k: remove_seeds(v) for k, v in obj.items() if k not in seed_keys}
        if isinstance(obj, (list, tuple)):
            # Tuples are emitted as JSON arrays anyway; recurse so dicts inside them
            # are seed-stripped as well.
            return [remove_seeds(item) for item in obj]
        return obj

    return json.dumps(remove_seeds(d), sort_keys=True, separators=(',', ':'))

def extract_transform_info(config_transforms, model_name):
    """Summarize a run's data transforms as a short label used to tell configs apart.

    * GPS / nsd: the sorted ``CombinedPSEs.encodings`` joined by "_", or
      "no_encoding" when none are configured.
    * topotune: always "cell_lifting".
    * any other model, or no transforms logged (None/NaN): "no_transform".
    """
    # Runs without a transforms config show up as None or NaN (missing column value).
    if config_transforms is None or (
        pd.api.types.is_scalar(config_transforms) and pd.isna(config_transforms)
    ):
        return "no_transform"

    if model_name in ('GPS', 'nsd'):
        pse_config = config_transforms.get('CombinedPSEs') if isinstance(config_transforms, dict) else None
        # A logged `CombinedPSEs: null` / `encodings: null` means "no positional encodings"
        # rather than a crash of the whole DataFrame.apply.
        encodings = pse_config.get('encodings') if isinstance(pse_config, dict) else None
        if encodings:
            return '_'.join(sorted(encodings))
        return "no_encoding"

    if model_name == 'topotune':
        return "cell_lifting"

    return "no_transform"

def extract_parameter_value(df, parameter_name):
    """Add a ``parameter_name`` column holding that dataset parameter for every run.

    The value is looked up first under
    ``loader.parameters.generation_parameters.family_parameters`` and then directly
    under ``loader.parameters`` in each row's ``config_dataset``; list values are
    converted to tuples so they are hashable and comparable. ``df`` is modified in
    place and also returned.
    """
    candidate_paths = (
        ['loader', 'parameters', 'generation_parameters', 'family_parameters', parameter_name],
        ['loader', 'parameters', parameter_name],
    )

    def lookup(config_dataset):
        if pd.isna(config_dataset) or config_dataset is None:
            return None
        found = None
        for path in candidate_paths:
            found = safe_get_nested(config_dataset, path)
            if found is not None:
                break
        return tuple(found) if isinstance(found, list) else found

    df[parameter_name] = df['config_dataset'].apply(lookup)
    return df

def process_triangle_data(df):
    """Derive grouping columns for every run and keep only fully usable runs.

    Adds ``model_name``, the ``PARAMETER_TO_VARY`` column, ``model_config_str``
    (seed-free model config), ``transform_info`` and ``enhanced_model_config_str``
    (model config plus PE label for GPS/nsd) to ``df`` in place.

    Returns:
        ``(df_clean, val_metric, test_metric)`` where ``df_clean`` keeps the rows that
        have all derived columns and both metrics. An empty input returns
        ``(df, None, None)``; when no run logged one of the metrics, an empty frame
        is returned instead of raising KeyError.
    """
    if df.empty:
        return df, None, None

    VAL_METRIC = f'val/{METRIC}'
    TEST_METRIC = f'test/{METRIC}'

    # Extract model names
    df['model_name'] = df['config_model'].apply(
        lambda x: x.get('model_name') if isinstance(x, dict) else None
    )

    # Extract the parameter to vary
    df = extract_parameter_value(df, PARAMETER_TO_VARY)

    # Model config strings WITHOUT random seeds, so seeds of one config group together
    df['model_config_str'] = df['config_model'].apply(
        lambda x: dict_to_sorted_string_no_seed(x) if isinstance(x, dict) else "None"
    )

    df['transform_info'] = df.apply(
        lambda row: extract_transform_info(row.get('config_transforms'), row['model_name']),
        axis=1
    )

    # GPS/NSD configs additionally differ by their positional encodings
    df['enhanced_model_config_str'] = df.apply(
        lambda row: (row['model_config_str'] + f"_PE_{row['transform_info']}")
                    if row['model_name'] in ['GPS', 'nsd']
                    else row['model_config_str'],
        axis=1
    )

    # dropna(subset=...) raises KeyError for absent columns; report it instead.
    missing_metrics = [m for m in (VAL_METRIC, TEST_METRIC) if m not in df.columns]
    if missing_metrics:
        print(f"Metric column(s) {missing_metrics} not logged by any run.")
        return df.iloc[0:0], VAL_METRIC, TEST_METRIC

    # dropna already removes rows whose PARAMETER_TO_VARY is missing.
    df_clean = df.dropna(subset=['model_name', 'enhanced_model_config_str', PARAMETER_TO_VARY, VAL_METRIC, TEST_METRIC])
    df_clean = df_clean[df_clean['model_config_str'] != "None"]

    return df_clean, VAL_METRIC, TEST_METRIC

# MAIN EXECUTION
print("Loading triangle counting experiment data...")
df = load_triangle_experiment_data()

if df.empty:
    print("No data found. Please check the project name and ensure runs exist.")
    # sys.exit raises SystemExit, which ends a script and stops a notebook cell;
    # the bare exit() helper is meant for interactive sessions only.
    sys.exit(1)

print(f"Loaded {len(df)} runs")

# Process data
df_clean, val_metric, test_metric = process_triangle_data(df)

if df_clean.empty:
    print("No clean data available after processing.")
    sys.exit(1)

print(f"Clean data: {len(df_clean)} runs")

# Keep only runs generated with degree_separation_range == [0.0, 0.1]
# (list-valued parameters were stored as tuples by extract_parameter_value).
target_degree_separation = (0.0, 0.1)
# Compare cell by cell so pandas never interprets the tuple as a list-like operand.
df_filtered = df_clean[
    df_clean[PARAMETER_TO_VARY].map(lambda value: value == target_degree_separation)
]

if df_filtered.empty:
    print(f"No data found for {PARAMETER_TO_VARY} == {target_degree_separation}")
    sys.exit(1)

print(f"Filtered data for {PARAMETER_TO_VARY} == {target_degree_separation}: {len(df_filtered)} runs")

# Get best model per model type based on validation performance
def get_best_model_configs(df, val_metric, lower_is_better=True):
    """Select, per model type, all runs of its best hyper-parameter configuration.

    Runs are grouped by ``enhanced_model_config_str``. The configuration with the
    best mean ``val_metric`` wins: the lowest mean by default (e.g. MAE), or the
    highest when ``lower_is_better=False`` (accuracy-like metrics).

    Returns:
        dict mapping model_name -> DataFrame of the winning configuration's runs
        (typically one row per seed). Model types without any non-NaN validation
        score are skipped.
    """
    best_models = {}

    for model_name in df['model_name'].unique():
        model_data = df[df['model_name'] == model_name]

        # Mean validation performance per configuration (NaN groups cannot win).
        config_performance = model_data.groupby('enhanced_model_config_str')[val_metric].mean().dropna()
        if config_performance.empty:
            print(f"No valid {val_metric} values for model {model_name}; skipping")
            continue

        best_config = config_performance.idxmin() if lower_is_better else config_performance.idxmax()
        best_models[model_name] = model_data[model_data['enhanced_model_config_str'] == best_config]

    return best_models

# Winning hyper-parameter configuration per model type on the filtered runs
# (a dict model_name -> DataFrame of that configuration's runs).
best_models = get_best_model_configs(df_filtered, val_metric)

print(f"Found best configurations for {len(best_models)} model types:")
# NOTE: `model_name` and `runs` stay defined as notebook globals after this loop.
for model_name, runs in best_models.items():
    print(f"  {model_name}: {len(runs)} runs")

In [None]:
import time

# Define shift configurations for graph size increases
# One shift per node-count increase: both min_n_nodes and max_n_nodes grow by
# the same amount while universe parameters stay untouched.
SHIFT_CONFIGS = [
    {
        'universe_parameters': {},
        'family_parameters': {
            'min_n_nodes_shift': node_increase,
            'max_n_nodes_shift': node_increase,
        },
    }
    for node_increase in (200, 500)
]

class TriangleShiftEvaluator:
    """Re-test trained triangle-counting checkpoints on shifted GraphUniverse datasets.

    ``model_df`` holds one row per W&B run and is read for ``config_dataset``,
    ``config_model``, ``config_transforms`` (optional) and ``checkpoint_local``.
    A *shift config* has the form
    ``{'universe_parameters': {'<param>_shift': delta}, 'family_parameters': {...}}``;
    the deltas are added to the run's original generation parameters before a fresh
    evaluation dataset is generated.
    """

    # Hydra model config group for every model_name logged to W&B.
    _MODEL_CONFIG_MAP = {
        'gcn': 'graph/gcn',
        'gat': 'graph/gat',
        'GPS': 'graph/gps',
        'nsd': 'graph/nsd',
        'gin': 'graph/gin',
        'DeepSet': 'pointcloud/deepset',
        'topotune': 'cell/topotune',
        'GraphMLP': 'graph/graph_mlp',
        'GraphSAGE': 'graph/sage',
    }

    # [low, high] clip interval applied after shifting each generation parameter.
    _UNIVERSE_BOUNDS = {
        'center_variance': (0, 1),
        'cluster_variance': (0, 1),
        'edge_propensity_variance': (0, 1),
    }
    _FAMILY_INT_BOUNDS = {
        'min_n_nodes': (1, 10000),
        'max_n_nodes': (1, 10000),
        'min_communities': (1, 1000),
        'max_communities': (1, 1000),
    }
    _FAMILY_RANGE_BOUNDS = {
        'homophily_range': (0, 1),
        'avg_degree_range': (1, 100),
        'degree_separation_range': (0, 100),
        'power_law_exponent_range': (0, 10),
    }

    def __init__(self, model_df):
        """Initialize evaluator with model dataframe (also (re)initializes Hydra)."""
        self.model_df = model_df
        self._setup_hydra()

    def _setup_hydra(self):
        """Clear any global Hydra instance and initialize it on ``../configs``."""
        if GlobalHydra().is_initialized():
            GlobalHydra().clear()
        initialize(config_path="../configs", job_name="triangle_shift_evaluation")

    @staticmethod
    def _clip(value, low, high):
        """Clip ``value`` to [low, high] and return a plain Python int/float.

        ``np.clip`` yields NumPy scalars; NumPy 2 renders those inside lists as
        ``np.float64(...)``, which would corrupt the Hydra override strings.
        """
        return np.clip(value, low, high).item()

    def shift_dataset_parameters(self, shift_config, number_of_eval_graphs, run_idx):
        """Return a deep copy of the run's generation parameters with shifts applied.

        Args:
            shift_config: shift config; ``<param>_shift`` and ``min_/max_<range>_shift``
                entries are added to the original values and clipped to the bounds
                above. An optional ``family_parameters['seed']`` replaces the seed.
            number_of_eval_graphs: value written to ``family_parameters['n_graphs']``.
            run_idx: index label of the run in ``self.model_df``.
        """
        original = self.model_df.loc[run_idx]['config_dataset']['loader']['parameters']['generation_parameters']
        params = copy.deepcopy(original)
        universe_shifts = shift_config.get('universe_parameters', {})
        family_shifts = shift_config.get('family_parameters', {})
        family = params['family_parameters']

        # Generate as many graphs as the run originally held out for evaluation.
        family['n_graphs'] = number_of_eval_graphs

        for param, (low, high) in self._UNIVERSE_BOUNDS.items():
            shift_key = f'{param}_shift'
            if shift_key in universe_shifts:
                universe = params['universe_parameters']
                universe[param] = self._clip(universe[param] + universe_shifts[shift_key], low, high)

        for param, (low, high) in self._FAMILY_INT_BOUNDS.items():
            shift_key = f'{param}_shift'
            if shift_key in family_shifts:
                family[param] = self._clip(family[param] + family_shifts[shift_key], low, high)

        for param, (low, high) in self._FAMILY_RANGE_BOUNDS.items():
            for i, bound in enumerate(('min', 'max')):
                shift_key = f'{bound}_{param}_shift'
                if shift_key in family_shifts:
                    # Rebuild as a list so tuple-valued ranges can be shifted too.
                    values = list(family[param])
                    values[i] = self._clip(values[i] + family_shifts[shift_key], low, high)
                    family[param] = values

        if 'seed' in family_shifts:
            family['seed'] = family_shifts['seed']

        return params

    def config_to_overrides(self, config, prefix="", include_graphuniverse_parameters=False):
        """Flatten a nested config dict into Hydra ``dotted.key=value`` override strings.

        Nested dicts recurse with a dotted prefix, ``None`` becomes ``null`` and all
        other values (lists included) are rendered with ``str``. Keys named
        ``generation_parameters`` are skipped unless ``include_graphuniverse_parameters``.
        """
        overrides = []

        for key, value in config.items():
            if key == 'generation_parameters' and not include_graphuniverse_parameters:
                continue
            current_path = f"{prefix}.{key}" if prefix else key

            if isinstance(value, dict):
                overrides.extend(self.config_to_overrides(
                    value,
                    prefix=current_path,
                    include_graphuniverse_parameters=include_graphuniverse_parameters
                ))
            elif value is None:
                overrides.append(f"{current_path}=null")
            else:
                overrides.append(f"{current_path}={value}")

        return overrides

    def prepare_all_overrides(self, shift_config, run_idx):
        """Build the Hydra overrides that recreate run ``run_idx`` on a shifted dataset.

        The run's model, dataset and (GPS/nsd) transform configs are replayed, the
        train proportion is overridden to 0.0, and the generation parameters are
        replaced by ``shift_dataset_parameters``' output.

        Raises:
            KeyError: if the run's model_name has no registered Hydra model config.
        """
        row = self.model_df.loc[run_idx]
        data_config = row['config_dataset']
        model_config = row['config_model']
        transform_config = row.get('config_transforms')

        # The run's evaluation graph count: n_graphs times the non-train fraction.
        train_prop = data_config['split_params']['train_prop']
        n_graphs = data_config['loader']['parameters']['generation_parameters']['family_parameters']['n_graphs']
        n_graphs_eval = int(n_graphs * (1 - train_prop))

        shifted_params = self.shift_dataset_parameters(shift_config, n_graphs_eval, run_idx)

        model_name = model_config['model_name']
        if model_name not in self._MODEL_CONFIG_MAP:
            raise KeyError(f"No Hydra model config registered for model_name={model_name!r}")

        all_overrides = [
            f"model={self._MODEL_CONFIG_MAP[model_name]}",
            f"model.model_name={model_name}",
        ]
        all_overrides.extend(self.config_to_overrides(model_config, prefix="model"))

        data_overrides = self.config_to_overrides(data_config, prefix="dataset", include_graphuniverse_parameters=False)
        # Match the exact key (not any override merely containing "train_prop") and
        # add it when absent instead of failing with IndexError.
        train_prop_prefix = "dataset.split_params.train_prop="
        zero_train = f"{train_prop_prefix}0.0"
        data_overrides = [zero_train if o.startswith(train_prop_prefix) else o for o in data_overrides]
        if zero_train not in data_overrides:
            data_overrides.append(zero_train)
        all_overrides.extend(data_overrides)

        if model_name in ('GPS', 'nsd'):
            if isinstance(transform_config, dict):
                all_overrides.extend(self.config_to_overrides(transform_config, prefix="transforms"))
            else:
                # Run logged without transforms: disable positional encodings explicitly.
                all_overrides.append('transforms.CombinedPSEs.encodings=[]')

        # Reuse config_to_overrides so None renders as Hydra's `null` (not the string
        # "None") and nested dicts become dotted keys.
        generation_prefix = "dataset.loader.parameters.generation_parameters"
        for param_group in ('universe_parameters', 'family_parameters'):
            all_overrides.extend(self.config_to_overrides(
                {param_group: shifted_params.get(param_group, {})},
                prefix=generation_prefix,
                include_graphuniverse_parameters=True,
            ))

        return all_overrides

    @staticmethod
    def _average_target(dataloader):
        """Mean of ``batch.y`` over every sample of ``dataloader`` (NaN when empty).

        Weighted by sample count, unlike a mean of per-batch means, which is biased
        when the last batch is smaller.
        """
        total, count = 0.0, 0
        for batch in dataloader:
            targets = batch.y.float()
            total += targets.sum().item()
            count += targets.numel()
        return total / count if count else float('nan')

    def evaluate_single_model(self, overrides, run_idx):
        """Load run ``run_idx``'s checkpoint into the configured model and test it.

        Returns:
            The first dict of Lightning's ``trainer.test`` results extended with
            ``mae`` (copy of ``test/mae``), ``average_num_of_triangles`` (mean
            ``batch.y`` over the test set) and ``mae/average_num_of_triangles``;
            ``{}`` if the trainer returned no results; ``None`` when the checkpoint
            is missing or cannot be loaded.
        """
        # Validate the checkpoint before the comparatively expensive compose/instantiate.
        if 'checkpoint_local' not in self.model_df.columns:
            print("No checkpoint_local column found. Cannot evaluate model.")
            return None
        checkpoint_path = self.model_df.loc[run_idx]['checkpoint_local']
        if not isinstance(checkpoint_path, (str, os.PathLike)) or not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found: {checkpoint_path}")
            return None

        cfg = compose(config_name="run.yaml", overrides=overrides)
        model = instantiate(cfg.model, evaluator=cfg.evaluator, optimizer=cfg.optimizer, loss=cfg.loss)

        try:
            # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
            # Lightning checkpoints that store non-tensor hyper-parameters — confirm
            # against the installed torch version.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["state_dict"], strict=True)
        except Exception as e:
            print(f"Error loading checkpoint {checkpoint_path}: {e}")
            return None

        loader = instantiate(cfg.dataset.loader)
        dataset, dataset_dir = loader.load()
        transform_config = cfg.get("transforms", None)
        preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
        _, _, dataset_test = preprocessor.load_dataset_splits(cfg.dataset.split_params)

        datamodule = TBDataloader(
            dataset_train=None,
            dataset_val=None,
            dataset_test=dataset_test,
            **cfg.dataset.get("dataloader_params", {}),
        )

        trainer = L.Trainer(
            devices=1,
            accelerator='auto',
            logger=False,
            enable_checkpointing=False,
            num_sanity_val_steps=0,
        )

        test_results = trainer.test(model, datamodule)
        if not test_results:
            return {}
        result = dict(test_results[0])

        # Normalize the MAE by the average triangle count of the test graphs.
        average_triangles = self._average_target(datamodule.test_dataloader())
        print("average amount of triangles: ", average_triangles)

        mae = result.get('test/mae')
        result['mae'] = mae
        result['average_num_of_triangles'] = average_triangles
        result['mae/average_num_of_triangles'] = (
            mae / average_triangles if mae is not None and average_triangles else float('nan')
        )
        print("mae/average_num_of_triangles: ", result['mae/average_num_of_triangles'])

        return result

    @staticmethod
    def _paired_stats(differences):
        """Return (mean, standard error, 95% CI half-width, p-value) of paired differences.

        Uses the sample standard deviation (ddof=1) so the SE and CI agree with the
        one-sample t-test. With fewer than two pairs: SE = CI = 0 and p = 1.
        """
        from scipy import stats

        differences = np.asarray(differences, dtype=float)
        n = len(differences)
        mean_diff = float(np.mean(differences))
        if n < 2:
            return mean_diff, 0.0, 0.0, 1.0
        se_diff = float(np.std(differences, ddof=1) / np.sqrt(n))
        ci_95 = float(stats.t.ppf(0.975, n - 1) * se_diff)
        p_value = float(stats.ttest_1samp(differences, 0).pvalue)
        return mean_diff, se_diff, ci_95, p_value

    @staticmethod
    def _triangle_fields(result):
        """Triangle-specific result columns shared by baseline, per-run and summary rows."""
        return {
            'mae': result.get('mae', None),
            'average_num_of_triangles': result.get('average_num_of_triangles', None),
            'mae_per_triangle': result.get('mae/average_num_of_triangles', None),
        }

    def evaluate_multiple_shifts(self, shift_configs, metrics=('test/mae',)):
        """Evaluate every run on a zero shift ("Baseline") plus each of ``shift_configs``.

        Returns:
            DataFrame with one row per (metric, run) for the baseline and every shift,
            plus one ``<shift>_SUMMARY`` row per shift carrying the paired mean
            difference to the baseline, its SE, 95% t-CI and one-sample t-test p-value.
        """
        zero_shift = {
            'universe_parameters': {},
            'family_parameters': {}
        }
        all_shifts = [zero_shift] + list(shift_configs)

        # Display name of every shift in evaluation order; a no-op shift is "Baseline".
        shift_names = []
        for shift_config in all_shifts:
            title = self.get_shift_title(shift_config)
            shift_names.append("Baseline" if title == "No Shift" else title)
        # Taken from the configs rather than from the first run's results, so a shift
        # that failed only for that run is still reported.
        other_shift_names = list(dict.fromkeys(name for name in shift_names if name != 'Baseline'))

        # run_idx -> {shift_name: result dict}; keeps baseline/shift pairs per run.
        paired_results = {}
        for run_idx in self.model_df.index:
            print(f"Evaluating run {run_idx}")
            run_results = {}
            for shift_name, shift_config in zip(shift_names, all_shifts):
                print(f"Evaluating shift: {shift_name}")
                try:
                    overrides = self.prepare_all_overrides(shift_config, run_idx)
                    result = self.evaluate_single_model(overrides, run_idx)
                except Exception as e:
                    print(f"Error evaluating model {run_idx} with shift {shift_name}: {e}")
                    continue
                if result is None:
                    print(f"Skipping run {run_idx} with shift {shift_name} - checkpoint issue")
                    continue
                run_results[shift_name] = result

            if run_results:
                paired_results[run_idx] = run_results
            else:
                print(f"No successful evaluations for run {run_idx}, removing from results")

        results = []
        for metric in metrics:
            baseline_by_run = {
                run_idx: shifts['Baseline']
                for run_idx, shifts in paired_results.items()
                if shifts.get('Baseline') is not None and metric in shifts['Baseline']
            }
            if not baseline_by_run:
                print(f"No baseline values found for metric {metric}")
                continue

            # One row per baseline evaluation.
            for run_idx, baseline_result in baseline_by_run.items():
                baseline_value = baseline_result[metric]
                results.append({
                    'shift_name': 'Baseline',
                    'metric': metric,
                    'run_idx': run_idx,
                    'value': baseline_value,
                    **self._triangle_fields(baseline_result),
                    'is_baseline': True,
                    'mean_difference': 0.0,  # No difference for baseline
                    'se_difference': 0.0,
                    'ci_lower': 0.0,
                    'ci_upper': 0.0,
                    'p_value': 1.0,  # No significance test for baseline
                    'n_pairs': 1,
                    'baseline_mean': baseline_value,
                    'shift_mean': baseline_value
                })

            for shift_name in other_shift_names:
                # (run_idx, baseline value, shifted result) for runs evaluated under both.
                pairs = [
                    (run_idx, baseline_result[metric], paired_results[run_idx][shift_name])
                    for run_idx, baseline_result in baseline_by_run.items()
                    if paired_results[run_idx].get(shift_name) is not None
                    and metric in paired_results[run_idx][shift_name]
                ]
                if not pairs:
                    print(f"No valid pairs found for shift {shift_name} and metric {metric}")
                    continue

                baseline_paired = [baseline_value for _, baseline_value, _ in pairs]
                shift_values = [shift_result[metric] for _, _, shift_result in pairs]
                differences = np.array(shift_values) - np.array(baseline_paired)
                mean_diff, se_diff, ci_95, p_value = self._paired_stats(differences)
                n_pairs = len(pairs)

                # Per-run rows: individual difference, group-level SE / CI / p-value.
                for (run_idx, baseline_value, shift_result), shift_value in zip(pairs, shift_values):
                    results.append({
                        'shift_name': shift_name,
                        'metric': metric,
                        'run_idx': run_idx,
                        'value': shift_value,
                        **self._triangle_fields(shift_result),
                        'is_baseline': False,
                        'mean_difference': shift_value - baseline_value,
                        'se_difference': se_diff,
                        'ci_lower': mean_diff - ci_95,
                        'ci_upper': mean_diff + ci_95,
                        'p_value': p_value,
                        'n_pairs': n_pairs,
                        'baseline_mean': baseline_value,
                        'shift_mean': shift_value
                    })

                # Summary row with group statistics (run_idx -1 marks summaries).
                shift_results = [shift_result for _, _, shift_result in pairs]
                results.append({
                    'shift_name': f'{shift_name}_SUMMARY',
                    'metric': metric,
                    'run_idx': -1,
                    'value': np.mean(shift_values),
                    'mae': np.mean([r.get('mae', 0) for r in shift_results]),
                    'average_num_of_triangles': np.mean([r.get('average_num_of_triangles', 0) for r in shift_results]),
                    'mae_per_triangle': np.mean([r.get('mae/average_num_of_triangles', 0) for r in shift_results]),
                    'is_baseline': False,
                    'mean_difference': mean_diff,
                    'se_difference': se_diff,
                    'ci_lower': mean_diff - ci_95,
                    'ci_upper': mean_diff + ci_95,
                    'p_value': p_value,
                    'n_pairs': n_pairs,
                    'baseline_mean': np.mean(baseline_paired),
                    'shift_mean': np.mean(shift_values)
                })

        return pd.DataFrame(results)

    def get_shift_title(self, shift_config):
        """Describe the non-zero ``*_shift`` entries of ``shift_config``, or "No Shift".

        Universe shifts are printed with two signed decimals; family shifts with their
        natural sign-prefixed format and any ``_range`` suffix removed.
        """
        shift_parts = []

        for group_name, group_config in shift_config.items():
            if group_name == 'universe_parameters':
                for param, value in group_config.items():
                    if param.endswith('_shift') and value != 0:
                        shift_parts.append(f"{param.replace('_shift', '')}: {value:+.2f}")
            elif group_name == 'family_parameters':
                for param, value in group_config.items():
                    if param.endswith('_shift') and value != 0:
                        param_name = param.replace('_shift', '').replace('_range', '')
                        shift_parts.append(f"{param_name}: {value:+}")

        return " | ".join(shift_parts) if shift_parts else "No Shift"

    @staticmethod
    def _symmetric_ylim(mean_differences, se_differences):
        """Half-height of a symmetric y-range showing every bar plus its error bar.

        Uses ``|mean| + |se|`` (the original ``|mean + se|`` cut off negative bars)
        with 30% headroom; ``None`` when there is nothing non-zero and finite to show.
        """
        extents = (np.abs(np.asarray(mean_differences, dtype=float))
                   + np.abs(np.asarray(se_differences, dtype=float)))
        finite = extents[np.isfinite(extents)]
        if finite.size == 0 or finite.max() <= 0:
            return None
        return float(finite.max()) * 1.3

    def plot_shift_comparison(self, results_df, metrics=('test/mae',), save_plots=True, group_name=None):
        """Bar-plot each row's difference from baseline for every requested metric.

        Every row of ``results_df`` for a metric becomes one bar with its
        ``se_difference`` as error bar; bars are red for positive differences (MAE got
        worse) and green otherwise, annotated with value ± SE and significance stars.
        With ``save_plots`` the figure is written to
        ``triangle_<group>_shift_differences_<model>.png``.

        Returns:
            The matplotlib Figure, or ``None`` if no requested metric is in ``results_df``.
        """
        available_metrics = set(results_df['metric'].unique())
        metrics = [m for m in metrics if m in available_metrics]

        if not metrics:
            print("No matching metrics found in results_df")
            return None

        # NOTE: this updates the global matplotlib rcParams.
        plt.rcParams.update({
            'font.family': 'serif',
            'font.size': 12,
            'axes.labelsize': 13,
        })

        fig, axes = plt.subplots(1, len(metrics), figsize=(6*len(metrics), 8), dpi=300)
        if len(metrics) == 1:
            axes = [axes]

        for ax, metric in zip(axes, metrics):
            metric_data = results_df[results_df['metric'] == metric]
            if len(metric_data) == 0:
                continue

            metric_label = metric.replace("test/", "").upper()
            x_pos = np.arange(len(metric_data))

            # MAE: positive differences are worse (red), negative are better (green).
            colors = ['red' if diff > 0 else 'green' for diff in metric_data['mean_difference']]

            bars = ax.bar(x_pos, metric_data['mean_difference'],
                          yerr=metric_data['se_difference'],
                          color=colors,
                          alpha=0.7,
                          capsize=5)

            # Reference line: no difference from baseline.
            ax.axhline(y=0, color='black', linestyle='-', alpha=0.8, linewidth=1)

            ax.set_xlabel('Distribution Shift', fontweight='bold')
            ax.set_ylabel(f'Δ {metric_label}', fontweight='bold')
            ax.set_title(f'{group_name}: Performance Changes\n{metric_label} (Lower is Better)',
                         fontweight='bold', pad=20)

            ax.set_xticks(x_pos)
            ax.set_xticklabels(metric_data['shift_name'], rotation=45, ha='right')

            # Value labels with significance stars, placed beyond each error bar.
            for bar, row in zip(bars, metric_data.itertuples()):
                height = bar.get_height()
                y_pos = height + row.se_difference + 0.01 if height >= 0 else height - row.se_difference - 0.01

                if row.p_value < 0.001:
                    significance = "***"
                elif row.p_value < 0.01:
                    significance = "**"
                elif row.p_value < 0.05:
                    significance = "*"
                else:
                    significance = ""

                label = f'{row.mean_difference:.3f}±{row.se_difference:.3f}\n{significance}'
                ax.text(bar.get_x() + bar.get_width()/2., y_pos,
                        label,
                        ha='center', va='bottom' if height >= 0 else 'top',
                        fontweight='bold', fontsize=9)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.grid(axis='y', linestyle='--', alpha=0.7)

            y_max = self._symmetric_ylim(metric_data['mean_difference'], metric_data['se_difference'])
            if y_max is not None:
                ax.set_ylim([-y_max, y_max])

        plt.tight_layout()

        if save_plots:
            model_name = self.model_df.iloc[0]['config_model']['model_name']
            filename = f"triangle_{str(group_name).replace('(', '').replace(')', '').replace(', ', '_')}_shift_differences_{model_name.lower()}"
            plt.savefig(f"{filename}.png", bbox_inches='tight', dpi=300)

        return fig

def run_triangle_shift_evaluation(shift_configs, model_df, group_name):
    """Evaluate a set of distribution shifts for one group of triangle-counting runs.

    A ``TriangleShiftEvaluator`` is built around ``model_df`` and asked to score
    every entry of ``shift_configs`` on ``test/mae``. When at least one
    evaluation produced a row, the comparison is plotted and a one-line summary
    per result row is printed; otherwise a notice is printed and no plot is made.

    Args:
        shift_configs: Shift configurations forwarded to
            ``evaluate_multiple_shifts``.
        model_df: DataFrame of runs handed to the evaluator.
        group_name: Label used for the plot and the printed headings.

    Returns:
        Tuple ``(evaluator, results_df, fig)``; ``fig`` is ``None`` when
        ``results_df`` is empty.
    """
    shift_evaluator = TriangleShiftEvaluator(model_df)
    shift_results = shift_evaluator.evaluate_multiple_shifts(
        shift_configs,
        metrics=['test/mae'],
    )

    # Guard clause: nothing to plot or summarise without at least one result row.
    if shift_results.empty:
        print(f"No successful evaluations for group {group_name} - skipping plots")
        return shift_evaluator, shift_results, None

    figure = shift_evaluator.plot_shift_comparison(
        shift_results, metrics=['test/mae'], group_name=group_name
    )

    banner = "=" * 50
    print("\n" + banner)
    print(f"TRIANGLE SHIFT EVALUATION SUMMARY for {group_name}")
    print(banner)
    for _, entry in shift_results.iterrows():
        summary_line = (
            f"{entry['shift_name']:30} | {entry['metric']:12} | "
            f"Diff: {entry['mean_difference']:.3f} ± {entry['se_difference']:.3f} "
            f"(p={entry['p_value']:.3f}, n={entry['n_pairs']})"
        )
        print(summary_line)

    return shift_evaluator, shift_results, figure

def add_checkpoint_paths(best_models_dict, checkpoint_base_dir,
                         run_subdir="degree_separation_range_(0, 0.1)"):
    """Attach a ``checkpoint_local`` column with local checkpoint paths to each model's runs.

    For every ``model_name -> model_df`` entry, each row's local path is
    ``<checkpoint_base_dir>/<run_subdir>/<model_name>/<filename>``, where
    ``<filename>`` is the last ``'/'``-separated component of ``row['checkpoint']``.

    The input mapping and its DataFrames are left untouched; a new dict of
    DataFrame copies is returned.

    Args:
        best_models_dict: Mapping of model name to a DataFrame with a
            ``checkpoint`` column (assumed to hold ``'/'``-separated path
            strings — TODO confirm for runs without a checkpoint).
        checkpoint_base_dir: Root directory holding the downloaded checkpoints.
        run_subdir: Sub-directory of ``checkpoint_base_dir`` for this sweep.
            Defaults to the previously hard-coded value.

    Returns:
        A new dict with the same keys; each value is a copy of the input
        DataFrame with an added ``checkpoint_local`` column.
    """
    with_paths = {}
    for model_name, model_df in best_models_dict.items():
        # Join components separately instead of embedding '/' in one argument,
        # so the platform's path separator is used throughout.
        checkpoint_dir = os.path.join(checkpoint_base_dir, run_subdir, model_name)
        local_paths = [
            os.path.join(checkpoint_dir, row['checkpoint'].split('/')[-1])
            for _, row in model_df.iterrows()
        ]

        # Copy so the caller's DataFrame (and dict) are not modified.
        annotated_df = model_df.copy()
        annotated_df['checkpoint_local'] = local_paths
        with_paths[model_name] = annotated_df

    return with_paths

def _significance_stars(p_value):
    """Return '***' / '**' / '*' for p < 0.001 / 0.01 / 0.05, else '' (also for NaN)."""
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return ""


# Attach local checkpoint paths (rooted at CHECKPOINT_BASE_DIR) to every model's runs.
best_models_with_checkpoints = add_checkpoint_paths(best_models, CHECKPOINT_BASE_DIR)

# List the models and how many runs each contributes.
print("Models to evaluate:")
for model_name, model_df in best_models.items():
    print(f"  {model_name}: {len(model_df)} runs")

print("\nShift configurations to test:")
if best_models:
    # Only get_shift_title is needed here, so build one evaluator outside the loop.
    evaluator_temp = TriangleShiftEvaluator(next(iter(best_models.values())))
    for i, shift_config in enumerate(SHIFT_CONFIGS, start=1):
        shift_name = evaluator_temp.get_shift_title(shift_config)
        print(f"  {i}. {shift_name}")
else:
    print("  (no models loaded - cannot name the shift configurations)")

# Evaluate every model type; collect per-model frames and concatenate once at the end.
collected_results = []
for model_name, model_df in best_models_with_checkpoints.items():
    print(f'\nEvaluating {model_name} models...')
    print(f'Number of runs: {len(model_df)}')

    # Warn about absent local checkpoints, but still attempt the evaluation.
    missing_checkpoints = [
        path for path in model_df['checkpoint_local'] if not os.path.exists(path)
    ]
    if missing_checkpoints:
        print(f'Warning: {len(missing_checkpoints)} checkpoints missing for {model_name}')
        print(f'First few missing: {missing_checkpoints[:3]}')

    try:
        evaluator, results_df, fig = run_triangle_shift_evaluation(
            SHIFT_CONFIGS, model_df, f'{model_name}_degree_sep_0.0_0.1'
        )
    except Exception as e:
        # Notebook-level boundary: report the failure and move on to the next model.
        print(f'Error evaluating {model_name}: {e}')
        continue

    if results_df.empty:
        print(f'✗ No results for {model_name}')
        continue

    results_df['model_name'] = model_name
    collected_results.append(results_df)
    print(f'✓ Added results for {model_name}')

    # Display results for this model
    print(f'\nResults for {model_name}:')
    for _, row in results_df.iterrows():
        print(f'  {row["shift_name"]}: MAE change = {row["mean_difference"]:.4f} ± {row["se_difference"]:.4f}')

# ignore_index avoids duplicate row labels across models.
all_results_df = (
    pd.concat(collected_results, ignore_index=True) if collected_results else pd.DataFrame()
)

# Summary across all models
if not all_results_df.empty:
    print("\n" + "="*60)
    print("FINAL RESULTS SUMMARY")
    print("="*60)

    for model in all_results_df['model_name'].unique():
        model_results = all_results_df[all_results_df['model_name'] == model]
        print(f"\n{model}:")
        for _, row in model_results.iterrows():
            significance = _significance_stars(row['p_value'])
            print(f"  {row['shift_name']:20}: Δ{row['mean_difference']:+.4f} ± {row['se_difference']:.4f} {significance}")

    # Save all results
    all_results_df.to_csv('triangle_shift_evaluation_results.csv', index=False)
    print("\n✓ Results saved to 'triangle_shift_evaluation_results.csv'")
else:
    print("No successful evaluations completed.")