In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import glob
import seaborn as sns

## Read from `~/metrics/epoch_metrics/` directory

In [2]:
df = pd.read_parquet("../../metrics/epoch_metrics/")

# Hyper-Parameters

## All tests summary

Note that this is a summary. 

Notes:
1. The `synchronous` method implies `theta` equal to `0.0`, and vice versa.
2. The stopping criterion (currently the number of epochs) is not the same for all combinations. For `LeNet-5`, all tests currently end at 50 epochs. For `AdvancedCNN`, some tests run up to 350 epochs and others up to 250.
3. Tests with `theta` greater than 2 are exploratory and will not be considered at all later on.

In [3]:
df_lenet = df[df['nn_name'] == 'LeNet-5']

In [4]:
for col in ['dataset_name', 'fda_name', 'nn_name', 'num_clients', 'batch_size', 'num_steps_until_rtc_check', 'theta', 'bias']:
    print(f"{col}: {sorted(list(df_lenet[col].unique()))}")

dataset_name: ['MNIST']
fda_name: ['gm', 'linear', 'naive', 'sketch', 'synchronous']
nn_name: ['LeNet-5']
num_clients: [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60]
batch_size: [32, 64, 128, 256]
num_steps_until_rtc_check: [1]
theta: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 7.0, 10.0, 12.0, 15.0]
bias: [nan, 0.3, 0.6, 0.9]


## Query All tests

In [5]:
test_combinations = df.groupby(['dataset_name', 'nn_name', 'fda_name', 'num_steps_until_rtc_check', 'batch_size', 'theta', 'num_clients'])['epoch'].max().reset_index()

In [6]:
test_combinations[
    (test_combinations['nn_name'] == 'AdvancedCNN') &
    (test_combinations['fda_name'] == 'naive') &
    (test_combinations['theta'] == 20.0) &
    (test_combinations['num_clients'] == 5)
]

Unnamed: 0,dataset_name,nn_name,fda_name,num_steps_until_rtc_check,batch_size,theta,num_clients,epoch
348,MNIST,AdvancedCNN,naive,1,32,20.0,5,100


# Helpful new Dataframe metrics

### Add Helpful Dataset Metrics 

In [7]:
def dataset_n_train(row):
    """Return the number of training samples of the dataset named in `row`.

    `row` only needs to support `row['dataset_name']`. "MNIST" maps to 60_000;
    every other name maps to the sentinel -1.
    """
    return 60_000 if row['dataset_name'] == "MNIST" else -1


def dataset_one_sample_bytes(row):
    """Return the size in bytes of one training sample of the dataset in `row`.

    `row` only needs to support `row['dataset_name']`. Unknown datasets map to
    the sentinel -1.
    """
    if row['dataset_name'] != "MNIST":
        return -1

    # input image 784 tf.float32 pixels and a tf.int32 label, 4 bytes each
    bytes_per_value = 4
    values_per_sample = 784 + 1
    return bytes_per_value * values_per_sample

In [8]:
df['n_train'] = df.apply(dataset_n_train, axis=1)
df['one_sample_bytes'] = df.apply(dataset_one_sample_bytes, axis=1)

### Add Helpful model metrics

In [9]:
df['model_bytes'] = df['nn_num_weights'] * 4

### Add Helpful FDA method metrics

In [10]:
def fda_local_state_bytes(row):
    """Return the size in bytes of the local state one client sends per FDA step.

    `row` must support `row['fda_name']`; for the "sketch" method it must also
    provide `row['sketch_width']` and `row['sketch_depth']`. Returns None when
    the method name is not one of the known ones.
    """
    method = row['fda_name']

    if method == "sketch":
        # width * depth cells of 4 bytes each, plus 4 extra bytes
        return row['sketch_width'] * row['sketch_depth'] * 4 + 4

    fixed_sizes = (
        ("naive", 4),
        ("linear", 8),
        ("synchronous", 0),
        ("gm", 0.125),  # one bit
    )
    for known_method, size in fixed_sizes:
        if method == known_method:
            return size

In [11]:
df['local_state_bytes'] = df.apply(fda_local_state_bytes, axis=1)

### Add Total Steps

Total steps: a single FDA step may consist of several ordinary SGD (batch) steps, so the total is `total_fda_steps * num_steps_until_rtc_check`.

In [12]:
df['total_steps'] = df['total_fda_steps'] * df['num_steps_until_rtc_check']

### Add communication metrics

The communication bytes exchanged for model synchronization. Remember that the Clients send their models to the Server and the Server sends the global model back. This happens at the end of every round.

In [13]:
df['model_bytes_exchanged'] = df['total_rounds'] * df['model_bytes'] * df['num_clients'] * 2

The communication bytes exchanged for monitoring the variance. This happens at the end of every FDA step which consists of `num_steps_until_rtc_check` number of steps. 

In [14]:
df['monitoring_bytes_exchanged'] = df['local_state_bytes'] * df['total_fda_steps'] * df['num_clients']

The total communication bytes exchanged in the whole Federated Learning lifecycle.

In [15]:
df['total_communication_bytes'] = df['model_bytes_exchanged'] + df['monitoring_bytes_exchanged']

In [16]:
df['total_communication_gb'] = df['total_communication_bytes'] / 10**9

Add rounds in one epoch.

In [17]:
# Columns that identify one run. `bias` is included because runs that differ only
# in bias are treated as separate runs elsewhere in this notebook; without it their
# epochs would be interleaved and `diff()` would subtract values of different runs.
run_keys = ['dataset_name', 'fda_name', 'nn_num_weights', 'num_clients', 'batch_size', 'num_steps_until_rtc_check', 'theta', 'bias']

df = df.sort_values(by=run_keys + ['epoch'])

# Rounds performed within each epoch = difference of the cumulative `total_rounds`
# between consecutive epochs of the same run. dropna=False keeps the runs whose
# bias is NaN as groups of their own instead of dropping them.
df['epoch_rounds'] = df.groupby(run_keys, dropna=False)['total_rounds'].diff()

