# In [None]:
from random import seed
import matplotlib.pyplot as plt
import benchmark_visualization as bv
import numpy as np
import glob
import os
import json
import re
from statistics import geometric_mean, mean


import matplotlib as mpl
from matplotlib.ticker import LogLocator, ScalarFormatter, FixedLocator
import matplot2tikz as tikzplotlib


# --- Academic Style Configuration ---
def set_academic_style():
    """Configure matplotlib rcParams for publication-style figures.

    LaTeX text rendering is switched on only when a ``latex`` executable is
    found on PATH; otherwise a warning is printed and the standard serif
    fonts are used. Font sizes, line, grid and legend defaults are applied
    in both cases.
    """
    # Local import keeps this block self-contained.
    import shutil

    # Assigning "text.usetex" does not check for a LaTeX installation -- a
    # missing install only surfaces later, when a figure is rendered. Probe
    # for the executable up front so the fallback is actually reachable.
    if shutil.which("latex"):
        plt.rcParams.update({
            "text.usetex": True,
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman"],
        })
    else:
        print("Warning: LaTeX not found. Text rendering will use standard fonts.")
        plt.rcParams.update({"text.usetex": False, "font.family": "serif"})

    plt.rcParams.update({
        # Font sizes
        "font.size": 11,
        "axes.titlesize": 12,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,

        # Line styles
        "lines.linewidth": 1.5,
        "lines.markersize": 0,  # Usually academic plots rely on lines, markers often clutter dense time series

        # Grid (Reference image has a light dotted grid)
        "axes.grid": True,
        "grid.alpha": 0.5,
        "grid.linestyle": ":",
        "grid.linewidth": 0.8,

        # Legend (Reference: No frame, inside plot)
        "legend.frameon": False,
        "legend.loc": "lower left",
        "legend.borderpad": 0.2,
    })

# Custom color palette based on the reference image, in this order:
# Green, Blue, Brown, Teal, Orange, Purple, Red, Black.
# Entries are hex RGB strings.
ACADEMIC_COLORS = [
    "#4DAF4A", # Green (KaHyPar-CA)
    "#377EB8", # Blue (KaHyPar-CA-V)
    "#A65628", # Brown
    "#98E0D6", # Teal/Light Blue
    "#FF7F00", # Orange
    "#984EA3", # Purple
    "#E41A1C", # Red
    "#000000", # Black
]

# Apply the rcParams defaults once, at import/cell-execution time.
set_academic_style()

# Resolve the home directory portably: os.environ['HOME'] raises KeyError on
# platforms where HOME is not set, while expanduser('~') yields the same
# value whenever HOME is set.
HOME_DIR = os.path.expanduser('~')

# Time limits used to tell short and long runs apart
# (presumably seconds -- confirm against the experiment setup).
SHORT_TIMELIMIT = 7200
LONG_TIMELIMIT = 28800

# Default figure save directory (user may override per call)
FIG_SAVE_DIR = os.path.join(HOME_DIR, 'Documents', 'BA_benchmarks')

# Module-level placeholder; the functions below build their own local dicts.
runs = {}
class BenchmarkRun:
    """One comma-separated result row, stored as a dict keyed by ``HEADER`` columns.

    Each field is stripped and converted to ``float`` (when it contains a
    dot) or ``int``; fields that fail the conversion stay strings. The
    ``timelimit`` passed to the constructor is stored under ``'timelimit'``.
    """
    HEADER = "algorithm,graph,timeout,seed,k,epsilon,num_threads,imbalance,totalPartitionTime,objective,km1,cut,failed"

    def __init__(self, content: str, timelimit: int = None):
        fields = content.split(',')
        self.data = {}
        for position, column in enumerate(self.HEADER.split(',')):
            self.data[column] = self._coerce(fields[position].strip())
        self.data['timelimit'] = timelimit

    @staticmethod
    def _coerce(text):
        """Return ``text`` as float/int when it parses as one, else unchanged."""
        try:
            return float(text) if '.' in text else int(text)
        except ValueError:
            return text

    def get(self, param):
        """Return the stored value for column ``param`` (KeyError if unknown)."""
        return self.data[param]

def parse_results_file(file_path: str):
    """Parse a ``.results`` file into a list of ``BenchmarkRun`` objects.

    The time limit is taken from the second-to-last dot-separated component
    of ``file_path`` (``<name>.<timelimit>.results``) and attached to every
    run. Blank lines are skipped; previously a trailing empty line made
    ``BenchmarkRun`` fail with an IndexError.

    Raises:
        ValueError: if the time-limit component is not an integer and the
            file contains at least one non-blank line.
    """
    timelimit = file_path.split('.')[-2]
    with open(file_path, 'r', encoding='utf-8') as f:
        return [
            BenchmarkRun(content=line, timelimit=int(timelimit))
            for line in f
            if line.strip()  # whitespace-only lines carry no run
        ]

def aggregate_runs(directory_path: str, instances_to_exclude: list = None):
    """Collect the km1 results of all ``*.results`` files in a directory.

    File names are read as ``<instance_name>.<timelimit>.results``; the
    instance name without its last three dot-separated components is
    compared against ``instances_to_exclude``.

    Args:
        directory_path: directory containing the result files.
        instances_to_exclude: instance names to skip; ``None`` means none
            (the ``eval_*`` callers pass ``None`` by default, which used to
            crash the membership test).

    Returns:
        dict mapping instance name to ``{"short": ..., "long": ...}``:
        a file with more than one row sets ``"short"`` to the minimum km1
        over rows whose ``failed`` column is ``'no'`` (``None`` if every row
        failed); a file with exactly one row sets ``"long"`` to its km1.
    """
    excluded = instances_to_exclude or ()
    runs = {}
    for file_path in glob.glob(directory_path + "/*.results"):
        file_name = os.path.basename(file_path)
        instance_name = '.'.join(file_name.split('.')[:-2])
        purely_instance_name = '.'.join(instance_name.split('.')[:-3])
        if purely_instance_name in excluded:
            continue
        runs.setdefault(instance_name, {"short": None, "long": None})

        run_results = parse_results_file(file_path)
        if len(run_results) > 1:
            successful = [run.get('km1') for run in run_results if run.get('failed') == 'no']
            # default=None avoids a ValueError when every run failed
            runs[instance_name]["short"] = min(successful, default=None)
        elif run_results:
            runs[instance_name]["long"] = run_results[0].get('km1')
        # files without any rows leave the entry untouched

    return runs


def convert_instance_naming_scheme(instance_name: str, use_fixed_seed: bool) -> str:
    """Rewrite an instance name as ``<base>.<threads>.<k>.<seed>.<timelimit>``.

    ``<base>`` is everything up to and including the ``hgr`` component. The
    remaining components are scanned for the prefixes ``k``, ``seed`` and
    ``timelimit``; a value that never appears is rendered as ``None``. The
    thread count is always ``1``, and the seed is forced to ``1`` when
    ``use_fixed_seed`` is true.

    Raises:
        ValueError: if the name has no ``hgr`` component.
    """
    tokens = instance_name.split('.')
    split_at = tokens.index('hgr') + 1

    # Order matters: each token is tested against 'k' first, then 'seed',
    # then 'timelimit', and only the first matching prefix is applied.
    prefixes = ('k', 'seed', 'timelimit')
    found = dict.fromkeys(prefixes)
    for token in tokens[split_at:]:
        for prefix in prefixes:
            if token.startswith(prefix):
                found[prefix] = token[len(prefix):]
                break

    base = '.'.join(tokens[:split_at])
    seed_text = '1' if use_fixed_seed else found['seed']
    return f"{base}.1.{found['k']}.{seed_text}.{found['timelimit']}"


def parse_end_result_history_file(file_path: str):
    """Return the last comma-separated field of the file's final line as an int.

    Returns ``None`` when the file has no lines at all.
    """
    with open(file_path, 'r') as handle:
        rows = handle.readlines()
    if not rows:
        return None
    final_field = rows[-1].strip().split(',')[-1]
    return int(final_field.strip())

