In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import glob
import seaborn as sns

## Skip (unless new tests): Combine metrics

In [2]:
#df_old = pd.read_parquet("../../metrics/epoch_metrics/combined_epoch_metrics.parquet")

In [3]:
"""
csv_files = []
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b32_t05/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b32_t1/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b32_t2/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b128_t05/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b128_t1/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b128_t2/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b256_t05/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b256_t1/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b256_t2/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/sync_32/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/sync_128/epoch_metrics/*.csv'))
csv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/sync_256/epoch_metrics/*.csv'))
"""

"\ncsv_files = []\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b32_t05/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b32_t1/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b32_t2/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b128_t05/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b128_t1/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b128_t2/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b256_t05/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b256_t1/epoch_metrics/*.csv'))\ncsv_files.extend(glob.glob('../../metrics/epoch_metrics/AdvancedCNN_complete/b256_t2/epoch_metrics/*.csv'))\n

In [4]:
#df_new = pd.concat([pd.read_csv(file) for file in csv_files], ignore_index=True)

In [5]:
#df = pd.concat([df_old, df_new])

## Read from `~/metrics/epoch_metrics/` directory

In [65]:
# Load the combined per-epoch metrics table (path is relative to this notebook).
df = pd.read_parquet("../../metrics/epoch_metrics/combined_epoch_metrics.parquet")

In [66]:
#df.to_parquet("../../metrics/epoch_metrics/combined_epoch_metrics.parquet")

# Hyper-Parameters

## All tests summary

Note that this is a summary. 

Notes:
1. The `synchronous` method implies `theta` equal to `0.0`, and vice versa.
2. The stopping criterion (currently the number of epochs) is not the same for all combinations. For `LeNet-5`, all tests currently end at 50 epochs. For `AdvancedCNN`, some tests run up to 350 epochs and others up to 250.
3. Tests with `Theta` greater than 2 are exploratory and will not be considered later on.

In [67]:
# One line per hyper-parameter column: its distinct values in ascending order.
for col in (
    'dataset_name', 'fda_name', 'nn_num_weights', 'num_clients',
    'batch_size', 'num_steps_until_rtc_check', 'theta',
):
    print(f"{col}: {sorted(df[col].unique())}")

dataset_name: ['EMNIST']
fda_name: ['linear', 'naive', 'sketch', 'synchronous']
nn_num_weights: [61706, 2592202]
num_clients: [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60]
batch_size: [32, 128, 256]
num_steps_until_rtc_check: [1]
theta: [0.0, 0.5, 1.0, 2.0]


## Query All tests

In [68]:
# Highest recorded epoch for every distinct combination of test hyper-parameters.
test_combinations = (
    df.groupby([
        'dataset_name', 'nn_num_weights', 'fda_name', 'num_steps_until_rtc_check',
        'batch_size', 'theta', 'num_clients',
    ])['epoch']
    .max()
    .reset_index()
)

In [69]:
# Example lookup: 61706-weight model, sketch method, theta = 2, 5 clients.
test_combinations.loc[
    test_combinations['nn_num_weights'].eq(61706)
    & test_combinations['fda_name'].eq('sketch')
    & test_combinations['theta'].eq(2)
    & test_combinations['num_clients'].eq(5)
]

Unnamed: 0,dataset_name,nn_num_weights,fda_name,num_steps_until_rtc_check,batch_size,theta,num_clients,epoch
240,EMNIST,61706,sketch,1,32,2.0,5,50
276,EMNIST,61706,sketch,1,128,2.0,5,50
312,EMNIST,61706,sketch,1,256,2.0,5,50


# Helpful new Dataframe metrics

### Remove the exploratory tests (`Theta` above 2)

In [70]:
# Keep only the rows with theta <= 2. The explicit copy makes `df` an independent frame
# instead of a slice of the loaded one, so the column assignments in the following cells
# do not trigger pandas' SettingWithCopyWarning.
df = df[df['theta'] <= 2].copy()

### Add NN name

In [71]:
def nn_name(row):
    """Return the network name for a metrics row, looked up by its weight count.

    Reads `row['nn_num_weights']`: 61706 maps to 'LeNet-5' and 2592202 maps to
    'AdvancedCNN'. Any other value yields None.
    """
    names_by_weight_count = {61706: 'LeNet-5', 2592202: 'AdvancedCNN'}
    return names_by_weight_count.get(row['nn_num_weights'])

In [72]:
# Row-wise lookup: every row gets the network name matching its weight count.
df['nn_name'] = df.apply(nn_name, axis='columns')

### Add Helpful Dataset Metrics 

In [73]:
def dataset_n_train(row):
    """Return the number of training samples for the row's dataset.

    Reads `row['dataset_name']`: 60_000 for "EMNIST", and -1 as the sentinel for
    any other dataset name.
    """
    return 60_000 if row['dataset_name'] == "EMNIST" else -1


def dataset_one_sample_bytes(row):
    """Return the size in bytes of one training sample of the row's dataset.

    Reads `row['dataset_name']`; -1 is the sentinel for any dataset other than "EMNIST".
    """
    if row['dataset_name'] != "EMNIST":
        return -1
    # 784 tf.float32 pixels of the input image plus one tf.int32 label, 4 bytes each
    return 4 * (784 + 1)

In [74]:
# Row-wise lookups of the per-dataset constants (training-set size, bytes per sample).
df['n_train'] = df.apply(dataset_n_train, axis='columns')
df['one_sample_bytes'] = df.apply(dataset_one_sample_bytes, axis='columns')

### Add Helpful model metrics

