# Objectives
This notebook is the new version of the analysis and supports additional logged data. The old notebook is kept for backward compatibility.

## Data definition and loading

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import re

# Experiment-wide constants.
# TOTAL_PROCESSES and MAX_MACHINES are not referenced by the cells visible in this notebook.
TOTAL_PROCESSES = 128
MAX_MACHINES =  8
# Default iteration bounds of `plot_privacyvsloss`; MAX_ITERATIONS also selects the
# "last iteration" row in the total-bytes bar chart.
STARTING_ITERATION = 0
MAX_ITERATIONS=4000

# Alternative experiment folders; uncomment one pair to switch the input/output locations.
# results_path = "formatted_results/36nodes"
# save_directory = "assets/36nodes"

# results_path = "formatted_results/128nodes"
# save_directory = "assets/128nodes"

# Folder the result CSV files are read from, and folder the figures are saved to.
results_path = "formatted_results/averaging_steps"
save_directory = "assets/averaging_steps"

# Every entry of the results folder is treated as one experiment file below.
dirlist = os.listdir(results_path)
print(dirlist)

def get_noise(key):
    """Return the first integer found in ``key``, or 0 when it contains no digits."""
    first_number = re.search(r'\d+', key)
    return int(first_number.group()) if first_number else 0

# Order the files by noise level first, then by noise family, then static before
# dynamic; entries that tie on all three keep their listing order (stable sort).
dirlist.sort(
    key=lambda name: (
        get_noise(name),
        2 if "ZeroSum" in name else 1 if "Gaussian" in name else 0,
        1 if "dynamic" in name else 0,
    )
)
print(dirlist)

# One DataFrame per result file, keyed by the file name up to its first dot.
global_data = {
    filename.split('.')[0]: pd.read_csv(f"{results_path}/{filename}")
    for filename in dirlist
}

# Line style applied to every run whose name contains the given substring.
linestyles = {"ZeroSum": "--"}

# Base font size shared by all figures.
fontsize = 20

# Opacity of the shaded bands.
alpha = 0.1
print(global_data.keys())

# Global loss display:

In [None]:
# Must be user defined, feel free to adjust figsize (limits to the naïve autoscaling performed here)
# `figsize` and `alpha` are read as globals by `plot_attributes`.
figsize = (25,10)
alpha = 0.1

# One subplot row per attribute; the plotted column is named "<attribute> <metric>".
attributes = ["test_acc","test_niid_acc"]

metrics = ["mean"]  # The metric we want to evaluate, must be computed in the table

# to_plot = [ 
#     "No_noise_static",
#     "No_noise_dynamic",

#     # "Gaussian64_static",
#     # "Gaussian64_dynamic",

#     # "Gaussian32_static",
#     # "Gaussian32_dynamic",

#     "Gaussian16_static",
#     "Gaussian16_dynamic",

#     # "Gaussian8_static",
#     # "Gaussian8_dynamic",

#     # "Gaussian4_static",
#     # "Gaussian4_dynamic",

#     # "Gaussian2_static",
#     # "Gaussian2_dynamic",

#     # "ZeroSum64_static",
#     # "ZeroSum64_dynamic",

#     # "ZeroSum32_static",
#     # "ZeroSum32_dynamic",
    
#     "ZeroSum16_static",
#     "ZeroSum16_dynamic",
    
#     # "ZeroSum8_static",
#     # "ZeroSum8_dynamic",
    
#     # "ZeroSum4_static",
#     # "ZeroSum4_dynamic"

#     # "ZeroSum2_static",
#     # "ZeroSum2_dynamic"
# ]

# Runs drawn by the next plots; each name must be a key of `global_data`.
# Commented entries are deliberately left out of the figure.
to_plot = ['No_noise_static',
 'Muffliato16_static_1step', 
 'ZeroSum16_static',
#  'Muffliato16_static_2step', 
 'Muffliato16_static_3step',
#  'Muffliato16_static_5step',
 'Muffliato16_static_10step', 
 'Muffliato16_static_20step'
]

