In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import glob
import seaborn as sns

## Read from `~/metrics/epoch_metrics/` directory

In [2]:
df = pd.read_parquet("../../metrics/epoch_metrics/")

# Hyper-Parameters

## All tests summary

Note that this is a summary. 

Notes:
1. The `synchronous` method always has `theta` equal to `0.0`, and `theta` equal to `0.0` always means the `synchronous` method.
2. The stopping criterion (currently a number of epochs) is not the same for all combinations. For `LeNet-5`, all tests currently end at 50 epochs. For `AdvancedCNN`, some tests run up to 350 epochs and others up to 250.
3. Tests with `theta` greater than 2 are exploratory and will not be considered later on.

In [3]:
# Print the sorted distinct values of every hyper-parameter column
summary_columns = ['dataset_name', 'fda_name', 'nn_num_weights', 'num_clients', 'batch_size', 'num_steps_until_rtc_check', 'theta']
for col in summary_columns:
    distinct_values = sorted(df[col].unique())
    print(f"{col}: {distinct_values}")

dataset_name: ['MNIST']
fda_name: ['linear', 'naive', 'sketch', 'synchronous']
nn_num_weights: [61706, 2592202]
num_clients: [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60]
batch_size: [32, 64, 128, 256]
num_steps_until_rtc_check: [1]
theta: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 7.0, 10.0, 15.0]


## Query All tests

In [4]:
# One row per tested hyper-parameter combination, with the largest epoch recorded for it
combination_keys = ['dataset_name', 'nn_num_weights', 'fda_name', 'num_steps_until_rtc_check', 'batch_size', 'theta', 'num_clients']
test_combinations = (
    df.groupby(combination_keys)['epoch']
    .max()
    .reset_index()
)

In [5]:
# Look up one specific combination: 61706-weight model, sketch method, theta 2, 5 clients
is_selected_combination = (
    (test_combinations['nn_num_weights'] == 61706)
    & (test_combinations['fda_name'] == 'sketch')
    & (test_combinations['theta'] == 2)
    & (test_combinations['num_clients'] == 5)
)
test_combinations[is_selected_combination]

Unnamed: 0,dataset_name,nn_num_weights,fda_name,num_steps_until_rtc_check,batch_size,theta,num_clients,epoch


# Helpful new Dataframe metrics

### Add Helpful Dataset Metrics 

In [6]:
def dataset_n_train(row):
    """Return the training-set size for the row's dataset.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with a `dataset_name` entry.

    Returns:
        60000 for "MNIST", otherwise the sentinel -1.
    """
    is_mnist = row['dataset_name'] == "MNIST"
    return 60_000 if is_mnist else -1


def dataset_one_sample_bytes(row):
    """Return the size in bytes of one training sample of the row's dataset.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with a `dataset_name` entry.

    Returns:
        3140 for "MNIST", otherwise the sentinel -1.
    """
    if row['dataset_name'] != "MNIST":
        return -1

    # 784 tf.float32 pixels plus one tf.int32 label, 4 bytes each
    bytes_per_value = 4
    pixels_per_image = 784
    return bytes_per_value * (pixels_per_image + 1)

In [7]:
df['n_train'] = df.apply(dataset_n_train, axis=1)
df['one_sample_bytes'] = df.apply(dataset_one_sample_bytes, axis=1)

### Add Helpful model metrics

In [8]:
df['model_bytes'] = df['nn_num_weights'] * 4

### Add Helpful FDA method metrics

In [9]:
def fda_local_state_bytes(row):
    """Return the size in bytes of the local state of the row's FDA method.

    Args:
        row: Mapping-like object with an `fda_name` entry; `sketch` rows must
            also provide `sketch_width` and `sketch_depth`.

    Returns:
        4 for "naive", 8 for "linear", 0 for "synchronous",
        `sketch_width * sketch_depth * 4 + 4` for "sketch",
        and None for any other method name.
    """
    method = row['fda_name']

    # The sketch size depends on the row, so it is handled separately
    if method == "sketch":
        return row['sketch_width'] * row['sketch_depth'] * 4 + 4

    fixed_sizes = (("naive", 4), ("linear", 8), ("synchronous", 0))
    for known_method, size_bytes in fixed_sizes:
        if method == known_method:
            return size_bytes

    return None