In [75]:
# 4 bytes per weight (presumably float32 weights; confirm against the training code).
df['model_bytes'] = df['nn_num_weights'].mul(4)

### Add Helpful FDA method metrics

In [76]:
def fda_local_state_bytes(row):
    """Return the size in bytes of the local state for the row's FDA method.

    Reads `row['fda_name']`:
      * "naive"       -> 4
      * "linear"      -> 8
      * "sketch"      -> sketch_width * sketch_depth * 4 + 4 (read from the row)
      * "synchronous" -> 0
    Any other method name yields None.
    """
    method = row['fda_name']
    if method == "sketch":
        # The sketch dimensions are only read for the sketch method.
        return row['sketch_width'] * row['sketch_depth'] * 4 + 4
    fixed_sizes = {"naive": 4, "linear": 8, "synchronous": 0}
    return fixed_sizes.get(method)

In [77]:
# Row-wise lookup of the monitoring state size for each row's FDA method.
df['local_state_bytes'] = df.apply(fda_local_state_bytes, axis='columns')

### Add Total Steps

Total steps: a single FDA step may consist of several ordinary SGD (batch) steps, namely `num_steps_until_rtc_check` of them.

In [78]:
# FDA steps scaled by the number of ordinary steps inside each FDA step.
df['total_steps'] = df['total_fda_steps'].mul(df['num_steps_until_rtc_check'])

### Add communication metrics

The communication bytes exchanged for model synchronization. Remember that the Clients send their models to the Server and the Server sends the global model back. This happens at the end of every round.

In [79]:
# Per round, every client uploads its model and downloads the global one (factor 2).
df['model_bytes_exchanged'] = (
    df['total_rounds']
    .mul(df['model_bytes'])
    .mul(df['num_clients'])
    .mul(2)
)

The communication bytes exchanged for monitoring the variance. This happens at the end of every FDA step which consists of `num_steps_until_rtc_check` number of steps. 

In [80]:
# One local state per client is accounted for at every FDA step.
df['monitoring_bytes_exchanged'] = (
    df['local_state_bytes']
    .mul(df['total_fda_steps'])
    .mul(df['num_clients'])
)

The total communication bytes exchanged in the whole Federated Learning lifecycle.

In [81]:
# Model synchronisation traffic plus variance-monitoring traffic.
df['total_communication_bytes'] = df['model_bytes_exchanged'].add(df['monitoring_bytes_exchanged'])

In [82]:
# Decimal gigabytes: 1 GB = 10^9 bytes.
df['total_communication_gb'] = df['total_communication_bytes'] / 1_000_000_000

Add rounds in one epoch.

In [83]:
# Order the rows of every test configuration by epoch so consecutive rows are consecutive epochs.
df = df.sort_values(by=[
    'dataset_name', 'fda_name', 'nn_num_weights', 'num_clients',
    'batch_size', 'num_steps_until_rtc_check', 'theta', 'epoch',
])

# Rounds performed within each epoch = difference of the cumulative `total_rounds` inside a
# configuration. The first epoch of a configuration has no predecessor (NaN after diff), so
# it takes its own cumulative count.
df['epoch_rounds'] = (
    df.groupby([
        'dataset_name', 'fda_name', 'nn_num_weights', 'num_clients',
        'batch_size', 'num_steps_until_rtc_check', 'theta',
    ])['total_rounds']
    .diff()
    .fillna(df['total_rounds'])
    .astype(int)
)

# HyperParameter ranking

### AdvancedCNN
On 8 CPUs, the step time:

1. *Batch Size* = 32 -> `307ms`
2. *Batch Size* = 64 -> `445ms`
3. *Batch Size* = 128 -> `815ms`
4. *Batch Size* = 256 -> `1401ms`

Best fit line:

step(ms) = 4.97092 * batch_size + 147.739

### LeNet-5
On 8 CPUs, the step time:

1. *Batch Size* = 32 -> `5.93ms`
2. *Batch Size* = 64 -> `9.16ms`
3. *Batch Size* = 128 -> `18.5ms`
4. *Batch Size* = 256 -> `30.6ms`

Best fit line:

step(ms) = 0.11124 * batch_size + 2.69913

In [84]:
def step_ms(batch_size, nn_name):
    """Return the estimated duration of one training step in milliseconds.

    Evaluates the best-fit line `slope * batch_size + intercept` for the given network.
    Coefficients exist for 'AdvancedCNN' and 'LeNet-5'; any other name yields None.
    """
    # (slope, intercept) of the best-fit lines
    fits = {
        'AdvancedCNN': (4.97092, 147.739),
        'LeNet-5': (0.11124, 2.69913),
    }
    if nn_name not in fits:
        return None
    slope, intercept = fits[nn_name]
    return slope * batch_size + intercept

Time cost for training-reducing

In [85]:
import numpy as np

def cpu_time_cost(row):
    """Return the total CPU time cost of a row in seconds.

    Multiplies `row['total_steps']` by the per-step time in milliseconds that `step_ms`
    gives for the row's `batch_size` and `nn_name`, then converts to seconds.
    A single `step` means each client performed a single `step`.
    """
    per_step_ms = step_ms(row['batch_size'], row['nn_name'])
    return row['total_steps'] * per_step_ms / 1000