def parse_history_file(file_path: str):
    """Read a history file into a list of ``(timestamp, mode, km1)`` tuples.

    A line starting with ``Starttime:`` becomes ``(timestamp, None, None)``;
    every other non-empty line is read as ``timestamp,mode,km1``. Empty
    lines are ignored.
    """
    def _decode(record):
        if record.startswith('Starttime:'):
            return (int(record.split(':')[1].strip()), None, None)
        fields = record.split(',')
        return (int(fields[0].strip()), fields[1].strip(), int(fields[2].strip()))

    with open(file_path, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [_decode(record) for record in stripped_lines if record]

def get_average_diff_from_matrix(matrix):
    """Return the mean of all off-diagonal entries of a square matrix.

    Yields ``0`` when there are no off-diagonal entries (e.g. a 1x1 matrix).
    """
    grid = np.array(matrix)
    # Boolean identity marks the diagonal; its negation selects the rest.
    off_diagonal = grid[~np.eye(grid.shape[0], dtype=bool)]
    if off_diagonal.size:
        return off_diagonal.mean()
    return 0

def get_average_diff_from_matrices(matrices):
    """Return the off-diagonal mean of each matrix, in input order."""
    return [get_average_diff_from_matrix(single) for single in matrices]


def get_max_diff_from_matrix(matrix):
    """Return the largest entry of ``matrix`` (diagonal included)."""
    return np.array(matrix).max()

def get_max_diff_from_matrices(matrices):
    """Return the largest entry of each matrix, in input order."""
    return [get_max_diff_from_matrix(single) for single in matrices]

# --- Figure saving support ---

def _ensure_dir(path: str):
    """Create ``path`` (including parents) unless it is empty or already a directory."""
    if not path or os.path.isdir(path):
        return
    os.makedirs(path, exist_ok=True)


def _finalize_figure(fig, show: bool, save: bool, save_dir: str, filename: str, dpi: int):
    """Optionally save ``fig``, then either show it or close it.

    When ``save`` is true the figure is written to ``save_dir`` (falling
    back to ``FIG_SAVE_DIR`` when ``save_dir`` is empty) under ``filename``.
    When ``show`` is false the figure is closed instead of displayed.
    """
    if save:
        destination = save_dir if save_dir else FIG_SAVE_DIR
        _ensure_dir(destination)
        fig.savefig(os.path.join(destination, filename), dpi=dpi, bbox_inches='tight')

    if not show:
        plt.close(fig)
        return
    plt.show()


def plot_combined_data(combined_data, title: str = "Combined History and Difference Matrices", filename: str = None,
                       show: bool = True, save: bool = False, save_dir: str = None, dpi: int = 150):
    """Draw KM1 and difference metrics over time as two stacked subplots.

    Params:
      combined_data: list of dicts with 'timestamp', 'km1', 'diff_matrix'
      title: figure title; also the source of the default file name
      filename: optional explicit filename (e.g. 'run1_combined.png')
      show: display the figure (otherwise it is closed)
      save: if True, write the figure to disk
      save_dir: target directory (defaults to FIG_SAVE_DIR)
      dpi: resolution of the saved figure
    """
    def _column(key):
        return [record[key] for record in combined_data]

    times = _column('timestamp')
    km1_series = _column('km1')
    mean_diff_series = [get_average_diff_from_matrix(m) for m in _column('diff_matrix')]
    peak_diff_series = [get_max_diff_from_matrix(m) for m in _column('diff_matrix')]

    figure, (km1_axis, diff_axis) = plt.subplots(2, 1, figsize=(12, 10))

    # Upper panel: KM1 values
    km1_color = 'tab:blue'
    km1_axis.set_xlabel('Time (seconds)', fontsize=12)
    km1_axis.set_ylabel('KM1 Value', color=km1_color, fontsize=12)
    km1_axis.plot(times, km1_series, color=km1_color, marker='o', label='KM1 Value')
    km1_axis.tick_params(axis='y', labelcolor=km1_color)
    km1_axis.set_title('KM1 Value over Time', fontsize=14, fontweight='bold')
    km1_axis.grid(True, alpha=0.3)

    # Lower panel: average and maximum differences
    diff_axis.set_xlabel('Time (seconds)', fontsize=12)
    diff_axis.set_ylabel('Difference Value', fontsize=12)
    for series, line_color, line_marker, line_label in (
        (mean_diff_series, 'tab:red', 'x', 'Average Difference'),
        (peak_diff_series, 'tab:orange', 's', 'Max Difference'),
    ):
        diff_axis.plot(times, series, color=line_color, marker=line_marker, label=line_label)
    diff_axis.legend(loc='upper right')
    diff_axis.set_title('Difference Metrics over Time', fontsize=14, fontweight='bold')
    diff_axis.grid(True, alpha=0.3)

    figure.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()

    # Derive a file name from the title when none was given.
    if filename is None:
        slug = title.lower().replace(' ', '_').replace('/', '_')
        filename = f"{slug}.png"
    _finalize_figure(figure, show, save, save_dir, filename, dpi)


def combine_history_and_diff(history_run, diff_run, time_limit: float = None):
    """Pair each non-initial history step with the diff matrix at the same position.

    Args:
        history_run: iterable of ``(timestamp, mode, km1)`` tuples. An entry
            with ``mode is None`` sets the start time; ``'Initial'`` entries
            are skipped.
        diff_run: sequence of diff matrices; the i-th retained history step
            is paired with ``diff_run[i]``.
        time_limit: timestamp assigned to the trailing entry (see below).

    Returns:
        list of dicts with keys ``'timestamp'`` (offset from the start time
        divided by 1000 -- presumably milliseconds to seconds), ``'mode'``,
        ``'km1'`` and ``'diff_matrix'``. If diff matrices are left over after
        the history is exhausted, one extra entry is appended at
        ``time_limit`` carrying the last matrix and the mode/km1 of the last
        history tuple seen (``None`` for both when the history is empty).

    Raises:
        IndexError: if there are more retained history steps than matrices.
    """
    combined_data = []
    index = 0
    start_time = None
    # Pre-bind the loop variables: the trailing entry below reads them even
    # when ``history_run`` is empty, which used to raise a NameError.
    mode = None
    km1_value = None

    for timestamp, mode, km1_value in history_run:
        if mode is None:
            start_time = timestamp
            continue
        if mode == 'Initial':
            continue

        combined_data.append({
            'timestamp': (timestamp - start_time) / 1000.0,
            'mode': mode,
            'km1': km1_value,
            'diff_matrix': diff_run[index],
        })
        index += 1

    # Append the last diff matrix if any are left
    if index < len(diff_run):
        combined_data.append({
            'timestamp': time_limit,
            'mode': mode,
            'km1': km1_value,
            'diff_matrix': diff_run[-1],
        })

    return combined_data

def aggregate_history_runs(directory_path: str, full_history: bool = False, instances_to_exclude: list = None):
    """Load every ``*.csv`` history in a directory, keeping the best run per instance.

    The third-from-last dot-separated component of the file name is taken as
    the thread id; everything before it is converted with
    ``convert_instance_naming_scheme`` (seed fixed to 1 unless
    ``full_history`` is true).

    Args:
        directory_path: directory containing the history files.
        full_history: keep the real seed in the instance name, so each seed
            is kept as its own entry.
        instances_to_exclude: base instance names to skip; ``None`` means
            none (callers pass ``None`` by default, which used to crash the
            membership test).

    Returns:
        dict mapping the converted instance name to
        ``{'thread_id', 'km1', 'history'}``, where ``km1`` is the value of
        the last history entry that has a mode (``None`` if there is none).
        For duplicate instance names the run with the smaller km1 wins; an
        entry whose km1 is ``None`` is replaced by any run that has one.
    """
    excluded = instances_to_exclude or ()
    history_runs = {}
    for file_path in glob.glob(directory_path + "/*.csv"):
        name_parts = os.path.basename(file_path).split('.')
        thread_id = name_parts[-3]
        instance_name = convert_instance_naming_scheme('.'.join(name_parts[:-3]), not full_history)

        purely_instance_name = '.'.join(instance_name.split('.')[:-4])
        if purely_instance_name in excluded:
            continue

        history = parse_history_file(file_path)

        # The last entry carrying a mode holds the final result
        km1 = next((value for _, mode, value in reversed(history) if mode is not None), None)

        candidate = {'thread_id': thread_id, 'km1': km1, 'history': history}
        current = history_runs.get(instance_name)
        if current is None:
            history_runs[instance_name] = candidate
        elif km1 is not None and (current['km1'] is None or km1 < current['km1']):
            # Comparing against a stored None used to raise a TypeError.
            history_runs[instance_name] = candidate

    return history_runs

def get_diff_matrices_for_best_run(history_runs, diff_matrices_list, instance_name: str):
    """Return the matrices recorded by the thread that produced the best run.

    Looks up the ``thread_id`` stored for ``instance_name`` in
    ``history_runs`` and returns the ``'matrices'`` of the first entry in
    ``diff_matrices_list`` with that thread id. Returns ``None`` when the
    instance is unknown or no entry matches.
    """
    if instance_name in history_runs:
        wanted_thread = history_runs[instance_name]['thread_id']
        return next(
            (candidate['matrices'] for candidate in diff_matrices_list
             if candidate['thread_id'] == wanted_thread),
            None,
        )
    return None


def aggregate_diff_runs(directory_path: str, instances_to_exclude: list = None):
    """Load the diff matrices of every ``*.csv`` file in a directory, grouped by instance.

    The third-from-last dot-separated component of the file name is taken as
    the thread id; everything before it is converted with
    ``convert_instance_naming_scheme`` using a fixed seed.

    Args:
        directory_path: directory containing the diff files.
        instances_to_exclude: base instance names to skip; ``None`` means
            none (previously ``None`` crashed the membership test).

    Returns:
        dict mapping the converted instance name to a list of
        ``{'thread_id', 'matrices'}`` dicts, one per file.
    """
    excluded = instances_to_exclude or ()
    diff_runs = {}
    for file_path in glob.glob(directory_path + "/*.csv"):
        name_parts = os.path.basename(file_path).split('.')
        thread_id = name_parts[-3]
        instance_name = convert_instance_naming_scheme('.'.join(name_parts[:-3]), True)
        purely_instance_name = '.'.join(instance_name.split('.')[:-4])
        if purely_instance_name in excluded:
            continue
        # Parse only after the exclusion check so skipped files cost nothing.
        matrices = bv.parse_diff_matrices(file_path)
        diff_runs.setdefault(instance_name, []).append({
            'thread_id': thread_id,
            'matrices': matrices,
        })

    return diff_runs

def split_long_short_history_runs(history_runs):
    """Partition runs by the time limit encoded as the last name component.

    Returns ``(short_runs, long_runs)``: runs whose time limit equals
    ``SHORT_TIMELIMIT`` go into the first dict, all others into the second.
    """
    buckets = {True: {}, False: {}}
    for name, payload in history_runs.items():
        is_short = int(name.split('.')[-1]) == SHORT_TIMELIMIT
        buckets[is_short][name] = payload
    return buckets[True], buckets[False]

def split_seed_history_runs(history_runs):
    """Group runs by seed and return the groups as a list ordered by seed.

    The seed is the second-to-last dot-separated component of the instance
    name. The group for seed ``s`` is stored at list position ``s - 1``, so
    the seeds are expected to be ``1..n`` without gaps.
    """
    by_seed = {}
    for name, payload in history_runs.items():
        by_seed.setdefault(name.split('.')[-2], {})[name] = payload

    # Turn the seed -> group mapping into a positional list
    ordered = [None] * len(by_seed)
    for seed_text, group in by_seed.items():
        ordered[int(seed_text) - 1] = group
    return ordered

def instance_to_seeds(history_runs):
    """Map each base instance name to a ``{seed: run_data}`` dict.

    The base name is the instance name without its last four dot-separated
    components; the seed is the second-to-last component, as an int.
    """
    grouped = {}
    for full_name, payload in history_runs.items():
        tokens = full_name.split('.')
        base_name = '.'.join(tokens[:-4])
        seed_number = int(tokens[-2])
        grouped.setdefault(base_name, {})[seed_number] = payload
    return grouped

def split_k_value_history_runs(history_runs, *k_values):
    """Split history runs into one dict per requested k value.

    The k value is the third-from-last dot-separated component of the
    instance name. The result list is parallel to ``k_values``; runs whose k
    is not requested are dropped.
    """
    slot_of = {k: slot for slot, k in enumerate(k_values)}
    buckets = [{} for _ in k_values]

    for name, payload in history_runs.items():
        slot = slot_of.get(name.split('.')[-3])
        if slot is not None:
            buckets[slot][name] = payload

    return buckets

def convert_time_for_history_run(history_run):
    """Rebase a history so timestamps are offsets from its start entry.

    An entry with ``mode is None`` marks the start: it is emitted with
    timestamp 0 and every later entry becomes ``(timestamp - start) / 1000.0``.
    If a start entry already has timestamp 0 the history is treated as
    converted and a shallow copy of the input is returned.
    """
    origin = None
    rebased = []
    for stamp, mode, km1_value in history_run:
        if mode is not None:
            rebased.append(((stamp - origin) / 1000.0, mode, km1_value))
            continue
        origin = stamp
        # A start time of 0 means the timestamps are already relative.
        if origin == 0:
            return history_run.copy()
        rebased.append((0, mode, km1_value))
    return rebased


def make_history_runs_sequential(*history_runs, time_limit=None):
    """Chain several history-run dicts one after another on the time axis.

    The i-th argument has its timestamps made relative (via
    ``convert_time_for_history_run``) and shifted by ``i * time_limit``. The
    seed component of every instance name is replaced by ``1`` so runs of
    the same instance end up under one key, and the shifted dicts are then
    combined with ``merge_histories``. The inputs are not modified.
    """
    shifted_runs = []
    for position, history_run in enumerate(history_runs):
        shifted = {}
        for instance_name, run_data in history_run.items():
            # Same name with the seed component (second to last) set to 1
            head = instance_name.rsplit('.', 2)[0]
            tail = instance_name.rsplit('.', 1)[1]
            unified_name = head + '.1.' + tail

            # Work on a copy so the caller's history list stays untouched
            relative = convert_time_for_history_run(list(run_data['history']))
            offset_history = [
                (timestamp + position * time_limit, mode, km1_value)
                for timestamp, mode, km1_value in relative
            ]

            shifted[unified_name] = {
                'thread_id': run_data['thread_id'],
                'km1': run_data['km1'],
                'history': offset_history,
            }
        shifted_runs.append(shifted)

    return merge_histories(*shifted_runs)

def merge_histories(*history_runs):
    """Merge several history-run dicts into one improving history per instance.

    Histories of the same instance are concatenated and sorted by timestamp.
    Only entries that strictly improve (lower) the best km1 seen so far are
    kept; any other entry with timestamp 0 is kept as a ``(0, None, None)``
    start marker. ``km1`` and ``thread_id`` are taken from the run with the
    smallest km1.

    A ``None`` km1 is treated as "no result": it never replaces an existing
    value and is replaced by any real value (the comparison used to raise a
    TypeError). The input dicts and their history lists are not modified.
    """
    merged = {}
    for history_run in history_runs:
        for instance_name, run_data in history_run.items():
            current = merged.get(instance_name)
            if current is None:
                merged[instance_name] = {
                    'thread_id': run_data['thread_id'],
                    'km1': run_data['km1'],
                    'history': list(run_data['history']),  # copy, never alias
                }
                continue

            # Build a new list so the first contributor's copy is not shared
            current['history'] = current['history'] + list(run_data['history'])
            new_km1 = run_data['km1']
            if new_km1 is not None and (current['km1'] is None or new_km1 < current['km1']):
                current['km1'] = new_km1
                current['thread_id'] = run_data['thread_id']

    for run_data in merged.values():
        run_data['history'].sort(key=lambda entry: entry[0])

        # Keep only strictly decreasing km1 values
        best_so_far = float('inf')
        kept = []
        for entry in run_data['history']:
            timestamp, _mode, km1_value = entry
            if km1_value is not None and km1_value < best_so_far:
                kept.append(entry)
                best_so_far = km1_value
            elif timestamp == 0:
                kept.append((0, None, None))
        run_data['history'] = kept

    return merged
               
def split_runs_k_value(runs, *k_values):
    """Split result runs into one dict per requested k value.

    The k value is the second-to-last dot-separated component of the
    instance name. The result list is parallel to ``k_values``; runs whose k
    is not requested are dropped.
    """
    # k value -> position in the result list
    positions = {k: position for position, k in enumerate(k_values)}
    partitions = [{} for _ in k_values]

    for name, payload in runs.items():
        try:
            position = positions[name.split('.')[-2]]
        except KeyError:
            continue
        partitions[position][name] = payload

    return partitions

def extract_info_from_config(config_name: str):
    """Parse a configuration folder name into ``{parameter: [values]}``.

    A leading ``YYYY-M-D_`` date is removed. The rest is split on ``-`` into
    parameter groups, and each group on ``_`` into a name followed by its
    values. A parameter without values maps to an empty list; empty groups
    and groups with an empty name are ignored.
    """
    undated = re.sub(r'^\d{4}-\d{1,2}-\d{1,2}_', '', config_name)

    parsed = {}
    for chunk in undated.split('-'):
        if not chunk:
            # produced by an accidental double '-'
            continue
        name, *values = chunk.split('_')
        if name == '':
            continue
        parsed[name] = values
    return parsed

def create_mean_over_seeds_single_instance(seeds_to_run, instance_name: str):  
    """Merge the histories of several seeds into one series of mean km1 values.

    Args:
        seeds_to_run: mapping whose values are run dicts with a 'history'
            list of (timestamp, mode, km1) tuples.
        instance_name: not read anywhere in the body.

    Returns:
        list of (time, mean_km1) tuples. The first tuple is taken once every
        seed has its first value; one further tuple is appended for each
        later history entry of any seed, in time order.

    NOTE(review): every converted history is read starting at index 1, so
    index 0 is presumably the start marker and each history needs at least
    two entries -- confirm against the producers of these histories.
    """
    result_list = []
    # Per-seed cursor state, indexed by the seed's position in seeds_to_run
    current_indices = [0] * len(seeds_to_run)
    current_values = [None] * len(seeds_to_run)
    current_times = [0] * len(seeds_to_run)
    max_time = 0
    
    # Make all timestamps relative to the start of their run
    converted_runs = []
    for seed_run in seeds_to_run.values():
        converted_history = convert_time_for_history_run(seed_run['history'])
        converted_runs.append(converted_history)
    
    # First value is a special case: start every seed at its entry 1 and
    # remember the latest of those first timestamps
    for i, history in enumerate(converted_runs):
        time, mode, km1_value = history[1]
        current_values[i] = km1_value
        current_indices[i] = 1
        current_times[i] = time
        if time > max_time:
            max_time = time
    # Advance the other seeds up to max_time so the first mean is taken at
    # a time where every seed already has a value
    for i in range(len(current_times)):
        while current_times[i] < max_time:
            index = current_indices[i]
            history = converted_runs[i]
            if index + 1 < len(history) and history[index + 1][0] <= max_time:
                time, mode, km1_value = history[index + 1]
                current_indices[i] += 1
                current_values[i] = km1_value
                current_times[i] = time
            else:
                break
    result_list.append((max_time, mean(current_values)))
    
    end = False
    while not end:
        # Find the seed whose next entry comes earliest; on equal times the
        # first seed in iteration order wins (strict '<')
        next_time = float('inf')
        for i, history in enumerate(converted_runs):
            index = current_indices[i]
            if index + 1 < len(history):
                time, mode, km1_value = history[index + 1]
                if time < next_time:
                    next_time = time
                    instance_to_increment = i
                    km1_to_update = km1_value
        # No seed has an entry left: done
        if next_time == float('inf'):
            end = True
            continue
        # Advance the cursor of the seed that owns the next entry
        current_indices[instance_to_increment] += 1
        current_values[instance_to_increment] = km1_to_update
        current_times[instance_to_increment] = next_time
        
        # Record the mean over the seeds' current values at this time
        result_list.append((next_time, mean(current_values)))
    return result_list
        
        
def max_iterations_per_instance(iteration_runs, multiple_seeds: bool = False):
    """Return the maximum iteration per instance as a dict.

    Every run is a pair ``(timestamp_to_iteration, iteration_to_metric_value)``;
    its maximum iteration is the first element of the last
    ``iteration_to_metric_value`` entry. With ``multiple_seeds`` the runs are
    grouped with ``instance_to_seeds`` and the value per instance is the mean
    over its seeds.
    """
    def _last_iteration(run):
        _timestamp_to_iteration, iteration_to_metric_value = run
        return iteration_to_metric_value[-1][0]

    if not multiple_seeds:
        return {name: _last_iteration(run) for name, run in iteration_runs.items()}

    # Average between multiple seeds of the same instance
    return {
        name: mean([_last_iteration(seed_run) for seed_run in seed_runs.values()])
        for name, seed_runs in instance_to_seeds(iteration_runs).items()
    }
    
def get_iteration_multipliers(baseline_max_iterations, other_max_iterations):
    """Return ``other / baseline`` iteration ratios per shared instance.

    Only instances present in both dicts with a positive baseline count are
    included. Returns ``None`` when no ratio could be computed.
    """
    ratios = {
        name: other_max_iterations[name] / baseline_iters
        for name, baseline_iters in baseline_max_iterations.items()
        if name in other_max_iterations and baseline_iters > 0
    }
    return ratios or None
    
def create_multiplier_based_history_runs(history_runs, multipliers):
    """Stretch the time axis of each history by its instance's multiplier.

    Args:
        history_runs: dict mapping instance name to a run dict with
            'thread_id', 'km1' and 'history' (list of (timestamp, mode, km1)).
        multipliers: dict mapping base instance name (the instance name
            without its last four dot-separated components) to a factor.
            ``None`` is accepted and means "no multipliers" -- this is what
            ``get_iteration_multipliers`` returns when it finds none, and it
            used to crash the membership test here.

    Returns:
        dict with the same keys. Start entries (``mode is None``) and
        ``'Initial'`` entries keep their timestamp and become the new
        reference point; every other entry is moved to
        ``reference + round((timestamp - reference) * multiplier)``. Runs
        without a multiplier are passed through as the same object.
    """
    multipliers = multipliers or {}
    adjusted_history_runs = {}
    for instance_name, run_data in history_runs.items():
        purely_instance_name = '.'.join(instance_name.split('.')[:-4])
        if purely_instance_name not in multipliers:
            # No adjustment if no multiplier found
            adjusted_history_runs[instance_name] = run_data
            continue

        multiplier = multipliers[purely_instance_name]
        # Reference point: the first timestamp, then the latest start/Initial entry
        last_initial_timestamp = run_data['history'][0][0] if run_data['history'] else 0
        adjusted_history = []
        for timestamp, mode, km1_value in run_data['history']:
            if mode is not None and mode != 'Initial':
                time_diff = timestamp - last_initial_timestamp
                # Round to nearest integer
                adjusted_time = last_initial_timestamp + round(time_diff * multiplier)
            else:
                last_initial_timestamp = timestamp
                adjusted_time = timestamp
            adjusted_history.append((adjusted_time, mode, km1_value))
        adjusted_history_runs[instance_name] = {
            'thread_id': run_data['thread_id'],
            'km1': run_data['km1'],
            'history': adjusted_history,
        }
    return adjusted_history_runs
           
                    
def create_geomean_over_all_instances(history_runs, multiple_seeds: bool = False):
    """Build the geometric mean of km1 over all instances as a time series.

    Args:
        history_runs: dict mapping instance name to a run dict with a
            'history' list of (timestamp, mode, km1) tuples.
        multiple_seeds: when true, runs are grouped per base instance with
            ``instance_to_seeds`` and each instance contributes the mean over
            its seeds (``create_mean_over_seeds_single_instance``); otherwise
            every run contributes its own history, made relative with
            ``convert_time_for_history_run``.

    Returns:
        list of (time, geometric_mean) tuples. The first tuple is taken at
        the earliest time at which every instance has a value; one further
        tuple is appended per later history entry of any instance, in time
        order. Returns ``[]`` when there are no instances.

    Raises:
        ValueError: if an instance history contains no km1 value at all.

    Changes against the previous implementation:
      * Each instance starts at its first entry that carries a km1 value.
        Before, the value came from entry 0 while the cursor was set to 1,
        which put the start marker's ``None`` into the geometric mean on the
        single-seed path and skipped entry 1 on the multi-seed path.
      * A km1 of 0 is replaced by 1 on every update, not only for the first
        data point, so ``geometric_mean`` always gets positive inputs.
      * Later entries without a km1 value keep the previous value.
    """
    if multiple_seeds:
        converted_runs = {}
        for instance_name, seed_runs in instance_to_seeds(history_runs).items():
            mean_history = create_mean_over_seeds_single_instance(seed_runs, instance_name)
            # Same tuple layout as a regular history: (time, mode, km1)
            converted_runs[instance_name] = [(time, None, km1) for time, km1 in mean_history]
    else:
        converted_runs = {
            instance_name: convert_time_for_history_run(run_data['history'])
            for instance_name, run_data in history_runs.items()
        }

    if not converted_runs:
        return []

    current_indices = {}
    current_values = {}
    current_times = {}

    def _advance(instance_name, index):
        """Move the cursor of ``instance_name`` to ``index`` and adopt its value."""
        time, _mode, km1_value = converted_runs[instance_name][index]
        current_indices[instance_name] = index
        current_times[instance_name] = time
        if km1_value is not None:
            # The geometric mean needs strictly positive values
            current_values[instance_name] = 1 if km1_value == 0 else km1_value

    # Start every instance at its first measured value
    for instance_name, history in converted_runs.items():
        first_valued = next(
            (index for index, entry in enumerate(history) if entry[2] is not None),
            None,
        )
        if first_valued is None:
            raise ValueError(f"History of instance '{instance_name}' contains no km1 value")
        _advance(instance_name, first_valued)
    max_time = max(0, max(current_times.values()))

    # Bring the other instances up to max_time so that the first data point
    # is taken at a time where every instance already has a value
    for instance_name, history in converted_runs.items():
        while current_times[instance_name] < max_time:
            upcoming = current_indices[instance_name] + 1
            if upcoming < len(history) and history[upcoming][0] <= max_time:
                _advance(instance_name, upcoming)
            else:
                break

    result_list = [(max_time, geometric_mean(current_values.values()))]

    while True:
        # Find the instance whose next entry comes earliest; on equal times
        # the first instance in iteration order wins (strict '<')
        next_time = float('inf')
        next_instance = None
        for instance_name, history in converted_runs.items():
            upcoming = current_indices[instance_name] + 1
            if upcoming < len(history) and history[upcoming][0] < next_time:
                next_time = history[upcoming][0]
                next_instance = instance_name
        if next_instance is None:
            break

        _advance(next_instance, current_indices[next_instance] + 1)
        result_list.append((next_time, geometric_mean(current_values.values())))

    return result_list

def eval_k_kway(
    configs_name="",
    all_configs_root=os.path.expanduser("~/Documents/experiment_results"),
    results_dirname="mt_kahypar_evo_results",
    diff_dirname="evo_diff",
    history_dirname="evo_history",
    instances_to_exclude=None,
    show_full_history=False,
    default_pop_size=10,
):
    """Compute geomean km1 histories per (k, population size), grouped by kway value.

    Every sub-directory of ``<all_configs_root>/<configs_name>`` is one
    configuration. Its name is parsed with ``extract_info_from_config``; the
    groups ``k`` (all values), ``kway`` (first value) and ``pop`` (first
    value, else ``default_pop_size``) are read.

    Args:
        configs_name: experiment folder below ``all_configs_root``.
        all_configs_root: root directory of all experiments.
        results_dirname: per-configuration folder with ``*.results`` files.
        diff_dirname: per-configuration diff folder (currently not read).
        history_dirname: per-configuration folder with history ``*.csv`` files.
        instances_to_exclude: base instance names to skip; ``None`` means none.
        show_full_history: keep seeds separate and average over them.
        default_pop_size: population size used when the name has no ``pop``.

    Returns:
        ``(k_pop_to_kway_geomeans, save_path)``: a dict mapping
        ``(k, pop_size)`` to ``{kway: geomean_series}``, and the directory
        ``FIG_SAVE_DIR/<configs_name>``.

    A configuration without a ``k`` group used to crash while iterating over
    ``None``; it now yields one series over all its history runs, stored
    under ``k = None``.
    """
    # The helpers test membership with ``in``; hand them a real container so
    # the default None does not crash there.
    excluded = instances_to_exclude if instances_to_exclude is not None else []

    all_configs_dir = os.path.join(all_configs_root, configs_name)
    all_geomeans = {}

    for config_dir in glob.glob(all_configs_dir + "/*/"):
        config_name = os.path.basename(os.path.normpath(config_dir))
        print(f"Processing configuration: {config_name}")

        config_info = extract_info_from_config(config_name)
        k_values = config_info.get('k', None)
        kway_value = config_info.get('kway', None)
        if kway_value:
            kway_value = kway_value[0]
        pop_size = config_info.get('pop', None)
        pop_label = pop_size[0] if pop_size else default_pop_size

        result_dir = os.path.join(config_dir, results_dirname)
        history_dir = os.path.join(config_dir, history_dirname)

        # NOTE(review): the aggregated results are not used by anything
        # below. The call is kept so result files are still read as before;
        # drop it if that is not needed.
        aggregate_runs(result_dir, instances_to_exclude=excluded)

        history_runs = aggregate_history_runs(
            history_dir, full_history=show_full_history, instances_to_exclude=excluded
        )
        if k_values is None:
            history_runs_per_k = {None: history_runs}
        else:
            history_runs_per_k = dict(
                zip(k_values, split_k_value_history_runs(history_runs, *k_values))
            )

        # One geomean series per k value of this configuration
        for k, k_history_run in history_runs_per_k.items():
            geomean = create_geomean_over_all_instances(k_history_run, multiple_seeds=show_full_history)
            all_geomeans[(k, kway_value, pop_label)] = geomean

    # Regroup: (k, pop_size) -> {kway: geomean series}
    k_pop_to_kway_geomeans = {}
    for (k, kway, pop_size), geomean in all_geomeans.items():
        k_pop_to_kway_geomeans.setdefault((k, pop_size), {})[kway] = geomean

    # Directory where plots will be saved by the final loop
    save_path = os.path.join(FIG_SAVE_DIR, configs_name)
    return k_pop_to_kway_geomeans, save_path

def eval_evothreads(
    configs_name="",
    all_configs_root=os.path.expanduser("~/Documents/experiment_results"),
    results_dirname="mt_kahypar_evo_results",
    diff_dirname="evo_diff",
    history_dirname="evo_history",
    instances_to_exclude=None,
    show_full_history=False
):
    """Collect one geomean history per threads-per-worker configuration.

    Every sub-directory of ``all_configs_root/configs_name`` is treated as
    one configuration. The thread count is read from the ``'parallel'`` entry
    returned by ``extract_info_from_config`` for the directory name.

    Args:
        configs_name: Name of the experiment folder below ``all_configs_root``.
        all_configs_root: Root directory containing all experiment folders.
        results_dirname: Per-configuration sub-directory with result files.
        diff_dirname: Accepted for signature compatibility; not used here.
        history_dirname: Per-configuration sub-directory with history files.
        instances_to_exclude: Forwarded to the aggregation helpers.
        show_full_history: Forwarded as ``full_history``; when true the
            geomean is built with ``multiple_seeds=True``.

    Returns:
        ``(all_geomeans, save_path, thread_count)`` where ``all_geomeans``
        maps the integer thread count (or the string ``"default"`` when the
        configuration name carries no ``'parallel'`` information) to its
        geomean, ``save_path`` is ``FIG_SAVE_DIR/configs_name`` and
        ``thread_count`` is the largest integer thread count seen (at least 1).
    """
    all_configs_dir = os.path.join(all_configs_root, configs_name)
    thread_count = 1
    all_geomeans = {}
    for config_dir in glob.glob(all_configs_dir + "/*/"):
        config_name = os.path.basename(os.path.normpath(config_dir))
        print(f"Processing configuration: {config_name}")

        config_info = extract_info_from_config(config_name)
        parallel_info = config_info.get('parallel', None)

        result_dir = os.path.join(config_dir, results_dirname)
        history_dir = os.path.join(config_dir, history_dirname)

        # The aggregated results are not used below; the call is kept so that
        # the function behaves as before (including any error it may raise).
        aggregate_runs(result_dir, instances_to_exclude=instances_to_exclude)
        history_runs = aggregate_history_runs(
            history_dir,
            full_history=show_full_history,
            instances_to_exclude=instances_to_exclude,
        )

        if not show_full_history:
            geomean = create_geomean_over_all_instances(history_runs)
        else:
            geomean = create_geomean_over_all_instances(history_runs, multiple_seeds=True)

        if parallel_info:
            threads_per_worker = int(parallel_info[0])
            # Only real thread counts take part in the maximum; comparing the
            # "default" placeholder string with an int would raise TypeError.
            thread_count = max(thread_count, threads_per_worker)
        else:
            threads_per_worker = "default"
        all_geomeans[threads_per_worker] = geomean

    # Directory where plots will be saved by the final loop
    save_path = os.path.join(FIG_SAVE_DIR, configs_name)
    return all_geomeans, save_path, thread_count
    
    
def eval_generic_by_name(
    configs_name="",
    all_configs_root=os.path.expanduser("~/Documents/experiment_results"),
    results_dirname="mt_kahypar_evo_results",
    diff_dirname="evo_diff",
    history_dirname="evo_history",
    instances_to_exclude=None,
    show_full_history=False,
    plot_iterations=False
):
    """Collect one geomean history per configuration, keyed by its name.

    Each sub-directory of ``all_configs_root/configs_name`` is one
    configuration. Its directory name, with a leading date prefix removed by
    ``remove_date_prefix``, becomes the key in the returned mapping.

    ``diff_dirname`` and ``plot_iterations`` are accepted but not used by
    this function.

    Returns:
        ``(all_geomeans, save_path)`` where ``save_path`` is
        ``FIG_SAVE_DIR/configs_name``.
    """
    geomean_by_config = {}
    experiment_root = os.path.join(all_configs_root, configs_name)

    for config_dir in glob.glob(experiment_root + "/*/"):
        config_name = os.path.basename(os.path.normpath(config_dir))
        print(f"Processing configuration: {config_name}")

        # Result files are aggregated but their content is not used below.
        aggregate_runs(
            os.path.join(config_dir, results_dirname),
            instances_to_exclude=instances_to_exclude,
        )
        history_runs = aggregate_history_runs(
            os.path.join(config_dir, history_dirname),
            full_history=show_full_history,
            instances_to_exclude=instances_to_exclude,
        )

        # multiple_seeds is only passed explicitly for full histories.
        geomean_kwargs = {"multiple_seeds": True} if show_full_history else {}
        geomean = create_geomean_over_all_instances(history_runs, **geomean_kwargs)

        geomean_by_config[remove_date_prefix(config_name)] = geomean

    # Directory where plots will be saved by the final loop
    return geomean_by_config, os.path.join(FIG_SAVE_DIR, configs_name)
    
    
def remove_date_prefix(name: str) -> str:
    """Return *name* without a leading ``YYYY-M-D_`` / ``YYYY-MM-DD_`` prefix.

    Names that do not start with such a prefix are returned unchanged.
    """
    prefix = re.match(r'^\d{4}-\d{1,2}-\d{1,2}_', name)
    if prefix is None:
        return name
    return name[prefix.end():]
    
def aggregate_iterations_runs(directory_path: str, show_full_history: bool = False, instances_to_exclude: list = None):
    """Load every ``*.csv`` iteration log found directly in *directory_path*.

    The instance key is derived from the file name by dropping its last three
    dot-separated components and passing the remainder through
    ``convert_instance_naming_scheme`` (with a fixed seed unless
    *show_full_history* is true).

    Args:
        directory_path: Directory that is searched (non-recursively).
        show_full_history: When true, the naming scheme is converted without
            a fixed seed.
        instances_to_exclude: Optional collection of instance names to skip.
            A name is compared against the converted key with its last four
            dot-separated components removed. ``None`` excludes nothing.

    Returns:
        Dict mapping the converted instance name to the
        ``(timestamp_to_iteration, iteration_to_metric_value)`` tuple returned
        by ``get_iteration_run``.
    """
    # None instead of a shared mutable default; also matches the callers in
    # this file that pass instances_to_exclude=None through.
    excluded = instances_to_exclude if instances_to_exclude is not None else ()
    use_fixed_seed = not show_full_history

    iteration_runs = {}
    for file_path in glob.glob(directory_path + "/*.csv"):
        file_name = os.path.basename(file_path)
        instance_name = '.'.join(file_name.split('.')[:-3])
        instance_name = convert_instance_naming_scheme(instance_name, use_fixed_seed)

        purely_instance_name = '.'.join(instance_name.split('.')[:-4])
        if purely_instance_name in excluded:
            continue

        iteration_runs[instance_name] = get_iteration_run(file_path)
    return iteration_runs

def get_iteration_run(iteration_run_file: str):
    """Parse an iteration log into two parallel lists.

    Each usable line has at least three comma-separated fields:
    ``iteration, timestamp, metric_value``. Lines with fewer fields, or with
    fields that cannot be parsed as ``int``/``int``/``float``, are skipped.

    Returns:
        ``(timestamp_to_iteration, iteration_to_metric_value)`` where the
        first list holds ``(seconds_since_first_entry, iteration)`` pairs
        (the raw timestamp difference divided by 1000) and the second holds
        ``(iteration, metric_value)`` pairs.
    """
    time_iteration_pairs = []
    iteration_metric_pairs = []
    origin = None  # raw timestamp of the first successfully parsed line

    with open(iteration_run_file, 'r') as log:
        for raw_line in log:
            fields = raw_line.strip().split(',')
            if len(fields) < 3:
                continue
            try:
                iteration = int(fields[0].strip())
                timestamp = int(fields[1].strip())
                metric_value = float(fields[2].strip())
            except ValueError:
                # e.g. a header line or a corrupted record
                continue

            if origin is None:
                origin = timestamp
            elapsed_seconds = (timestamp - origin) / 1000.0

            time_iteration_pairs.append((elapsed_seconds, iteration))
            iteration_metric_pairs.append((iteration, metric_value))

    return time_iteration_pairs, iteration_metric_pairs

def simulate_stop_after_k_no_improvement_iterations(
    timestamp_to_iteration, iteration_to_km1, k: int
):
    """Replay a run and stop after *k* consecutive non-improving log entries.

    Args:
        timestamp_to_iteration: List of ``(time, iteration)`` pairs.
        iteration_to_km1: List of ``(iteration, km1)`` pairs.
        k: Number of consecutive entries without a new best value that
            triggers the stop.

    Returns:
        ``(stop_time, stop_iter, stop_km1, last_km1, frac_improvement,
        frac_time)``, or six ``None`` values when *iteration_to_km1* is empty.
        If the stop rule never fires, the stop point is the last improving
        entry, or the end of the log when nothing ever improved.
    """
    if not iteration_to_km1:
        return (None,) * 6

    time_of_iteration = dict(
        (iteration, moment) for moment, iteration in timestamp_to_iteration
    )

    initial_iter, initial_km1 = iteration_to_km1[0]
    final_iter, final_km1 = iteration_to_km1[-1]
    start_time = time_of_iteration.get(initial_iter, 0.0)
    end_time = time_of_iteration.get(final_iter, start_time)

    # Default stop point: the very end of the log.
    stop_iter, stop_km1, stop_time = final_iter, final_km1, end_time

    incumbent = initial_km1
    stale_entries = 0
    for current_iter, current_km1 in iteration_to_km1[1:]:
        improved = current_km1 < incumbent
        if improved:
            incumbent = current_km1
            stale_entries = 0
        else:
            stale_entries += 1
        triggered = (not improved) and stale_entries >= k

        if improved or triggered:
            stop_iter, stop_km1 = current_iter, current_km1
            stop_time = _time_at_or_before_iteration(timestamp_to_iteration, stop_iter)
        if triggered:
            break

    overall_gain = initial_km1 - final_km1
    if overall_gain > 0:
        frac_improvement = (initial_km1 - stop_km1) / overall_gain
    else:
        frac_improvement = 1.0

    # Guard against a zero-length run when normalising the time.
    run_duration = max(end_time - start_time, 1e-9)
    frac_time = (stop_time - start_time) / run_duration

    return stop_time, stop_iter, stop_km1, final_km1, frac_improvement, frac_time

def extract_improvements_from_iterations(iteration_to_km1):
    """Return every strict improvement over the running best value.

    Args:
        iteration_to_km1: List of ``(iteration, km1)`` pairs; the first entry
            provides the initial best value.

    Returns:
        List of ``(iteration, km1, delta)`` tuples, one per entry whose km1 is
        strictly below the best value seen before it; ``delta`` is the amount
        by which that best value dropped. Empty for empty input.
    """
    if not iteration_to_km1:
        return []

    incumbent = iteration_to_km1[0][1]
    found = []
    for iteration, value in iteration_to_km1[1:]:
        if value >= incumbent:
            continue
        found.append((iteration, value, incumbent - value))
        incumbent = value
    return found


def _time_at_or_before_iteration(timestamp_to_iteration, target_it):
    """
    Given a list of (time, iteration) pairs sorted by time,
    return the time of the largest iteration <= target_it.
    If all iterations are > target_it, return the first time.
    """
    # timestamp_to_iteration: [(time0, it0), (time1, it1), ...]
    last_time = timestamp_to_iteration[0][0]
    last_it = timestamp_to_iteration[0][1]

    for t, it in timestamp_to_iteration:
        if it > target_it:
            break
        last_time = t
        last_it = it

    return last_time

def simulate_stop_by_improvement_rate(
    timestamp_to_iteration,
    iteration_to_km1,
    early_window_improvs: int = 5,
    recent_window_improvs: int = 5,
    alpha: float = 0.1,
    max_iters_without_improv: int = 500,
):
    """Replay a run and stop once the recent improvement rate collapses.

    The "early rate" is the km1 drop per iteration between the first and the
    ``early_window_improvs``-th improvement. Afterwards, for every further
    improvement, the rate over the last ``recent_window_improvs`` improvements
    is compared against ``alpha * early_rate``; the run stops at the first
    improvement whose recent rate falls below that threshold. A plateau of
    more than ``max_iters_without_improv`` iterations between improvements
    (or after the last one) stops the run as well.

    If the log contains at most ``early_window_improvs`` improvements, a pure
    plateau rule (``max_iters_without_improv`` iterations without a new best)
    is used instead.

    Args:
        timestamp_to_iteration: List of ``(time, iteration)`` pairs.
        iteration_to_km1: List of ``(iteration, km1)`` pairs.
        early_window_improvs: Number of improvements defining the early rate.
        recent_window_improvs: Number of improvements in the sliding window.
        alpha: Fraction of the early rate below which the run is stopped.
        max_iters_without_improv: Plateau length (in iterations) that stops
            the run.

    Returns:
        ``(stop_time, stop_iter, stop_km1, last_km1, frac_improvement,
        frac_time)``, or six ``None`` values when fewer than two log entries
        are given.
    """
    if len(iteration_to_km1) < 2:
        return None, None, None, None, None, None

    iter_to_time = {iter_: t for t, iter_ in timestamp_to_iteration}
    first_iter, first_km1 = iteration_to_km1[0]
    last_iter, last_km1 = iteration_to_km1[-1]
    total_improvement = first_km1 - last_km1
    start_time = iter_to_time.get(first_iter, 0.0)
    end_time = iter_to_time.get(last_iter, start_time)
    # Guard against a zero-length run (identical or missing timestamps), which
    # would otherwise raise ZeroDivisionError when computing frac_time.
    total_time = max(end_time - start_time, 1e-9)

    def km1_at_or_before(target_it):
        for it, km1 in reversed(iteration_to_km1):
            if it <= target_it:
                return km1
        return iteration_to_km1[0][1]

    def result_at(stop_time, stop_iter, stop_km1):
        """Build the return tuple for a given stop point."""
        improvement_at_stop = first_km1 - stop_km1
        frac_improvement = (
            improvement_at_stop / total_improvement if total_improvement > 0 else 1.0
        )
        frac_time = (stop_time - start_time) / total_time
        return stop_time, stop_iter, stop_km1, last_km1, frac_improvement, frac_time

    # Extract improvements as (iter, km1, delta)
    improvements = extract_improvements_from_iterations(iteration_to_km1)

    # Fallback: too few improvements to measure an early rate, so only the
    # plateau rule applies.
    if len(improvements) < early_window_improvs + 1:
        best_so_far = first_km1
        last_improv_iter = first_iter
        stop_iter = last_iter
        stop_km1 = last_km1
        stop_time = end_time

        for it, km1 in iteration_to_km1[1:]:
            if km1 < best_so_far:
                best_so_far = km1
                last_improv_iter = it
                stop_iter = it
                stop_km1 = km1
                stop_time = _time_at_or_before_iteration(timestamp_to_iteration, stop_iter)
            elif it - last_improv_iter >= max_iters_without_improv:
                # Plateau fallback
                stop_iter = it
                stop_km1 = best_so_far
                stop_time = _time_at_or_before_iteration(timestamp_to_iteration, stop_iter)
                break

        return result_at(stop_time, stop_iter, stop_km1)

    # Normal case

    # Early rate from first early_window_improvs improvements
    early_slice = improvements[:early_window_improvs]
    it0, km1_0, _ = early_slice[0]
    itW, km1_W, _ = early_slice[-1]
    early_delta = km1_0 - km1_W
    early_span = max(itW - it0, 1)
    early_rate = early_delta / early_span

    # Walk over improvements and compute recent rate over a sliding window of improvements
    for i in range(early_window_improvs, len(improvements)):
        window = improvements[max(0, i - recent_window_improvs + 1): i + 1]
        it_start = window[0][0]
        it_end, km_end, _ = window[-1]
        delta_sum = sum(d for _, _, d in window)
        span = max(it_end - it_start, 1)
        recent_rate = delta_sum / span

        # Most recent improvement considered so far
        last_improv_iter = it_end

        if recent_rate < alpha * early_rate:
            return result_at(
                _time_at_or_before_iteration(timestamp_to_iteration, it_end),
                it_end,
                km_end,
            )

        if i + 1 < len(improvements):
            # Plateau fallback: stop if the next improvement is further than
            # max_iters_without_improv iterations away.
            next_improv_iter = improvements[i + 1][0]
            if next_improv_iter - last_improv_iter <= max_iters_without_improv:
                continue
            target_it = last_improv_iter + max_iters_without_improv
        else:
            # No future improvements at all
            target_it = min(last_improv_iter + max_iters_without_improv, last_iter)

        return result_at(
            _time_at_or_before_iteration(timestamp_to_iteration, target_it),
            target_it,
            km1_at_or_before(target_it),
        )

    # Only reached if the loop above did not run; report the end of the log.
    return result_at(end_time, last_iter, last_km1)



def simulate_stop_by_global_tangent(
    timestamp_to_iteration,
    iteration_to_km1,
    alpha: float = 0.1,
    min_iters: int = 50,
    max_iters_without_improv: int = 500,
):
    """Replay a run and stop when the global improvement rate flattens.

    The rate is always measured from the first improvement over the initial
    value. The reference ("initial") rate is taken at the first log entry that
    lies at least ``min_iters`` iterations after that first improvement. From
    there on the run stops as soon as the rate from the first improvement to
    the current entry is at most ``alpha`` times the reference rate, or after
    ``max_iters_without_improv`` iterations without a new best value.

    Args:
        timestamp_to_iteration: List of ``(time, iteration)`` pairs.
        iteration_to_km1: List of ``(iteration, km1)`` pairs.
        alpha: Fraction of the reference rate at which the run is stopped.
        min_iters: Iterations after the first improvement before the
            reference rate is measured.
        max_iters_without_improv: Plateau length that stops the run.

    Returns:
        ``(stop_time, stop_iter, stop_km1, last_km1, frac_improvement,
        frac_time)``, or six ``None`` values when fewer than two log entries
        are given. If no improvement occurs before the end of the log, or no
        positive reference rate can be measured, the end of the log is
        returned with both fractions set to 1.0.
    """
    if len(iteration_to_km1) < 2:
        return None, None, None, None, None, None

    # The original comprehension used the undefined name `it` as key and
    # therefore crashed for every non-empty timestamp list.
    iter_to_time = {iter_: t for t, iter_ in timestamp_to_iteration}
    first_iter, first_km1 = iteration_to_km1[0]
    last_iter, last_km1 = iteration_to_km1[-1]

    total_improvement = first_km1 - last_km1
    start_time = iter_to_time.get(first_iter, 0.0)
    end_time = iter_to_time.get(last_iter, start_time)
    # Guard against a zero-length run to avoid ZeroDivisionError in frac_time.
    total_time = max(end_time - start_time, 1e-9)

    def result_at(stop_time, stop_iter, stop_km1):
        """Build the return tuple for a given stop point."""
        improvement_at_stop = first_km1 - stop_km1
        frac_improvement = (
            improvement_at_stop / total_improvement if total_improvement > 0 else 1.0
        )
        frac_time = (stop_time - start_time) / total_time
        return stop_time, stop_iter, stop_km1, last_km1, frac_improvement, frac_time

    # --- Find first improvement over the initial value ---
    first_improv_iter = None
    first_improv_km1 = None
    no_improv_since_start = 0

    for it, km1 in iteration_to_km1[1:]:
        if km1 < first_km1:
            first_improv_iter = it
            first_improv_km1 = km1
            break
        no_improv_since_start += 1
        if no_improv_since_start >= max_iters_without_improv:
            # Plateau or optimal from the very beginning
            return result_at(
                _time_at_or_before_iteration(timestamp_to_iteration, it), it, km1
            )

    # No improvement at all before the end of the log
    if first_improv_iter is None:
        return end_time, last_iter, last_km1, last_km1, 1.0, 1.0

    # Helper returning the logged km1 of the last entry at or before an
    # iteration; only necessary because of "skipped" iterations in the log.
    def best_km1_at_iter(target_it):
        for it, km1 in reversed(iteration_to_km1):
            if it <= target_it:
                return km1
        return iteration_to_km1[0][1]

    # Initial global rate is measured once we have enough iterations since first_improv_iter
    init_rate_it = None
    init_rate_val = None

    for it, km1 in iteration_to_km1:
        if it < first_improv_iter + min_iters:
            continue
        delta = first_improv_km1 - best_km1_at_iter(it)
        span = max(it - first_improv_iter, 1)
        init_rate_it = it
        init_rate_val = delta / span
        break

    # If we never got enough iterations (or no measurable rate), run to the end
    if init_rate_it is None or init_rate_val <= 0:
        return end_time, last_iter, last_km1, last_km1, 1.0, 1.0

    best_so_far = best_km1_at_iter(init_rate_it)
    last_improv_iter = init_rate_it
    stop_iter = last_iter
    stop_km1 = last_km1
    stop_time = end_time

    for it, km1 in iteration_to_km1:
        if it < init_rate_it:
            continue

        if km1 < best_so_far:
            best_so_far = km1
            last_improv_iter = it

        # Check idle plateau as a backup (local minimum / true optimum)
        if it - last_improv_iter >= max_iters_without_improv:
            stop_iter = it
            stop_km1 = best_so_far
            stop_time = _time_at_or_before_iteration(timestamp_to_iteration, stop_iter)
            break

        # Compute global rate from first_improv_iter to current it
        delta = first_improv_km1 - best_so_far
        span = max(it - first_improv_iter, 1)
        curr_rate = delta / span

        # Compare to initial global rate
        if curr_rate <= alpha * init_rate_val:
            stop_iter = it
            stop_km1 = best_so_far
            stop_time = _time_at_or_before_iteration(timestamp_to_iteration, stop_iter)
            break

    return result_at(stop_time, stop_iter, stop_km1)


def simulate_stop_by_improvement_rate_iter_window(
    timestamp_to_iteration,
    iteration_to_km1,
    early_window_iters: int = 100,
    recent_window_iters: int = 20,
    alpha: float = 0.1,
    max_iters_without_improv: int = 500
):
    """Replay a run and stop when the rate over a sliding iteration window drops.

    The early rate is the km1 drop per iteration over the first
    ``early_window_iters`` iterations (clamped to the end of the log). If that
    window shows no change, it is extended iteration by iteration until the
    first improvement, the end of the log, or ``max_iters_without_improv``
    iterations from the start. Then a window of ``recent_window_iters``
    iterations is slid forward one iteration at a time; the run stops at the
    first position whose rate is below ``alpha * early_rate``.

    Args:
        timestamp_to_iteration: List of ``(time, iteration)`` pairs.
        iteration_to_km1: List of ``(iteration, km1)`` pairs.
        early_window_iters: Length of the early reference window.
        recent_window_iters: Length of the sliding window.
        alpha: Fraction of the early rate below which the run is stopped.
        max_iters_without_improv: Upper bound for extending the early window.

    Returns:
        ``(stop_time, stop_iter, stop_km1, last_km1, frac_improvement,
        frac_time)``, or six ``None`` values when fewer than two log entries
        are given. If the rule never fires, the end of the log is reported.
    """
    if len(iteration_to_km1) < 2:
        return None, None, None, None, None, None

    iter_to_time = {iter_: t for t, iter_ in timestamp_to_iteration}
    first_iter, first_km1 = iteration_to_km1[0]
    last_iter, last_km1 = iteration_to_km1[-1]
    total_improvement = first_km1 - last_km1
    start_time = iter_to_time.get(first_iter, 0.0)
    end_time = iter_to_time.get(last_iter, start_time)
    # Guard against a zero-length run (identical or missing timestamps), which
    # would otherwise raise ZeroDivisionError when computing frac_time.
    total_time = max(end_time - start_time, 1e-9)

    def km1_at_iter(target_it):
        # Logged km1 of the last entry at or before target_it.
        for it, km1 in reversed(iteration_to_km1):
            if it <= target_it:
                return km1
        return iteration_to_km1[0][1]

    # Early window: from first_iter to first_iter + early_window_iters
    early_start_it = first_iter
    early_end_it = min(first_iter + early_window_iters, last_iter)
    km1_start = km1_at_iter(early_start_it)
    km1_end = km1_at_iter(early_end_it)
    early_delta = km1_start - km1_end

    if early_delta == 0:
        # Iterate to first improvement if no improvement in early window
        it = early_end_it + 1
        while it <= last_iter and (it - early_start_it) <= max_iters_without_improv:
            km1_curr = km1_at_iter(it)
            if km1_curr < km1_end:
                early_end_it = it
                km1_end = km1_curr
                early_delta = km1_start - km1_end
                break
            it += 1

    early_span = max(early_end_it - early_start_it, 1)
    early_rate = early_delta / early_span

    stop_iter = last_iter
    stop_km1 = last_km1
    stop_time = end_time

    # Slide recent window in iteration space
    it = early_end_it
    while it <= last_iter:
        recent_start_it = max(first_iter, it - recent_window_iters)
        km1_recent_end = km1_at_iter(it)
        delta = km1_at_iter(recent_start_it) - km1_recent_end
        span = max(it - recent_start_it, 1)
        recent_rate = delta / span

        if recent_rate < alpha * early_rate:
            stop_iter = it
            stop_km1 = km1_recent_end
            stop_time = _time_at_or_before_iteration(timestamp_to_iteration, stop_iter)
            break

        it += 1

    improvement_at_stop = first_km1 - stop_km1
    frac_improvement = (
        improvement_at_stop / total_improvement if total_improvement > 0 else 1.0
    )
    frac_time = (stop_time - start_time) / total_time

    return stop_time, stop_iter, stop_km1, last_km1, frac_improvement, frac_time



def plot_iteration_data(*datasets, title: str = "", labels=None, filename: str = None,
                     show: bool = True, save: bool = False, save_dir: str = None, dpi: int = 150, iteration_to_metric: bool = False,
                     timestamp_to_iteration: bool = False):
    """Plot one line per dataset of ``(x, y)`` entries from an iteration log.

    Exactly one of *iteration_to_metric* / *timestamp_to_iteration* must be
    true; it selects the axis labels. Missing labels fall back to
    ``"Dataset <n>"``. When *filename* is ``None`` it is derived from the
    lower-cased title with spaces and slashes replaced by underscores. The
    figure is handed to ``_finalize_figure`` for showing/saving.
    """
    assert (iteration_to_metric != timestamp_to_iteration), "Only one of iteration_to_metric or timestamp_to_iteration can be True."

    palette = ['b', 'r', 'g', 'orange', 'purple', 'brown', 'pink', 'gray', 'cyan']
    marker_styles = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*']

    fig, ax = plt.subplots(figsize=(10, 6))
    if labels is None:
        labels = [f'Dataset {i+1}' for i in range(len(datasets))]

    for index, series in enumerate(datasets):
        xs = [point[0] for point in series]
        ys = [point[1] for point in series]
        series_label = labels[index] if index < len(labels) else f'Dataset {index+1}'
        ax.plot(
            xs,
            ys,
            marker=marker_styles[index % len(marker_styles)],
            color=palette[index % len(palette)],
            label=series_label,
            markersize=3,
        )

    ax.set_title(title, fontsize=14, fontweight='bold')
    # (enabled, x-axis label, y-axis label) for each plot mode
    axis_modes = (
        (iteration_to_metric, 'Iteration', 'Metric Value'),
        (timestamp_to_iteration, 'Time (seconds)', 'Iteration'),
    )
    for enabled, x_text, y_text in axis_modes:
        if enabled:
            ax.set_xlabel(x_text, fontsize=12)
            ax.set_ylabel(y_text, fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend()

    plt.tight_layout()
    if filename is None:
        slug = title.lower().translate(str.maketrans({' ': '_', '/': '_'}))
        filename = f"{slug}.png"
    _finalize_figure(fig, show, save, save_dir, filename, dpi)


def plot_time_series(*datasets, title: str = "Geometric Mean KM1 over Time", labels=None, filename: str = None,
                     show: bool = True, save: bool = False, save_dir: str = None, dpi: int = 300,
                     log_x: bool = True, x_label: str = "time $t$", y_label: str = r"mean min $(\lambda - 1)$",
                     max_time: float = None):
    """Plot time series and optionally save them as PNG and TikZ/PGFPlots.

    Args:
        *datasets: Each dataset is either a dict ``{time: value}`` or a
            sequence of ``(time, value)`` entries.
        title: Title shown in the figure and the PNG; it is removed while the
            TikZ file is written.
        labels: Legend labels; missing ones fall back to ``"Dataset <n>"``.
        filename: Output name without extension (default ``"plot_output"``).
            A trailing ``.png`` or ``.tex`` is stripped, because both
            extensions are appended here.
        show: Show the figure; otherwise it is closed.
        save: Write ``<filename>.png`` and ``<filename>.tex``.
        save_dir: Target directory, created if missing. When saving without a
            directory the current working directory is used.
        dpi: Resolution of the PNG.
        log_x: Use a logarithmic x axis with fixed ticks.
        x_label: Label of the x axis.
        y_label: Label of the y axis.
        max_time: If given, only points with ``time <= max_time`` are drawn
            and datasets without such points are skipped.
    """
    # 1. Setup Figure
    fig, ax = plt.subplots(figsize=(7, 5))

    # 2. Colors matching the reference PDF (Green, Blue, Brown, Teal, Orange, Purple, Red, Black)
    academic_colors = ["#4DAF4A", "#377EB8", "#A65628", "#98E0D6", "#FF7F00", "#984EA3", "#E41A1C", "#000000"]
    # Ticks shared by the matplotlib preview and the PGFPlots export
    log_ticks = [1, 2, 5, 10, 20, 50, 100, 200, 500]

    if labels is None:
        labels = [f'Dataset {i+1}' for i in range(len(datasets))]

    # 3. Plot Data
    for i, data in enumerate(datasets):
        if isinstance(data, dict):
            times = list(data.keys())
            values = list(data.values())
        else:
            times = [entry[0] for entry in data]
            values = [entry[1] for entry in data]

        # Apply max_time cutoff first
        if max_time is not None:
            filtered_data = [(t, v) for t, v in zip(times, values) if t <= max_time]
            if not filtered_data:
                continue  # Skip this dataset if no data within max_time
            times, values = zip(*filtered_data)

        color = academic_colors[i % len(academic_colors)]
        # Fall back to a generic label instead of failing on a short label list
        label = labels[i] if i < len(labels) else f'Dataset {i+1}'
        ax.plot(times, values, color=color, label=label, linewidth=1.5)

    # 4. Axis Configuration (Matplotlib side - for the PNG preview)
    if log_x:
        ax.set_xscale('log')
        ax.xaxis.set_major_locator(FixedLocator(log_ticks))
        ax.xaxis.set_major_formatter(ScalarFormatter())

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.grid(True, which="major", linestyle=":", linewidth=0.8, alpha=0.7)

    # Legend settings
    legend = ax.legend(title="Algorithm", loc='best', frameon=False)
    plt.setp(legend.get_title(), fontsize=10, fontweight='bold')

    # 5. Handle Title
    # We show title in PNG preview, but REMOVE it for LaTeX (academic figures use \caption below)
    if title:
        ax.set_title(title)

    plt.tight_layout()

    # 6. Save Logic
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if filename is None:
        filename = "plot_output"
    else:
        # Avoid "name.png.png" / "name.png.tex" when an extension is passed in
        stem, ext = os.path.splitext(filename)
        if ext.lower() in (".png", ".tex"):
            filename = stem

    if save:
        # os.path.join(None, ...) would raise; default to the working directory
        out_dir = save_dir if save_dir else "."

        # Save PNG (Keep title for preview)
        png_path = os.path.join(out_dir, f"{filename}.png")
        fig.savefig(png_path, dpi=dpi, bbox_inches='tight')
        print(f"Saved PNG to: {png_path}")

        # 7. Save TikZ (The Critical Part for LaTeX)
        tex_path = os.path.join(out_dir, f"{filename}.tex")

        # Remove title specifically for the TeX file so it doesn't appear at the top
        current_title = ax.get_title()
        ax.set_title("")

        # PGFPlots commands to force the style
        extra_axis_opts = [
            'height=7cm',
            'width=9cm',
            'grid=both',
            'grid style={dotted, gray!50}',  # Matches the faint dotted grid
            'major grid style={dotted, gray!50}',
            'xlabel near ticks',
            'ylabel near ticks',
            'tick align=outside',
            'legend cell align={left}',
            'legend style={draw=none, fill=none, font=\\small}',  # No box around legend
            'x label style={font=\\large}',
            'y label style={font=\\large}',
        ]

        # Explicitly force Log Mode and Ticks for PGFPlots
        if log_x:
            extra_axis_opts.extend([
                'xmode=log',
                'log ticks with fixed point',  # Shows "10" instead of "10^1"
                'xtick={' + ', '.join(str(tick) for tick in log_ticks) + '}',
            ])

        tikzplotlib.save(
            tex_path,
            axis_height='7cm',
            axis_width='9cm',
            extra_axis_parameters=extra_axis_opts
        )

        # Restore title in case plot object is reused (optional)
        ax.set_title(current_title)
        print(f"Saved TikZ to: {tex_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)


def plot_single_matrix(matrix, title: str = "Difference Matrix", filename: str = None,
                       show: bool = True, save: bool = False, save_dir: str = None, dpi: int = 150):
    """Render *matrix* as a heat map with a colour bar.

    Non-array input is converted with ``np.array``. When *filename* is
    ``None`` it is derived from the lower-cased title with spaces and slashes
    replaced by underscores. Showing/saving is delegated to
    ``_finalize_figure``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    data = matrix if isinstance(matrix, np.ndarray) else np.array(matrix)

    heatmap = ax.imshow(data, cmap='viridis', aspect='auto', interpolation='nearest')
    fig.colorbar(heatmap, ax=ax, label='Difference Value')

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(False)  # no grid lines on top of the heat map

    plt.tight_layout()
    if filename is None:
        slug = title.lower().translate(str.maketrans({' ': '_', '/': '_'}))
        filename = f"{slug}.png"
    _finalize_figure(fig, show, save, save_dir, filename, dpi)
    
def plot_frac_time_vs_improvement(points, title: str, filename: str = None,
                                  show: bool = True, save: bool = False,
                                  save_dir: str = None, dpi: int = 150):
    """Scatter the fraction of time used against the fraction of improvement.

    *points* is a sequence whose entries carry the time fraction at index 0
    and the improvement fraction at index 1. Nothing is drawn for empty
    input. Both axes are fixed to ``[0, 1]`` and dashed guide lines are drawn
    at x = 0.5 and y = 0.8. Showing/saving is delegated to
    ``_finalize_figure``.
    """
    if not points:
        return

    time_fractions = [point[0] for point in points]
    improvement_fractions = [point[1] for point in points]

    fig, ax = plt.subplots(figsize=(6, 6))
    scatter = ax.scatter(
        time_fractions,
        improvement_fractions,
        c=improvement_fractions,
        cmap="viridis",
        s=20,
        alpha=0.7,
    )

    ax.set_xlabel("Fraction of time used", fontsize=12)
    ax.set_ylabel("Fraction of improvement achieved", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)

    # Always use fixed 0..1 ranges
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    colorbar = fig.colorbar(scatter, ax=ax)
    colorbar.set_label("Fraction of improvement", fontsize=12)

    # Optional guide lines
    guide_style = {"color": "gray", "linestyle": "--", "alpha": 0.4}
    ax.axvline(0.5, **guide_style)
    ax.axhline(0.8, **guide_style)

    plt.tight_layout()
    if filename is None:
        slug = title.lower().translate(str.maketrans({" ": "_", "/": "_"}))
        filename = f"{slug}.png"
    _finalize_figure(fig, show, save, save_dir, filename, dpi)

# --- Bulk helpers ---

def save_multiple_time_series(series_groups, titles, labels_groups=None, save_dir=None, base_filename=None, dpi=150, show=True):
    """Save one time-series plot per entry of *series_groups*.

    Args:
        series_groups: Sequence of groups. A group is either a single dataset
            (a dict, or a sequence of ``(time, value)`` entries) or a
            list/tuple of such datasets that are drawn into the same plot.
        titles: Plot titles; missing ones fall back to ``"series_<n>"``.
        labels_groups: Optional legend labels per group.
        save_dir: Directory passed on to ``plot_time_series``.
        base_filename: If given, plot ``n`` is saved under
            ``<base_filename>_<n>`` (``plot_time_series`` appends the file
            extensions itself); otherwise ``plot_time_series`` picks its
            default name.
        dpi: Resolution passed on to ``plot_time_series``.
        show: Whether each figure is shown.
    """
    def _is_dataset(candidate):
        # A dataset is a dict or a (possibly empty) sequence of (time, value)
        # entries; a bare (time, value) entry starts with a scalar instead.
        if isinstance(candidate, dict):
            return True
        if not isinstance(candidate, (list, tuple)):
            return False
        return not candidate or isinstance(candidate[0], (list, tuple))

    for i, group in enumerate(series_groups):
        title = titles[i] if i < len(titles) else f"series_{i+1}"
        labels = None
        if labels_groups and i < len(labels_groups):
            labels = labels_groups[i]
        # Normalize group to a sequence of datasets. Looking only at
        # "group[0] is a list/tuple" would mistake a single list of
        # (time, value) tuples for a group of datasets.
        if isinstance(group, (list, tuple)) and group and _is_dataset(group[0]):
            datasets = group
        else:
            datasets = (group,)
        filename = None
        if base_filename:
            # No ".png" here: plot_time_series adds ".png" and ".tex" itself.
            filename = f"{base_filename}_{i+1}"
        plot_time_series(*datasets, title=title, labels=labels, filename=filename, show=show, save=True, save_dir=save_dir, dpi=dpi)


def save_multiple_matrices(matrices, titles, save_dir=None, base_filename=None, dpi=150, show=True):
    """Plot and save one figure per matrix via ``plot_single_matrix``.

    Args:
        matrices: Iterable of matrices, each handed to ``plot_single_matrix``.
        titles: Titles by position; missing ones fall back to ``matrix_<n>``
            (1-based).
        save_dir: Forwarded to ``plot_single_matrix``.
        base_filename: When truthy, figure ``n`` is named
            ``<base_filename>_<n>.png``; otherwise ``filename=None`` is
            passed on.
        dpi: Forwarded to ``plot_single_matrix``.
        show: Forwarded to ``plot_single_matrix``.
    """
    for number, matrix in enumerate(matrices, start=1):
        position = number - 1
        if position < len(titles):
            title = titles[position]
        else:
            title = f"matrix_{number}"

        filename = f"{base_filename}_{number}.png" if base_filename else None

        plot_single_matrix(
            matrix,
            title=title,
            filename=filename,
            show=show,
            save=True,
            save_dir=save_dir,
            dpi=dpi,
        )

In [None]:
# Notebook cell: compare "short" and "long" evolutionary runs of one experiment
# (final KM1 values, diff matrices of selected instances, geomean over time).
# Relies on HOME_DIR and on the aggregate_*/split_*/plot_* helpers defined
# elsewhere in this file.
RESULT_FILES_DIR = f"{HOME_DIR}/Documents/experiment_results/2025-11-13_long_diff_combined_results/2025-11-13_long_diff_test/mt_kahypar_evo_results"
DIFF_FILES_DIR = f"{HOME_DIR}/Documents/experiment_results/2025-11-13_long_diff_combined_results/2025-11-13_long_diff_test/evo_diff"
HISTORY_FILES_DIR = f"{HOME_DIR}/Documents/experiment_results/2025-11-13_long_diff_combined_results/2025-11-13_long_diff_test/evo_history"

# NOTE(review): SHORT_TIMELIMIT is None, so `short_instance` below becomes
# "<run>.None"; that name is only used by the commented-out short-run code.
# LONG_TIMELIMIT is presumably in seconds (21600 s = 6 h) -- TODO confirm.
SHORT_TIMELIMIT = None
LONG_TIMELIMIT = 21600


# Aggregate runs from result files
runs = aggregate_runs(RESULT_FILES_DIR)
k8_runs, k32_runs = split_runs_k_value(runs, '8', '32')

# Per-instance difference (long - short) of the final KM1 values.
# NOTE(review): despite its name, diff_array_k2 is filled from the k=32 runs.
diff_array_k2 = []
for instance_name, results in k32_runs.items():
    short_km1 = results["short"]
    long_km1 = results["long"]
    if short_km1 is not None and long_km1 is not None:
        diff = long_km1 - short_km1
        diff_array_k2.append(diff)
# Same for k=8; only consumed by the commented-out statistics below.
diff_array_k8 = []
for instance_name, results in k8_runs.items():
    short_km1 = results["short"]
    long_km1 = results["long"]
    if short_km1 is not None and long_km1 is not None:
        diff = long_km1 - short_km1
        diff_array_k8.append(diff)

# How often is long run better than short runs (percentage)
# A negative difference (long KM1 < short KM1) is counted as "long is better".
if diff_array_k2:
    better_count = sum(1 for d in diff_array_k2 if d < 0)
    same_count = sum(1 for d in diff_array_k2 if d == 0)
    percentage_better = (better_count / len(diff_array_k2)) * 100
    percentage_same = (same_count / len(diff_array_k2)) * 100
    print(f"Percentage of instances where long run is better than short runs (k=32): {percentage_better:.2f}%")
    print(f"Percentage of instances where long run is the same as short runs (k=32): {percentage_same:.2f}%")
    # better_count_k8 = sum(1 for d in diff_array_k8 if d < 0)
    # percentage_better_k8 = (better_count_k8 / len(diff_array_k8)) * 100
    # print(f"Percentage of instances where long run is better than short runs (k=8): {percentage_better_k8:.2f}%")
else:
    print("No valid differences to analyze.")


# Get Geomean for short and long runs
# Short and long values are filtered independently, so the two means may be
# taken over different sets of instances when some results are missing.
short_km1_values_k32 = [km1["short"] for km1 in k32_runs.values() if km1["short"] is not None]
long_km1_values_k32 = [km1["long"] for km1 in k32_runs.values() if km1["long"] is not None]
if short_km1_values_k32 and long_km1_values_k32:
    geomean_short = geometric_mean(short_km1_values_k32)
    geomean_long = geometric_mean(long_km1_values_k32)
    print(f"Geometric Mean KM1 (k=32) - Short Runs: {geomean_short}, Long Runs: {geomean_long}")
    
short_km1_values_k8 = [km1["short"] for km1 in k8_runs.values() if km1["short"] is not None]
long_km1_values_k8 = [km1["long"] for km1 in k8_runs.values() if km1["long"] is not None]
if short_km1_values_k8 and long_km1_values_k8:
    geomean_short_k8 = geometric_mean(short_km1_values_k8)
    geomean_long_k8 = geometric_mean(long_km1_values_k8)
    print(f"Geometric Mean KM1 (k=8) - Short Runs: {geomean_short_k8}, Long Runs: {geomean_long_k8}")

# Rank the k=32 instances by relative difference (long - short) / short,
# largest first. With "negative = long is better" (see above), the head of the
# list holds the instances where the long run compares worst to the short run.
# Instances with a long result but no short result are ranked first (inf).
relative_diffs = []
for run, km1 in k32_runs.items():
    short_km1 = km1["short"]
    long_km1 = km1["long"]
    if short_km1 is not None and long_km1 is not None:
        diff = long_km1 - short_km1
        # NOTE(review): raises ZeroDivisionError if a short KM1 is exactly 0.
        relative_diff = diff / short_km1
        relative_diffs.append((run, relative_diff))
    elif short_km1 is None and long_km1 is not None:
        relative_diffs.append((run, float('inf')))
relative_diffs.sort(key=lambda x: x[1], reverse=True)
print("relative differences: ", relative_diffs[:10])


# Diff Matrices Analysis
diff_runs = aggregate_diff_runs(DIFF_FILES_DIR)

# Analyze diff matrices for worst long runs
show_full_history = False

history_runs = aggregate_history_runs(HISTORY_FILES_DIR, full_history=show_full_history)
# Only the ten top-ranked instances from relative_diffs are visualised.
for run, _ in relative_diffs[:10]:
    short_instance = f"{run}.{SHORT_TIMELIMIT}"
    long_instance = f"{run}.{LONG_TIMELIMIT}"   
    if long_instance in diff_runs:
        
        #short_matrices = diff_runs[short_instance]
        long_matrices = diff_runs[long_instance]

        # Only the first entry of the diff run is inspected; skip it when it
        # carries no matrices.
        if len(long_matrices[0]['matrices']) == 0:
            continue
        #if len(short_matrices[0]['matrices']) == 0:
        #    continue

        #best_short_matrices = get_diff_matrices_for_best_run(history_runs, short_matrices, short_instance)
        #last_short_matrix = best_short_matrices[-1] if best_short_matrices else None
        last_long_matrix = long_matrices[0]['matrices'][-1]
        
        long_diff_run = long_matrices[0]['matrices']
        #short_diff_run = best_short_matrices
        #short_history_run = history_runs[short_instance]['history']
        # NOTE(review): raises KeyError if the instance has diff data but no
        # history entry.
        long_history_run = history_runs[long_instance]['history']
        
        #combined_short = combine_history_and_diff(short_history_run, short_diff_run)
        combined_long = combine_history_and_diff(long_history_run, long_diff_run, time_limit=LONG_TIMELIMIT)
        #plot_combined_data(combined_short, title=f"Short Run Combined Data for {short_instance}")
        plot_combined_data(combined_long, title=f"Long Run Combined Data for {long_instance}")

        if last_long_matrix:
            #plot_single_matrix(last_short_matrix, title=f"Short Run Diff Matrix for {short_instance}")
            plot_single_matrix(last_long_matrix, title=f"Long Run Diff Matrix for {long_instance}")
        
# Geometric Mean over all instances
short_history_runs, long_history_runs = split_long_short_history_runs(history_runs)
short_k8_runs, short_k32_runs = split_k_value_history_runs(short_history_runs, '8', '32')
long_k8_runs, long_k32_runs = split_k_value_history_runs(long_history_runs, '8', '32')
# geomean_short = create_geomean_over_all_instances(short_k32_runs)
# NOTE: geomean_long is rebound here; from now on it holds the result of
# create_geomean_over_all_instances, not the scalar computed further up.
geomean_long = create_geomean_over_all_instances(long_k32_runs)
# plot_time_series(geomean_short, geomean_long, 
#                  title="Geometric Mean KM1 over Time (Short vs Long Runs)",
#                  labels=['Short Runs', 'Long Runs'])

# Show full history
# The history is loaded a second time with full_history=True; only the short
# k=32 runs of that reload are used below, split per seed.
show_full_history = True
history_runs = aggregate_history_runs(HISTORY_FILES_DIR, full_history=show_full_history)
short_history_runs_full, _ = split_long_short_history_runs(history_runs)
_, short_k32_runs_full = split_k_value_history_runs(short_history_runs_full, '8', '32')
all_seeds_short_histories_k32 = split_seed_history_runs(short_k32_runs_full)

# NOTE(review): fixed_short_history_runs is not used by the code below (its
# only consumer is commented out), and time_limit is None here.
fixed_short_history_runs = make_history_runs_sequential(*all_seeds_short_histories_k32, time_limit=SHORT_TIMELIMIT)
#fixed_geomean_short = create_geomean_over_all_instances(fixed_short_history_runs)

# One geomean curve per seed of the short runs, plotted together with the
# long-run curve.
geomean_for_all_seeds = []
for seed_runs in all_seeds_short_histories_k32:
    geomean = create_geomean_over_all_instances(seed_runs)
    geomean_for_all_seeds.append(geomean)
plot_time_series(*geomean_for_all_seeds, geomean_long,
                 title="Geometric Mean KM1 over Time (k=32)",
                 labels=[f'Seed {i+1}' for i in range(len(geomean_for_all_seeds))] + ['Long Runs'])



In [None]:
# Notebook cell: for every (k, population size) pair, plot the geometric-mean
# KM1 curve of each kway setting into one figure.
CONFIGS_NAME = "2025-11-28_combined_results"
ALL_CONFIGS_DIR = os.path.expanduser(f"~/Documents/experiment_results/{CONFIGS_NAME}")
RESULTS = "mt_kahypar_evo_results"
DIFF = "evo_diff"
HISTORY = "evo_history"
DEFAULT_POP_SIZE = 10
show_full_history = True


INSTANCES_TO_EXCLUDE = ["Pd_rhs.mtx.hgr", "wb-edu.mtx.hgr"]


def _kway_sort_key(kway):
    """Return a sort key that orders kway values numerically when possible.

    Values convertible to a number sort first, by numeric value (so 2 < 4 < 16
    instead of the lexicographic "16" < "2" < "4"). Everything else sorts
    afterwards by its string form, which is the ordering used previously.
    """
    text = str(kway)
    try:
        return (0, float(kway), text)
    except (TypeError, ValueError):
        return (1, 0.0, text)


k_pop_to_kway_geomeans, save_path = eval_k_kway(configs_name=CONFIGS_NAME,
                                         all_configs_root=os.path.expanduser("~/Documents/experiment_results"),
                                         results_dirname=RESULTS,
                                         diff_dirname=DIFF,
                                         history_dirname=HISTORY,
                                         instances_to_exclude=INSTANCES_TO_EXCLUDE,
                                         show_full_history=show_full_history,
                                         default_pop_size="dynamic")

for (k, pop_size), kway_dict in k_pop_to_kway_geomeans.items():
    # Stable, numeric-aware order so the legend reads kway=2, 4, 8, 16, ...
    sorted_kways = sorted(kway_dict, key=_kway_sort_key)

    datasets = [kway_dict[kway] for kway in sorted_kways]
    labels = [f'kway={kway}' for kway in sorted_kways]

    if datasets:
        plot_time_series(*datasets,
                         title=f"Geometric Mean KM1 over Time (k={k}, pop={pop_size})",
                         labels=labels,
                         save=True,
                         save_dir=save_path)

In [None]:
# Notebook cell: compare the configurations of the "parallel" experiment, once
# over wall-clock time and once over iteration-normalised time.
CONFIGS_NAME = "2025-11-29_parallel_combined_results"
# NOTE(review): this directory lives under ~/Documents/BA_results while
# eval_evothreads below is pointed at ~/Documents/experiment_results -- confirm
# that both roots are intended.
ALL_CONFIGS_DIR = os.path.expanduser(f"~/Documents/BA_results/{CONFIGS_NAME}")
RESULTS = "mt_kahypar_evo_results"
DIFF = "evo_diff"
HISTORY = "evo_history"
ITERATION_LOG = "evo_iteration_log"
DEFAULT_POP_SIZE = 10
INSTANCES_TO_EXCLUDE = ["Pd_rhs.mtx.hgr"]
show_full_history = True


threads_per_worker_to_geomeans, save_path, max_threads = eval_evothreads(
    configs_name=CONFIGS_NAME,
    all_configs_root=os.path.expanduser("~/Documents/experiment_results"),
    results_dirname=RESULTS,
    diff_dirname=DIFF,
    history_dirname=HISTORY,
    instances_to_exclude=INSTANCES_TO_EXCLUDE,
    show_full_history=show_full_history,
)

# Load the iteration logs and histories of every configuration sub-directory,
# keyed by the directory name with its date prefix removed.
all_aggregated_iteration_runs = {}
all_aggregated_history_runs = {}
for config_path in glob.glob(ALL_CONFIGS_DIR + "/*/"):
    log_dir = os.path.join(config_path, ITERATION_LOG)
    history_dir = os.path.join(config_path, HISTORY)
    label = remove_date_prefix(os.path.basename(os.path.normpath(config_path)))
    all_aggregated_iteration_runs[label] = aggregate_iterations_runs(
        log_dir,
        show_full_history=show_full_history,
        instances_to_exclude=INSTANCES_TO_EXCLUDE,
    )
    all_aggregated_history_runs[label] = aggregate_history_runs(
        history_dir,
        full_history=show_full_history,
        instances_to_exclude=INSTANCES_TO_EXCLUDE,
    )

# Baseline = the configuration with the highest mean of its per-instance
# maximum iteration counts.
all_geomeans_iterations = {}
baseline_config = None
baseline_iterations = None
current_max = 0
for label, runs_of_config in all_aggregated_iteration_runs.items():
    per_instance_max = max_iterations_per_instance(
        runs_of_config, multiple_seeds=show_full_history
    )
    if per_instance_max:
        average_iterations = mean(per_instance_max.values())
        if average_iterations > current_max:
            current_max = average_iterations
            baseline_config = label
            baseline_iterations = per_instance_max

# Rescale every configuration's history by its multipliers relative to the
# baseline, then build the geomean curve from the rescaled history.
for label, runs_of_config in all_aggregated_iteration_runs.items():
    per_instance_max = max_iterations_per_instance(runs_of_config, multiple_seeds=show_full_history)
    scale_factors = get_iteration_multipliers(
        baseline_max_iterations=baseline_iterations,
        other_max_iterations=per_instance_max,
    )
    rescaled_history = create_multiplier_based_history_runs(
        all_aggregated_history_runs[label], scale_factors
    )
    all_geomeans_iterations[label] = create_geomean_over_all_instances(
        rescaled_history, multiple_seeds=show_full_history
    )


# Both figures put all configurations into a single plot.
# NOTE(review): the titles mention "Modified Combine Strategies" although the
# data comes from eval_evothreads -- possibly copied from another cell.
labels_for_all = list(threads_per_worker_to_geomeans.keys())
plot_time_series(
    *threads_per_worker_to_geomeans.values(),
    title="Geometric Mean KM1 over actual Time for Modified Combine Strategies",
    labels=labels_for_all,
    save=True,
    save_dir=save_path,
)

labels_for_all = list(all_geomeans_iterations.keys())
plot_time_series(
    *all_geomeans_iterations.values(),
    title="Geometric Mean KM1 over iteration normalized Time for Modified Combine Strategies",
    labels=labels_for_all,
    save=True,
    save_dir=save_path,
)

In [None]:
# Notebook cell: plot the geometric-mean KM1 curves of all configurations of
# the "meta" experiment; optionally (time_normalized) also build
# iteration-normalised curves relative to the "default" configuration.
CONFIGS_NAME = "2026-1-5_meta"
ALL_CONFIGS_DIR = os.path.expanduser(f"~/Documents/BA_results/{CONFIGS_NAME}")
RESULTS = "mt_kahypar_evo_results"
DIFF = "evo_diff"
HISTORY = "evo_history"
ITERATION_LOG = "evo_iteration_log"
DEFAULT_POP_SIZE = 10
INSTANCES_TO_EXCLUDE = ["Pd_rhs.mtx.hgr"]
MAX_TIME = 7200
show_full_history = True

time_normalized = False

modifiedCombines_to_geomeans, save_path = eval_generic_by_name(
    configs_name=CONFIGS_NAME,
    all_configs_root=os.path.expanduser("~/Documents/BA_results"),
    results_dirname=RESULTS,
    diff_dirname=DIFF,
    history_dirname=HISTORY,
    instances_to_exclude=INSTANCES_TO_EXCLUDE,
    show_full_history=show_full_history)

# Go through all config directories
all_aggregated_iteration_runs = {}
all_aggregated_history_runs = {}
for config_dir in glob.glob(ALL_CONFIGS_DIR + "/*/"):
    iteration_log_path = os.path.join(config_dir, ITERATION_LOG)
    history_path = os.path.join(config_dir, HISTORY)
    config_name = os.path.basename(os.path.normpath(config_dir))
    config_name = remove_date_prefix(config_name)
    all_aggregated_iteration_runs[config_name] = aggregate_iterations_runs(iteration_log_path, show_full_history=show_full_history, instances_to_exclude=INSTANCES_TO_EXCLUDE)
    all_aggregated_history_runs[config_name] = aggregate_history_runs(history_path, full_history=show_full_history, instances_to_exclude=INSTANCES_TO_EXCLUDE)

# Create geomeans for iterations
all_geomeans_iterations = {}
current_max = 0  # not read in this cell; kept so the cell still resets it
# Reset explicitly: in a notebook a value left over from a previously executed
# cell would otherwise be picked up silently when no "default" config exists.
baseline_iterations = None
for config_name, iteration_runs in all_aggregated_iteration_runs.items():
    # The first configuration whose name contains "default" is the baseline.
    if "default" in config_name:
        baseline_iterations = max_iterations_per_instance(iteration_runs, multiple_seeds=show_full_history)
        break

# Calculate multipliers for each config based on baseline
if time_normalized:
    if baseline_iterations is None:
        print("Warning: no configuration containing 'default' found; "
              "skipping iteration-normalized geomeans.")
    else:
        for config_name, iteration_runs in all_aggregated_iteration_runs.items():
            max_iterations = max_iterations_per_instance(iteration_runs, multiple_seeds=show_full_history)
            multipliers = get_iteration_multipliers(baseline_max_iterations=baseline_iterations, other_max_iterations=max_iterations)

            adjusted_history_runs = create_multiplier_based_history_runs(all_aggregated_history_runs[config_name], multipliers)
            all_geomeans_iterations[config_name] = create_geomean_over_all_instances(adjusted_history_runs, multiple_seeds=show_full_history)


# Plot all in one plot
labels_for_all = list(modifiedCombines_to_geomeans)
plot_time_series(*modifiedCombines_to_geomeans.values(), 
                     title="Geometric Mean KM1 over all instances",
                     labels=labels_for_all,
                     save=True,
                     save_dir=save_path,
                     log_x=False,
                     max_time=MAX_TIME
                 )

# labels_for_all = []
# for config_name, geomean in all_geomeans_iterations.items():
#     labels_for_all.append(config_name)
# plot_time_series(*all_geomeans_iterations.values(), 
#                      title="Geometric Mean KM1 over iteration normalized Time for Modified Combine Strategies",
#                      labels=labels_for_all,
#                      save=True,
#                      save_dir=save_path)

In [None]:
### Stopping Criteria Evaluation ####
# Settings for the stopping-criteria simulation, followed by loading the
# history and iteration logs of the reference configuration.
CONFIGS_NAME = "2025-11-25_combined_results"
ALL_CONFIGS_DIR = os.path.expanduser(f"~/Documents/experiment_results/{CONFIGS_NAME}")
RESULTS = "mt_kahypar_evo_results"
DIFF = "evo_diff"
HISTORY = "evo_history"
ITERATION_LOG = "evo_iteration_log"
DEFAULT_POP_SIZE = 10
INSTANCES_TO_EXCLUDE = ["Pd_rhs.mtx.hgr", "wb-edu.mtx.hgr"]
show_full_history = True
REFERENCE_CONFIG = "2025-11-25_defaultCombine"

from statistics import mean
import json


# Both data sets are read from sub-directories of the reference configuration.
history_runs = aggregate_history_runs(
    os.path.join(ALL_CONFIGS_DIR, REFERENCE_CONFIG, HISTORY),
    full_history=show_full_history,
    instances_to_exclude=INSTANCES_TO_EXCLUDE,
)
iteration_runs = aggregate_iterations_runs(
    os.path.join(ALL_CONFIGS_DIR, REFERENCE_CONFIG, ITERATION_LOG),
    show_full_history=show_full_history,
    instances_to_exclude=INSTANCES_TO_EXCLUDE,
)



# 1) Simulate the "stop after K iterations without improvement" criterion.
# Per-instance results are collected per K, then summarised, printed and
# drawn as one scatter plot per K.
max_k_values = [10, 20, 50, 65, 100]
stopping_criteria_stats = {limit: list() for limit in max_k_values}
stopping_criteria_summary = {}

for instance_name, (timestamp_to_iteration, iteration_to_km1) in iteration_runs.items():
    # Instances with fewer than two recorded iterations are ignored.
    if len(iteration_to_km1) >= 2:
        for k in max_k_values:
            outcome = simulate_stop_after_k_no_improvement_iterations(
                timestamp_to_iteration, iteration_to_km1, k
            )
            stop_time, stop_iter, stop_km1, final_km1, frac_improvement, frac_time = outcome
            # A stop_time of None means there is nothing to record for this K.
            if stop_time is not None:
                stopping_criteria_stats[k].append(dict(
                    instance=instance_name,
                    stop_time=stop_time,
                    stop_iter=stop_iter,
                    stop_km1=stop_km1,
                    final_km1=final_km1,
                    frac_improvement=frac_improvement,
                    frac_time=frac_time,
                ))

for k, stats in stopping_criteria_stats.items():
    if stats:
        improvement_values = [entry["frac_improvement"] for entry in stats]
        time_values = [entry["frac_time"] for entry in stats]

        avg_frac_impr = mean(improvement_values)
        min_frac_impr = min(improvement_values)
        avg_frac_time = mean(time_values)

        report = (
            f"K={k}: "
            f"avg frac_improvement={avg_frac_impr:.3f}, "
            f"min frac_improvement={min_frac_impr:.3f}, "
            f"avg frac_time={avg_frac_time:.3f}, "
            f"num_instances={len(stats)}"
        )
        print(report)

        stopping_criteria_summary[k] = dict(
            avg_frac_improvement=avg_frac_impr,
            min_frac_improvement=min_frac_impr,
            avg_frac_time=avg_frac_time,
            num_instances=len(stats),
        )

        # One point per instance: (fraction of time, fraction of improvement).
        points = list(zip(time_values, improvement_values))
        plot_frac_time_vs_improvement(
            points,
            title=f"K={k}: frac_time vs frac_improvement",
            save=True,
            save_dir=FIG_SAVE_DIR,
        )


# 2) Simulate the "sliding window improvement rate" criterion.
# Each tuple is (early window, recent window, alpha, max iters without improv).
rate_params = [
    (5, 5, 0.1, 100),
    (5, 5, 0.05, 100),
    (5, 5, 0.1, 200),
    (5, 5, 0.05, 200),
    (5, 5, 0.1, 300),
]
rate_summary = {}
for params in rate_params:
    ew, rw, alpha, max_iters_without_improv = params

    # Collect one record per instance for which the simulation reports a stop.
    stats_for_param = []
    for instance_name, (timestamp_to_iteration, iteration_to_km1) in iteration_runs.items():
        if len(iteration_to_km1) >= 2:
            outcome = simulate_stop_by_improvement_rate(
                timestamp_to_iteration,
                iteration_to_km1,
                early_window_improvs=ew,
                recent_window_improvs=rw,
                alpha=alpha,
                max_iters_without_improv=max_iters_without_improv,
            )
            stop_time, stop_iter, stop_km1, final_km1, frac_improvement, frac_time = outcome
            if stop_time is not None:
                stats_for_param.append(dict(
                    instance=instance_name,
                    frac_improvement=frac_improvement,
                    frac_time=frac_time,
                ))

    if not stats_for_param:
        continue

    improvement_values = [entry["frac_improvement"] for entry in stats_for_param]
    time_values = [entry["frac_time"] for entry in stats_for_param]

    avg_frac_improvement = mean(improvement_values)
    min_frac_impr = min(improvement_values)
    avg_frac_time = mean(time_values)

    report = (
        f"Sliding window improvements (early={ew}, recent={rw}, alpha={alpha}, "
        f"max_iters_without_improv={max_iters_without_improv}): "
        f"avg frac_improvement={avg_frac_improvement:.3f}, "
        f"min frac_improvement={min_frac_impr:.3f}, "
        f"avg frac_time={avg_frac_time:.3f}, "
        f"num_instances={len(stats_for_param)}"
    )
    print(report)

    # The summary is keyed by the string form of the parameter tuple.
    rate_summary[str(params)] = dict(
        early_window_improvs=ew,
        recent_window_improvs=rw,
        alpha=alpha,
        max_iters_without_improv=max_iters_without_improv,
        avg_frac_improvement=avg_frac_improvement,
        min_frac_improvement=min_frac_impr,
        avg_frac_time=avg_frac_time,
        num_instances=len(stats_for_param),
    )

    points = list(zip(time_values, improvement_values))
    plot_frac_time_vs_improvement(
        points,
        title=f"Sliding window (early={ew}, recent={rw}, alpha={alpha}, max_idle={max_iters_without_improv})",
        save=True,
        save_dir=FIG_SAVE_DIR,
    )



# 3) Simulate the "sliding window iteration rate" criterion.
# Each tuple is (early window iters, recent window iters, alpha).
iter_rate_params = [
    (20, 200, 0.1),
    (50, 200, 0.1),
    (15, 100, 0.05),
]

rate_stats = {setting: list() for setting in iter_rate_params}
iter_rate_summary = {}
for instance_name, (timestamp_to_iteration, iteration_to_km1) in iteration_runs.items():
    # Instances with fewer than two recorded iterations are ignored.
    if len(iteration_to_km1) >= 2:
        for params in iter_rate_params:
            ew, rw, alpha = params
            outcome = simulate_stop_by_improvement_rate_iter_window(
                timestamp_to_iteration,
                iteration_to_km1,
                early_window_iters=ew,
                recent_window_iters=rw,
                alpha=alpha,
            )
            stop_time, stop_iter, stop_km1, final_km1, frac_improvement, frac_time = outcome
            if stop_time is not None:
                rate_stats[params].append(dict(
                    instance=instance_name,
                    stop_time=stop_time,
                    stop_iter=stop_iter,
                    stop_km1=stop_km1,
                    final_km1=final_km1,
                    frac_improvement=frac_improvement,
                    frac_time=frac_time,
                ))

for params, stats in rate_stats.items():
    if stats:
        ew, rw, alpha = params
        improvement_values = [entry["frac_improvement"] for entry in stats]
        time_values = [entry["frac_time"] for entry in stats]

        avg_frac_impr = mean(improvement_values)
        min_frac_impr = min(improvement_values)
        avg_frac_time = mean(time_values)

        report = (
            f"Sliding window iterations (early={ew}, recent={rw}, alpha={alpha}): "
            f"avg frac_improvement={avg_frac_impr:.3f}, "
            f"min frac_improvement={min_frac_impr:.3f}, "
            f"avg frac_time={avg_frac_time:.3f}, "
            f"num_instances={len(stats)}"
        )
        print(report)

        # The summary is keyed by the string form of the parameter tuple.
        iter_rate_summary[str(params)] = dict(
            early_window_iters=ew,
            recent_window_iters=rw,
            alpha=alpha,
            avg_frac_improvement=avg_frac_impr,
            min_frac_improvement=min_frac_impr,
            avg_frac_time=avg_frac_time,
            num_instances=len(stats),
        )

        points = list(zip(time_values, improvement_values))
        plot_frac_time_vs_improvement(
            points,
            title=f"Iter-window (early={ew}, recent={rw}, alpha={alpha})",
            save=True,
            save_dir=FIG_SAVE_DIR,
        )


# 4) Simulate the "global tangent" criterion.
# Each tuple is (alpha, min_iters_for_rate, max_iters_without_improv).
global_tangent_params = [
    (0.35, 50, 300),
]

global_stats = {setting: list() for setting in global_tangent_params}
global_summary = {}
for instance_name, (timestamp_to_iteration, iteration_to_km1) in iteration_runs.items():
    # Instances with fewer than two recorded iterations are ignored.
    if len(iteration_to_km1) >= 2:
        for params in global_tangent_params:
            alpha, min_iters_for_rate, max_idle = params
            outcome = simulate_stop_by_global_tangent(
                timestamp_to_iteration,
                iteration_to_km1,
                alpha=alpha,
                min_iters=min_iters_for_rate,
                max_iters_without_improv=max_idle,
            )
            stop_time, stop_iter, stop_km1, final_km1, frac_improvement, frac_time = outcome
            if stop_time is not None:
                global_stats[params].append(dict(
                    instance=instance_name,
                    stop_time=stop_time,
                    stop_iter=stop_iter,
                    stop_km1=stop_km1,
                    final_km1=final_km1,
                    frac_improvement=frac_improvement,
                    frac_time=frac_time,
                ))

for params, stats in global_stats.items():
    if stats:
        alpha, min_iters_for_rate, max_idle = params
        improvement_values = [entry["frac_improvement"] for entry in stats]
        time_values = [entry["frac_time"] for entry in stats]

        avg_frac_impr = mean(improvement_values)
        min_frac_impr = min(improvement_values)
        avg_frac_time = mean(time_values)

        report = (
            f"Global tangent (alpha={alpha}, min_iters_for_rate={min_iters_for_rate}, "
            f"max_iters_without_improv={max_idle}): "
            f"avg frac_improvement={avg_frac_impr:.3f}, "
            f"min frac_improvement={min_frac_impr:.3f}, "
            f"avg frac_time={avg_frac_time:.3f}, "
            f"num_instances={len(stats)}"
        )
        print(report)

        # The summary is keyed by the string form of the parameter tuple.
        global_summary[str(params)] = dict(
            alpha=alpha,
            min_iters_for_rate=min_iters_for_rate,
            max_iters_without_improv=max_idle,
            avg_frac_improvement=avg_frac_impr,
            min_frac_improvement=min_frac_impr,
            avg_frac_time=avg_frac_time,
            num_instances=len(stats),
        )

        points = list(zip(time_values, improvement_values))
        plot_frac_time_vs_improvement(
            points,
            title=f"Global tangent (alpha={alpha}, min_iters={min_iters_for_rate}, max_idle={max_idle})",
            save=True,
            save_dir=FIG_SAVE_DIR,
        )

# Write the four summary dictionaries as JSON files into the figure directory.
_stats_out_dir = FIG_SAVE_DIR
_ensure_dir(_stats_out_dir)

_summary_outputs = (
    ("stopping_criteria_summary.json", stopping_criteria_summary),
    ("rate_summary.json", rate_summary),
    ("iter_rate_summary.json", iter_rate_summary),
    ("global_summary.json", global_summary),
)
for _summary_filename, _summary_data in _summary_outputs:
    _summary_path = os.path.join(_stats_out_dir, _summary_filename)
    with open(_summary_path, "w") as f:
        json.dump(_summary_data, f, indent=2)