# The first epoch of every run has no predecessor (NaN): its rounds are the total so far
df['epoch_rounds'] = df['epoch_rounds'].fillna(df['total_rounds'])

df['epoch_rounds'] = df['epoch_rounds'].astype(int)

# HyperParameter ranking

### LeNet-5
On 2 `Nvidia A10`:
1. *Batch Size* = 32 -> 6.613 ms ± 0.128 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2. *Batch Size* = 64 -> 7.509 ms ± 0.065 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3. *Batch Size* = 128 -> 8.02 ms ± 0.099 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4. *Batch Size* = 256 -> 9.258 ms ± 0.336 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

## AdvancedCNN

On 2 `Nvidia A10`:
1. *Batch Size* = 32 -> 8.853 ms ± 0.0917 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2. *Batch Size* = 64 -> 10.325 ms ± 0.215 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3. *Batch Size* = 128 -> 11.989 ms ± 0.134 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4. *Batch Size* = 256 -> 16.47 ms ± 0.294 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [18]:
step_ms = {
    ("AdvancedCNN", 32): 8.853,
    ("AdvancedCNN", 64): 10.325,
    ("AdvancedCNN", 128): 11.989,
    ("AdvancedCNN", 256): 16.47,
    ("LeNet-5", 32): 6.613,
    ("LeNet-5", 64): 7.509,
    ("LeNet-5", 128): 8.02,
    ("LeNet-5", 256): 9.258
}

In [19]:
import numpy as np

def cpu_time_cost(row):
    """Total cpu time cost of a run, in seconds.

    A single `step` means each client performed a single `step`. The duration
    of one step (in ms) is looked up in the notebook-level `step_ms` table by
    the `(nn_name, batch_size)` pair of `row`.
    """
    steps = row['total_steps']
    ms_per_step = step_ms[(row['nn_name'], row['batch_size'])]
    return steps * ms_per_step / 1000

def communication_time_cost(num_clients, total_communication_bytes, comm_model):
    """Estimate the communication time in seconds, assuming a 1 Gbps channel.

    Args:
        num_clients: number of clients; a scalar or anything supporting
            element-wise arithmetic (the notebook passes pandas Series).
        total_communication_bytes: total bytes exchanged; scalar or array-like.
        comm_model: 'common_channel' or 'hypercube'.

    Returns:
        The time cost in seconds, with the same shape as the inputs.

    Raises:
        ValueError: if `comm_model` is not one of the supported models. The
            original silently returned None, which would end up as a column
            of missing values.
    """
    # bytes -> gigabits; on a 1 Gbps channel one gigabit takes one second
    total_communication_gbit = total_communication_bytes * 8e-9

    if comm_model == 'common_channel':
        return ((num_clients - 1) / num_clients) * total_communication_gbit    # sec

    if comm_model == 'hypercube':
        # NOTE(review): np.log is the natural logarithm. A hypercube of n nodes
        # normally has ceil(log2(n)) dimensions, so np.log2 may have been
        # intended. Kept unchanged because the results already produced in this
        # notebook were computed with it; confirm before switching.
        return (np.ceil(np.log(num_clients)) / num_clients) * total_communication_gbit   # sec

    raise ValueError(
        f"Unknown comm_model {comm_model!r}; expected 'common_channel' or 'hypercube'"
    )

In [20]:
df['cpu_time_cost'] = df.apply(cpu_time_cost, axis=1)

In [21]:
df['hypercube_communication_time_cost'] = communication_time_cost(df['num_clients'], df['total_communication_bytes'], 'hypercube')

In [22]:
df['common_channel_communication_time_cost'] = communication_time_cost(df['num_clients'], df['total_communication_bytes'], 'common_channel')

In [23]:
df['hypercube_time_cost'] = df['cpu_time_cost'] + df['hypercube_communication_time_cost']

In [24]:
df['common_channel_time_cost'] = df['cpu_time_cost'] + df['common_channel_communication_time_cost']

In [25]:
df['hypercube_comm_cpu_time_ratio'] = df['hypercube_communication_time_cost'] / df['cpu_time_cost']

In [26]:
df['common_channel_comm_cpu_time_ratio'] = df['common_channel_communication_time_cost'] / df['cpu_time_cost']

# Plots about cost

In [27]:
# Define styles for each fda_name
fda_styles = {
    'naive': 'o-r',
    'linear': 's-g',
    'sketch': '^-b',
    'synchronous': 'x-c'
}
fda_names = ['gm', 'naive', 'linear', 'sketch', 'synchronous']

In [28]:
import matplotlib

# One distinct colour per `num_clients` value, sampled evenly from the 'tab20b' colormap
num_clients_values = sorted(df['num_clients'].unique())
cmap = matplotlib.colormaps['tab20b']
palette = cmap(np.linspace(0, 1, len(num_clients_values)))
colors_dict = dict(zip(num_clients_values, palette))

## KDE Helper

In [29]:
"""
sns_params = {
    'bw_method': 'scott',
    'bw_adjust': 0.7,
    'fill': True,
    'alpha': 0.2
}
"""
sns_params = {
    'bw_method': 'scott',
    'bw_adjust': 0.7,
    'fill': False,
    'alpha': 1
}

sns_params_biases = {
    'bw_method': 'scott',
    'bw_adjust': 0.7,
    'fill': False,
    'alpha': 0.8
}

base_colors = {
    'gm': 'blue',
    'naive': 'orange',
    'linear': 'green',
    'sketch': 'red',
    'synchronous': 'purple'
}

In [30]:
import matplotlib.pyplot as plt

plt.rcParams['font.size'] = 14

## Total time cost with accuracy

### KDE