def communication_time_cost(num_clients, total_communication_bytes, comm_model):
    """Estimate the communication time in seconds, assuming a 1 Gbps channel.

    The byte count is converted to gigabits, which at 1 Gbps equals seconds, and is
    then scaled by a factor that depends on the communication model:

      * 'common_channel': (num_clients - 1) / num_clients
      * 'hypercube':      ceil(log2(num_clients)) / num_clients

    A hypercube over n nodes has dimension log2(n), so the base-2 logarithm is used.
    (The natural logarithm used previously under-counts the steps for most n.)

    Args:
        num_clients: Number of clients; a scalar or an array-like such as a pandas Series.
        total_communication_bytes: Total bytes exchanged; scalar or array-like.
        comm_model: Either 'common_channel' or 'hypercube'.

    Returns:
        The estimated communication time in seconds, with the shape of the inputs.

    Raises:
        ValueError: If `comm_model` is not one of the supported models.
    """
    total_communication_gbit = total_communication_bytes * 8e-9

    if comm_model == 'common_channel':
        return ((num_clients - 1) / num_clients) * total_communication_gbit    # sec

    if comm_model == 'hypercube':
        return (np.ceil(np.log2(num_clients)) / num_clients) * total_communication_gbit   # sec

    # Fail loudly instead of silently returning None for an unknown model.
    raise ValueError(
        f"Unknown comm_model {comm_model!r}; expected 'common_channel' or 'hypercube'"
    )

In [86]:
# Row-wise: total steps times the estimated step time of the row's network and batch size.
df['cpu_time_cost'] = df.apply(cpu_time_cost, axis='columns')

In [87]:
# Vectorised over whole columns: the cost function is called once with Series arguments.
df['hypercube_communication_time_cost'] = communication_time_cost(
    num_clients=df['num_clients'],
    total_communication_bytes=df['total_communication_bytes'],
    comm_model='hypercube',
)

In [88]:
# Vectorised over whole columns: the cost function is called once with Series arguments.
df['common_channel_communication_time_cost'] = communication_time_cost(
    num_clients=df['num_clients'],
    total_communication_bytes=df['total_communication_bytes'],
    comm_model='common_channel',
)

In [89]:
# Total time cost = CPU time + communication time (hypercube model).
df['hypercube_time_cost'] = df['cpu_time_cost'].add(df['hypercube_communication_time_cost'])

In [90]:
# Total time cost = CPU time + communication time (common-channel model).
df['common_channel_time_cost'] = df['cpu_time_cost'].add(df['common_channel_communication_time_cost'])

In [91]:
# Communication time relative to CPU time (hypercube model).
df['hypercube_comm_cpu_time_ratio'] = df['hypercube_communication_time_cost'].div(df['cpu_time_cost'])

In [92]:
# Communication time relative to CPU time (common-channel model).
df['common_channel_comm_cpu_time_ratio'] = df['common_channel_communication_time_cost'].div(df['cpu_time_cost'])

# Plots about cost

In [93]:
# Define styles for each fda_name
# Matplotlib format strings (marker, line style, colour) keyed by FDA method name.
fda_styles = dict(naive='o-r', linear='s-g', sketch='^-b', synchronous='x-c')

# Method names present in the data, in ascending order; the plotting helpers iterate over these.
fda_names = sorted(df['fda_name'].unique())

In [94]:
import matplotlib

# One colour per distinct client count, sampled at evenly spaced points of the 'tab20b' colormap.
num_clients_values = sorted(df['num_clients'].unique())
cmap = matplotlib.colormaps['tab20b']
colors_dict = dict(zip(num_clients_values, cmap(np.linspace(0, 1, len(num_clients_values)))))

## Total time cost with accuracy (scatter)

### Scatter