In [10]:
df['local_state_bytes'] = df.apply(fda_local_state_bytes, axis=1)

### Add Total Steps

Total steps: a single FDA step may consist of several ordinary SGD (batch) steps.

In [11]:
df['total_steps'] = df['total_fda_steps'] * df['num_steps_until_rtc_check']

### Add communication metrics

The communication bytes exchanged for model synchronization. Remember that the Clients send their models to the Server and the Server sends the global model back. This happens at the end of every round.

In [12]:
df['model_bytes_exchanged'] = df['total_rounds'] * df['model_bytes'] * df['num_clients'] * 2

The communication bytes exchanged for monitoring the variance. This happens at the end of every FDA step which consists of `num_steps_until_rtc_check` number of steps. 

In [13]:
df['monitoring_bytes_exchanged'] = df['local_state_bytes'] * df['total_fda_steps'] * df['num_clients']

The total communication bytes exchanged in the whole Federated Learning lifecycle.

In [14]:
df['total_communication_bytes'] = df['model_bytes_exchanged'] + df['monitoring_bytes_exchanged']

In [15]:
df['total_communication_gb'] = df['total_communication_bytes'] / 10**9

Add rounds in one epoch.

In [16]:
# Columns that together identify a single run
run_key_columns = ['dataset_name', 'fda_name', 'nn_num_weights', 'num_clients', 'batch_size', 'num_steps_until_rtc_check', 'theta']

# Order every run by epoch so that consecutive rows are consecutive epochs
df = df.sort_values(by=run_key_columns + ['epoch'])

# Rounds within an epoch = change of `total_rounds` from the previous epoch of the same run
rounds_since_previous_epoch = df.groupby(run_key_columns)['total_rounds'].diff()

# The first epoch of a run has no predecessor (NaN), so its own `total_rounds` is used
df['epoch_rounds'] = rounds_since_previous_epoch.fillna(df['total_rounds']).astype(int)

# HyperParameter ranking

### AdvancedCNN
On 8 CPUs, the step time:

1. *Batch Size* = 32 -> `307ms`
2. *Batch Size* = 64 -> `445ms`
3. *Batch Size* = 128 -> `815ms`
4. *Batch Size* = 256 -> `1401ms`

Best fit line:

step(ms) = 4.97092 * batch_size + 147.739

### LeNet-5
On 8 CPUs, the step time:

1. *Batch Size* = 32 -> `5.93ms`
2. *Batch Size* = 64 -> `9.16ms`
3. *Batch Size* = 128 -> `18.5ms`
4. *Batch Size* = 256 -> `30.6ms`

Best fit line:

step(ms) = 0.11124 * batch_size + 2.69913

In [17]:
def step_ms(batch_size, nn_name):
    """Estimate the duration of one training step in milliseconds.

    Uses the best-fit lines (step time on 8 CPUs versus batch size) listed
    in the markdown above this cell.

    Args:
        batch_size: Batch size of the step.
        nn_name: Either 'AdvancedCNN' or 'LeNet-5'.

    Returns:
        The estimated step time in milliseconds.

    Raises:
        ValueError: If `nn_name` is not one of the two fitted models. Without
            this the function would return None and fail later with an
            unrelated TypeError.
    """
    if nn_name == 'AdvancedCNN':
        return 4.97092 * batch_size + 147.739
    if nn_name == 'LeNet-5':
        return 0.11124 * batch_size + 2.69913
    raise ValueError(f"Unknown nn_name {nn_name!r}; expected 'AdvancedCNN' or 'LeNet-5'")

Time cost for training-reducing

In [18]:
import numpy as np

def cpu_time_cost(row):
    """Total CPU time cost of a run in seconds.

    A single `step` means each client performed a single `step`.

    Args:
        row: Mapping-like object with `total_steps`, `batch_size` and `nn_name`.

    Returns:
        `total_steps` times the estimated per-step time, converted from ms to sec.
    """
    per_step_ms = step_ms(row['batch_size'], row['nn_name'])
    return row['total_steps'] * per_step_ms / 1000