def plot_attributes(to_plot,attributes,metrics):
    """Plot ``"<attribute> <metric>"`` against ``"iteration"`` for every selected run.

    One subplot is drawn per (attribute, metric) pair: attributes are the rows,
    metrics the columns. When the metric is ``"mean"`` the area between the
    ``"<attribute> min"`` and ``"<attribute> max"`` columns is shaded as well.
    Each subplot is also saved as ``<save_directory>/<subplot title>.pdf``.

    Reads the module-level ``global_data``, ``linestyles``, ``figsize``,
    ``fontsize``, ``alpha`` and ``save_directory``.

    Args:
        to_plot: iterable of keys of ``global_data`` to draw.
        attributes: list of attribute names (column prefixes).
        metrics: list of metric names (column suffixes).
    """
    n_rows, n_cols = len(attributes), len(metrics)
    # squeeze=False always returns a 2-D grid, so axs[i][j] works for any grid shape.
    fig, axs = plt.subplots(n_rows, n_cols, sharex=True, figsize=figsize, squeeze=False)

    # Titles and axis labels.
    for i, attribute in enumerate(attributes):
        for j, metric in enumerate(metrics):
            ax = axs[i][j]
            if "test" in attribute:
                ax.set_title(f"{metric} of {attribute}, evaluated on global test set",fontsize=fontsize)
            elif "train" in attribute:
                ax.set_title(f"{metric} of {attribute}, evaluated on local train set",fontsize=fontsize)
            else:
                ax.set_title(f"{metric} of {attribute}",fontsize=fontsize)
            ax.set_xlabel(f"Communication rounds",fontsize=fontsize)
            ax.set_ylabel(f"{attribute} {metric}",fontsize=fontsize)

    # Curves (and min/max bands).
    for key in to_plot:
        data = global_data[key]
        # The line style depends only on the run name; the last matching entry wins.
        line = "-"
        for substring, linestyle in linestyles.items():
            if substring in key:
                line = linestyle
        for i, attribute in enumerate(attributes):
            for j, metric in enumerate(metrics):
                ax = axs[i][j]
                column = f"{attribute} {metric}"
                curve = data[["iteration", column]].dropna()
                (curve_handle,) = ax.plot(curve["iteration"], curve[column], line, label=key)
                if metric == "mean":  # When displaying the mean, we also display the min and max for each iteration.
                    low, high = f"{attribute} min", f"{attribute} max"
                    # Drop NaNs jointly so x, lower and upper bounds always have the same rows.
                    band = data[["iteration", low, high]].dropna()
                    # Tie the band colour to its curve instead of relying on two colour cycles staying in step.
                    ax.fill_between(band["iteration"], band[low], band[high],
                                    alpha=alpha, facecolor=curve_handle.get_color())

    for ax in axs.flat:
        ax.legend(ncol=3,fontsize=3*fontsize/4)
        ax.grid()
        ax.tick_params(labelbottom=True,labelleft = True,axis="both", labelsize=fontsize)

    fig.tight_layout()
    # savefig does not create missing folders.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    renderer = fig.canvas.get_renderer()
    for ax in axs.flat:
        # Crop the saved file to the bounding box of this single subplot.
        extent = ax.get_tightbbox(renderer).transformed(fig.dpi_scale_trans.inverted())
        fig.savefig(f"{save_directory}/{ax.get_title()}.pdf", bbox_inches=extent)
# Draw the `test_acc` and `test_niid_acc` means for the runs selected above.
plot_attributes(to_plot,attributes,metrics)

In [None]:
# Same plot for the "total_bytes sum" column; no min/max band is drawn because the metric is not "mean".
attributes = ["total_bytes"]

metrics = ["sum"]  # The metric we want to evaluate, must be computed in the table
plot_attributes(to_plot,attributes,metrics)

In [None]:
import re

def get_number_steps(name):
    """Return the step count encoded in a run name, or -1 if there is none.

    The name is split on ``"_"``; the first piece containing ``"step"`` must hold
    exactly one integer, which is returned (``"Muffliato16_static_10step"`` -> 10).

    Raises:
        AssertionError: if the ``step`` piece does not contain exactly one number.
    """
    for sequence in name.split("_"):
        if "step" in sequence:
            numbers = re.findall(r'\d+', sequence)
            assert len(numbers) == 1, f"expected exactly one number in {sequence!r}, found {numbers}"
            return int(numbers[0])
    # No "..step" piece in the name.
    return -1

# Bar positions: the first run sits at 0, every other run at its number of steps
# (-1 when the run name encodes no step count).
x = [0] + [get_number_steps(name) for name in to_plot[1:]]

print(x)
for position, name in zip(x, to_plot):
    run = global_data[name]
    # `.item()` requires exactly one row at the last iteration.
    total_bytes = run.loc[run["iteration"] == MAX_ITERATIONS, "total_bytes sum"].item()
    plt.bar(position, total_bytes, label=name)

plt.title("total_bytes sent at last iteration")
plt.legend()

## Privacy attack display:

The relevant data must have been loaded by now into `global_data` (see "Data definition and loading").

In [None]:
# Available privacy metrics; only the first one ("Attacker advantage") is kept for plotting.
metrics = ["Attacker advantage","AUC"]
metrics = [metrics[0]]
# Size of one subplot; `plot_privacy_data` scales it by the grid dimensions.
figsize = (15,15)
LOCATIONS_OF_ATTACKS = ["PRE-STEP", "PRE-STEP-niid"] # The one we want to show here
# Width of the rolling mean applied to each curve.
window_size = 10
alpha=0.1
# to_plot = [ 
#     "No_noise_static",
#     "No_noise_dynamic",