In [95]:
def scatter_time_cost(df, filename):
    """Save a two-panel accuracy-vs-total-time-cost scatter plot to a PDF.

    The left panel uses `common_channel_time_cost`, the right panel `hypercube_time_cost`.
    Each panel draws one scatter series per method in the module-level `fda_names` and
    shows a text box with the mean time cost of every method.

    Args:
        df: DataFrame with columns `fda_name`, `accuracy`, `common_channel_time_cost`
            and `hypercube_time_cost`.
        filename: Path of the single-page PDF to write.
    """
    # (time-cost column, panel title), left to right
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    )

    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, axs = plt.subplots(1, 2, figsize=(20, 6))
        try:
            for ax, (cost_column, panel_title) in zip(axs, panels):
                avg_info = []
                for fda_name in fda_names:
                    fda_filtered_data = df[df['fda_name'] == fda_name]
                    ax.scatter(fda_filtered_data[cost_column], fda_filtered_data['accuracy'], label=fda_name)
                    avg_info.append(f'{fda_name}: {fda_filtered_data[cost_column].mean():.2f} sec')

                # Text box with the per-method averages, placed inside the panel
                ax.text(0.62, 0.97, "Average Time Cost:\n" + '\n'.join(avg_info),
                        transform=ax.transAxes, fontsize=9, verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

                ax.set_xlabel('Time Cost (sec)')
                ax.set_ylabel('Accuracy')
                ax.legend()
                ax.set_title(panel_title)
                ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

### Histogram

In [96]:
def histo_time_cost(df, filename):
    """Save a two-panel histogram of total time cost per FDA method to a PDF.

    The left panel uses `common_channel_time_cost`, the right panel `hypercube_time_cost`.
    In each panel the costs of every method in the module-level `fda_names` are binned
    side by side (15 bins), a dashed vertical line in the colour of the method's bars marks
    its mean, and a text box lists the means.

    Args:
        df: DataFrame with columns `fda_name`, `common_channel_time_cost` and
            `hypercube_time_cost`.
        filename: Path of the single-page PDF to write.
    """
    num_bins = 15

    # (time-cost column, panel title), left to right
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    )

    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, axs = plt.subplots(1, 2, figsize=(20, 6))
        try:
            for ax, (cost_column, panel_title) in zip(axs, panels):
                costs_per_method = [df.loc[df['fda_name'] == fda_name, cost_column] for fda_name in fda_names]
                avg_cost_per_method = [costs.mean() for costs in costs_per_method]

                _, _, patches = ax.hist([costs.tolist() for costs in costs_per_method], num_bins,
                                        histtype='bar', label=fda_names)

                # Dashed mean line per method, coloured like that method's bars
                for i, avg_cost in enumerate(avg_cost_per_method):
                    ax.axvline(avg_cost, color=patches[i][0].get_facecolor(),
                               linestyle='--', linewidth=0.6)

                avg_info = [f'{fda_name}: {avg_cost:.2f} sec'
                            for fda_name, avg_cost in zip(fda_names, avg_cost_per_method)]
                ax.text(0.62, 0.97, "Average Time Cost:\n" + '\n'.join(avg_info),
                        transform=ax.transAxes, fontsize=9, verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

                ax.set_xlabel('Time Cost (sec)')
                ax.set_ylabel('Count')
                ax.set_title(panel_title)
                ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

### KDE

In [97]:
def kde_time_cost(df, filename):
    """Save a two-panel KDE plot of total time cost per FDA method to a PDF.

    The left panel uses `common_channel_time_cost`, the right panel `hypercube_time_cost`.
    Each panel draws one filled density curve per method in the module-level `fda_names`,
    starts the x-axis at 0 and shows a text box with the mean time cost of every method.

    Args:
        df: DataFrame with columns `fda_name`, `common_channel_time_cost` and
            `hypercube_time_cost`.
        filename: Path of the single-page PDF to write.
    """
    # (time-cost column, panel title), left to right
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    )

    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, axs = plt.subplots(1, 2, figsize=(20, 6))
        try:
            for ax, (cost_column, panel_title) in zip(axs, panels):
                avg_info = []
                for fda_name in fda_names:
                    method_costs = df.loc[df['fda_name'] == fda_name, cost_column]
                    avg_info.append(f'{fda_name}: {method_costs.mean():.2f} sec')
                    sns.kdeplot(method_costs, label=fda_name, fill=True, alpha=0.4, ax=ax)

                ax.text(0.62, 0.97, "Average Time Cost:\n" + '\n'.join(avg_info),
                        transform=ax.transAxes, fontsize=9, verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

                ax.set_xlim(left=0)
                ax.set_xlabel('Time Cost (sec)')
                ax.set_ylabel('Density')
                ax.set_title(panel_title)
                ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

## Total Communication cost (in GB) with accuracy (scatter)

### Scatter

In [98]:
def scatter_communication_cost(df, filename):
    """Save an accuracy-vs-total-communication scatter plot to a PDF.

    Draws one scatter series per method in the module-level `fda_names` and a text box
    with the mean `total_communication_gb` of every method.

    Args:
        df: DataFrame with columns `fda_name`, `accuracy` and `total_communication_gb`.
        filename: Path of the single-page PDF to write.
    """
    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []
            for fda_name in fda_names:
                fda_filtered_data = df[df['fda_name'] == fda_name]
                ax.scatter(fda_filtered_data['total_communication_gb'], fda_filtered_data['accuracy'],
                           label=fda_name)
                avg_info.append(f"{fda_name}: {fda_filtered_data['total_communication_gb'].mean():.2f} GB")

            # Text box with the per-method averages, placed inside the plot
            ax.text(0.62, 0.97, "Average Communication:\n" + '\n'.join(avg_info),
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.set_xlabel('Communication (GB)')
            ax.set_ylabel('Accuracy')
            ax.legend()
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

### Histogram

In [99]:
def histo_communication_cost(df, filename):
    """Save a histogram of total communication (GB) per FDA method to a PDF.

    The `total_communication_gb` values of every method in the module-level `fda_names`
    are binned side by side (15 bins). A dashed vertical line in the colour of the method's
    bars marks its mean, and a text box lists the means.

    Args:
        df: DataFrame with columns `fda_name` and `total_communication_gb`.
        filename: Path of the single-page PDF to write.
    """
    num_bins = 15

    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            gb_per_method = [df.loc[df['fda_name'] == fda_name, 'total_communication_gb']
                             for fda_name in fda_names]
            avg_gb_per_method = [gb.mean() for gb in gb_per_method]

            _, _, patches = ax.hist([gb.tolist() for gb in gb_per_method], num_bins,
                                    histtype='bar', label=fda_names)

            # Dashed mean line per method, coloured like that method's bars
            for i, avg_gb in enumerate(avg_gb_per_method):
                ax.axvline(avg_gb, color=patches[i][0].get_facecolor(),
                           linestyle='--', linewidth=0.6)

            avg_info = [f'{fda_name}: {avg_gb:.2f} GB'
                        for fda_name, avg_gb in zip(fda_names, avg_gb_per_method)]
            ax.text(0.62, 0.97, "Average Communication:\n" + '\n'.join(avg_info),
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('Communication (GB)')
            ax.set_ylabel('Count')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

### KDE

In [100]:
def kde_communication_cost(df, filename):
    """Save a KDE plot of total communication (GB) per FDA method to a PDF.

    Draws one filled density curve of `total_communication_gb` per method in the
    module-level `fda_names`, starts the x-axis at 0 and shows a text box with the mean
    communication of every method.

    Args:
        df: DataFrame with columns `fda_name` and `total_communication_gb`.
        filename: Path of the single-page PDF to write.
    """
    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []
            for fda_name in fda_names:
                method_gb = df.loc[df['fda_name'] == fda_name, 'total_communication_gb']
                avg_info.append(f'{fda_name}: {method_gb.mean():.2f} GB')
                sns.kdeplot(method_gb, label=fda_name, fill=True, alpha=0.4, ax=ax)

            ax.text(0.62, 0.97, "Average Communication:\n" + '\n'.join(avg_info),
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.set_xlim(left=0)
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('Communication (GB)')
            ax.set_ylabel('Density')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

## Total CPU time (in Seconds) with accuracy

### Scatter

In [101]:
def scatter_cpu_time_cost(df, filename):
    """Save an accuracy-vs-CPU-time scatter plot to a PDF.

    Draws one scatter series per method in the module-level `fda_names` and a text box
    with the mean `cpu_time_cost` of every method.

    Args:
        df: DataFrame with columns `fda_name`, `accuracy` and `cpu_time_cost`.
        filename: Path of the single-page PDF to write.
    """
    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []
            for fda_name in fda_names:
                fda_filtered_data = df[df['fda_name'] == fda_name]
                ax.scatter(fda_filtered_data['cpu_time_cost'], fda_filtered_data['accuracy'],
                           label=fda_name)
                avg_info.append(f"{fda_name}: {fda_filtered_data['cpu_time_cost'].mean():.2f} sec")

            # Text box with the per-method averages, placed inside the plot
            ax.text(0.62, 0.97, "Average CPU time:\n" + '\n'.join(avg_info),
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.set_xlabel('CPU time cost (sec)')
            ax.set_ylabel('Accuracy')
            ax.legend()
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

### Histogram

In [102]:
def histo_cpu_time_cost(df, filename):
    """Save a histogram of CPU time cost per FDA method to a PDF.

    The `cpu_time_cost` values of every method in the module-level `fda_names` are binned
    side by side (15 bins). A dashed vertical line in the colour of the method's bars marks
    its mean, and a text box lists the means.

    Args:
        df: DataFrame with columns `fda_name` and `cpu_time_cost`.
        filename: Path of the single-page PDF to write.
    """
    num_bins = 15

    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            cpu_per_method = [df.loc[df['fda_name'] == fda_name, 'cpu_time_cost']
                              for fda_name in fda_names]
            avg_cpu_per_method = [cpu.mean() for cpu in cpu_per_method]

            _, _, patches = ax.hist([cpu.tolist() for cpu in cpu_per_method], num_bins,
                                    histtype='bar', label=fda_names)

            # Dashed mean line per method, coloured like that method's bars
            for i, avg_cpu in enumerate(avg_cpu_per_method):
                ax.axvline(avg_cpu, color=patches[i][0].get_facecolor(),
                           linestyle='--', linewidth=0.6)

            avg_info = [f'{fda_name}: {avg_cpu:.2f} sec'
                        for fda_name, avg_cpu in zip(fda_names, avg_cpu_per_method)]
            ax.text(0.62, 0.97, "Average CPU time:\n" + '\n'.join(avg_info),
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('CPU time cost (sec)')
            ax.set_ylabel('Count')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

### KDE

In [103]:
def kde_cpu_time_cost(df, filename):
    """Save a KDE plot of CPU time cost per FDA method to a PDF.

    Draws one filled density curve of `cpu_time_cost` per method in the module-level
    `fda_names`, starts the x-axis at 0 and shows a text box with the mean CPU time of
    every method.

    Args:
        df: DataFrame with columns `fda_name` and `cpu_time_cost`.
        filename: Path of the single-page PDF to write.
    """
    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            avg_info = []
            for fda_name in fda_names:
                method_cpu = df.loc[df['fda_name'] == fda_name, 'cpu_time_cost']
                avg_info.append(f'{fda_name}: {method_cpu.mean():.2f} sec')
                sns.kdeplot(method_cpu, label=fda_name, fill=True, alpha=0.4, ax=ax)

            # The box lists CPU times, so it is titled accordingly (it used to say
            # "Average Communication:").
            ax.text(0.62, 0.97, "Average CPU time:\n" + '\n'.join(avg_info),
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.4))

            ax.set_xlim(left=0)
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
            ax.set_xlabel('CPU time cost (sec)')
            # A KDE shows a density, not a count.
            ax.set_ylabel('Density')
            ax.legend()

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

## Communication/CPU time - Accuracy

### Scatter

In [104]:
def scatter_time_cost_cpu_ratio(df, filename):
    """Save a two-panel accuracy-vs-(communication time / CPU time) scatter plot to a PDF.

    The left panel uses `common_channel_comm_cpu_time_ratio`, the right panel
    `hypercube_comm_cpu_time_ratio`. Each panel draws one scatter series per method in the
    module-level `fda_names`.

    Args:
        df: DataFrame with columns `fda_name`, `accuracy`,
            `common_channel_comm_cpu_time_ratio` and `hypercube_comm_cpu_time_ratio`.
        filename: Path of the single-page PDF to write.
    """
    # (ratio column, panel title), left to right
    panels = (
        ('common_channel_comm_cpu_time_ratio', "Common Channel Communication Model"),
        ('hypercube_comm_cpu_time_ratio', "Hypercube Communication Model"),
    )

    # The context manager and the `finally` release the PDF handle and the figure even if
    # plotting raises.
    with PdfPages(filename) as pdf:
        fig, axs = plt.subplots(1, 2, figsize=(20, 6))
        try:
            for ax, (ratio_column, panel_title) in zip(axs, panels):
                for fda_name in fda_names:
                    fda_filtered_data = df[df['fda_name'] == fda_name]
                    ax.scatter(fda_filtered_data[ratio_column], fda_filtered_data['accuracy'],
                               label=fda_name)

                ax.set_xlabel('(Communication time) / (CPU time)')
                ax.set_ylabel('Accuracy')
                ax.legend()
                ax.set_title(panel_title)
                ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

            fig.tight_layout()
            pdf.savefig(fig)
        finally:
            # Close the figure so it is not displayed in the notebook
            plt.close(fig)

## Different FDA-method runs visualization with lines (per clients)

In [105]:
import matplotlib
from matplotlib import cm

def fda_method_run_line_plot(df, fda_name, filename, limit_x_axis=False):
    """Save accuracy-vs-time-cost line plots of one FDA method to a multi-page PDF.

    One page is written for every (batch_size, theta) combination of the method that has
    data. Each page has two panels, `common_channel_time_cost` on the left and
    `hypercube_time_cost` on the right, with one line per `num_clients` value coloured via
    the module-level `colors_dict`. Points are connected in the row order of `df`.

    Args:
        df: DataFrame with columns `fda_name`, `batch_size`, `theta`, `num_clients`,
            `accuracy`, `common_channel_time_cost` and `hypercube_time_cost`.
        fda_name: Method whose runs are plotted.
        filename: Path of the PDF to write.
        limit_x_axis: If True, each panel's x-axis is limited to the largest time cost among
            the rows with `num_clients != 2`, plus a fixed margin (700 for the common
            channel panel, 500 for the hypercube panel). The limit is skipped when no such
            rows exist.
    """
    # (time-cost column, panel title, x-axis margin used when limit_x_axis is set)
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model", 700),
        ('hypercube_time_cost', "Hypercube Communication Model", 500),
    )

    method_df = df[df.fda_name == fda_name]
    batch_size_values = sorted(method_df['batch_size'].unique())
    theta_values = sorted(method_df['theta'].unique())

    # The context manager finalises the PDF even if a page fails to render.
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            for theta in theta_values:
                filtered_data = method_df[(method_df['theta'] == theta) &
                                          (method_df['batch_size'] == batch_size)]
                if filtered_data.empty:
                    continue

                num_clients_values = sorted(filtered_data['num_clients'].unique())

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for ax, (cost_column, panel_title, x_margin) in zip(axs, panels):
                        # One line per client count, each with its own colour
                        for num_clients in num_clients_values:
                            data = filtered_data[filtered_data['num_clients'] == num_clients]
                            ax.plot(data[cost_column], data['accuracy'], color=colors_dict[num_clients],
                                    label=num_clients, marker='o', markersize=3)

                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)
                        ax.set_xlabel('Time Cost')
                        ax.set_ylabel('Accuracy')
                        ax.set_title(panel_title)
                        ax.legend(title='Num Clients')

                        if limit_x_axis:
                            # Because num_clients = 2 reaches very good accuracy very early,
                            # the x-axis is limited using the other runs so the rest of the
                            # data is easier to read.
                            max_cost = filtered_data.loc[filtered_data['num_clients'] != 2, cost_column].max()
                            # max() is NaN when there are no rows with num_clients != 2;
                            # set_xlim rejects NaN limits.
                            if np.isfinite(max_cost):
                                ax.set_xlim(0, max_cost + x_margin)

                    # Raw f-string: `\T` is not a valid escape sequence in a normal string.
                    fig.suptitle(rf'Batch Size : {batch_size} , $\Theta$ : {theta}')

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    plt.close(fig)

## Keep Hyper-parameters fixed and plot for filtered

In [106]:
def fda_methods_batch_size(df, filename):
    """Write a multi-page PDF of time cost against batch size.

    One page is produced per (num_clients, theta) pair. Each page has two
    panels -- the common-channel and the hypercube time cost -- with one line
    per method in the notebook-level `fda_names` list, plus a 'synchronous'
    line. The synchronous rows are selected by `num_clients` only, so the
    same synchronous line is repeated on every theta page of a given
    `num_clients`. Pages on which nothing was plotted are not written.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns 'num_clients', 'theta', 'fda_name',
        'batch_size', 'common_channel_time_cost' and 'hypercube_time_cost'.
    filename : str
        Path of the PDF file to write.
    """
    num_clients_values = sorted(df['num_clients'].unique())
    # The smallest theta present in `df` is skipped.
    # NOTE(review): presumably that is the placeholder theta carried by the
    # 'synchronous' rows -- confirm. If those rows are absent from `df`, a real
    # theta value is dropped instead.
    theta_values = sorted(df['theta'].unique())[1:]

    # (cost column, panel title) for the left and right axes respectively.
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    )
    x_ticks = [32, 128, 256]

    # The context manager finalises the PDF even if rendering a page raises.
    with PdfPages(filename) as pdf:
        for num_clients in num_clients_values:
            # Does not depend on theta, so it is selected once per num_clients.
            synchronous_data = df[(df['num_clients'] == num_clients) & (df['fda_name'] == 'synchronous')]

            for theta in theta_values:
                filtered_df = df[(df['theta'] == theta) & (df['num_clients'] == num_clients)]

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]
                        if fda_data.empty:
                            continue

                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(fda_data['batch_size'], fda_data[cost_column], marker='o', label=fda_name, markersize=3)

                    if not synchronous_data.empty:
                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(synchronous_data['batch_size'], synchronous_data[cost_column], marker='o', label='synchronous', markersize=3)

                    # Nothing was drawn for this combination: skip the page.
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Batch size')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    # Raw f-string: '\T' is not a valid escape sequence in a
                    # normal string literal; the resulting text is unchanged.
                    fig.suptitle(rf'Num Clients : {num_clients} , $\Theta$ : {theta}')

                    plt.tight_layout()
                    pdf.savefig(fig)
                finally:
                    # Always release the figure, including on the skip path.
                    plt.close(fig)

In [107]:
def fda_methods_clients(df, filename):
    """Write a multi-page PDF of time cost against the number of clients.

    One page is produced per (batch_size, theta) pair that has any rows in
    `df`. Each page has two panels -- the common-channel and the hypercube
    time cost -- with one line per method in the notebook-level `fda_names`
    list, plus a 'synchronous' line. The synchronous rows are selected by
    `batch_size` only, so the same synchronous line is repeated on every
    theta page of a given `batch_size`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns 'batch_size', 'theta', 'fda_name',
        'num_clients', 'common_channel_time_cost' and 'hypercube_time_cost'.
    filename : str
        Path of the PDF file to write.
    """
    batch_size_values = sorted(df['batch_size'].unique())
    # The smallest theta present in `df` is skipped.
    # NOTE(review): presumably that is the placeholder theta carried by the
    # 'synchronous' rows -- confirm. If those rows are absent from `df`, a real
    # theta value is dropped instead.
    theta_values = sorted(df['theta'].unique())[1:]

    # (cost column, panel title) for the left and right axes respectively.
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    )

    # Fixed x ticks every 5 clients, from 5 to 60 inclusive.
    min_clients = 5
    max_clients = 60
    x_ticks = list(range(min_clients, max_clients + 1, 5))

    # The context manager finalises the PDF even if rendering a page raises.
    with PdfPages(filename) as pdf:
        for batch_size in batch_size_values:
            # Does not depend on theta, so it is selected once per batch size.
            synchronous_data = df[(df['batch_size'] == batch_size) & (df['fda_name'] == 'synchronous')]

            for theta in theta_values:
                filtered_df = df[(df['theta'] == theta) & (df['batch_size'] == batch_size)]

                if filtered_df.empty:
                    continue

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in fda_names:
                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]
                        if fda_data.empty:
                            continue

                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(fda_data['num_clients'], fda_data[cost_column], marker='o', label=fda_name, markersize=3)

                    if not synchronous_data.empty:
                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(synchronous_data['num_clients'], synchronous_data[cost_column], marker='o', label='synchronous', markersize=3)

                    # Nothing was drawn for this combination: skip the page.
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Number of clients')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    # Raw f-string: '\T' is not a valid escape sequence in a
                    # normal string literal; the resulting text is unchanged.
                    fig.suptitle(rf'Batch Size : {batch_size} , $\Theta$ : {theta}')

                    plt.tight_layout()
                    pdf.savefig(fig)
                finally:
                    # Always release the figure, including on the skip path.
                    plt.close(fig)