def communication_time_cost(num_clients, total_communication_bytes, comm_model):
    """Estimate the communication time in seconds, assuming a 1 Gbps channel.

    Args:
        num_clients: Number of clients (a scalar or a pandas Series).
        total_communication_bytes: Total bytes exchanged (scalar or Series).
        comm_model: 'common_channel' or 'hypercube'.

    Returns:
        The estimated communication time in seconds, of the same kind
        (scalar or Series) as the inputs.

    Raises:
        ValueError: If `comm_model` is not a supported model. Without this the
            function would return None, which silently becomes an empty column.
    """
    # bytes -> gigabits; at 1 Gbps one gigabit takes one second
    total_communication_gbit = total_communication_bytes * 8e-9

    if comm_model == 'common_channel':
        return ((num_clients - 1) / num_clients) * total_communication_gbit    # sec

    if comm_model == 'hypercube':
        # NOTE(review): `np.log` is the natural logarithm. A hypercube of p nodes
        # has log2(p) dimensions, so `np.log2` may have been intended. Left
        # unchanged because it alters every hypercube figure; confirm.
        return (np.ceil(np.log(num_clients)) / num_clients) * total_communication_gbit   # sec

    raise ValueError(f"Unknown comm_model {comm_model!r}; expected 'common_channel' or 'hypercube'")

In [19]:
df['cpu_time_cost'] = df.apply(cpu_time_cost, axis=1)

In [20]:
df['hypercube_communication_time_cost'] = communication_time_cost(df['num_clients'], df['total_communication_bytes'], 'hypercube')

In [21]:
df['common_channel_communication_time_cost'] = communication_time_cost(df['num_clients'], df['total_communication_bytes'], 'common_channel')

In [22]:
df['hypercube_time_cost'] = df['cpu_time_cost'] + df['hypercube_communication_time_cost']

In [23]:
df['common_channel_time_cost'] = df['cpu_time_cost'] + df['common_channel_communication_time_cost']

In [24]:
df['hypercube_comm_cpu_time_ratio'] = df['hypercube_communication_time_cost'] / df['cpu_time_cost']

In [25]:
df['common_channel_comm_cpu_time_ratio'] = df['common_channel_communication_time_cost'] / df['cpu_time_cost']

# Plots about cost

In [26]:
# Matplotlib format string (marker, line style, colour) for each fda_name
fda_styles = dict(
    naive='o-r',
    linear='s-g',
    sketch='^-b',
    synchronous='x-c',
)
# Sorted distinct method names found in the data
fda_names = sorted(df['fda_name'].unique())

In [27]:
import matplotlib

num_clients_values = sorted(df['num_clients'].unique())
cmap = matplotlib.colormaps['tab20b']
# Sample the colormap at evenly spaced positions, one per distinct client count
client_palette = cmap(np.linspace(0, 1, len(num_clients_values)))
colors_dict = dict(zip(num_clients_values, client_palette))

## Total time cost with accuracy (scatter)

### KDE