#     # "Gaussian64_static",
#     # "Gaussian64_dynamic",

#     # "Gaussian32_static",
#     # "Gaussian32_dynamic",

#     "Gaussian16_static",
#     "Gaussian16_dynamic",

#     # "Gaussian8_static",
#     # "Gaussian8_dynamic",

#     # "Gaussian4_static",
#     # "Gaussian4_dynamic",

#     # "Gaussian2_static",
#     # "Gaussian2_dynamic",

#     # "ZeroSum64_static",
#     # "ZeroSum64_dynamic",

#     # "ZeroSum32_static",
#     # "ZeroSum32_dynamic",
    
#     "ZeroSum16_static",
#     "ZeroSum16_dynamic",
    
#     # "ZeroSum8_static",
#     # "ZeroSum8_dynamic",
    
#     # "ZeroSum4_static",
#     # "ZeroSum4_dynamic",

#     # "ZeroSum2_static",
#     # "ZeroSum2_dynamic"
# ]

# Runs drawn by the privacy plot; each name must be a key of `global_data`.
to_plot = ['No_noise_static',
 'Muffliato16_static_1step', 
 'Muffliato16_static_2step', 
 'Muffliato16_static_5step',
 'Muffliato16_static_10step', 
 'Muffliato16_static_20step'
]


# Base line width; `plot_privacy_data` draws its curves at twice this value.
linewidth = 3


def plot_privacy_data(to_plot,LOCATIONS_OF_ATTACKS):
    """Plot the rolling mean of each privacy metric, one subplot per (metric, location).

    For every run in ``to_plot`` the column ``"<metric> mean <location>"`` is
    smoothed with a rolling mean of ``window_size`` rows and drawn against the
    DataFrame index. Each subplot is also saved as
    ``<save_directory>/<subplot title>.pdf``.

    Reads the module-level ``metrics``, ``global_data``, ``linestyles``,
    ``figsize``, ``fontsize``, ``linewidth``, ``window_size`` and ``save_directory``.

    Args:
        to_plot: iterable of keys of ``global_data`` to draw.
        LOCATIONS_OF_ATTACKS: list of location suffixes; one subplot column each.
    """
    n_rows, n_cols = len(metrics), len(LOCATIONS_OF_ATTACKS)
    # The grid follows the number of locations (it used to be fixed at 2 columns);
    # squeeze=False keeps axs two-dimensional for any grid shape.
    fig, axs = plt.subplots(n_rows, n_cols, sharey=False, sharex=True,
                            figsize=(figsize[0]*n_cols, figsize[1]*n_rows), squeeze=False)

    # Titles and labels depend only on (metric, location), so set them once per subplot.
    for i, metric in enumerate(metrics):
        for j, location in enumerate(LOCATIONS_OF_ATTACKS):
            complementary_name = " NIID" if "niid" in location else ""
            axs[i][j].set_title(f"Evolution of the {metric} | {location} | Window : {window_size}",fontsize=fontsize)
            axs[i][j].set_ylabel(metric+complementary_name, fontsize=fontsize)
            axs[i][j].set_xlabel('Communication rounds', fontsize=fontsize)

    for label in to_plot:
        data = global_data[label]
        # The last matching entry of `linestyles` wins; solid line otherwise.
        line_style = '-'
        for key, line in linestyles.items():
            if key in label:
                line_style = line
        for i, metric in enumerate(metrics):
            for j, location in enumerate(LOCATIONS_OF_ATTACKS):
                mean = data[f"{metric} mean {location}"].rolling(window=window_size).mean()
                # NOTE(review): the x values are the row index, not the "iteration" column
                # used by the other plots — confirm this is what "Communication rounds" means here.
                axs[i][j].plot(data.index, mean, line_style, label= label + "-MEAN",linewidth=linewidth*2)

    for ax in axs.flat:
        ax.legend(fontsize=fontsize,ncol=2)
        ax.tick_params(labelbottom=True,labelleft = True)
        ax.tick_params(axis="both", labelsize=fontsize)
        ax.grid()

    fig.tight_layout()
    # savefig does not create missing folders.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    renderer = fig.canvas.get_renderer()
    for ax in axs.flat:
        # Crop the saved file to the bounding box of this single subplot.
        extent = ax.get_tightbbox(renderer).transformed(fig.dpi_scale_trans.inverted())
        fig.savefig(f"{save_directory}/{ax.get_title()}.pdf", bbox_inches=extent)


    

# Draw the smoothed privacy metric for the selected runs at each attack location.
plot_privacy_data(to_plot,LOCATIONS_OF_ATTACKS)


## Tentative d'affichage loss/privacy

In [None]:
# attributes = ["train_loss", "test_loss","test_acc","test_niid_loss","test_niid_acc"]
# Here `attributes` holds full column names; `plot_privacyvsloss` uses them as x-axis columns.
attributes = ["test_acc mean","test_niid_acc mean"]
# Size of one subplot and marker size of the scatter points.
figsize = (10,10)
point_size = 30