In [31]:
def kde_time_cost(df, filename, x_log=True):
    """Save KDE plots of the total time cost per FDA method to a one-page PDF.

    Two panels are drawn side by side, one per communication model
    (`common_channel_time_cost` on the left, `hypercube_time_cost` on the
    right), each with one density curve per entry of the notebook-level
    `fda_names`.

    Args:
        df: DataFrame with the columns `fda_name`, `common_channel_time_cost`
            and `hypercube_time_cost`.
        filename: path of the PDF file to write.
        x_log: if True the x-axis is logarithmic; otherwise it is linear and
            starts at 0.
    """
    log_scale = (True, False) if x_log else False

    # (column to plot, panel title), in left-to-right order
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]

    plt.rcParams['font.size'] = 20
    plt.rcParams['legend.fontsize'] = 12

    try:
        # The context manager finalizes the PDF even if plotting raises
        with PdfPages(filename) as pdf:
            fig, axs = plt.subplots(1, 2, figsize=(20, 8))
            try:
                for fda_name in fda_names:
                    method_df = df[df['fda_name'] == fda_name]
                    for ax, (column, _) in zip(axs, panels):
                        sns.kdeplot(method_df[column], label=fda_name, ax=ax, color=base_colors[fda_name], log_scale=log_scale, **sns_params)

                for ax, (_, title) in zip(axs, panels):
                    if not x_log:
                        ax.set_xlim(left=0)
                    ax.set_xlabel('Time Cost (sec)')
                    ax.set_ylabel('Density')
                    ax.set_title(title)
                    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                    ax.legend()

                fig.tight_layout()
                pdf.savefig(fig)
            finally:
                # Close the figure to prevent it from being displayed in the notebook
                plt.close(fig)
    finally:
        # Back to the notebook-wide font settings, also when plotting failed
        plt.rcParams['font.size'] = 14
        plt.rcParams['legend.fontsize'] = 12

In [32]:
def kde_time_cost_biases(df, filename, x_log=False):
    """Save KDE plots of the total time cost per FDA method and bias to a PDF.

    Two panels are drawn side by side, one per communication model. For every
    entry of the notebook-level `fda_names`, one curve is drawn per distinct
    `bias` value, in shades of that method's base colour (darkest first).

    Args:
        df: DataFrame with the columns `fda_name`, `bias`,
            `common_channel_time_cost` and `hypercube_time_cost`. A NaN bias
            is handled as its own group.
        filename: path of the PDF file to write.
        x_log: if True the x-axis is logarithmic; otherwise it is linear and
            starts at 0.
    """
    log_scale = (True, False) if x_log else False

    # (column to plot, panel title), in left-to-right order
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]

    plt.rcParams['font.size'] = 20
    plt.rcParams['legend.fontsize'] = 12

    try:
        # The context manager finalizes the PDF even if plotting raises
        with PdfPages(filename) as pdf:
            fig, axs = plt.subplots(1, 2, figsize=(20, 8))
            try:
                for fda_name in fda_names:
                    df_fda_method = df[df['fda_name'] == fda_name]

                    # NaN cannot be ordered by sorted() (every comparison is False),
                    # so the result would depend on the input order. Sort the real
                    # values ascending and put the NaN group first.
                    bias_values = df_fda_method['bias'].unique()
                    biases = sorted(b for b in bias_values if not pd.isna(b))
                    if pd.isna(bias_values).any():
                        biases.insert(0, float('nan'))

                    # One extra shade is requested so the lightest one is never used
                    color_palette = sns.light_palette(base_colors[fda_name], n_colors=len(biases) + 1)[::-1]

                    for i, bias in enumerate(biases):
                        # `== NaN` is always False, so the NaN group needs isna()
                        if pd.isna(bias):
                            mask = df_fda_method['bias'].isna()
                        else:
                            mask = df_fda_method['bias'] == bias

                        df_bias = df_fda_method[mask]
                        if df_bias.empty:
                            continue

                        for ax, (column, _) in zip(axs, panels):
                            sns.kdeplot(df_bias[column], label=f'{fda_name} w/ bias {bias}', ax=ax, color=color_palette[i], log_scale=log_scale, **sns_params_biases)

                for ax, (_, title) in zip(axs, panels):
                    if not x_log:
                        ax.set_xlim(left=0)
                    ax.set_xlabel('Time Cost (sec)')
                    ax.set_ylabel('Density')
                    ax.set_title(title)
                    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                    ax.legend()

                fig.tight_layout()
                pdf.savefig(fig)
            finally:
                # Close the figure to prevent it from being displayed in the notebook
                plt.close(fig)
    finally:
        # Back to the notebook-wide font settings, also when plotting failed
        plt.rcParams['font.size'] = 14
        plt.rcParams['legend.fontsize'] = 12

## Total Communication cost (in gb) with accuracy (scatter)

### KDE

