In [34]:
import nest_asyncio
# Patch asyncio so nested event loops are allowed — presumably so asyncio-based code can be
# run from inside the notebook kernel's already-running loop; confirm.
# NOTE(review): none of the cells visible here use asyncio directly, so this may be a
# leftover from an earlier experiment — confirm before removing.
nest_asyncio.apply()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import shutil
import re
import numpy as np

# Regular expression that captures the number inside a TensorFlow scalar repr such as
# "tf.Tensor(42.425056, shape=(), dtype=float32)". The optional sign applies to every
# numeric form (it previously only applied to decimals), and exponents ("1e-05"),
# trailing-dot floats ("42.") and nan/inf are accepted as well.
float_pattern = r'tf\.Tensor\(([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf),.*\)'


# Function to extract float value from a tensor or use float directly
def extract_float_value(value):
    """Return ``value`` as a float, unwrapping a ``tf.Tensor(...)`` string repr if present.

    Args:
        value: A number, a numeric string, or a string of the form
            ``"tf.Tensor(<number>, shape=(), dtype=...)"``.

    Returns:
        float: The parsed number.

    Raises:
        ValueError / TypeError: If ``value`` is neither a tensor repr nor accepted by ``float()``.
    """
    match = re.match(float_pattern, str(value))
    if match:
        return float(match.group(1))
    return float(value)

def _labelled_validation_files(directory_path, task, pattern, algorithm):
    """Return ``(legend_label, file_name)`` pairs for the CSVs ``task`` compares, sorted for the legend.

    Each file's label comes from the same regex match that selects the file, so a label can
    never be paired with the wrong file (a substring test such as ``"5_" in name`` also
    matches ``..._c10_e5_...``).

    Raises:
        ValueError: If ``task`` is not ``'epochs'``, ``'cohorts'`` or ``'algorithms'``.
    """
    if task == 'epochs':
        label_regex, sort_key = r'_e(\d+)_', int
    elif task == 'cohorts':
        label_regex, sort_key = r'_c(\d+)_', int
    elif task == 'algorithms':
        label_regex, sort_key = r'^([a-zA-Z]+)_', str
    else:
        raise ValueError(f"Unknown task {task!r}; expected 'epochs', 'cohorts' or 'algorithms'")

    labelled = []
    # Sorted listing keeps the result deterministic across filesystems.
    for file_name in sorted(os.listdir(directory_path)):
        if 'validation_data.csv' not in file_name:
            continue
        # epochs/cohorts compare configurations of one algorithm (file prefix);
        # algorithms compare every algorithm for one configuration (pattern).
        selected = pattern in file_name if task == 'algorithms' else file_name.startswith(algorithm)
        if not selected:
            continue
        label_match = re.search(label_regex, file_name)
        if label_match:
            labelled.append((label_match.group(1), file_name))
    return sorted(labelled, key=lambda pair: sort_key(pair[0]))


def _smooth_keep_warmup(series, window_size):
    """Trailing moving average of ``series`` with the same length and index.

    The first ``window_size - 1`` points, where the window is not yet full, keep their raw
    values instead of NaN, so the curve stays aligned with the evaluation points.
    """
    smoothed = series.rolling(window=window_size).mean()
    warmup = max(window_size - 1, 0)
    smoothed.iloc[:warmup] = series.iloc[:warmup]
    return smoothed


def plot_multiple_csv(directory_path, column_name, window_size, start_round, end_round, task, pattern, algorithm):
    """Plot one validation metric from several staged CSV files on a single figure.

    Args:
        directory_path: Directory holding the staged ``*validation_data.csv`` files.
        column_name: CSV column to plot (values parsed with ``extract_float_value``).
        window_size: Trailing moving-average window in evaluation points (1 = raw data).
        start_round, end_round: Round range to plot; converted to row indices with
            ``// rounds_per_eval`` and used as a half-open slice ``[start, end)``.
        task: ``'epochs'`` / ``'cohorts'`` compare configurations of ``algorithm``;
            ``'algorithms'`` compares algorithms for the configuration ``pattern``.
        pattern: Configuration substring (e.g. ``"c10_e4_"``); selects files for
            ``'algorithms'`` and is part of the output file name.
        algorithm: File-name prefix for ``'epochs'``/``'cohorts'``; otherwise only part of
            the output file name.

    Reads the module-level ``total_rounds``, ``rounds_per_eval``, ``target_dir`` and
    ``dataset`` from the configuration code of this cell, and saves the figure to
    ``<target_dir>/<algorithm>_<dataset>_<pattern><task>.jpg``.

    Raises:
        ValueError: If ``task`` is not one of the supported values.
    """
    labelled_files = _labelled_validation_files(directory_path, task, pattern, algorithm)
    legend_prefix = {'epochs': 'E:', 'cohorts': 'C:', 'algorithms': 'A:'}[task]

    # At most one CSV row per evaluation within the first total_rounds rounds.
    max_points = int(total_rounds / rounds_per_eval)
    start_index = start_round // rounds_per_eval
    end_index = end_round // rounds_per_eval

    fig, ax = plt.subplots(figsize=(8, 6))
    color_mapping = {}  # legend label -> colour, so a label keeps one colour

    for legend_label, csv_file in labelled_files:
        print(f'Legend Label: {legend_label}, Matching CSV File: {csv_file}')

        data = pd.read_csv(os.path.join(directory_path, csv_file), nrows=max_points)
        values = data[column_name].apply(extract_float_value).reset_index(drop=True)
        smoothed = _smooth_keep_warmup(values, window_size)
        # Assumes row i is the evaluation at round i * rounds_per_eval — TODO confirm
        # against the CSV's Round column.
        x_values = np.arange(len(smoothed)) * rounds_per_eval

        if legend_label not in color_mapping:
            color_mapping[legend_label] = plt.colormaps.get_cmap('tab10')(len(color_mapping) % 10)

        # x and y have the same length, so the slices always line up.
        ax.plot(x_values[start_index:end_index],
                smoothed.iloc[start_index:end_index].to_numpy(),
                label=legend_prefix + legend_label, linestyle='-', marker='o', alpha=0.7,
                color=color_mapping[legend_label])

    ax.set_xlabel('Round')
    ax.set_ylabel(column_name)
    ax.set_title(f'Validation {column_name}')
    ax.legend(loc='lower right')
    fig.tight_layout()

    fig.savefig(os.path.join(target_dir, algorithm + '_' + dataset + '_' + pattern + task + '.jpg'))

    plt.show()

algorithms = ["fedAvg", "fedAdadb", "fedAdam"]
target_dir = "figures/"
pattern = "c10_e4_" # e4_ / c5_ / c10_e4_
task = 'algorithms' #cohorts / epochs / algorithms
dataset="emnist" #emnist / shakespeare
total_rounds = 2000
rounds_per_eval = 5


def _clear_staged_csvs(directory):
    """Delete the ``*.csv`` files directly inside ``directory`` (portable form of ``rm -f <dir>*.csv``)."""
    if not os.path.isdir(directory):
        return
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        if name.endswith('.csv') and os.path.isfile(path):
            os.remove(path)


# Validation CSVs are staged in target_dir (prefixed with their algorithm name) so that
# plot_multiple_csv can compare them; the staged copies are removed afterwards even if
# plotting fails.
os.makedirs(target_dir, exist_ok=True)
_clear_staged_csvs(target_dir)
try:
    for algorithm in algorithms:
        source_path = os.path.join("results/official", dataset, algorithm, "training")

        # List files matching the pattern in the source directory
        matching_files = [f for f in os.listdir(source_path) if pattern in f and 'validation_data.csv' in f]

        for file_name in matching_files:
            shutil.copy(os.path.join(source_path, file_name),
                        os.path.join(target_dir, algorithm + "_" + file_name))

        # Per-algorithm comparison of epochs/cohorts: raw data (window 1), rounds 5-250.
        if task != 'algorithms':
            plot_multiple_csv(target_dir, 'Accuracy', 1, 5, 250, task, pattern, algorithm)

    # Cross-algorithm comparison needs every algorithm's files staged first.
    if task == 'algorithms':
        plot_multiple_csv(target_dir, 'Accuracy', 1, 5, 250, task, pattern, 'all')