In [98]:
def kde_time_cost(df, filename):
    """Save a two-panel KDE of the total time cost per FDA method to a one-page PDF.

    The left panel uses `common_channel_time_cost` and the right panel
    `hypercube_time_cost`. Each panel draws one density curve per entry of the
    module-level `fda_names`, plus a text box with the per-method mean.

    Args:
        df: DataFrame with `fda_name`, `common_channel_time_cost` and
            `hypercube_time_cost` columns.
        filename: Path of the PDF file to write.
    """
    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]

    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        fig, axs = plt.subplots(1, 2, figsize=(20, 6))
        try:
            for ax, (cost_column, panel_title) in zip(axs, panels):
                avg_info = []

                for fda_name in fda_names:
                    costs = df.loc[df['fda_name'] == fda_name, cost_column]
                    avg_info.append(f'{fda_name}: {costs.mean():.2f} sec')
                    sns.kdeplot(costs, label=fda_name, fill=True, alpha=0.4, ax=ax, bw_method='scott', bw_adjust=0.2)

                # Add the text annotation inside the plot
                text = "Average Time Cost:\n" + '\n'.join(avg_info)
                ax.text(0.62, 0.97, text, transform=ax.transAxes, fontsize=9, verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

                ax.set_xlim(left=0)
                ax.set_xlabel('Time Cost (sec)')
                ax.set_ylabel('Density')
                ax.set_title(panel_title)
                ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is neither displayed in the notebook nor kept in memory
            plt.close(fig)

## Total Communication cost (in gb) with accuracy (scatter)

### KDE

In [99]:
def kde_communication_cost(df, filename):
    """Save a KDE of the total communication (GB) per FDA method to a one-page PDF.

    Draws one density curve per entry of the module-level `fda_names`, plus a
    text box with the per-method mean.

    Args:
        df: DataFrame with `fda_name` and `total_communication_gb` columns.
        filename: Path of the PDF file to write.
    """
    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        # Explicit figure/axes instead of relying on the pyplot "current figure"
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []

            for fda_name in fda_names:
                communication_gb = df.loc[df['fda_name'] == fda_name, 'total_communication_gb']
                avg_info.append(f'{fda_name}: {communication_gb.mean():.2f} GB')
                sns.kdeplot(communication_gb, label=fda_name, fill=True, alpha=0.4, ax=ax)

            # Add the text annotation inside the plot
            text = "Average Communication:\n" + '\n'.join(avg_info)
            ax.text(0.62, 0.97, text, transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.set_xlim(left=0)
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('Communication (GB)')
            ax.set_ylabel('Density')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is neither displayed in the notebook nor kept in memory
            plt.close(fig)

## Total CPU time (in Seconds) with accuracy

### KDE

In [100]:
def kde_cpu_time_cost(df, filename):
    """Save a KDE of the CPU time cost (sec) per FDA method to a one-page PDF.

    Draws one density curve per entry of the module-level `fda_names`, plus a
    text box with the per-method mean.

    Args:
        df: DataFrame with `fda_name` and `cpu_time_cost` columns.
        filename: Path of the PDF file to write.
    """
    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        # Explicit figure/axes instead of relying on the pyplot "current figure"
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []

            for fda_name in fda_names:
                cpu_costs = df.loc[df['fda_name'] == fda_name, 'cpu_time_cost']
                avg_info.append(f'{fda_name}: {cpu_costs.mean():.2f} sec')
                sns.kdeplot(cpu_costs, label=fda_name, fill=True, alpha=0.4, ax=ax)

            # Add the text annotation inside the plot
            text = "Average CPU cost:\n" + '\n'.join(avg_info)
            ax.text(0.62, 0.97, text, transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.set_xlim(left=0)
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('CPU time cost (sec)')
            # A KDE curve shows a density, not a count
            ax.set_ylabel('Density')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is neither displayed in the notebook nor kept in memory
            plt.close(fig)

## Communication/CPU time - Accuracy

### Scatter

In [101]:
def scatter_time_cost_cpu_ratio(df, filename):
    """Save a two-panel scatter of accuracy vs. communication/CPU time ratio to a PDF.

    The left panel uses `common_channel_comm_cpu_time_ratio` and the right
    panel `hypercube_comm_cpu_time_ratio`. Each panel draws one point cloud
    per entry of the module-level `fda_names`.

    Args:
        df: DataFrame with `fda_name`, `accuracy` and the two ratio columns.
        filename: Path of the PDF file to write.
    """
    panels = [
        ('common_channel_comm_cpu_time_ratio', "Common Channel Communication Model"),
        ('hypercube_comm_cpu_time_ratio', "Hypercube Communication Model"),
    ]

    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        fig, axs = plt.subplots(1, 2, figsize=(20, 6))
        try:
            for ax, (ratio_column, panel_title) in zip(axs, panels):
                # One scatter series per method (fda_name)
                for fda_name in fda_names:
                    method_rows = df[df['fda_name'] == fda_name]
                    ax.scatter(method_rows[ratio_column], method_rows['accuracy'], label=fda_name)

                ax.set_xlabel('(Communication time) / (CPU time)')
                ax.set_ylabel('Accuracy')
                ax.legend()
                ax.set_title(panel_title)
                ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is neither displayed in the notebook nor kept in memory
            plt.close(fig)

## Different FDA-method runs visualization with lines (per clients)

In [102]:
import matplotlib
from matplotlib import cm

def fda_method_run_line_plot(df, fda_name, filename, limit_x_axis=False):
    """Plot accuracy against time cost for every run of one FDA method.

    Writes a multi-page PDF with one page per (batch size, theta) pair that
    has data. Each page has a common-channel panel and a hypercube panel,
    with one line per number of clients coloured via the module-level
    `colors_dict`.

    Args:
        df: DataFrame with `fda_name`, `batch_size`, `theta`, `num_clients`,
            `accuracy`, `common_channel_time_cost` and `hypercube_time_cost`.
        fda_name: The method whose runs are plotted.
        filename: Path of the PDF file to write.
        limit_x_axis: If True, cap each x-axis at the largest time cost among
            runs with `num_clients != 2`, plus a fixed margin.
    """
    method_df = df[df.fda_name == fda_name]

    batch_size_values = sorted(method_df['batch_size'].unique())
    theta_values = sorted(method_df['theta'].unique())

    # (axes index, cost column, panel title, x-axis margin used with `limit_x_axis`)
    panels = [
        (0, 'common_channel_time_cost', "Common Channel Communication Model", 700),
        (1, 'hypercube_time_cost', "Hypercube Communication Model", 500),
    ]

    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            for theta in theta_values:
                filtered_data = method_df[(method_df['theta'] == theta) &
                                          (method_df['batch_size'] == batch_size)]

                if filtered_data.empty:
                    continue

                num_clients_values = sorted(filtered_data['num_clients'].unique())

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    # Plot each group with a unique color based on num_clients
                    for num_clients in num_clients_values:
                        data = filtered_data[filtered_data['num_clients'] == num_clients]
                        for axis_index, cost_column, _, _ in panels:
                            axs[axis_index].plot(data[cost_column], data['accuracy'], color=colors_dict[num_clients],
                                                 label=num_clients, marker='o', markersize=3)

                    for axis_index, cost_column, panel_title, x_margin in panels:
                        ax = axs[axis_index]
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                        ax.set_xlabel('Time Cost')
                        ax.set_ylabel('Accuracy')
                        ax.set_title(panel_title)
                        ax.legend(title='Num Clients')

                        if limit_x_axis:
                            # Runs with num_clients = 2 are left out of the limit so that
                            # the remaining runs stay readable.
                            not_two_clients = filtered_data['num_clients'] != 2
                            max_cost = filtered_data.loc[not_two_clients, cost_column].max()
                            # max() is NaN when only 2-client runs exist; keep the automatic limit then
                            if pd.notna(max_cost):
                                ax.set_xlim(0, max_cost + x_margin)

                    # Raw f-string: `\T` is not a valid escape sequence in a normal string literal
                    title = rf'Batch Size : {batch_size} , $\Theta$ : {theta}'
                    fig.suptitle(title)

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

## Keep Hyper-parameters fixed and plot for filtered

In [103]:
def fda_methods_batch_size(df, filename):
    """Plot time cost against batch size for all FDA methods, clients and theta fixed.

    Writes a multi-page PDF with one page per (number of clients, theta) pair
    that has data. Each page has a common-channel panel and a hypercube panel
    with one line per method; the synchronous runs for the same number of
    clients are added to every page as a baseline.

    Args:
        df: DataFrame with `fda_name`, `num_clients`, `theta`, `batch_size`,
            `common_channel_time_cost` and `hypercube_time_cost` columns.
        filename: Path of the PDF file to write.
    """
    num_clients_values = sorted(df['num_clients'].unique())
    # theta == 0 is used by the synchronous baseline, which is drawn on every page.
    # Filter it out by value: dropping the first sorted entry would discard a real
    # theta whenever `df` holds no synchronous run.
    theta_values = [theta for theta in sorted(df['theta'].unique()) if theta != 0]

    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]
    x_ticks = [32, 64, 128, 256]

    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        for num_clients in num_clients_values:
            same_clients = df['num_clients'] == num_clients
            # Sorted by the x variable so each line is drawn left to right
            synchronous_data = df[same_clients & (df['fda_name'] == 'synchronous')].sort_values('batch_size', kind='stable')

            for theta in theta_values:
                filtered_df = df[(df['theta'] == theta) & same_clients].sort_values('batch_size', kind='stable')

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]

                        if fda_data.empty:
                            continue

                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(fda_data['batch_size'], fda_data[cost_column], marker='o', label=fda_name, markersize=3)

                    if not synchronous_data.empty:
                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(synchronous_data['batch_size'], synchronous_data[cost_column], marker='o', label='synchronous', markersize=3)

                    # Nothing was drawn for this pair: skip the page (the figure is closed below)
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Batch size')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    # Raw f-string: `\T` is not a valid escape sequence in a normal string literal
                    title = rf'Num Clients : {num_clients} , $\Theta$ : {theta}'
                    fig.suptitle(title)

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

In [104]:
def fda_methods_clients(df, filename):
    """Plot time cost against number of clients for all FDA methods, batch size and theta fixed.

    Writes a multi-page PDF with one page per (batch size, theta) pair that
    has data. Each page has a common-channel panel and a hypercube panel with
    one line per method; the synchronous runs for the same batch size are
    added to every page as a baseline.

    Args:
        df: DataFrame with `fda_name`, `num_clients`, `theta`, `batch_size`,
            `common_channel_time_cost` and `hypercube_time_cost` columns.
        filename: Path of the PDF file to write.
    """
    batch_size_values = sorted(df['batch_size'].unique())
    # theta == 0 is used by the synchronous baseline, which is drawn on every page.
    # Filter it out by value: dropping the first sorted entry would discard a real
    # theta whenever `df` holds no synchronous run.
    theta_values = [theta for theta in sorted(df['theta'].unique()) if theta != 0]

    panels = [
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    ]

    # Fixed ticks every 5 clients, from 5 to 60
    min_clients = 5
    max_clients = 60
    x_ticks = list(range(min_clients, max_clients + 1, 5))

    # The context manager closes the PDF even when plotting raises
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            same_batch_size = df['batch_size'] == batch_size
            # Sorted by the x variable so each line is drawn left to right
            synchronous_data = df[same_batch_size & (df['fda_name'] == 'synchronous')].sort_values('num_clients', kind='stable')

            for theta in theta_values:
                filtered_df = df[(df['theta'] == theta) & same_batch_size].sort_values('num_clients', kind='stable')

                if filtered_df.empty:
                    continue

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]

                        if fda_data.empty:
                            continue

                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(fda_data['num_clients'], fda_data[cost_column], marker='o', label=fda_name, markersize=3)

                    if not synchronous_data.empty:
                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(synchronous_data['num_clients'], synchronous_data[cost_column], marker='o', label='synchronous', markersize=3)

                    # Nothing was drawn for this pair: skip the page (the figure is closed below)
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Number of clients')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    # Raw f-string: `\T` is not a valid escape sequence in a normal string literal
                    title = rf'Batch Size : {batch_size} , $\Theta$ : {theta}'
                    fig.suptitle(title)

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

