In [4]:
import os
import sys
import re
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

# Use one large tick-label font size (17) on both the x and y axes of every
# figure produced below, so all subplots read consistently.
matplotlib.rcParams.update({'xtick.labelsize': 17, 'ytick.labelsize': 17})

def parse_log(file_name):
    """
    Parse log file and extract experimental metrics.

    Every line is matched (case-insensitively) against four patterns:

    * ``At round <n> training accuracy: <x>`` -> ``<n>`` appended to rounds
    * ``At round <n> accuracy: <x>``          -> ``<x>`` appended to test accuracies
    * ``At round <n> training loss: <x>``     -> ``<x>`` appended to losses
    * ``gradient difference: <x>``            -> ``<x>`` appended to similarities

    Round numbers must be plain integers. This is what stops a
    "training accuracy" line from also being counted as a test accuracy.

    Args:
        file_name (str): Path to the log file to parse

    Returns:
        Tuple ``(rounds, sim, loss, accu)`` of lists: round indices (int),
        gradient differences, training losses and test accuracies (float).
        Four empty lists are returned if the file does not exist.

    Raises:
        ValueError: If a matched metric value cannot be converted to float.
    """
    # Compile once per call rather than resolving the patterns for every line.
    # ``\d+`` (instead of the original greedy ``.*``) matters for the
    # test-accuracy pattern: with ``.*`` the round group could swallow
    # "3 training", so every training-accuracy line was double-counted as a
    # test accuracy.
    train_accu_re = re.compile(r'At round (\d+) training accuracy: (.*)', re.I)
    test_accu_re = re.compile(r'At round (\d+) accuracy: (.*)', re.I)
    loss_re = re.compile(r'At round (\d+) training loss: (.*)', re.I)
    sim_re = re.compile(r'gradient difference: (.*)', re.I)

    rounds = []
    accu = []
    loss = []
    sim = []

    try:
        with open(file_name, 'r') as file:
            for line in file:
                # Training rounds (the round index is taken from the
                # training-accuracy line)
                match = train_accu_re.search(line)
                if match:
                    rounds.append(int(match.group(1)))

                # Test accuracy
                match = test_accu_re.search(line)
                if match:
                    accu.append(float(match.group(2)))

                # Training loss
                match = loss_re.search(line)
                if match:
                    loss.append(float(match.group(2)))

                # Gradient difference
                match = sim_re.search(line)
                if match:
                    sim.append(float(match.group(1)))

    except FileNotFoundError:
        print(f"Warning: File {file_name} not found.")
        return [], [], [], []

    return rounds, sim, loss, accu