In [33]:
def kde_communication_cost(df, filename, x_log=True):
    """Save a KDE plot of the total communication (GB) per FDA method to a PDF.

    One density curve is drawn per entry of the notebook-level `fda_names`,
    and a text box lists the mean communication of every method.

    Args:
        df: DataFrame with the columns `nn_name`, `fda_name` and
            `total_communication_gb`. It must not contain more than one
            distinct `nn_name`; otherwise a message is printed and nothing is
            written.
        filename: path of the PDF file to write.
        x_log: if True the x-axis is logarithmic; otherwise it is linear and
            starts at 0.
    """
    nn_names = df['nn_name'].unique()
    if len(nn_names) > 1:
        print("problem... kda_communication_cost")
        return

    # An empty frame has no model name at all. Indexing [0] would raise
    # IndexError, so it falls back to the default legend placement.
    for_lenet = len(nn_names) == 1 and nn_names[0] == 'LeNet-5'

    log_scale = (True, False) if x_log else False

    # The context manager finalizes the PDF even if plotting raises
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []
            for fda_name in fda_names:
                fda_data = df[df['fda_name'] == fda_name]['total_communication_gb']
                avg_info.append(f'{fda_name}: {fda_data.mean():.2f} GB')

                # Plotting only the KDE using kdeplot
                sns.kdeplot(fda_data, label=fda_name, ax=ax, log_scale=log_scale, color=base_colors[fda_name], **sns_params)

            # Text annotation with the per-method averages, inside the plot
            text = "Average Communication:\n" + '\n'.join(avg_info)
            ax.text(0.02, 0.97, text, transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            if not x_log:
                ax.set_xlim(left=0)

            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('Communication (GB)')
            ax.set_ylabel('Density')

            if for_lenet:
                # Restrict where 'best' may place the legend for the LeNet-5 plot
                ax.legend(loc='best', bbox_to_anchor=(0, 0.1, 1, 0.67))
            else:
                ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure to prevent it from being displayed in the notebook
            plt.close(fig)

## Total CPU time (in Seconds) with accuracy

### KDE

In [34]:
def kde_cpu_time_cost(df, filename, x_log=False):
    """Save a KDE plot of the CPU time cost (sec) per FDA method to a PDF.

    One density curve is drawn per entry of the notebook-level `fda_names`.

    Args:
        df: DataFrame with the columns `fda_name` and `cpu_time_cost`.
        filename: path of the PDF file to write.
        x_log: if True the x-axis is logarithmic; otherwise it is linear and
            starts at 0.
    """
    log_scale = (True, False) if x_log else False

    # The context manager finalizes the PDF even if plotting raises
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for fda_name in fda_names:
                fda_data = df[df['fda_name'] == fda_name]['cpu_time_cost']

                # Plotting only the KDE using kdeplot
                sns.kdeplot(fda_data, label=fda_name, ax=ax, log_scale=log_scale, color=base_colors[fda_name], **sns_params)

            if not x_log:
                ax.set_xlim(left=0)
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('CPU time cost (sec)')
            ax.set_ylabel('Density')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure to prevent it from being displayed in the notebook
            plt.close(fig)

## Different FDA-method runs visualization with lines (per clients)

In [35]:
import matplotlib
from matplotlib import cm

def fda_method_run_line_plot(df, fda_name, filename, limit_x_axis=False):
    """Save accuracy-vs-time-cost line plots of one FDA method to a PDF.

    One page is written per (batch size, theta) combination that has data.
    Every page has two panels, one per communication model, with one line per
    number of clients, coloured through the notebook-level `colors_dict`.

    Args:
        df: DataFrame with the columns `fda_name`, `batch_size`, `theta`,
            `num_clients`, `accuracy`, `common_channel_time_cost` and
            `hypercube_time_cost`.
        fda_name: the method whose rows are plotted.
        filename: path of the PDF file to write.
        limit_x_axis: if True, the x-axis of each panel is cut shortly after
            the largest time cost among the rows with `num_clients != 2`.
    """
    method_df = df[df.fda_name == fda_name]

    batch_size_values = sorted(method_df['batch_size'].unique())
    theta_values = sorted(method_df['theta'].unique())

    # (column to plot, panel title, margin added to the x-limit), left to right
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model", 700),
        ('hypercube_time_cost', "Hypercube Communication Model", 500),
    ]

    # The context manager finalizes the PDF even if plotting raises
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            for theta in theta_values:
                filtered_data = method_df[(method_df['theta'] == theta) &
                                          (method_df['batch_size'] == batch_size)]

                if filtered_data.empty:
                    continue

                num_clients_values = sorted(filtered_data['num_clients'].unique())

                # Because num_clients = 2 reaches very good accuracy very early, those
                # rows are left out when the x-axis limit is derived, so the rest of
                # the data can be visualized more easily.
                without_two_clients = filtered_data[filtered_data['num_clients'] != 2]

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    # Plot each group with a unique color based on num_clients
                    for num_clients in num_clients_values:
                        data = filtered_data[filtered_data['num_clients'] == num_clients]
                        for ax, (column, _, _) in zip(axs, panels):
                            ax.plot(data[column], data['accuracy'], color=colors_dict[num_clients], label=num_clients, marker='o', markersize=3)

                    for ax, (column, panel_title, margin) in zip(axs, panels):
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                        ax.set_xlabel('Time Cost')
                        ax.set_ylabel('Accuracy')
                        ax.set_title(panel_title)
                        ax.legend(title='Num Clients')
                        if limit_x_axis:
                            upper = without_two_clients[column].max()
                            # max() of an empty selection is NaN, which set_xlim
                            # rejects; keep the automatic limit in that case
                            if pd.notna(upper):
                                ax.set_xlim(0, upper + margin)

                    # Raw string: '\T' is not a valid escape sequence in a normal string
                    fig.suptitle(rf'Batch Size : {batch_size} , $\Theta$ : {theta}')

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

## Keep Hyper-parameters fixed and plot for filtered