In [108]:
def fda_methods_theta(df, filename):
    """Write a multi-page PDF of time cost against theta.

    One page is produced per (num_clients, batch_size) pair. Each page has
    two panels -- the common-channel and the hypercube time cost -- with one
    line per method in the notebook-level `fda_names` list followed by
    'synchronous'. Pages on which nothing was plotted are not written.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns 'num_clients', 'batch_size', 'fda_name',
        'theta', 'common_channel_time_cost' and 'hypercube_time_cost'.
    filename : str
        Path of the PDF file to write.
    """
    num_clients_values = sorted(df['num_clients'].unique())
    batch_sizes = sorted(df['batch_size'].unique())

    # (cost column, panel title) for the left and right axes respectively.
    panels = (
        ('common_channel_time_cost', "Common Channel Communication Model"),
        ('hypercube_time_cost', "Hypercube Communication Model"),
    )
    x_ticks = [0., 0.5, 1., 2.]
    # 'synchronous' is plotted last, through the same per-method filter.
    method_names = fda_names + ['synchronous']

    # The context manager finalises the PDF even if rendering a page raises.
    with PdfPages(filename) as pdf:
        for num_clients in num_clients_values:
            for batch_size in batch_sizes:
                filtered_df = df[(df['batch_size'] == batch_size) & (df['num_clients'] == num_clients)]

                fig, axs = plt.subplots(1, 2, figsize=(20, 6))
                try:
                    for fda_name in method_names:
                        fda_data = filtered_df[filtered_df['fda_name'] == fda_name]
                        if fda_data.empty:
                            continue

                        for ax, (cost_column, _) in zip(axs, panels):
                            ax.plot(fda_data['theta'], fda_data[cost_column], marker='o', label=fda_name, markersize=3)

                    # Nothing was drawn for this combination: skip the page.
                    if not axs[0].has_data():
                        continue

                    for ax, (_, panel_title) in zip(axs, panels):
                        ax.set_xticks(x_ticks)
                        ax.set_xlabel('Theta')
                        ax.legend()
                        ax.set_title(panel_title)
                        ax.set_ylabel('Time Cost')
                        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.5)

                    fig.suptitle(f'Num Clients : {num_clients} , Batch Size : {batch_size}')

                    plt.tight_layout()
                    pdf.savefig(fig)
                finally:
                    # Always release the figure, including on the skip path.
                    plt.close(fig)