def plot_experiments(metric='loss', log_types=None, *, log_dir='log_synthetic', epochs=20):
    """
    Create visualization for multiple experiments and mu values.

    One subplot is drawn per log type. Within each subplot, one line is drawn
    per mu value (0-5) whose log file yields data for the requested metric.
    The figure is saved as ``'<metric>_multi_mu.pdf'`` in the current
    working directory.

    Args:
        metric (str): Type of metric to plot ('loss', 'accuracy', or
            'similarity'). Any other value falls back to the gradient
            difference ('similarity') series.
        log_types (list): List of log type prefixes to process. Defaults to
            the four synthetic dataset prefixes.
        log_dir (str): Directory holding the log files. The default is
            'log_synthetic'.
        epochs (int): Epoch count embedded in the log file names
            (``..._epoch{epochs}_mu{mu}``) and shown in the legend. The
            default is 20.

    Raises:
        ValueError: If ``log_types`` is an empty list.
    """
    print(f"Current working directory: {os.getcwd()}")

    # Diagnostic listing so that a missing or misnamed log directory is obvious
    try:
        print(f"Files in {log_dir} directory:")
        print(os.listdir(log_dir))
    except FileNotFoundError:
        print(f"{log_dir} directory not found!")

    # Default log types if not specified
    if log_types is None:
        log_types = ["synthetic_iid", "synthetic_0_0", "synthetic_0.5_0.5", "synthetic_1_1"]
    if not log_types:
        raise ValueError("log_types must contain at least one log type prefix")

    # Resolve the metric once, before any file is read, so that the y-label is
    # always defined. The index selects from parse_log's
    # (rounds, sim, loss, accu) tuple.
    metric_specs = {
        'loss': (2, 'Training Loss'),
        'accuracy': (3, 'Testing Accuracy'),
    }
    value_index, ylabel = metric_specs.get(metric, (1, "Variance of Local Grad."))

    # Color palette: one color per mu value, mu = 0..5
    colors = ['#17becf', '#e377c2', '#7f7f7f', '#bcbd22', '#9467bd', '#8c564b']

    # squeeze=False always returns a 2-D array of axes, even for one subplot
    fig, axes = plt.subplots(1, len(log_types), figsize=(20, 4), squeeze=False)

    for idx, (ax, log) in enumerate(zip(axes[0], log_types)):
        for mu, color in enumerate(colors):
            file_path = f"{log_dir}/{log}_client10_epoch{epochs}_mu{mu}"
            try:
                parsed = parse_log(file_path)
                rounds_data, values = parsed[0], parsed[value_index]

                # The round list and the metric series may differ in length.
                # Plot their common prefix instead of failing on a length
                # mismatch.
                n_points = min(len(rounds_data), len(values))
                if n_points:
                    ax.plot(np.asarray(rounds_data[:n_points]),
                            np.asarray(values[:n_points]),
                            linewidth=3.0,
                            label=f'μ={mu}, E={epochs}',
                            color=color,
                            linestyle='--' if mu % 2 == 0 else '-')
            except (OSError, ValueError) as e:
                # Best effort: report an unreadable or malformed log and move on
                print(f"Error processing {log} with mu={mu}: {e}")

        # Subplot styling
        ax.set_xlabel("# Rounds", fontsize=22)
        ax.set_title(log, fontsize=22)

        # Y-label only on the first subplot (the subplots share one metric)
        if idx == 0:
            ax.set_ylabel(ylabel, fontsize=22)

        # Spine and tick styling
        for spine in ['bottom', 'top', 'right', 'left']:
            ax.spines[spine].set_color('#dddddd')
        ax.tick_params(color='#dddddd')

        # Legend only on the last subplot
        if idx == len(log_types) - 1:
            ax.legend(fontsize=14, loc='best')

    # Adjust layout, save, and release this figure explicitly
    fig.tight_layout()
    fig.savefig(f'{metric}_multi_mu.pdf')
    plt.close(fig)

def main(argv=None):
    """
    Handle command-line arguments and trigger plotting.

    Usage: ``python <script> [loss|accuracy|similarity]``. Defaults to
    ``loss``.

    Args:
        argv (list[str] | None): Arguments excluding the program name. When
            ``None`` (the default), ``sys.argv[1:]`` is used, so script
            behaviour is unchanged. Pass an explicit list, e.g.
            ``main(['accuracy'])``, when calling from code or a notebook. In a
            notebook, ``sys.argv`` holds the kernel's own arguments and would
            otherwise be rejected as an invalid metric.

    Raises:
        SystemExit: With status 1 if the requested metric is not supported.
    """
    if argv is None:
        argv = sys.argv[1:]

    # Default to loss if no argument provided
    metric = argv[0] if argv else 'loss'

    supported_metrics = ('loss', 'accuracy', 'similarity')
    if metric not in supported_metrics:
        # Usage errors go to stderr so they don't mix with normal output
        print(f"Invalid metric. Choose from: {', '.join(supported_metrics)}", file=sys.stderr)
        sys.exit(1)

    plot_experiments(metric)

if __name__ == "__main__":
    main()

Invalid metric. Choose from: loss, accuracy, similarity


SystemExit: 1