In [59]:
def fda_methods_batch_size(df, filename):
    """Save time-cost-vs-batch-size line plots of all FDA methods to a PDF.

    One page is written per (number of clients, theta) combination that has
    data. Every page has two panels, one per communication model, with one
    line per FDA method plus the `synchronous` rows of the same number of
    clients as a reference line. Lines are coloured through the notebook-level
    `base_colors`, like in the other fixed-hyper-parameter plots.

    Args:
        df: DataFrame with the columns `fda_name`, `num_clients`, `theta`,
            `batch_size`, `common_channel_time_cost` and `hypercube_time_cost`.
        filename: path of the PDF file to write.
    """
    num_clients_values = sorted(df['num_clients'].unique())
    # The smallest theta is skipped, presumably the 0.0 of the synchronous
    # baseline (which is drawn separately below). TODO confirm
    theta_values = sorted(df['theta'].unique())[1:]

    # (column to plot, panel title), in left-to-right order
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]
    x_ticks = [32, 64, 128, 256]

    # The context manager finalizes the PDF even if plotting raises
    with PdfPages(filename) as pdf:
        for num_clients in num_clients_values:
            for theta in theta_values:
                filtered_df = df[(df['theta'] == theta) & (df['num_clients'] == num_clients)]

                # The original tested the undefined name `filtered_data` here (NameError)
                if filtered_df.empty:
                    continue

                synchronous_data = df[(df['num_clients'] == num_clients) & (df['fda_name'] == 'synchronous')]

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        # The synchronous baseline is drawn once, after this loop
                        if fda_name == 'synchronous':
                            continue

                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]
                        if fda_data.empty:
                            continue

                        for ax, (column, _) in zip(axs, panels):
                            ax.plot(fda_data['batch_size'], fda_data[column], marker='o', color=base_colors[fda_name], label=fda_name, markersize=3)

                    if not synchronous_data.empty:
                        for ax, (column, _) in zip(axs, panels):
                            ax.plot(synchronous_data['batch_size'], synchronous_data[column], marker='o', label='synchronous', color=base_colors['synchronous'], markersize=3)

                    # Nothing was drawn: skip the page (the figure is closed in `finally`)
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Batch size')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    # Raw string: '\T' is not a valid escape sequence in a normal string
                    fig.suptitle(rf'Num Clients : {num_clients} , $\Theta$ : {theta}')

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

In [60]:
def fda_methods_clients(df, filename):
    """Save time-cost-vs-number-of-clients line plots of all FDA methods to a PDF.

    One page is written per (batch size, theta) combination that has data.
    Every page has two log-scaled panels, one per communication model, with
    one line per FDA method plus the `synchronous` rows of the same batch size
    as a reference line.

    Args:
        df: DataFrame with the columns `fda_name`, `num_clients`, `theta`,
            `batch_size`, `common_channel_time_cost` and `hypercube_time_cost`.
        filename: path of the PDF file to write.
    """
    batch_size_values = sorted(df['batch_size'].unique())
    # The smallest theta is skipped, presumably the 0.0 of the synchronous
    # baseline (which is drawn separately below). TODO confirm
    theta_values = sorted(df['theta'].unique())[1:]

    # (column to plot, panel title), in left-to-right order
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]

    # Fixed x ticks every 5 clients over the hard-coded range 5..60
    min_clients = 5
    max_clients = 60
    x_ticks = list(range(min_clients, max_clients + 1, 5))

    # The context manager finalizes the PDF even if plotting raises
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            for theta in theta_values:
                filtered_df = df[(df['theta'] == theta) & (df['batch_size'] == batch_size)]

                if filtered_df.empty:
                    continue

                synchronous_data = df[(df['batch_size'] == batch_size) & (df['fda_name'] == 'synchronous')]

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        # The synchronous baseline is drawn once, after this loop
                        if fda_name == 'synchronous':
                            continue

                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]
                        if fda_data.empty:
                            continue

                        for ax, (column, _) in zip(axs, panels):
                            ax.plot(fda_data['num_clients'], fda_data[column], marker='o', color=base_colors[fda_name], label=fda_name, markersize=3)

                    if not synchronous_data.empty:
                        for ax, (column, _) in zip(axs, panels):
                            ax.plot(synchronous_data['num_clients'], synchronous_data[column], marker='o', label='synchronous', color=base_colors['synchronous'], markersize=3)

                    # Nothing was drawn: skip the page (the figure is closed in `finally`)
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_yscale('log')
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Number of clients')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    # Raw string: '\T' is not a valid escape sequence in a normal string
                    fig.suptitle(rf'Batch Size : {batch_size} , $\Theta$ : {theta}')

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

In [69]:
def fda_methods_theta(df, filename):
    """Save time-cost-vs-theta line plots of all FDA methods to a PDF.

    One page is written per (batch size, number of clients) combination that
    has at least one non-synchronous row. Every page has two log-scaled
    panels, one per communication model, with one line per FDA method and a
    horizontal line at the time cost of the first matching `synchronous` row.

    Args:
        df: DataFrame with the columns `fda_name`, `num_clients`, `theta`,
            `batch_size`, `common_channel_time_cost` and `hypercube_time_cost`.
        filename: path of the PDF file to write.
    """
    batch_size_values = sorted(df['batch_size'].unique())
    num_clients_values = sorted(df['num_clients'].unique())

    # (column to plot, panel title), in left-to-right order
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]

    # The context manager finalizes the PDF even if plotting raises
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            for num_clients in num_clients_values:
                filtered_df = df[(df['batch_size'] == batch_size) & (df['num_clients'] == num_clients)]

                # A page with only the synchronous baseline is not worth drawing
                if filtered_df[filtered_df['fda_name'] != 'synchronous'].empty:
                    continue

                synchronous_data = df[(df['batch_size'] == batch_size) & (df['fda_name'] == 'synchronous') & (df['num_clients'] == num_clients)]

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        # The synchronous baseline is drawn once, after this loop
                        if fda_name == 'synchronous':
                            continue

                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]
                        if fda_data.empty:
                            continue

                        for ax, (column, _) in zip(axs, panels):
                            ax.plot(fda_data['theta'], fda_data[column], marker='o', label=fda_name, color=base_colors[fda_name], markersize=3)

                    if not synchronous_data.empty:
                        # Only the first synchronous row is used as the reference level
                        for ax, (column, _) in zip(axs, panels):
                            ax.axhline(y=synchronous_data[column].iloc[0], label='synchronous', marker='o', color=base_colors['synchronous'], markersize=3)

                    # Nothing was drawn: skip the page (the figure is closed in `finally`)
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_yscale('log')
                        ax.set_xlabel('Theta')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    fig.suptitle(f'Batch Size : {batch_size} , Num Clients : {num_clients}')

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

# Help-Stat