# Runs drawn on the trade-off plots; each name must be a key of `global_data`.
# Commented entries are deliberately left out.
to_plot = [ 
    "No_noise_static",
    "No_noise_dynamic",

    "Gaussian64_static",
    # "Gaussian64_dynamic",

    "Gaussian32_static",
    "Gaussian32_dynamic",

    "Gaussian16_static",
    "Gaussian16_dynamic",

    "Gaussian8_static",
    "Gaussian8_dynamic",

    "Gaussian4_static",
    # "Gaussian4_dynamic",

    "Gaussian2_static",
    # "Gaussian2_dynamic",

    "ZeroSum64_static",
    # "ZeroSum64_dynamic",

    "ZeroSum32_static",
    "ZeroSum32_dynamic",
    
    "ZeroSum16_static",
    "ZeroSum16_dynamic",
    
    "ZeroSum8_static",
    "ZeroSum8_dynamic",
    
    "ZeroSum4_static",
    # "ZeroSum4_dynamic"

    "ZeroSum2_static",
    # "ZeroSum2_dynamic",

    "Muffliato64_static",
    # "Muffliato64_dynamic",

    "Muffliato32_static",
    "Muffliato32_dynamic",

    "Muffliato16_static",
    "Muffliato16_dynamic",

    "Muffliato8_static",
    # "Muffliato8_dynamic",

    "Muffliato4_static",
    # "Muffliato4_dynamic",

    "Muffliato2_static",
    # "Muffliato2_dynamic",
]

# Iteration bounds passed to `plot_privacyvsloss` by the calls below.
min_iteration = 3000
max_iteration = 4000

# Marker and legend text per noise level. A key is a tuple of substrings that must
# all occur in the run name. `find_discriminator` returns the first matching key, so
# the order matters: "64" must come before "4", and "32" before "2".
shapes = {
    ("No_noise",): ("o", "No noise"),
    ("64",) : ("1","σ/64"),
    ("32",): ("*","σ/32"),
    ("16",): ("D","σ/16"),
    ("8",): ("X","σ/8"),
    ("4",): ("^","σ/4"),
    ("2",): ("s","σ/2")
}

# Colour per (noise family, topology). Every key is a tuple, including the
# single-substring baselines: a bare string would be iterated character by character.
color_groups = {
    ("Gaussian","static"): "blue",
    ("Gaussian","dynamic") : "cyan",
    ("ZeroSum","static"): "red",
    ("ZeroSum","dynamic"): "orange",
    ("Muffliato","static"): "violet",
    ("Muffliato","dynamic"): "pink",
    ("No_noise_static",):"darkgreen",
    ("No_noise_dynamic",) :"limegreen",
}


# Attack locations available in the data; the calls below pass their own subset.
locations = ["PRE-STEP", "POST-STEP", "PRE-STEP-niid", "POST-STEP-niid"]
# locations = ["PRE-STEP"]


from matplotlib.patches import Patch
from matplotlib.lines import Line2D

def clip_iterations_bounds(data,min_iteration,max_iteration):
    """Return the rows of ``data`` whose ``"iteration"`` lies in [min_iteration, max_iteration].

    Both bounds are inclusive and the original index is kept.

    Raises:
        KeyError: if ``data`` has no ``"iteration"`` column (the columns and the
            frame are printed first to help debugging).
    """
    if "iteration" not in data.columns:
        print(data.columns)
        print(data)
        raise KeyError(f"`iteration` key not in the dataframe column, probably missing a reset_index")

    iterations = data["iteration"]
    within_bounds = (iterations <= max_iteration) & (iterations >= min_iteration)
    return data[within_bounds]

def find_discriminator(key,shapes):
    """Return the first key of ``shapes`` whose attributes all occur in ``key``.

    Args:
        key: run name to classify.
        shapes: mapping whose keys are discriminators — a tuple of substrings, or
            a single string treated as one substring. May be ``None``.

    Returns:
        The matching discriminator (the mapping key itself), or ``None`` when
        nothing matches or ``shapes`` is ``None``; a message is printed in that case.
    """
    if shapes is not None:
        for discriminator in shapes:
            # Iterating a bare string would test its characters one by one and match
            # any key containing the same letters; treat it as a single attribute.
            required = (discriminator,) if isinstance(discriminator, str) else discriminator
            # All the attributes of the discriminator must be matched in the name we want to display.
            if all(attribute in key for attribute in required):
                return discriminator
    print(f"RETURNED NONE on {key}")
    return None

def format_attribute(attr):
    """Turn a discriminator into a legend label.

    A tuple is joined with ``" + "`` (``("Gaussian", "static")`` ->
    ``"Gaussian + static"``); an empty tuple gives ``""``. Any other value is
    returned unchanged.
    """
    if not isinstance(attr, tuple):
        return attr
    return " + ".join(str(part) for part in attr)