finally:
    _clear_staged_csvs(target_dir)



In [None]:
import pandas as pd
import os
import tensorflow as tf
import glob


import re  # needed by extract_float_value below; this cell must not rely on an earlier cell's import

algorithms = ["fedAvg", "fedAdam", "fedAdadb"]
dataset="shakespeare"
total_rounds = 250      # training rounds to analyse
rounds_per_eval = 5     # rounds between consecutive validation rows in the CSVs
cohorts=[5,10,50]
epochs=[1,2,4,8,16]
c1,c2=0,0               # convergence rounds for th1/th2, reassigned per file
th1=40 # accuracy threshold for the first convergence round; 40 shakespeare, 85 emnist
th2=50 # second threshold; NOTE(review): the earlier note read "55 shakespeare 99 emnist" but 50 is used — confirm
last_rounds=10 # intended number of final validation rows to average; 10 shakespeare, 100 emnist

# Regular expression that captures the number inside a TensorFlow scalar repr such as
# "tf.Tensor(42.425056, shape=(), dtype=float32)". The optional sign applies to every
# numeric form, and exponents ("1e-05"), trailing-dot floats and nan/inf are accepted.
float_pattern = r'tf\.Tensor\(([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf),.*\)'

# Function to extract float value from a tensor or use float directly
def extract_float_value(value):
    """Return ``value`` as a float, unwrapping a ``tf.Tensor(...)`` string repr if present.

    Raises:
        ValueError / TypeError: If ``value`` is neither a tensor repr nor accepted by ``float()``.
    """
    match = re.match(float_pattern, str(value))
    if match:
        return float(match.group(1))
    return float(value)

def get_avg_metrics(csv_file, column_name, metric, window_size, convergence_percentage=None):
    """Summarise one column of a validation CSV.

    Args:
        csv_file: Path of the CSV to read.
        column_name: Column to analyse; values are parsed with ``extract_float_value``.
        metric: ``"convergence"`` returns the row position of the first row whose trailing
            ``window_size``-row moving average exceeds ``convergence_percentage`` (None if it
            never does); ``"avgAcc"`` returns the mean of the last ``window_size`` rows.
        window_size: Moving-average window / number of trailing rows.
        convergence_percentage: Threshold; required when ``metric == "convergence"``.

    Returns:
        int | None for ``"convergence"``, float for ``"avgAcc"``.

    Raises:
        ValueError: For an unknown ``metric`` or a missing convergence threshold.
    """
    if metric not in ("convergence", "avgAcc"):
        raise ValueError(f"Unknown metric {metric!r}; expected 'convergence' or 'avgAcc'")
    if metric == "convergence" and convergence_percentage is None:
        raise ValueError("convergence_percentage is required when metric='convergence'")

    data = pd.read_csv(csv_file)

    # Extract data from the specified column
    data_column = data[column_name].apply(extract_float_value)

    if metric == "convergence":
        moving_average = data_column.rolling(window=window_size).mean()
        # Incomplete windows are NaN and compare False, so they never count as converged.
        return next((i for i, avg in enumerate(moving_average) if avg > convergence_percentage), None)

    # metric == "avgAcc": average over the final window_size rows
    return data_column.tail(window_size).mean()

def _read_learning_rates(settings_path):
    """Parse ``key: value`` lines of a training_settings.txt file; return ``(client_lr, server_lr)``, None if absent."""
    client_lr = server_lr = None
    with open(settings_path, 'r') as settings_file:
        for line in settings_file:
            key, sep, value = line.strip().partition(': ')
            if not sep:
                continue  # blank or malformed line
            if key == 'client_learning_rate':
                client_lr = float(value)
            elif key == 'server_learning_rate':
                server_lr = float(value)
    return client_lr, server_lr


def _find_settings_file(source_path, pattern):
    """Return the settings file for ``pattern``, else the first (sorted) settings file in ``source_path``, else None."""
    exact_path = os.path.join(source_path, pattern + 'training_settings.txt')
    if os.path.exists(exact_path):
        return exact_path
    # Fallback presumably relies on learning rates being shared by all configurations of
    # one algorithm — confirm.
    candidates = sorted(glob.glob(os.path.join(source_path, '*training_settings.txt')))
    return candidates[0] if candidates else None