In [70]:
def explore_top(df, acc_thresh, nn_name, fda_name):
    """List, per run, the first epoch whose accuracy exceeds `acc_thresh`.

    Only rows of the given `nn_name` and `fda_name` with accuracy strictly
    above the threshold are considered. A run is one combination of
    (`fda_name`, `num_clients`, `bias`, `batch_size`, `theta`), NaN bias
    included; for each run the row with the smallest `epoch` is kept.

    Returns:
        A DataFrame with the cost-related columns of the kept rows, sorted by
        ascending `common_channel_time_cost`.
    """
    run_keys = ['fda_name', 'num_clients', 'bias', 'batch_size', 'theta']
    report_columns = [
        'num_clients', 'batch_size', 'theta', 'bias', 'total_rounds', 'total_fda_steps', 'epoch',
        'total_communication_gb', 'cpu_time_cost', 'hypercube_time_cost', 'common_channel_time_cost',
    ]

    selected = (df['accuracy'] > acc_thresh) & (df['nn_name'] == nn_name) & (df['fda_name'] == fda_name)
    candidates = df[selected]

    # Index label of the earliest qualifying epoch of every run
    first_hits = candidates.groupby(run_keys, dropna=False)['epoch'].idxmin()

    report = candidates.loc[first_hits][report_columns]
    return report.sort_values(by='common_channel_time_cost')

In [71]:
def mean_epoch_per_method(df, acc_thresh, nn_name, fda_name):
    """Print mean and standard deviation of epochs and rounds needed by a method.

    The statistics are taken over the rows returned by `explore_top` for the
    same arguments.
    """
    top_df = explore_top(df, acc_thresh, nn_name, fda_name)
    epochs = top_df['epoch']
    rounds = top_df['total_rounds']

    print(f"Mean epochs : {epochs.mean()}")
    print(f"Std epochs : {epochs.std()}")
    print(f"Mean rounds : {rounds.mean()}")
    print(f"Std rounds : {rounds.std()}")

## Save all those time-cost plots

In [72]:
import os

def time_cost_plots(df, acc_threshold, nn_name, limit_x_axis=False, show_runs=False, params=False, addi_name='', kde_time_log=True, kde_comm_log=True, kde_cpu_log=False,
                    show_bias=False, kde_time_log_bias=False):
    """Generate all time-cost plots of one model at one accuracy threshold.

    The PDFs are written to `../../metrics/plots/{nn_name}/{threshold}/`,
    where the '.' of the threshold is replaced by '_'. The directory is
    created if needed.

    Args:
        df: the metrics DataFrame (one row per run and epoch).
        acc_threshold: only rows with accuracy strictly above it are used.
        nn_name: only rows of this model are used.
        limit_x_axis: forwarded to `fda_method_run_line_plot`.
        show_runs: if True, also write the per-method run plots (runs without
            bias only).
        params: if True, also write the fixed-hyper-parameter plots.
        addi_name: suffix appended to the KDE file names.
        kde_time_log, kde_comm_log, kde_cpu_log: log x-axis switches of the
            time, communication and CPU KDE plots.
        show_bias: if True, also write the per-bias KDE plot.
        kde_time_log_bias: log x-axis switch of the per-bias KDE plot.
    """
    # Filter out based on `acc_threshold`
    acceptable_acc_df = df[(df.accuracy > acc_threshold) & (df.nn_name == nn_name)]

    str_thresh = str(acc_threshold).replace('.', '_')  # replace '.'
    out_dir = f"../../metrics/plots/{nn_name}/{str_thresh}"

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(out_dir, exist_ok=True)

    if show_runs:
        accept_acc_no_bias = acceptable_acc_df[acceptable_acc_df['bias'].isna()]
        # Plot the runs of each method. x-axis : time cost, y-axis : accuracy, PER number of clients (lines)
        for method in ('sketch', 'naive', 'linear', 'synchronous'):
            fda_method_run_line_plot(accept_acc_no_bias, method, f"{out_dir}/{method}_run.pdf", limit_x_axis=limit_x_axis)

    # Keep one row per run: the instance which first hits the `acc_threshold`
    run_keys = ['fda_name', 'num_clients', 'batch_size', 'theta', 'bias']
    idx = acceptable_acc_df.groupby(run_keys, dropna=False)['epoch'].idxmin()
    filtered_acceptable_acc_df = acceptable_acc_df.loc[idx]

    # NO BIAS: KDE plots over the runs without bias
    no_bias_filt_df = filtered_acceptable_acc_df[filtered_acceptable_acc_df['bias'].isna()]
    kde_time_cost(no_bias_filt_df, f"{out_dir}/kde_time_cost{addi_name}.pdf", x_log=kde_time_log)
    kde_communication_cost(no_bias_filt_df, f"{out_dir}/kde_communication_cost{addi_name}.pdf", x_log=kde_comm_log)
    kde_cpu_time_cost(no_bias_filt_df, f"{out_dir}/kde_cpu_time_cost{addi_name}.pdf", x_log=kde_cpu_log)

    if show_bias:
        kde_time_cost_biases(filtered_acceptable_acc_df, f"{out_dir}/biases.pdf", x_log=kde_time_log_bias)

    # Parameters fixed, and plot
    if params:
        # The batch-size plot stays disabled, as in the original:
        #fda_methods_batch_size(no_bias_filt_df, f"{out_dir}/fda_methods_batch_size.pdf")
        fda_methods_clients(no_bias_filt_df, f"{out_dir}/fda_methods_clients.pdf")
        fda_methods_theta(no_bias_filt_df, f"{out_dir}/fda_methods_theta.pdf")

In [73]:
pd.set_option('display.max_rows', None)

In [74]:
df_bias = df[~df['bias'].isna()]
df_no_bias = df[df['bias'].isna()]

## AdvancedCNN

In [77]:
# Batch size 32 only; the synchronous baseline is kept regardless of its batch size
df_32 = df.loc[df['batch_size'].eq(32) | df['fda_name'].eq('synchronous')]

# Of those, theta >= 15 only; the synchronous baseline is again kept regardless
df_fin = df_32.loc[df_32['theta'].ge(15) | df_32['fda_name'].eq('synchronous')]