def get_custom_legend(shapes,color_groups):
    """Build the legend handles: one marker entry per shape, then one line entry per colour group.

    Args:
        shapes: mapping whose values are ``(marker, legend text)`` pairs.
        color_groups: mapping from discriminator to colour.

    Returns:
        List of ``Line2D`` handles, shape entries first.
    """
    # Marker-only entries (zero line width), all drawn in blue.
    marker_handles = [
        Line2D([0], [0], marker=entry[0], linewidth=0,color='blue',
               label=format_attribute(entry[1]), markersize=15)
        for entry in shapes.values()
    ]
    # Thick coloured line entries, labelled with the formatted discriminator.
    color_handles = [
        Line2D([0], [0], marker=".", color=group_color, lw=4, label=format_attribute(group))
        for group, group_color in color_groups.items()
    ]
    return marker_handles + color_handles

def plot_privacyvsloss(to_plot, locations = ["PRE-STEP", "POST-STEP"], display_type = "all" ,shapes = None,\
                        min_iteration=STARTING_ITERATION, max_iteration=MAX_ITERATIONS):
    """Scatter the attacker advantage against each utility attribute.

    One subplot is drawn per (attribute, location): rows follow the module-level
    ``attributes`` (used as x-axis column names), columns follow ``locations``.
    The y values come from ``"Attacker advantage mean <location>"``. Only rows with
    ``min_iteration <= iteration <= max_iteration`` are used, as the titles state.
    Each subplot is also saved as ``<save_directory>/<subplot title>.pdf``.

    Reads the module-level ``attributes``, ``global_data``, ``color_groups``,
    ``figsize``, ``fontsize``, ``point_size`` and ``save_directory``.

    Args:
        to_plot: iterable of keys of ``global_data`` to draw.
        locations: attack locations, one subplot column each (never mutated).
        display_type: ``"all"`` scatters every point; ``"center"`` draws one
            centroid per run and links the centroids of each colour group.
        shapes: mapping discriminator -> ``(marker, legend text)``, or ``None``
            for default markers.
        min_iteration: lowest iteration kept (inclusive).
        max_iteration: highest iteration kept (inclusive).
    """
    n_rows, n_cols = len(attributes), len(locations)
    # Width scales with the number of columns, height with the number of rows;
    # squeeze=False keeps axs two-dimensional for any grid shape.
    fig, axs = plt.subplots(n_rows, n_cols, sharey=False, sharex=False,
                            figsize=(figsize[0] * n_cols, figsize[1] * n_rows), squeeze=False)
    show_centers = display_type == "center"

    for i, attribute in enumerate(attributes):
        for j, location in enumerate(locations):
            ax = axs[i][j]
            legend_elements = None
            points_history = {}
            if show_centers:
                legend_elements = get_custom_legend(shapes if shapes is not None else {}, color_groups)
                # Centroids collected per colour group, in plotting order.
                points_history = {group: [] for group in color_groups.keys()}

            ax.set_xlabel(attribute,fontsize=fontsize*2/3)
            suffix = "NIID" if "niid" in location else "IID"
            ax.set_ylabel(f"Attacker advantage {suffix}",fontsize=fontsize*2/3)
            if display_type == "all":
                ax.set_title(f"Attacker advantage vs {attribute} at {location}, iterations {min_iteration}-{max_iteration}",fontsize=fontsize*2/3)
            elif show_centers:
                ax.set_title(f"Centroid of attacker advantage vs {attribute} at {location}, iterations {min_iteration}-{max_iteration}",fontsize=fontsize*2/3)

            for key in to_plot:
                # Restrict to the iteration window announced in the title.
                data = clip_iterations_bounds(global_data[key], min_iteration, max_iteration)
                x_axis = data[attribute]
                privacy_data = data[f"Attacker advantage mean {location}"]

                shape_discriminator = find_discriminator(key,shapes)
                marker = shapes[shape_discriminator][0] if shape_discriminator is not None else None

                color_discriminator = find_discriminator(key,color_groups)
                # A run that matches no colour group falls back to matplotlib's default colour.
                color = color_groups[color_discriminator] if color_discriminator is not None else None

                if display_type == "all":
                    ax.scatter(x_axis,privacy_data, s= point_size,marker = marker, color=color, label = key)
                elif show_centers:
                    # Fully transparent scatter of the individual points; only the centroid is visible.
                    ax.scatter(x_axis,privacy_data, s= point_size,marker = marker,color = color,alpha=0)
                    x_value = x_axis.mean()
                    y_value = privacy_data.mean(axis=0)
                    if color_discriminator is not None:
                        points_history[color_discriminator].append((x_value,y_value))

                    ax.scatter(x_value,y_value, s= point_size * 10,marker = marker, color = color, label = key)

            if show_centers:
                for group, coordinates in points_history.items():
                    if len(coordinates)>1: #To discard the unnoised baselines.
                        ax.plot([point[0] for point in coordinates],[point[1] for point in coordinates],
                                label = format_attribute(group),color=color_groups[group])

            ax.legend(fontsize=fontsize*2/3,ncol=3,handles=legend_elements)

    for ax in axs.flat:
        ax.tick_params(labelbottom=True,labelleft = True,axis="both", labelsize=2*fontsize/3)
        ax.grid()
    fig.tight_layout()
    # savefig does not create missing folders.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    renderer = fig.canvas.get_renderer()
    for ax in axs.flat:
        # Crop the saved file to the bounding box of this single subplot.
        extent = ax.get_tightbbox(renderer).transformed(fig.dpi_scale_trans.inverted())
        fig.savefig(f"{save_directory}/{ax.get_title()}.pdf", bbox_inches=extent)