# Help-Stat

In [160]:
def explore_top(df, acc_thresh, nn_name, fda_name):
    """List, per run, the first epoch that exceeds an accuracy threshold.

    Keeps the rows of the given model and FDA method whose accuracy is above
    `acc_thresh`, picks the row with the smallest epoch for every
    (fda_name, num_clients, batch_size, theta) group, and returns a summary
    sorted by ascending `common_channel_time_cost`.

    Args:
        df: Metrics DataFrame holding the columns used below.
        acc_thresh: Accuracy that a row must strictly exceed.
        nn_name: Model name to keep.
        fda_name: FDA method name to keep.

    Returns:
        DataFrame with the columns in `report_columns`, one row per run.
    """
    run_keys = ['fda_name', 'num_clients', 'batch_size', 'theta']
    report_columns = ['num_clients', 'batch_size', 'theta', 'total_rounds', 'total_fda_steps', 'epoch', 'total_communication_gb', 'common_channel_time_cost']

    is_candidate = (df.accuracy > acc_thresh) & (df.nn_name == nn_name) & (df['fda_name'] == fda_name)
    candidates = df[is_candidate]

    # Index label of the earliest qualifying epoch within each run
    first_hit_labels = candidates.groupby(run_keys)['epoch'].idxmin()
    first_hits = candidates.loc[first_hit_labels]

    return first_hits[report_columns].sort_values(by='common_channel_time_cost')