In [78]:
# Generate the AdvancedCNN plot set once per threshold (second positional argument),
# tagging the KDE file names with the '_b32' suffix.
for threshold in (0.99, 0.993, 0.995):
    time_cost_plots(df_fin, threshold, 'AdvancedCNN', addi_name='_b32', params=True)
# Disabled run for the 0.996 threshold:
#time_cost_plots(df_fin, 0.996, 'AdvancedCNN', addi_name='_b32')

In [46]:
#explore_top(df_fin, 0.995, 'AdvancedCNN', 'synchronous')

In [47]:
# Inspect the 'gm' runs of AdvancedCNN at the 0.995 threshold. In the saved output the rows
# are ordered by ascending common_channel_time_cost.
# NOTE(review): explore_top is defined earlier in the notebook; its exact filtering and row
# limit are not visible from this cell — confirm against its definition.
explore_top(df_fin, 0.995, 'AdvancedCNN', 'gm')

Unnamed: 0,num_clients,batch_size,theta,bias,total_rounds,total_fda_steps,epoch,total_communication_gb,cpu_time_cost,hypercube_time_cost,common_channel_time_cost
5910,5,32,100.0,,34,4125,11,3.525397,36.518625,47.799896,59.081168
6104,5,32,20.0,,77,1875,5,7.983983,16.599375,42.148122,67.696868
6308,5,32,50.0,,63,3375,9,6.532351,29.878875,50.782399,71.685922
6206,5,32,30.0,,75,2625,7,7.776608,23.239125,48.124269,73.009414
16,10,32,100.0,,32,3188,17,6.636041,28.223364,44.149863,76.00286
522,15,32,100.0,,28,2875,23,8.709804,25.452375,39.388062,90.485579
415,10,32,50.0,,62,3000,16,12.857326,26.559,57.416582,119.131745
1140,20,32,100.0,,36,3844,41,14.931093,34.030932,51.948244,147.50724
313,10,32,30.0,,84,2625,14,17.419601,23.239125,65.046167,148.66025
919,15,32,50.0,,55,2500,20,17.108538,22.1325,49.506161,149.87625


In [48]:
# Same inspection for the 'sketch' method (AdvancedCNN, threshold 0.995); the saved output
# is ordered by ascending common_channel_time_cost.
explore_top(df_fin, 0.995, 'AdvancedCNN', 'sketch')

Unnamed: 0,num_clients,batch_size,theta,bias,total_rounds,total_fda_steps,epoch,total_communication_gb,cpu_time_cost,hypercube_time_cost,common_channel_time_cost
178008,5,32,75.0,,24,3375,9,2.572956,29.878875,38.112336,46.345796
177511,5,32,100.0,,22,4500,12,2.393728,39.8385,47.498429,55.158358
177807,5,32,30.0,,55,3000,8,5.777904,26.559,45.048294,63.537588
170516,10,32,75.0,,27,3188,17,5.758684,28.223364,42.044205,69.685888
170827,15,32,100.0,,19,3500,28,6.172931,30.9855,40.862189,77.076715
170415,10,32,50.0,,36,3000,16,7.615662,26.559,44.836588,81.391765
171324,15,32,75.0,,24,3125,25,7.700104,27.665625,39.985792,85.159737
170027,10,32,100.0,,25,5250,28,5.447114,46.47825,59.551324,85.697471
177914,5,32,50.0,,55,5625,15,5.843582,49.798125,68.497587,87.197049
170311,10,32,30.0,,48,2250,12,10.066646,19.91925,44.0792,92.399099


In [49]:
# 'synchronous' baseline for AdvancedCNN at the stricter 0.996 threshold.
# df_fin keeps synchronous runs for every batch size, which is why the saved output mixes
# batch sizes (32, 64, 128, 256) and shows theta == 0.0 throughout.
explore_top(df_fin, 0.996, 'AdvancedCNN', 'synchronous')

Unnamed: 0,num_clients,batch_size,theta,bias,total_rounds,total_fda_steps,epoch,total_communication_gb,cpu_time_cost,hypercube_time_cost,common_channel_time_cost
183370,20,128,0.0,,4008,4008,171,1662.327299,48.051912,2042.84467,12681.739381
213695,5,256,0.0,,23250,23250,496,2410.74786,382.9275,8097.320652,15811.713804
221439,15,64,0.0,,8750,8750,140,2721.8121,90.34375,4445.24311,20413.20743
218767,35,32,0.0,,3643,3643,68,2644.149728,32.251479,2449.759802,20581.072223
182747,15,128,0.0,,17125,17125,548,5326.97511,205.311625,8728.471801,39980.059113
217342,10,32,0.0,,26813,26813,143,5560.376978,237.375489,13582.280236,40272.089731
198390,20,256,0.0,,13958,13958,1191,5789.112883,229.88826,7176.823719,44227.146167
221143,10,64,0.0,,32250,32250,344,6687.88116,332.98125,16383.896034,48485.725602
184888,25,128,0.0,,12919,12919,689,6697.731528,154.885891,8727.982246,51593.464023
196721,15,256,0.0,,23782,23782,1522,7397.729756,391.68954,12228.057149,55628.071716


In [50]:
# 'gm' runs of AdvancedCNN at the 0.996 threshold (saved output lists 9 rows, ordered by
# ascending common_channel_time_cost).
explore_top(df_fin, 0.996, 'AdvancedCNN', 'gm')