# Scatter every point (default display_type "all") for the two PRE-STEP locations.
plot_privacyvsloss(to_plot,["PRE-STEP","PRE-STEP-niid"], shapes = shapes, min_iteration=min_iteration, max_iteration=max_iteration)

### Affichage des centres de gravité


In [None]:
# Use every loaded run for the centroid plot.
to_plot = global_data.keys()

# Marker and legend text per noise level. A key is a tuple of substrings that must
# all occur in the run name. `find_discriminator` returns the first matching key, so
# the order matters: "64" must come before "4", and "32" before "2".
shapes = {
    ("No_noise",): ("o", "No noise"),
    ("64",) : ("1","σ/64"),
    ("32",): ("*","σ/32"),
    ("16",): ("D","σ/16"),
    ("8",): ("X","σ/8"),
    ("4",): ("^","σ/4"),
    ("2",): ("s","σ/2")
}

# Colour per (noise family, topology). Every key is a tuple, including the
# single-substring baselines: a bare string would be iterated character by character.
color_groups = {
    ("Gaussian","static"): "blue",
    ("Gaussian","dynamic") : "cyan",
    ("ZeroSum","static"): "red",
    ("ZeroSum","dynamic"): "orange",
    ("Muffliato","static"): "violet",
    ("Muffliato","dynamic"): "pink",
    ("No_noise_static",):"darkgreen",
    ("No_noise_dynamic",) :"limegreen",
}

# Centroid view: one large marker per run, with the centroids of each colour group linked.
plot_privacyvsloss(to_plot,["PRE-STEP","PRE-STEP-niid"],display_type = "center", shapes = shapes, min_iteration=min_iteration, max_iteration=max_iteration)

## Utility comparison display
We will simply present the utility of each method in a clearer manner here.

In [None]:
# Print every run name as a quoted, comma-terminated line, ready to paste into a `to_plot` list.
for key in global_data.keys():
    print(f"\"{key}\",")


In [None]:
# Settings read as globals by `plot_utility`.
metrics = ["test_acc", "test_niid_acc"]
figsize = (20, 4)
alpha = 0.4
width = 0.25  # the width of the bars
window_size = 1
ncols = 1

topology = "static"

# Default matplotlib colour cycle, used to give each method a fixed colour.
prop_cycle = plt.rcParams['axes.prop_cycle']
total_colors = prop_cycle.by_key()['color']

# The attributes to display and color by: (substrings a run name must contain, bar colour).
# The "dynamic" variants (colours 1, 3 and 5) are left out of this static-topology figure.
attributes = [
    (["Muffliato", "static"], total_colors[0]),
    (["Gaussian", "static"], total_colors[2]),
    (["ZeroSum", "static"], total_colors[4]),
]

# Baseline runs drawn as horizontal dashed lines, with their colour.
baselines = {"No_noise_static": "b", "No_noise_dynamic": "r"}

# One bar group per noise level; a run belongs to a group when its name contains
# both the level and the topology.
to_plot = [[level, topology] for level in ("64", "32", "16", "8", "4", "2")]

# Group name (the noise level) -> names of the runs in that group.
to_plot_dict = {criteria[0]: [] for criteria in to_plot}

for run_name in global_data.keys():
    # A run goes to the first group whose substrings all match, so "64" wins over "4".
    for criteria in to_plot:
        if all(token in run_name for token in criteria):
            to_plot_dict[criteria[0]].append(run_name)
            break



def is_attribute_match(name,attribute_list):
    """Return True when every entry of ``attribute_list`` is a substring of ``name``.

    An empty ``attribute_list`` matches any name.
    """
    return all(attribute in name for attribute in attribute_list)

def get_last_n_indexes(data_to_consider,window_size):
    """Return the index labels of the last ``window_size`` entries of ``data_to_consider``."""
    return data_to_consider.tail(window_size).index