In [161]:
def mean_epoch_per_method(df, acc_thresh, nn_name, fda_name):
    """Print mean and standard deviation of epochs and rounds needed to pass a threshold.

    The statistics are computed over the rows returned by `explore_top`
    for the same arguments.

    Args:
        df: Metrics DataFrame accepted by `explore_top`.
        acc_thresh: Accuracy that a row must strictly exceed.
        nn_name: Model name to keep.
        fda_name: FDA method name to keep.
    """
    first_hits = explore_top(df, acc_thresh, nn_name, fda_name)

    for label, column in (('epochs', 'epoch'), ('rounds', 'total_rounds')):
        values = first_hits[column]
        print(f"Mean {label} : {values.mean()}")
        print(f"Std {label} : {values.std()}")

## Save all those time-cost plots

In [162]:
import os

def time_cost_plots(df, acc_threshold, nn_name, limit_x_axis=False, show_runs=False, params=False, addi_name=''):
    """Generate the time-cost PDF plots for one model and accuracy threshold.

    Output goes to `../../metrics/plots/{nn_name}/{threshold}/`, where the
    threshold has its '.' replaced by '_'. The directory is created if needed.

    Args:
        df: Metrics DataFrame with the columns used by the plotting helpers.
        acc_threshold: Only rows with accuracy strictly above this are used.
        nn_name: Model name to keep.
        limit_x_axis: Forwarded to `fda_method_run_line_plot`.
        show_runs: If True, also write one `<method>_run.pdf` per FDA method
            using every qualifying epoch of each run.
        params: If True, also write the fixed-hyper-parameter plots
            (`fda_methods_batch_size.pdf`, `fda_methods_clients.pdf`).
        addi_name: Suffix appended to the `kde_time_cost` file name.
    """
    # Filter out based on `acc_threshold`
    acceptable_acc_df = df[(df.accuracy > acc_threshold) & (df.nn_name == nn_name)]

    str_thresh = str(acc_threshold).replace('.', '_')  # replace '.'
    out_dir = f"../../metrics/plots/{nn_name}/{str_thresh}"

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(out_dir, exist_ok=True)

    # 1. Same runs are included (every epoch above the threshold)
    if show_runs:
        # Plot the runs of each method. x-axis : time cost, y-axis : accuracy, PER number of clients (lines)
        for fda_name in ('sketch', 'naive', 'linear', 'synchronous'):
            fda_method_run_line_plot(acceptable_acc_df, fda_name, f"{out_dir}/{fda_name}_run.pdf", limit_x_axis=limit_x_axis)

    # 2. Filter out same runs. We choose the instance which first hits the `acc_threshold`
    idx = acceptable_acc_df.groupby(['fda_name', 'num_clients', 'batch_size', 'theta'])['epoch'].idxmin()
    filtered_acceptable_acc_df = acceptable_acc_df.loc[idx]

    kde_time_cost(filtered_acceptable_acc_df, f"{out_dir}/kde_time_cost{addi_name}.pdf")

    # Currently disabled plots on the filtered rows: scatter_time_cost_cpu_ratio,
    # kde_communication_cost and kde_cpu_time_cost. Each takes
    # (filtered_acceptable_acc_df, f"{out_dir}/<plot name>.pdf").

    # Parameters fixed, and plot
    if params:
        fda_methods_batch_size(filtered_acceptable_acc_df, f"{out_dir}/fda_methods_batch_size.pdf")
        fda_methods_clients(filtered_acceptable_acc_df, f"{out_dir}/fda_methods_clients.pdf")