Unnamed: 0,num_clients,batch_size,theta,bias,total_rounds,total_fda_steps,epoch,total_communication_gb,cpu_time_cost,hypercube_time_cost,common_channel_time_cost
419,10,32,50.0,,80,3750,20,16.590097,33.19875,73.014984,152.647452
1188,20,32,100.0,,130,8344,89,53.917822,73.869432,138.570819,483.644883
1087,15,32,75.0,,214,11000,88,66.567768,97.383,203.891429,594.422334
2163,25,32,50.0,,158,4800,64,81.913598,42.4944,147.343806,671.590834
2290,25,32,75.0,,193,6825,91,100.059019,60.421725,188.497269,828.874987
6394,5,32,50.0,,820,35625,95,85.024248,315.388125,587.465718,859.543311
6157,5,32,20.0,,1094,21750,58,113.434773,192.55275,555.544024,918.535298
1363,20,32,20.0,,459,6000,64,190.37133,53.118,281.563596,1499.940107
3288,35,32,30.0,,312,4768,89,226.454788,42.211104,249.255481,1802.08831


In [51]:
#explore_top(df_fin, 0.995, 'AdvancedCNN', 'linear')

## LeNet-5

In [52]:
# 'naive' runs of LeNet-5 at the 0.985 threshold. df_32 (not df_fin) is used here, so no
# theta cut is applied and small theta values (e.g. 0.5, 1.0) appear in the saved output.
# NOTE(review): the meaning of the trailing '~~~' marker is not documented — confirm or drop.
explore_top(df_32, 0.985, 'LeNet-5', 'naive') #~~~

Unnamed: 0,num_clients,batch_size,theta,bias,total_rounds,total_fda_steps,epoch,total_communication_gb,cpu_time_cost,hypercube_time_cost,common_channel_time_cost
371118,10,32,3.0,,35,3563,19,0.172919,23.562119,23.977125,24.807138
369916,10,32,0.5,,149,3188,17,0.735663,21.082244,22.847835,26.379018
371826,15,32,1.0,,94,3375,27,0.696246,22.318875,23.432869,27.517513
387010,5,32,1.0,,101,4125,11,0.249375,27.278625,28.076624,28.874623
370120,10,32,1.0,,107,3938,21,0.528361,26.041994,27.31006,29.846192
370521,10,32,1.5,,78,4125,22,0.38521,27.278625,28.20313,30.05214
374543,20,32,3.0,,43,4125,44,0.424867,27.278625,27.788466,30.507616
375647,25,32,1.5,,71,3600,48,0.876585,23.8068,24.928829,30.538974
372231,15,32,1.5,,74,4000,32,0.548189,26.452,27.329103,30.545147
371626,15,32,0.5,,158,3375,27,1.170148,22.318875,24.191112,31.055982


In [53]:
# 'naive' runs of LeNet-5 at the looser 0.98 threshold, again from df_32 (no theta cut);
# the saved output is ordered by ascending common_channel_time_cost.
explore_top(df_32, 0.98, 'LeNet-5', 'naive')

Unnamed: 0,num_clients,batch_size,theta,bias,total_rounds,total_fda_steps,epoch,total_communication_gb,cpu_time_cost,hypercube_time_cost,common_channel_time_cost
369909,10,32,0.5,,92,1875,10,0.454231,12.399375,13.48953,15.669839
371815,15,32,1.0,,59,2000,16,0.436998,13.226,13.925198,16.488922
370111,10,32,1.0,,64,2250,12,0.316025,14.87925,15.637709,17.154628
387506,5,32,15.0,,7,2625,7,0.01733,17.359125,17.414582,17.470038
374524,20,32,3.0,,27,2344,25,0.266757,15.500872,15.820981,17.528229
370313,10,32,10.0,,11,2625,14,0.054406,17.359125,17.4897,17.75085
371313,10,32,5.0,,18,2625,14,0.088962,17.359125,17.572633,17.999649
371113,10,32,3.0,,27,2625,14,0.13339,17.359125,17.679261,18.319533
387006,5,32,1.0,,64,2625,7,0.15802,17.359125,17.864789,18.370452
372820,15,32,3.0,,27,2625,21,0.200085,17.359125,17.679261,18.853093


In [54]:
#explore_top(df, 0.98, 'LeNet-5', 'linear')

In [79]:
# Generate the LeNet-5 plot set once per threshold (second positional argument), with log
# scaling requested for the communication KDE and the bias KDE.
for threshold in (0.98, 0.975, 0.985):
    time_cost_plots(df_32, threshold, 'LeNet-5', kde_comm_log=True, kde_time_log_bias=True, params=True)
# Disabled runs on the unfiltered frame at higher thresholds:
#time_cost_plots(df, 0.987, 'LeNet-5', kde_time_log=False, kde_comm_log=True, show_bias=True)
#time_cost_plots(df, 0.99, 'LeNet-5', kde_time_log=False, kde_comm_log=True, show_bias=True)

In [56]:
# Summary statistics for the 'synchronous' method on LeNet-5 at the 0.98 threshold, computed
# on the unfiltered frame. The saved output reports mean/std of epochs and of rounds.
# NOTE(review): mean_epoch_per_method is defined elsewhere — confirm how runs are selected.
mean_epoch_per_method(df, 0.98, 'LeNet-5', 'synchronous')

Mean epochs : 116.35849056603773
Std epochs : 136.56654533243602
Mean rounds : 2803.106918238994
Std rounds : 3219.040861929935


In [57]:
# Mean/std of epochs and rounds for the 'naive' method on AdvancedCNN at the 0.995
# threshold, restricted to df_fin (batch size 32, theta >= 15).
mean_epoch_per_method(df_fin, 0.995, 'AdvancedCNN', 'naive')

Mean epochs : 43.24285714285714
Std epochs : 25.10005237842316
Mean rounds : 67.32857142857142
Std rounds : 32.167010478858586


In [58]:
# Mean/std of epochs and rounds for the 'synchronous' baseline on AdvancedCNN at the 0.995
# threshold, computed on the unfiltered frame (all batch sizes).
mean_epoch_per_method(df, 0.995, 'AdvancedCNN', 'synchronous')

Mean epochs : 192.22222222222223
Std epochs : 310.4632088354497
Mean rounds : 3307.0666666666666
Std rounds : 1897.6396750604779