def plot_utility(metric,to_plot_dict,baselines, window_size):
    """Draw a grouped bar chart of the final ``"<metric> mean"`` of each run.

    Each key of ``to_plot_dict`` is one bar group on the x axis. Inside a group,
    one bar is drawn per entry of the module-level ``attributes`` that the run
    name matches. The bar height is the mean of the column over the index labels
    of the last ``window_size`` non-NaN rows of the first run encountered.
    Baselines are drawn as horizontal dashed lines. The figure is saved as
    ``<save_directory>/<title>.pdf`` and then shown.

    Reads the module-level ``global_data``, ``attributes``, ``figsize``,
    ``fontsize``, ``width``, ``ncols``, ``topology`` and ``save_directory``.

    Args:
        metric: column prefix; the column read is ``"<metric> mean"``.
        to_plot_dict: mapping group name -> list of keys of ``global_data``.
        baselines: mapping key of ``global_data`` -> line colour.
        window_size: number of trailing rows to average.
    """
    fig, ax = plt.subplots(figsize=figsize)
    group_names = list(to_plot_dict.keys())
    x = np.arange(len(group_names))
    indexes = None
    # (legend label, colour) -> list of (group position, bar height). Keeping the
    # position with each value places every bar in its own group even when a
    # method has no run for some noise level.
    series = {}
    for group_position, keys in enumerate(to_plot_dict.values()):
        for key in keys:
            data_to_consider = global_data[key][f"{metric} mean"].dropna()
            if indexes is None:
                # The same index labels are then reused for every run and baseline.
                indexes = get_last_n_indexes(data_to_consider, window_size)
            data_to_plot = np.mean(data_to_consider.loc[indexes])
            for attribute_list, color in attributes:
                if is_attribute_match(key, attribute_list):
                    label = " ".join(attribute_list)
                    series.setdefault((label, color), []).append((group_position, data_to_plot))

    n_series = len(series)
    for multiplier, ((attribute, color), points) in enumerate(series.items()):
        offset = width * multiplier
        positions = [x[group_position] + offset for group_position, _ in points]
        heights = [height for _, height in points]
        rects = ax.bar(positions, heights, width, label=attribute, color=color)
        # Padding follows the number of bar series, not a loop variable leaked from above.
        ax.bar_label(rects, padding=n_series)

    #Draw the baseline:
    for key, color in baselines.items():
        data_to_consider = global_data[key][f"{metric} mean"].dropna()
        if indexes is None:
            # No run was plotted; take the window from the first baseline instead.
            indexes = get_last_n_indexes(data_to_consider, window_size)
        data_to_plot = np.mean(data_to_consider.loc[indexes])
        ax.axhline(y=data_to_plot,linewidth=1,label = key,linestyle ="--",color= color)

    # Centre each tick under its group of bars.
    ax.set_xticks(x + ((max(n_series, 1) - 1) * width/2), group_names,fontsize=fontsize)
    ax.set_ylabel(f"{metric}",fontsize=fontsize)
    ax.grid()
    if topology is not None:
        ax.set_title(f"Mean of last {window_size} {metric} comparison, {topology} topology")
    else:
        ax.set_title(f"Mean of last {window_size} {metric} comparison")
    ax.legend(fontsize=fontsize*2/3,ncol = ncols )
    ax.tick_params(labelbottom=True,labelleft = True,axis="both", labelsize=2*fontsize/3)
    # savefig does not create missing folders.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    extent = ax.get_tightbbox(fig.canvas.get_renderer()).transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(f"{save_directory}/{ax.get_title()}.pdf", bbox_inches=extent)
    plt.show()


# Show which runs ended up in each bar group, then draw one chart per metric.
for name,keys in to_plot_dict.items():
    print(f"{name} : {keys}")
for metric in metrics:
    plot_utility(metric,to_plot_dict,baselines,window_size)


    

In [None]:
ncols = 1

# The attributes to display and color by: only the "dynamic" variants (colours 1, 3
# and 5); the "static" ones are left out of this figure.
attributes = [
    (["Muffliato", "dynamic"], total_colors[1]),
    (["Gaussian", "dynamic"], total_colors[3]),
    (["ZeroSum", "dynamic"], total_colors[5]),
]
topology = "dynamic"
baselines = {"No_noise_static": "b", "No_noise_dynamic": "r"}

# One bar group per noise level, restricted to the chosen topology.
# (A shorter alternative is the levels "32", "16", "8" only.)
to_plot = [[level, topology] for level in ("64", "32", "16", "8", "4", "2")]

# Initialize the dict of things to plot and group by
to_plot_dict = {criteria[0]: [] for criteria in to_plot}
for run_name in global_data.keys():
    # A run goes to the first group whose substrings all match, so "64" wins over "4".
    for criteria in to_plot:
        if all(token in run_name for token in criteria):
            to_plot_dict[criteria[0]].append(run_name)
            break