In [106]:
time_cost_plots(df, 0.95, 'LeNet-5')

In [107]:
time_cost_plots(df, 0.955, 'LeNet-5')

In [108]:
time_cost_plots(df, 0.96, 'LeNet-5')

In [109]:
time_cost_plots(df, 0.965, 'LeNet-5')

In [110]:
time_cost_plots(df, 0.97, 'LeNet-5')

In [111]:
time_cost_plots(df, 0.975, 'LeNet-5')

In [112]:
time_cost_plots(df, 0.98, 'LeNet-5')

In [113]:
time_cost_plots(df, 0.985, 'LeNet-5')

In [114]:
time_cost_plots(df, 0.98, 'AdvancedCNN')

In [115]:
time_cost_plots(df, 0.985, 'AdvancedCNN')

In [116]:
time_cost_plots(df, 0.988, 'AdvancedCNN')

In [117]:
time_cost_plots(df, 0.99, 'AdvancedCNN')

In [118]:
time_cost_plots(df, 0.993, 'AdvancedCNN')

In [119]:
time_cost_plots(df, 0.995, 'AdvancedCNN')

In [163]:
explore_top(df, 0.995, 'AdvancedCNN', 'synchronous')

Unnamed: 0,num_clients,batch_size,theta,total_rounds,total_fda_steps,epoch,total_communication_gb,common_channel_time_cost
340375,15,128,0.0,813,813,26,252.895227,2525.689988
371378,5,256,0.0,1360,1360,29,141.015789,2834.101596
341390,20,128,0.0,961,961,41,398.57698,3782.625151
342394,25,128,0.0,844,844,45,437.563698,4022.199343
378361,5,32,0.0,4500,4500,12,466.59636,4366.854684
351424,10,256,0.0,1758,1758,75,364.567289,5121.762249
383975,5,64,0.0,4875,4875,26,505.47939,5506.222761
339410,10,128,0.0,2860,2860,61,593.095818,6512.57782
378986,10,64,0.0,3469,3469,37,719.387899,6795.723239
355478,20,256,0.0,1512,1512,129,627.105508,6913.487174