## Save all those time-cost plots

In [109]:
import os

def time_cost_plots(df, acc_threshold, nn_name, limit_x_axis=False, show_runs=False, params=False):
    """Generate the time-cost PDF plots for one network and accuracy threshold.

    Rows of `df` are kept when `accuracy > acc_threshold` and
    `nn_name == nn_name`. All PDFs are written under
    ``../../metrics/plots/{nn_name}/{threshold}``, where the threshold has its
    '.' replaced by '_' (the directory is created if missing).

    Parameters
    ----------
    df : pandas.DataFrame
        Uses the columns 'accuracy', 'nn_name', 'fda_name', 'num_clients',
        'batch_size', 'theta' and 'epoch' here; the plotting helpers read
        further columns.
    acc_threshold : float
        Strict lower bound on accuracy.
    nn_name : str
        Network name to select; also used in the output path.
    limit_x_axis : bool, default False
        Forwarded to `fda_method_run_line_plot` (only relevant with `show_runs`).
    show_runs : bool, default False
        Also write one run plot per method (sketch, naive, linear, synchronous).
    params : bool, default False
        Also write the fixed-hyper-parameter plots (batch size, clients, theta).
    """
    # Filter out based on `acc_threshold`
    acceptable_acc_df = df[(df.accuracy > acc_threshold) & (df.nn_name == nn_name)]

    str_thresh = str(acc_threshold).replace('.', '_')  # replace '.'
    out_dir = f"../../metrics/plots/{nn_name}/{str_thresh}"

    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(out_dir, exist_ok=True)

    # 1. Same runs are included
    #scatter_time_cost(acceptable_acc_df, f"{out_dir}/nonfiltered_scatter_time_cost.pdf")
    #scatter_time_cost_cpu_ratio(acceptable_acc_df, f"{out_dir}/nonfiltered_scatter_time_cost_cpu_ratio.pdf")

    if show_runs:
        # Plot the runs of each method. x-axis : time cost, y-axis : accuracy, PER number of clients (lines)
        for method in ('sketch', 'naive', 'linear', 'synchronous'):
            fda_method_run_line_plot(acceptable_acc_df, method, f"{out_dir}/{method}_run.pdf", limit_x_axis=limit_x_axis)

    # 2. Filter out same runs: per (fda_name, num_clients, batch_size, theta)
    #    keep the row with the smallest epoch among those above the threshold,
    #    i.e. the instance which first hits `acc_threshold`.
    idx = acceptable_acc_df.groupby(['fda_name', 'num_clients', 'batch_size', 'theta'])['epoch'].idxmin()
    filtered_acceptable_acc_df = acceptable_acc_df.loc[idx]

    # 3. Same runs are NOT included
    #scatter_time_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_scatter_time_cost.pdf")
    #histo_time_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_histo_time_cost.pdf")
    kde_time_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_kde_time_cost.pdf")

    #scatter_time_cost_cpu_ratio(filtered_acceptable_acc_df, f"{out_dir}/filtered_scatter_time_cost_cpu_ratio.pdf")

    #scatter_communication_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_scatter_communication_cost.pdf")
    #histo_communication_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_histo_communication_cost.pdf")
    kde_communication_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_kde_communication_cost.pdf")

    #scatter_cpu_time_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_scatter_cpu_time_cost.pdf")
    #histo_cpu_time_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_histo_cpu_time_cost.pdf")
    kde_cpu_time_cost(filtered_acceptable_acc_df, f"{out_dir}/filtered_kde_cpu_time_cost.pdf")

    # Parameters fixed, and plot
    if params:
        fda_methods_batch_size(filtered_acceptable_acc_df, f"{out_dir}/filtered_fda_methods_batch_size.pdf")
        fda_methods_clients(filtered_acceptable_acc_df, f"{out_dir}/filtered_fda_methods_clients.pdf")
        fda_methods_theta(filtered_acceptable_acc_df, f"{out_dir}/filtered_fda_methods_theta.pdf")