# One CSV-style line per validation file:
# algorithm, epochs, cohort, server_lr, client_lr, round reaching th1, round reaching th2, final accuracy
for algorithm in algorithms:
    source_path = os.path.join("results/official", dataset, algorithm, "training")
    for c in cohorts:
        for e in epochs:
            pattern = f"c{c}_e{e}_"

            # Re-read per configuration so rates from a previous iteration are never reused.
            settings_path = _find_settings_file(source_path, pattern)
            n, ns = _read_learning_rates(settings_path) if settings_path else (None, None)

            matching_files = [f for f in os.listdir(source_path) if pattern in f and 'validation_data.csv' in f]
            for file_name in matching_files:
                csv_path = os.path.join(source_path, file_name)
                c1 = get_avg_metrics(csv_path, 'Accuracy', 'convergence', 5, th1)
                c2 = get_avg_metrics(csv_path, 'Accuracy', 'convergence', 5, th2)
                # get_avg_metrics returns a row position; convert it to a round number.
                if c1 is not None:
                    c1 *= rounds_per_eval
                if c2 is not None:
                    c2 *= rounds_per_eval

                data = pd.read_csv(csv_path)
                # Rows are evaluations, so the first total_rounds rounds span
                # total_rounds // rounds_per_eval rows.
                selected_data = data.iloc[: total_rounds // rounds_per_eval]
                final_average = selected_data['Accuracy'].apply(extract_float_value).tail(last_rounds).mean()
                max_acc = str(final_average)
                print(f"{algorithm},{e},{c},{ns},{n},{c1},{c2},{max_acc}")

            


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import re
from scipy.stats import friedmanchisquare
import scikit_posthocs as sp
import seaborn as sns


def extract_tensor_value(tensor_str):
    """
    Extracts the numeric value from a string in the format:
    "tf.Tensor(42.425056, shape=(), dtype=float32)"

    Exponents ("4.2e-05") are parsed in full, not truncated at the "e". Plain numbers and
    numeric strings (e.g. "0.42") are converted directly, so a column mixing tensor reprs
    and plain values is parsed completely. Anything else yields NaN.
    """
    text = str(tensor_str)  # tolerate non-string cells such as NaN
    match = re.search(
        r"tf\.Tensor\(\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf)\s*[,)]",
        text,
    )
    if match:
        return float(match.group(1))
    try:
        return float(text)
    except ValueError:
        return np.nan

# Define the base directory for the results
# (update path as needed)
base_dir = 'results/official/cifar100'
algorithms = ['fedAvg', 'fedAdam', 'fedAdadb']
num_runs = 5  # Number of runs per algorithm
validation_csv = 'c10_e4_validation_data.csv'
final_round_start, final_round_end = 1901, 2000  # inclusive round window averaged as "final accuracy"

# One final-accuracy value per (algorithm, run). The run number is kept so the statistical
# tests below can match runs across algorithms by number rather than by list position.
metric_data = {'Algorithm': [], 'Run': [], 'FinalAccuracy': []}

for algorithm in algorithms:
    runs_dir = os.path.join(base_dir, algorithm, 'training', 'runs')
    for run_number in range(1, num_runs + 1):
        file_path = os.path.join(runs_dir, str(run_number), validation_csv)
        try:
            data = pd.read_csv(file_path)

            # Accuracy may be stored as tf.Tensor(...) strings; convert to numbers.
            if data['Accuracy'].dtype == object:
                data['Accuracy'] = data['Accuracy'].apply(extract_tensor_value)

            final_rounds = data[data['Round'].between(final_round_start, final_round_end)]

            if not final_rounds.empty:
                metric_data['Algorithm'].append(algorithm)
                metric_data['Run'].append(run_number)
                metric_data['FinalAccuracy'].append(final_rounds['Accuracy'].mean())
            else:
                print(f"Insufficient final rounds in {file_path}")

        except FileNotFoundError:
            print(f"File missing: {file_path}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

# Build DataFrame
df_metric = pd.DataFrame(metric_data)
print(f"\nFinal Model Quality (Avg. Rounds {final_round_start}-{final_round_end}):")
print(df_metric)

# --- Statistical Testing ---
# Friedman / Nemenyi treat each run as a block, so the values must be matched by run
# number across algorithms; only runs that succeeded for every algorithm are kept.
paired = (
    df_metric.pivot(index='Run', columns='Algorithm', values='FinalAccuracy')
    .reindex(columns=algorithms)
    .dropna()
)
n_paired = len(paired)

if n_paired >= 2:
    if n_paired < num_runs:
        print(f"Warning: Not all algorithms have {num_runs} runs. "
              f"Using the {n_paired} runs common to all algorithms for testing.")
    stat, p = friedmanchisquare(*(paired[alg].to_numpy() for alg in algorithms))
    print(f"\nFriedman stat: {stat:.3f}, p-value: {p:.3e}")
    if p < 0.05:
        print("Significant: running Nemenyi post-hoc")
        # Rows = runs (blocks), columns = algorithms (groups).
        p_nem = sp.posthoc_nemenyi_friedman(paired.to_numpy())
        p_nem.index = p_nem.columns = algorithms
        print("\nNemenyi pairwise p-values:")
        print(p_nem)
    else:
        print("No significant differences.")
else:
    print("Need at least 2 runs per algorithm (with data) for statistical testing.")
# --- End Statistical Testing ---


# --- Plotting and Saving ---
fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(x='Algorithm', y='FinalAccuracy', data=df_metric, ax=ax)
ax.set_title(f'Boxplot Comparison of Final Validation Accuracy (Rounds {final_round_start}–{final_round_end})')
ax.set_xlabel('Algorithm')
ax.set_ylabel('Average Accuracy')
ax.grid(True)

# PNG (raster) and PDF (vector, scales cleanly in papers)
output_filename_png = 'boxplot_final_accuracy.png'
output_filename_pdf = 'boxplot_final_accuracy.pdf'
dpi_setting = 300 # Common requirement for papers, increase to 600 if needed

# Save BEFORE showing: plt.show() may release the figure.
fig.savefig(output_filename_png, dpi=dpi_setting, bbox_inches='tight')
print(f"\nPlot saved as high-quality PNG: {output_filename_png}")

fig.savefig(output_filename_pdf, bbox_inches='tight') # DPI is less relevant for vector formats like PDF
print(f"Plot saved as high-quality PDF: {output_filename_pdf}")

plt.show()


In [None]:
import pandas as pd
import numpy as np
import os
import re
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel

# --- Tensor extractor ---
def extract_tensor_value(tensor_str):
    """Return the number inside a "tf.Tensor(<number>, shape=..., dtype=...)" repr.

    Exponents are parsed in full, plain numbers/numeric strings are converted directly,
    and anything else yields NaN.
    """
    text = str(tensor_str)  # tolerate non-string cells such as NaN
    match = re.search(
        r"tf\.Tensor\(\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf)\s*[,)]",
        text,
    )
    if match:
        return float(match.group(1))
    try:
        return float(text)
    except ValueError:
        return np.nan

# --- Configuration ---
base_dir = 'results/official/cifar100'
algorithms = ['fedAvg', 'fedAdam', 'fedAdadb']
num_runs = 5
validation_csv = 'c10_e4_validation_data.csv'
final_round_start, final_round_end = 1901, 2000  # inclusive round window averaged as "final accuracy"
comparisons = [('fedAdadb', 'fedAdam'), ('fedAdadb', 'fedAvg')]

# --- Load data ---
metric_data = {'Algorithm': [], 'FinalAccuracy': [], 'Run': []}
for algorithm in algorithms:
    for run_number in range(1, num_runs + 1):
        path = os.path.join(base_dir, algorithm, 'training', 'runs', str(run_number), validation_csv)
        try:
            run_df = pd.read_csv(path)
            if run_df['Accuracy'].dtype == object:
                run_df['Accuracy'] = run_df['Accuracy'].apply(extract_tensor_value)
            in_window = run_df['Round'].between(final_round_start, final_round_end)
            final_acc = run_df.loc[in_window, 'Accuracy'].mean()
        except Exception as err:  # a bare `except:` would also swallow KeyboardInterrupt
            print(f"Missing or failed file: {path} ({err!r})")
            continue
        if pd.isna(final_acc):
            # An empty round window would otherwise enter the tests as NaN.
            print(f"No accuracy values in rounds {final_round_start}-{final_round_end}: {path}")
            continue
        metric_data['Algorithm'].append(algorithm)
        metric_data['FinalAccuracy'].append(final_acc)
        metric_data['Run'].append(run_number)

df = pd.DataFrame(metric_data)
pivot_df = df.pivot(index='Run', columns='Algorithm', values='FinalAccuracy')

# --- Paired t-tests ---
# Runs are paired by run number; runs missing for either algorithm are dropped, since
# ttest_rel returns nan when NaNs are present.
print("Paired t-tests:")
for a1, a2 in comparisons:
    if a1 not in pivot_df.columns or a2 not in pivot_df.columns:
        print(f"{a1} vs {a2}: skipped (no data for one of the algorithms)")
        continue
    pair = pivot_df[[a1, a2]].dropna()
    if len(pair) < 2:
        print(f"{a1} vs {a2}: skipped (fewer than 2 paired runs)")
        continue
    t_stat, p_val = ttest_rel(pair[a1], pair[a2])
    print(f"{a1} vs {a2}: t = {t_stat:.3f}, p = {p_val:.3e}")

# --- Boxplot of final accuracies ---
plt.figure(figsize=(8, 6))
sns.boxplot(x='Algorithm', y='FinalAccuracy', data=df)
plt.title('Final Accuracy (Avg. of Last 100 Rounds)')
plt.grid(True)
plt.tight_layout()
plt.savefig('boxplot_accuracy.png', dpi=300)
plt.show()

# --- Histogram of Differences ---
diffs = {
    f'{a1} vs {a2}': (pivot_df[a1] - pivot_df[a2]).dropna()
    for a1, a2 in comparisons
    if a1 in pivot_df.columns and a2 in pivot_df.columns
}

for label, values in diffs.items():
    plt.figure(figsize=(6, 4))
    sns.histplot(values, kde=True)
    plt.title(f'Difference in Accuracy: {label}')
    plt.xlabel('Accuracy Difference')
    plt.grid(True)
    plt.tight_layout()
    # File names are kept as before (e.g. diff_fedAdadb_vs__fedAdam.png) so existing
    # references to them keep working.
    plt.savefig(f'diff_{label.replace(" ", "_").replace("vs", "vs_")}.png', dpi=300)
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import re
import numpy as np

# Regular expression that captures the number inside a TensorFlow scalar repr such as
# "tf.Tensor(42.425056, shape=(), dtype=float32)". The optional sign applies to every
# numeric form, and exponents ("1e-05"), trailing-dot floats and nan/inf are accepted.
float_pattern = r'tf\.Tensor\(([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf),.*\)'

# Function to extract float value from a tensor or use float directly
def extract_float_value(value):
    """Extract a float from a number, numeric string, or ``tf.Tensor(...)`` repr.

    Returns:
        float: The parsed number, or NaN (with a printed warning) if ``value`` cannot be
        interpreted as a number.
    """
    # Convert to string so non-string inputs (numbers, NaN) are handled uniformly.
    value_str = str(value)
    match = re.match(float_pattern, value_str)
    if match:
        try:
            return float(match.group(1))
        except (ValueError, TypeError):
            print(f"Warning: Could not convert extracted value '{match.group(1)}' to float. Original value: {value_str}")
            return np.nan
    try:
        # Not a tensor repr: try a direct conversion.
        return float(value_str)
    except (ValueError, TypeError):
        print(f"Warning: Could not convert value '{value_str}' to float.")
        return np.nan

def plot_multiple_runs(base_directory, run_numbers, csv_filename, column_name, window_size, start_round, end_round, rounds_per_eval, output_filename):
    """
    Plots data from a specific column of CSV files located in different run directories.

    Row ``i`` of each CSV is treated as the evaluation at round ``i * rounds_per_eval``
    (assumption — TODO confirm against the files' ``Round`` column).

    Args:
        base_directory (str): The base path containing the numbered 'run' subdirectories.
        run_numbers (list): A list of integers representing the run directories to plot.
        csv_filename (str): The name of the CSV file within each run directory.
        column_name (str): The name of the column to plot from the CSV files.
        window_size (int): The window size for the moving average. Use 1 for no smoothing.
        start_round (int): First communication round to plot (inclusive).
        end_round (int | None): Last communication round to plot (inclusive); None plots
            everything available.
        rounds_per_eval (int): The number of rounds between each evaluation point.
        output_filename (str): The path and filename to save the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    last_round = np.inf if end_round is None else end_round
    all_data_found = True  # set to False whenever a run could not be plotted

    for run_num in run_numbers:
        file_path = os.path.join(base_directory, str(run_num), csv_filename)

        try:
            data = pd.read_csv(file_path)

            if column_name not in data.columns:
                print(f"Warning: Column '{column_name}' not found in {file_path}. Skipping this run.")
                all_data_found = False
                continue

            # Drop unparseable rows but keep each remaining row's original position, so
            # its round number is not shifted by rows removed before it.
            values = data[column_name].reset_index(drop=True).apply(extract_float_value).dropna()

            if values.empty:
                print(f"Warning: No valid data found in column '{column_name}' for {file_path} after processing. Skipping this run.")
                all_data_found = False
                continue

            rounds = values.index.to_numpy() * rounds_per_eval

            if window_size > 1:
                # min_periods=1 averages whatever is available during the first window.
                plot_data = values.rolling(window=window_size, min_periods=1).mean()
            else:
                plot_data = values

            # Select by round value so both ends of [start_round, end_round] are inclusive.
            in_range = (rounds >= start_round) & (rounds <= last_round)
            ax.plot(rounds[in_range], plot_data.to_numpy()[in_range],
                    label=f'Run {run_num}', alpha=0.8)

        except FileNotFoundError:
            print(f"Warning: File not found at {file_path}. Skipping this run.")
            all_data_found = False
        except pd.errors.EmptyDataError:
            print(f"Warning: File is empty at {file_path}. Skipping this run.")
            all_data_found = False
        except Exception as e:
            print(f"An error occurred while processing {file_path}: {e}")
            all_data_found = False

    ax.set_xlabel('Round')
    ax.set_ylabel(column_name)
    ax.set_title(f'Validation {column_name} Across Runs')
    # Only add a legend when something was plotted (avoids matplotlib's "no artists" warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='best')
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    try:
        fig.savefig(output_filename)
        print(f"Plot saved to {output_filename}")
    except Exception as e:
        print(f"Error saving plot: {e}")

    plt.show()

    if not all_data_found:
        print("Note: Some run data files were not found or could not be processed.")

# --- Configuration ---
dataset_name = "shakespeare"
algorithm_name = "fedAdadb"
base_dir = f"results/official/{dataset_name}/{algorithm_name}/training/runs/"
run_ids = [1, 2, 3, 4, 5]
csv_file = "c10_e4_validation_data_new.csv"
col_to_plot = 'Accuracy'
smoothing_window = 1  # Set to 1 to plot raw data points, >1 for moving average
plot_start_round = 0   # Start plotting from this round
plot_end_round = 2000  # Stop plotting at this round (inclusive if data exists)
eval_frequency = 5     # Rounds per evaluation point in the CSV
# Derived from the same names as base_dir so the file name always matches the runs
# being plotted.
output_plot_file = f"figures/{dataset_name}_{algorithm_name}_runs_accuracy.jpg"

# --- Create output directory if it doesn't exist ---
output_dir = os.path.dirname(output_plot_file)
if output_dir and not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"Created output directory: {output_dir}")


# --- Generate the plot ---
plot_multiple_runs(
    base_directory=base_dir,
    run_numbers=run_ids,
    csv_filename=csv_file,
    column_name=col_to_plot,
    window_size=smoothing_window,
    start_round=plot_start_round,
    end_round=plot_end_round,
    rounds_per_eval=eval_frequency,
    output_filename=output_plot_file
)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import re
import numpy as np
from typing import List, Optional

# Regular expression that captures the number inside a TensorFlow scalar repr such as
# "tf.Tensor(42.425056, shape=(), dtype=float32)". The optional sign applies to every
# numeric form, and exponents ("1e-05"), trailing-dot floats and nan/inf are accepted.
float_pattern = r'tf\.Tensor\(([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf),.*\)'

# Function to extract float value from a tensor or use float directly
def extract_float_value(value):
    """Extract a float from a number, numeric string, or ``tf.Tensor(...)`` repr.

    Returns:
        float: The parsed number, or NaN if ``value`` cannot be interpreted as a number
        (non-tensor failures are silent to keep long CSVs from flooding the output).
    """
    # Convert to string so non-string inputs (numbers, NaN) are handled uniformly.
    value_str = str(value)
    match = re.match(float_pattern, value_str)
    if match:
        try:
            return float(match.group(1))
        except (ValueError, TypeError):
            print(f"Warning: Could not convert extracted value '{match.group(1)}' to float. Original value: {value_str}")
            return np.nan
    try:
        # Not a tensor repr: try a direct conversion.
        return float(value_str)
    except (ValueError, TypeError):
        return np.nan

def plot_algorithm_comparison(
    datasets: List[str],
    algorithms: List[str],
    run_numbers: List[int],
    base_results_dir: str = "results/official",
    csv_filename_template: str = "c10_e4_validation_data_new.csv", # Use a template or ensure consistency
    column_name: str = 'Accuracy',
    window_size: int = 1,
    start_round: int = 0,
    end_round: Optional[int] = None, # Allow plotting to the very end if None
    rounds_per_eval: int = 5,
    output_base_dir: str = "figures/comparisons"
):
    """
    Plot, per dataset, each algorithm's mean validation curve across runs.

    Each algorithm is drawn as the mean over the runs that could be loaded, with a
    band of +/- one standard deviation (lower edge clipped at 0). One figure per
    dataset is saved, shown and closed.

    Args:
        datasets (List[str]): Dataset folder names (e.g. ["emnist", "shakespeare"]).
        algorithms (List[str]): Algorithm folder names (e.g. ["fedAvg", "fedAdam"]).
        run_numbers (List[int]): Run ids to aggregate. Run ``r`` is read from
            ``<base_results_dir>/<dataset>/<algorithm>/training/runs/<r>/<csv_filename_template>``;
            missing or unusable runs are reported and skipped.
        base_results_dir (str): Root folder containing the dataset folders.
        csv_filename_template (str): CSV file name inside every run folder (same for all runs).
        column_name (str): Column to plot; each cell is parsed with ``extract_float_value``
            and cells that yield NaN are dropped.
        window_size (int): Moving-average window applied to the mean curve only (1 = no smoothing).
        start_round (int): First communication round drawn (inclusive).
        end_round (Optional[int]): Last communication round drawn (inclusive); None draws all data.
        rounds_per_eval (int): Communication rounds between consecutive CSV rows.
        output_base_dir (str): Folder receiving ``<dataset>_algos_comparison_<column>.jpg``.
    """
    if output_base_dir and not os.path.exists(output_base_dir):
        os.makedirs(output_base_dir)
        print(f"Created base output directory: {output_base_dir}")

    palette = plt.cm.tab10.colors

    for dataset in datasets:
        print(f"\n--- Processing Dataset: {dataset} ---")
        fig, ax = plt.subplots(figsize=(12, 8))
        results_found_for_dataset = False

        for i, algorithm in enumerate(algorithms):
            print(f" Processing Algorithm: {algorithm}")
            algo_base_path = os.path.join(base_results_dir, dataset, algorithm, "training", "runs")
            run_data_list = []
            loaded_run_ids = []  # run ids that actually contributed a series (used as column labels)

            # --- Aggregate data across runs for the current algorithm ---
            for run_num in run_numbers:
                file_path = os.path.join(algo_base_path, str(run_num), csv_filename_template)
                try:
                    data = pd.read_csv(file_path, on_bad_lines='skip')
                    if column_name not in data.columns:
                        print(f" Warning: Column '{column_name}' missing in {file_path}. Skipping run {run_num}.")
                        continue

                    # Parse tensor strings / numbers and drop cells that could not be parsed.
                    data_column = data[column_name].apply(extract_float_value).dropna()

                    if data_column.empty:
                        print(f" Warning: No valid data in column '{column_name}' for {file_path}. Skipping run {run_num}.")
                        continue

                    run_data_list.append(data_column.reset_index(drop=True))
                    loaded_run_ids.append(run_num)

                except FileNotFoundError:
                    print(f" Warning: File not found {file_path}. Skipping run {run_num}.")
                except pd.errors.EmptyDataError:
                    print(f" Warning: File empty {file_path}. Skipping run {run_num}.")
                except Exception as e:
                    print(f" Error processing {file_path}: {e}. Skipping run {run_num}.")

            if not run_data_list:
                print(f" Error: No valid data found for any specified run of algorithm '{algorithm}'. Skipping this algorithm for dataset '{dataset}'.")
                continue

            # Columns = runs, rows = evaluation steps; shorter runs are padded with NaN.
            aligned_data = pd.concat(run_data_list, axis=1)
            # Label with the runs actually loaded; slicing run_numbers would mislabel
            # every column after a skipped run.
            aligned_data.columns = [f'run_{run_num}' for run_num in loaded_run_ids]

            mean_values = aligned_data.mean(axis=1, skipna=True)
            std_values = aligned_data.std(axis=1, skipna=True)

            # Only the mean is smoothed; the std band is drawn unsmoothed.
            if window_size > 1:
                plot_mean = mean_values.rolling(window=window_size, min_periods=1).mean()
            else:
                plot_mean = mean_values
            plot_std = std_values

            # Select the requested round range with a mask, so a start_round beyond the
            # data selects nothing (instead of falling back to the whole series).
            x_values = np.arange(len(plot_mean)) * rounds_per_eval
            in_range = x_values >= start_round
            if end_round is not None:
                in_range &= x_values <= end_round

            x_plot = x_values[in_range]
            mean_plot = plot_mean.iloc[in_range]
            std_plot = plot_std.iloc[in_range]

            if len(x_plot) == 0:
                print(f" Warning: No data points available for algorithm '{algorithm}' in the specified round range ({start_round}-{end_round}). Skipping plot segment.")
                continue

            # Wrap around the palette so more than 10 algorithms do not raise IndexError.
            color = palette[i % len(palette)]
            ax.plot(x_plot, mean_plot, label=f'{algorithm}', color=color, linewidth=2.5)
            ax.fill_between(x_plot,
                            (mean_plot - std_plot).clip(lower=0), # Prevent std dev going below 0 visually
                            mean_plot + std_plot,
                            color=color, alpha=0.25)
            results_found_for_dataset = True

        if not results_found_for_dataset:
            print(f"Error: No data could be plotted for dataset '{dataset}'. Skipping plot generation.")
            plt.close(fig)
            continue

        ax.set_xlabel('Communication Round', fontsize=14)
        ax.set_ylabel(f'Validation {column_name}', fontsize=14)
        ax.set_title(f'{dataset.upper()} Validation {column_name} Comparison', fontsize=16)
        ax.legend(loc='best', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.6)

        # --- Adjust Y-axis limits from matplotlib's autoscaled limits ---
        current_bottom, current_top = ax.get_ylim()
        metric_name = column_name.lower()
        if 'accuracy' in metric_name or 'percent' in metric_name:
            print(f"Adjusting y-axis for percentage scale (0-100) for metric: {column_name}")
            # Values are assumed to be on a 0-100 scale: anchor at 0, top between 101 and 105.
            ax.set_ylim(bottom=0, top=min(105, max(101, current_top * 1.02)))
        elif 'loss' in metric_name:
            print(f"Adjusting y-axis for loss scale for metric: {column_name}")
            # Keep an autoscaled bottom in [0, 1); otherwise start the axis at 0.
            ax.set_ylim(bottom=max(0, current_bottom) if current_bottom < 1 else 0)
        # Other metrics keep matplotlib's default autoscaling.

        ax.tick_params(axis='both', which='major', labelsize=12)
        fig.tight_layout()  # after set_ylim so labels are laid out for the final limits

        output_filename = os.path.join(output_base_dir, f"{dataset}_algos_comparison_{column_name.lower().replace(' ', '_')}.jpg")
        try:
            fig.savefig(output_filename, dpi=300, bbox_inches='tight')
            print(f" Plot saved to {output_filename}")
        except Exception as e:
            print(f" Error saving plot {output_filename}: {e}")

        plt.show()
        plt.close(fig)  # free the figure before the next dataset

# --- Configuration: experiments to compare ---
base_results = "results/official"                         # folder holding one sub-folder per dataset
datasets_to_plot = ["emnist", "shakespeare", "cifar100"]  # e.g. ["emnist"] to plot a single dataset
algorithms_to_plot = ["fedAvg", "fedAdam", "fedAdadb"]
run_ids_to_process = [1, 2, 3, 4, 5]                      # runs averaged into each mean/std curve

# The same CSV name is looked up for every dataset/algorithm/run.
# *** IMPORTANT: verify it exists everywhere; differing names need per-dataset logic. ***
common_csv_file = "c10_e4_validation_data.csv"
col_to_plot = 'Accuracy'

# --- Configuration: plot range and appearance ---
smoothing_window = 5    # moving-average window on the MEAN curve; 1 disables smoothing
plot_start_round = 0    # first communication round drawn
plot_end_round = 2000   # last communication round drawn (None = all available data)
eval_frequency = 5      # rounds between evaluations; must match the experiment setup
output_plot_dir = "figures/comparisons"  # where the comparison figures are written

# --- Generate the plots ---
plot_algorithm_comparison(
    datasets=datasets_to_plot,
    algorithms=algorithms_to_plot,
    run_numbers=run_ids_to_process,
    base_results_dir=base_results,
    csv_filename_template=common_csv_file,
    column_name=col_to_plot,
    window_size=smoothing_window,
    start_round=plot_start_round,
    end_round=plot_end_round,
    rounds_per_eval=eval_frequency,
    output_base_dir=output_plot_dir,
)

print("\n--- Plotting script finished ---")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import re
import numpy as np
from typing import List, Optional, Dict, Any

# Pattern for scalar tf.Tensor reprs such as "tf.Tensor(0.85, shape=(), dtype=float32)",
# defined here so this cell does not depend on an earlier cell. The capture group
# accepts an optional sign, decimals, an exponent (e.g. "1.2e-05") and nan/inf.
float_pattern = r'tf\.Tensor\(\s*([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))\s*,.*\)'

def extract_float_value(value):
    """Convert a logged metric cell to a float.

    Accepts a plain number / numeric string, or the string form of a scalar
    ``tf.Tensor`` (e.g. ``"tf.Tensor(0.85, shape=(), dtype=float32)"``).

    Args:
        value: Any object; it is converted with ``str()`` before parsing.

    Returns:
        The parsed float, or ``np.nan`` when the value cannot be interpreted as a
        number (so callers can drop such cells with ``dropna()``).
    """
    value_str = str(value)
    match = re.match(float_pattern, value_str)
    # Tensor reprs contribute only their payload; anything else is parsed as-is.
    candidate = match.group(1) if match else value_str
    try:
        return float(candidate)
    except (ValueError, TypeError):
        if match:
            print(f"Warning: Could not convert extracted value '{candidate}' to float. Original value: {value_str}")
        # Non-tensor, non-numeric values fail silently to keep the output short.
        return np.nan

def calculate_metrics(
    mean_values: pd.Series,
    x_values: np.ndarray,
    thresholds_to_check: Optional[List[float]] = None,
    consistency_window: int = 5 # Number of consecutive points (including current) >= threshold
) -> Dict[float, float]:
    """Return, per threshold, the first round at which the curve *stays* at or above it.

    A threshold counts as reached at position k when the ``consistency_window``
    consecutive values ending at k (k included) are all >= the threshold. The
    reported round is ``x_values[k]``, i.e. the end of that first window.

    Args:
        mean_values: Metric values in evaluation order (e.g. mean accuracy across runs).
        x_values: Round number for each position of ``mean_values``.
        thresholds_to_check: Thresholds to test; None means [75.0, 85.0].
        consistency_window: Consecutive points required; values below 1 are treated as 1.

    Returns:
        Dict mapping each threshold to its round, or ``np.inf`` if it is never
        reached (always ``np.inf`` for an empty ``mean_values``).
    """
    if thresholds_to_check is None:
        thresholds_to_check = [75.0, 85.0]
    rounds_achieved = {thr: np.inf for thr in thresholds_to_check}

    if mean_values.empty:
        return rounds_achieved

    window = max(1, int(consistency_window))
    # Minimum over each trailing window; NaN until a full window is available,
    # and NaN compares False below, so partial windows never count.
    rolling_min = mean_values.rolling(window=window, min_periods=window).min()

    for threshold in thresholds_to_check:
        hits = np.flatnonzero((rolling_min >= threshold).to_numpy())
        if hits.size and hits[0] < len(x_values):
            rounds_achieved[threshold] = x_values[hits[0]]

    return rounds_achieved


def _collect_run_series(algo_base_path: str, run_numbers: List[int],
                        csv_filename: str, column_name: str) -> Optional[pd.DataFrame]:
    """Load `column_name` from every run's CSV and align the runs side by side.

    Returns a DataFrame with one ``run_<id>`` column per run that yielded data
    (rows = evaluation steps, shorter runs padded with NaN), or None when no run
    could be used. Problems with individual runs are printed and the run is skipped.
    """
    run_series = []
    run_labels = []
    for run_num in run_numbers:
        file_path = os.path.join(algo_base_path, str(run_num), csv_filename)
        try:
            data = pd.read_csv(file_path, on_bad_lines='skip')
            if column_name not in data.columns:
                print(f" Warning: Column '{column_name}' missing in {file_path}. Skipping run {run_num}.")
                continue
            # Parse tensor strings / numbers and drop cells that yielded NaN.
            values = data[column_name].apply(extract_float_value).dropna()
            if values.empty:
                print(f" Warning: No valid data in column '{column_name}' for {file_path}. Skipping run {run_num}.")
                continue
            run_series.append(values.reset_index(drop=True))
            run_labels.append(f'run_{run_num}')
        except FileNotFoundError:
            print(f" Warning: File not found {file_path}. Skipping run {run_num}.")
        except pd.errors.EmptyDataError:
            print(f" Warning: File empty {file_path}. Skipping run {run_num}.")
        except Exception as e:
            print(f" Error processing {file_path}: {e}. Skipping run {run_num}.")

    if not run_series:
        return None
    aligned = pd.concat(run_series, axis=1)
    # Label with the runs actually loaded (slicing run_numbers would mislabel after a skipped run).
    aligned.columns = run_labels
    return aligned


def _apply_metric_ylim(ax, column_name: str) -> None:
    """Adjust y-limits by metric kind, starting from the autoscaled limits of `ax`.

    Accuracy/percentage metrics (assumed 0-100 scale) get bottom 0 and a top between
    101 and 105. Loss metrics keep an autoscaled bottom in [0, 1) and otherwise start
    at 0. Other metrics keep matplotlib's autoscaling.
    """
    current_bottom, current_top = ax.get_ylim()
    metric_name = column_name.lower()
    if 'accuracy' in metric_name or 'percent' in metric_name:
        ax.set_ylim(bottom=0, top=min(105, max(101, current_top * 1.02)))
    elif 'loss' in metric_name:
        ax.set_ylim(bottom=max(0, current_bottom) if current_bottom < 1 else 0)


def _fill_post_threshold_averages(metrics: Dict[str, Dict[str, Any]],
                                  mean_data_per_algo: Dict[str, pd.Series],
                                  x_values_per_algo: Dict[str, np.ndarray],
                                  threshold: float,
                                  consistency_window: int) -> None:
    """Compute Metric 3 in place: each algorithm's mean after the last one reaches `threshold`.

    Only computed when every algorithm in `metrics` has data and consistently reaches
    the threshold; otherwise every 'post_threshold_avg_acc' keeps its NaN placeholder.
    """
    reach_rounds = {}
    all_reached = True
    for algo, algo_metrics in metrics.items():
        if algo not in mean_data_per_algo:
            all_reached = False  # no data was loaded for this algorithm
            continue
        round_reached = algo_metrics['rounds_to_threshold'].get(threshold)
        if round_reached is None:
            # The Metric 3 threshold is not one of the Metric 2 thresholds: evaluate it the
            # same way instead of treating it as "never reached".
            round_reached = calculate_metrics(
                mean_data_per_algo[algo], x_values_per_algo[algo], [threshold], consistency_window
            )[threshold]
        if round_reached == np.inf:
            all_reached = False
            print(f" Info: Algorithm '{algo}' did not reach {threshold}% threshold for Metric 3 calculation.")
        else:
            reach_rounds[algo] = round_reached

    if not all_reached or not reach_rounds:
        print(f" Info: Metric 3 (Post-{threshold}% Avg Acc) cannot be calculated because not all algorithms reached the threshold.")
        return

    last_reach_round = max(reach_rounds.values())
    print(f" Info: Calculating Metric 3 starting after round {last_reach_round} (last algo hit {threshold}%)")
    for algo in reach_rounds:
        # Only rounds strictly after the last algorithm reached the threshold.
        after_mask = x_values_per_algo[algo] > last_reach_round
        post_values = mean_data_per_algo[algo].iloc[after_mask]
        metrics[algo]['post_threshold_avg_acc'] = (
            post_values.mean(skipna=True) if not post_values.empty else np.nan
        )


def _metrics_display_table(metrics: Dict[str, Dict[str, Any]],
                           x_values_per_algo: Dict[str, np.ndarray],
                           metric_thresholds: List[float],
                           post_avg_threshold: float) -> pd.DataFrame:
    """Format the per-algorithm metrics as a printable table (one row per algorithm).

    Round cells hold one of three things:
    * the round number;
    * ">R" when the threshold was never reached, where R is that algorithm's last
      evaluated round (the metric is computed over all of its data);
    * "N/A" when the algorithm had no data.
    The Metric 3 cell is "xx.xx%" or "N/A".
    """
    round_cols = [f'Rounds to {thr}%' for thr in metric_thresholds]
    metric3_col_name = f'Avg Acc Post-{post_avg_threshold}%'
    rows = []
    for algo, algo_metrics in metrics.items():
        x_vals = x_values_per_algo.get(algo)
        row = {}
        for thr, col in zip(metric_thresholds, round_cols):
            reached = algo_metrics['rounds_to_threshold'].get(thr, np.inf)
            if x_vals is None or len(x_vals) == 0:
                row[col] = "N/A"
            elif reached == np.inf:
                row[col] = f">{int(x_vals[-1])}"
            else:
                row[col] = int(reached)
        avg = algo_metrics['post_threshold_avg_acc']
        row[metric3_col_name] = f"{avg:.2f}%" if pd.notna(avg) else "N/A"
        rows.append(row)
    table = pd.DataFrame(rows, index=list(metrics.keys()), columns=round_cols + [metric3_col_name])
    table.index.name = 'Algorithm'
    return table


def plot_algorithm_comparison(
    datasets: List[str],
    algorithms: List[str],
    run_numbers: List[int],
    base_results_dir: str = "results/official",
    csv_filename_template: str = "c10_e4_validation_data_new.csv",
    column_name: str = 'Accuracy',
    window_size: int = 1,
    start_round: int = 0,
    end_round: Optional[int] = None,
    rounds_per_eval: int = 5,
    output_base_dir: str = "figures/comparisons",
    metric_thresholds: Optional[List[float]] = None, # Thresholds for Metric 2; None -> [75.0, 85.0]
    post_avg_threshold: float = 85.0, # Threshold for Metric 3
    metric_consistency_window: int = 5 # Window for Metric 2/3 consistency check
):
    """
    Plot an algorithm comparison per dataset and report convergence metrics.

    For each dataset, every algorithm's runs are loaded and averaged. The mean
    curve (optionally smoothed) is drawn with a +/- one standard deviation band,
    and the figure is saved to ``output_base_dir``. Two metrics are computed on the
    *unsmoothed* mean and printed as a table:

    * Metric 2 (rounds to threshold): for each value in ``metric_thresholds``, the
      first round at which ``metric_consistency_window`` consecutive evaluations are
      all >= the threshold (see ``calculate_metrics``).
    * Metric 3 (post-threshold average): only when every algorithm reaches
      ``post_avg_threshold``. It is each algorithm's mean over the rounds strictly
      after the round at which the last algorithm reached it.

    Args:
        datasets (List[str]): Dataset folder names.
        algorithms (List[str]): Algorithm folder names.
        run_numbers (List[int]): Run ids read from
            ``<base_results_dir>/<dataset>/<algorithm>/training/runs/<run>/<csv_filename_template>``.
        base_results_dir (str): Root folder containing the dataset folders.
        csv_filename_template (str): CSV file name inside every run folder.
        column_name (str): Column to plot and evaluate.
        window_size (int): Moving-average window for the plotted mean (1 = no smoothing).
        start_round (int): First communication round drawn (inclusive).
        end_round (Optional[int]): Last communication round drawn (inclusive); None draws all data.
            Only the plot is restricted; metrics use all available data.
        rounds_per_eval (int): Communication rounds between consecutive CSV rows.
        output_base_dir (str): Folder receiving ``<dataset>_algos_comparison_<column>.jpg``.
        metric_thresholds (Optional[List[float]]): Metric 2 thresholds; None means [75.0, 85.0].
        post_avg_threshold (float): Threshold that starts the Metric 3 window; it does not
            have to appear in ``metric_thresholds``.
        metric_consistency_window (int): Consecutive evaluations required for a threshold to count.

    Returns:
        dict: ``{dataset: {algorithm: {'rounds_to_threshold': {threshold: round or np.inf},
        'post_threshold_avg_acc': float or NaN}}}``.
    """
    if metric_thresholds is None:
        metric_thresholds = [75.0, 85.0]

    if output_base_dir and not os.path.exists(output_base_dir):
        os.makedirs(output_base_dir)
        print(f"Created base output directory: {output_base_dir}")

    palette = plt.cm.tab10.colors
    all_datasets_metrics = {}

    for dataset in datasets:
        print(f"\n--- Processing Dataset: {dataset} ---")
        fig, ax = plt.subplots(figsize=(12, 8))
        results_found_for_dataset = False

        metrics_for_current_dataset: Dict[str, Dict[str, Any]] = {}
        mean_data_per_algo: Dict[str, pd.Series] = {}
        x_values_per_algo: Dict[str, np.ndarray] = {}

        for i, algorithm in enumerate(algorithms):
            print(f" Processing Algorithm: {algorithm}")
            algo_base_path = os.path.join(base_results_dir, dataset, algorithm, "training", "runs")
            aligned_data = _collect_run_series(algo_base_path, run_numbers, csv_filename_template, column_name)

            if aligned_data is None:
                print(f" Error: No valid data found for any specified run of algorithm '{algorithm}'. Skipping this algorithm for dataset '{dataset}'.")
                # Placeholder entry so the algorithm still appears in the metrics table/result.
                metrics_for_current_dataset[algorithm] = {
                    'rounds_to_threshold': {thr: np.inf for thr in metric_thresholds},
                    'post_threshold_avg_acc': np.nan
                }
                continue

            # Mean/std across runs; the metrics use this unsmoothed mean.
            mean_values = aligned_data.mean(axis=1, skipna=True)
            std_values = aligned_data.std(axis=1, skipna=True)
            x_values = np.arange(len(mean_values)) * rounds_per_eval

            mean_data_per_algo[algorithm] = mean_values
            x_values_per_algo[algorithm] = x_values
            metrics_for_current_dataset[algorithm] = {
                'rounds_to_threshold': calculate_metrics(
                    mean_values, x_values, metric_thresholds, metric_consistency_window
                ),
                'post_threshold_avg_acc': np.nan  # Metric 3, filled once all algorithms are loaded
            }

            # Only the plotted mean is smoothed; the std band is left unsmoothed.
            if window_size > 1:
                plot_mean = mean_values.rolling(window=window_size, min_periods=1).mean()
            else:
                plot_mean = mean_values

            # Mask on rounds: a start_round beyond the data selects nothing
            # (instead of falling back to the whole series).
            in_range = x_values >= start_round
            if end_round is not None:
                in_range &= x_values <= end_round
            if not in_range.any():
                print(f" Warning: No data points available for algorithm '{algorithm}' in the plot range. Skipping plot segment.")
                continue

            x_plot = x_values[in_range]
            mean_plot = plot_mean.iloc[in_range]
            std_plot = std_values.iloc[in_range]

            # Wrap around the palette so more than 10 algorithms do not raise IndexError.
            color = palette[i % len(palette)]
            ax.plot(x_plot, mean_plot, label=f'{algorithm}', color=color, linewidth=2.5)
            ax.fill_between(x_plot,
                            (mean_plot - std_plot).clip(lower=0), # Prevent std dev going below 0 visually
                            mean_plot + std_plot,
                            color=color, alpha=0.15)
            results_found_for_dataset = True

        # Metric 3 needs every algorithm's reach round, so it runs after the loop.
        _fill_post_threshold_averages(
            metrics_for_current_dataset, mean_data_per_algo, x_values_per_algo,
            post_avg_threshold, metric_consistency_window
        )

        if results_found_for_dataset:
            ax.set_xlabel('Communication Round', fontsize=14)
            ax.set_ylabel(f'Validation {column_name}', fontsize=14)
            ax.set_title(f'{dataset.upper()} Validation {column_name} Comparison', fontsize=16)
            ax.legend(loc='best', fontsize=12)
            ax.grid(True, linestyle='--', alpha=0.6)
            _apply_metric_ylim(ax, column_name)
            ax.tick_params(axis='both', which='major', labelsize=12)
            fig.tight_layout()

            output_filename = os.path.join(output_base_dir, f"{dataset}_algos_comparison_{column_name.lower().replace(' ', '_')}.jpg")
            try:
                fig.savefig(output_filename, dpi=300, bbox_inches='tight')
                print(f" Plot saved to {output_filename}")
            except Exception as e:
                print(f" Error saving plot {output_filename}: {e}")
            plt.show()
        else:
            print(f"Error: No data could be plotted for dataset '{dataset}'. Skipping plot generation.")
        plt.close(fig)

        print(f"\n--- Calculated Metrics for Dataset: {dataset} ---")
        display_df = _metrics_display_table(
            metrics_for_current_dataset, x_values_per_algo, metric_thresholds, post_avg_threshold
        )
        print(display_df)
        print("-" * (len(display_df.columns) * 15))

        all_datasets_metrics[dataset] = metrics_for_current_dataset

    return all_datasets_metrics

# --- Configuration: experiments to compare ---
base_results = "results/official"                         # folder holding one sub-folder per dataset
datasets_to_plot = ["emnist", "shakespeare", "cifar100"]
algorithms_to_plot = ["fedAvg", "fedAdam", "fedAdadb"]
run_ids_to_process = [1, 2, 3, 4, 5]                      # runs averaged into each curve
common_csv_file = "c10_e4_validation_data.csv"            # *** VERIFY FILENAME CONSISTENCY across datasets/algorithms ***
col_to_plot = 'Accuracy'

# --- Configuration: plot range and appearance ---
smoothing_window = 5    # moving-average window on the MEAN curve; 1 disables smoothing
plot_start_round = 0    # first communication round drawn
plot_end_round = 2000   # last communication round drawn (None = all available data)
eval_frequency = 5      # *** must match the evaluation frequency used during the experiments ***
output_plot_dir = "figures/comparisons"

# --- Configuration: convergence metrics ---
metrics_thresholds_list = [35.0, 45.0]  # Metric 2: rounds needed to consistently reach each threshold
metric3_threshold = 45.0                # Metric 3: averaging starts once every algorithm reaches this
metric_consistency = 3                  # consecutive evaluations (current + 2 previous) that must be >= threshold

# --- Generate the plots and calculate metrics ---
calculated_metrics = plot_algorithm_comparison(
    datasets=datasets_to_plot,
    algorithms=algorithms_to_plot,
    run_numbers=run_ids_to_process,
    base_results_dir=base_results,
    csv_filename_template=common_csv_file,
    column_name=col_to_plot,
    window_size=smoothing_window,
    start_round=plot_start_round,
    end_round=plot_end_round,
    rounds_per_eval=eval_frequency,
    output_base_dir=output_plot_dir,
    metric_thresholds=metrics_thresholds_list,
    post_avg_threshold=metric3_threshold,
    metric_consistency_window=metric_consistency,
)

print("\n--- Plotting and Metrics Script Finished ---")


In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib import colors

# Tuning grid shared by server and client: 9 log-spaced learning rates from 1e-3 to 1e1.
server_learning_rates = np.logspace(start=-3, stop=1, num=9)
client_learning_rates = np.logspace(start=-3, stop=1, num=9)

def plot_heatmap_from_csv(csv_file, server_learning_rates, client_learning_rates,
                          results_path, dpi=300, font_size=8):
    """Render an annotated heatmap of a learning-rate tuning grid and save it as JPEG.

    The matrix is read from ``csv_file`` (comma-separated numbers). Matrix
    row ``i`` is drawn along the y axis and column ``j`` along the x axis.
    The y axis is titled "Server Learning Rates" and the x axis "Client
    Learning Rates". The tick labels follow the same convention, so the CSV
    is expected to hold one row per server LR and one column per client LR.
    NOTE(review): that orientation is inferred from the axis titles; confirm
    it against the script that writes the tuning CSV.

    Args:
        csv_file: Path to the CSV containing the accuracy matrix. Non-finite
            cells (NaN) are drawn with the colormap's "bad" colour and
            annotated with dark text so the label stays readable.
        server_learning_rates: Values used as y-axis tick labels (one per row).
        client_learning_rates: Values used as x-axis tick labels (one per column).
        results_path: Directory the figure is written to; created if missing.
            The file is named after ``csv_file`` with a ``.jpg`` extension.
        dpi: Resolution of the saved image.
        font_size: Font size of the per-cell value annotations.
    """
    # genfromtxt returns a 1-D array for a single-row CSV; keep it 2-D so
    # the [i, j] indexing and shape[1] below always work.
    results = np.atleast_2d(np.genfromtxt(csv_file, delimiter=','))
    basename = os.path.splitext(os.path.basename(csv_file))[0]
    os.makedirs(results_path, exist_ok=True)
    output_file = os.path.join(results_path, basename + ".jpg")

    # Format tick labels
    s_labels = [f"{lr:.3f}" for lr in server_learning_rates]
    c_labels = [f"{lr:.3f}" for lr in client_learning_rates]

    # A single colormap/norm pair drives both the mesh and the text-colour
    # lookup, so the contrast decision matches the colour actually drawn.
    cmap = plt.get_cmap('YlGn')
    norm = colors.Normalize(vmin=np.nanmin(results), vmax=np.nanmax(results))

    # Set up figure
    fig, ax = plt.subplots(figsize=(8, 7))
    im = ax.pcolormesh(results, cmap=cmap, norm=norm, shading='auto')
    fig.colorbar(im, ax=ax, label='Accuracy')

    # Ticks: columns (x) are client LRs and rows (y) are server LRs,
    # consistent with the axis titles set just below.
    ax.set_xticks(np.arange(len(c_labels)) + 0.5)
    ax.set_yticks(np.arange(len(s_labels)) + 0.5)
    ax.set_xticklabels(c_labels, rotation=45, ha='right')
    ax.set_yticklabels(s_labels)
    ax.set_xlabel('Client Learning Rates')
    ax.set_ylabel('Server Learning Rates')
    ax.set_title('Validation Accuracy Grid')

    # Annotate each cell
    for i in range(results.shape[0]):
        for j in range(results.shape[1]):
            val = results[i, j]
            if np.isfinite(val):
                red, green, blue, _ = cmap(norm(val))
                # Perceived luminance (ITU-R BT.601 weights): light text on dark cells.
                brightness = 0.299 * red + 0.587 * green + 0.114 * blue
                text_color = 'white' if brightness < 0.5 else 'black'
            else:
                # NaN maps to the "bad" colour (transparent by default), whose
                # RGB is 0 -> the luminance test would pick white text on a
                # white background. Use dark text instead.
                text_color = 'black'
            ax.text(j + 0.5, i + 0.5, f"{val:.2f}",
                    ha='center', va='center',
                    color=text_color,
                    fontsize=font_size,
                    fontweight='bold')

    # Tight layout so labels don't get cut off
    fig.tight_layout()

    # Save at high resolution
    fig.savefig(output_file, dpi=dpi, bbox_inches='tight')
    plt.show()

# Heatmap for the fedAdadb / EMNIST tuning grid; the figure is written next to its CSV.
results_path = 'results/official/emnist/fedAdadb/tuning'
plot_heatmap_from_csv(
    f"{results_path}/c5_e1_tuning_data.csv",
    server_learning_rates,
    client_learning_rates,
    results_path,
)