In [150]:
explore_top(df, 0.995, 'AdvancedCNN', 'naive')

Unnamed: 0,num_clients,batch_size,theta,total_rounds,epoch,total_communication_gb,common_channel_time_cost
288654,5,64,1.5,159,5,16.486423,542.506562
259058,10,32,15.0,74,9,15.345903,628.383151
259458,10,32,7.0,121,9,25.092583,698.559243
263220,30,32,10.0,82,21,51.014693,797.353107
259258,10,32,3.0,194,9,40.231043,807.556153
144866,10,128,1.5,129,17,26.751557,817.472565
272262,10,64,1.5,198,13,41.060528,863.54094
258257,10,32,1.0,304,8,63.042413,914.118031
262211,25,32,1.5,161,12,83.468994,917.169473
259361,10,32,5.0,192,12,39.816313,976.996442


In [155]:
explore_top(df, 0.985, 'LeNet-5', 'synchronous')

Unnamed: 0,num_clients,batch_size,theta,total_rounds,epoch,total_communication_gb,common_channel_time_cost
907291,10,256,0.0,985,42,4.862433,65.718438
922279,5,256,0.0,1407,30,3.472814,66.091442
932263,5,64,0.0,2625,14,6.47913,67.239968
905774,5,128,0.0,2344,25,5.785555,76.72987
898282,10,128,0.0,1547,33,7.636735,81.187343
899045,15,128,0.0,1438,46,10.647987,103.861601
928274,10,64,0.0,2344,25,11.571109,106.326526
908829,15,256,0.0,1250,80,9.2559,108.081433
899805,20,128,0.0,1313,56,12.963196,120.75969
928680,15,64,0.0,1938,31,14.350347,126.177494


In [169]:
explore_top(df, 0.985, 'LeNet-5', 'naive')

Unnamed: 0,num_clients,batch_size,theta,total_rounds,total_fda_steps,epoch,total_communication_gb,common_channel_time_cost
703566,10,32,0.5,149,3188,17,0.735663,25.24986
705126,15,32,1.0,94,3375,27,0.696246,26.322122
720687,15,64,1.0,64,2375,38,0.474045,26.858447
715910,5,32,1.0,101,4125,11,0.249375,27.41359
738963,5,64,0.5,118,2625,14,0.291305,27.637887
703920,10,32,1.0,107,3938,21,0.528361,28.451392
704321,10,32,1.5,78,4125,22,0.38521,28.591106
705531,15,32,1.5,74,4000,32,0.548189,29.128387
707947,25,32,1.5,71,3600,48,0.876585,29.26389
723505,20,64,1.5,54,2625,56,0.53335,29.826995


In [171]:
mean_epoch_per_method(df, 0.98, 'LeNet-5', 'synchronous')

Mean epochs : 63.666666666666664
Std epochs : 45.50489023677557
Mean rounds : 1274.9375
Std rounds : 447.0183206417402