In [110]:
# LeNet-5, accuracy > 0.95: KDE plots, per-method run plots and fixed-parameter plots
# (written under ../../metrics/plots/LeNet-5/0_95).
time_cost_plots(df, 0.95, 'LeNet-5', show_runs=True, params=True)

In [111]:
# LeNet-5, accuracy > 0.955: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.955, 'LeNet-5', params=True)

In [112]:
# LeNet-5, accuracy > 0.96: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.96, 'LeNet-5', params=True)

In [113]:
# LeNet-5, accuracy > 0.965: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.965, 'LeNet-5', params=True)

In [114]:
# LeNet-5, accuracy > 0.97: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.97, 'LeNet-5', params=True)

In [115]:
# LeNet-5, accuracy > 0.975: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.975, 'LeNet-5', params=True)

In [116]:
# AdvancedCNN, accuracy > 0.95: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.95, 'AdvancedCNN', params=True)

In [117]:
# AdvancedCNN, accuracy > 0.96: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.96, 'AdvancedCNN', params=True)

In [118]:
# AdvancedCNN, accuracy > 0.97: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.97, 'AdvancedCNN', params=True)

In [119]:
# AdvancedCNN, accuracy > 0.98: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.98, 'AdvancedCNN', params=True)

In [120]:
# AdvancedCNN, accuracy > 0.985: KDE plots, per-method run plots and fixed-parameter plots
# (written under ../../metrics/plots/AdvancedCNN/0_985).
time_cost_plots(df, 0.985, 'AdvancedCNN', show_runs=True, params=True)

In [121]:
# AdvancedCNN, accuracy > 0.988: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.988, 'AdvancedCNN', params=True)

In [122]:
# AdvancedCNN, accuracy > 0.99: KDE plots and fixed-parameter plots (no per-method run plots).
time_cost_plots(df, 0.99, 'AdvancedCNN', params=True)