for name, keys in to_plot_dict.items():
    print(f"{name} : {keys}")
for metric in metrics:
    plot_utility(metric, to_plot_dict, baselines, window_size)


In [None]:
width = 0.20  # the width of the bars
ncols = 2

# The attributes to display and color by: every method in both topologies.
attributes = [
    ([method, kind], total_colors[position])
    for position, (method, kind) in enumerate(
        (method, kind)
        for method in ("Muffliato", "Gaussian", "ZeroSum")
        for kind in ("static", "dynamic")
    )
]
# No topology filter: the title omits it and groups match on the noise level alone.
topology = None

baselines = {"No_noise_static": "b", "No_noise_dynamic": "r"}
to_plot = [[level] for level in ("64", "32", "16", "8", "4", "2")]

# Initialize the dict of things to plot and group by
to_plot_dict = {criteria[0]: [] for criteria in to_plot}
for run_name in global_data.keys():
    # A run goes to the first group whose substrings all match, so "64" wins over "4".
    for criteria in to_plot:
        if all(token in run_name for token in criteria):
            to_plot_dict[criteria[0]].append(run_name)
            break

for name, keys in to_plot_dict.items():
    print(f"{name} : {keys}")
for metric in metrics:
    plot_utility(metric, to_plot_dict, baselines, window_size)

In [None]:
# Curves to draw: one per (noise family, topology), plus the two baselines.
# Baseline run names use an underscore ("No_noise_*") so that they match both the
# `baseline` list and the keys of `global_data`.
to_plot = {
    "Gaussian_static" : [key for key in global_data.keys() if "Gaussian" in key and "static" in key],
    "Gaussian_dynamic" : [key for key in global_data.keys() if "Gaussian" in key and "dynamic" in key],
    "ZeroSum_static" : [key for key in global_data.keys() if "ZeroSum" in key and "static" in key],
    "ZeroSum_dynamic" : [key for key in global_data.keys() if "ZeroSum" in key and "dynamic" in key],
    "No noise_static" : ["No_noise_static"],
    "No noise_dynamic" : ["No_noise_dynamic"]
}

baseline = ["No_noise_static", "No_noise_dynamic"]
# Colour of a baseline line, chosen by the topology substring in its name.
colors = {
    "static" : "b",
    "dynamic" : "r"
}
metrics = ["test_acc","test_niid_acc"]

for metric in metrics:
    fig, ax = plt.subplots(figsize = figsize)
    # Index labels of the last `window_size` rows, taken from the first series
    # encountered and reused for all the others.
    indexes = None
    for expe_type, expe_list in to_plot.items():
        if not expe_list:
            # No loaded run belongs to this family; do not add an empty legend entry.
            continue

        if len(expe_list) == 1 and expe_list[0] in baseline:
            # Baselines are horizontal dashed lines.
            expe = expe_list[0]
            data = global_data[expe][f"{metric} mean"].dropna()

            if indexes is None:
                indexes = get_last_n_indexes(data,window_size)
                print(indexes)
            y_value = np.mean(data.loc[indexes])
            # None falls back to matplotlib's default line colour.
            color = next((c for substring, c in colors.items() if substring in expe), None)
            ax.axhline(y=y_value,linewidth=1, color=color,label = expe,linestyle ="--")
        else:
            data_to_plot = []
            for expe in expe_list:
                data = global_data[expe]
                # NOTE(review): assumes the noise level is stored in a flat
                # "generated_noise_std mean" column, like the other "<attr> <metric>"
                # columns of this notebook — TODO confirm the column name.
                x_value = np.mean(data["generated_noise_std mean"])

                # Same flat column naming as the baseline branch.
                data = data[f"{metric} mean"].dropna()
                if indexes is None:
                    indexes = get_last_n_indexes(data,window_size)
                    print(indexes)
                y_value = np.mean(data.loc[indexes])
                data_to_plot.append((x_value,y_value))

            # Sort by noise level so the line is drawn from left to right.
            data_to_plot.sort(key = lambda point : point[0])

            ax.plot([x_value for (x_value, _) in data_to_plot],[y_value for (_, y_value) in data_to_plot], marker = 'o', label=expe_type)

    ax.legend(fontsize=fontsize*2/3)
    ax.set_xlabel("Noise level",fontsize=fontsize)
    ax.set_ylabel(metric,fontsize=fontsize)
    ax.tick_params(labelbottom=True,labelleft = True,axis="both", labelsize=2*fontsize/3)
    ax.set_title(f"Evolution of the last {window_size} {metric} according to noise level",fontsize=fontsize)
    ax.grid()
    # savefig does not create missing folders.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    extent = ax.get_tightbbox(fig.canvas.get_renderer()).transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(f"{save_directory}/{ax.get_title()}.pdf", bbox_inches=extent)
    plt.show()