In [None]:
import shutil
import os
import gc
import time

# Run from the directory two levels above the notebook, presumably the repository
# root, so the local `ditto`, `femnist` and `utils` imports below resolve.
# The launch directory is remembered on first execution, so re-running this cell
# always lands in the same place instead of climbing two more levels each time.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, '..', '..'))

from ditto.base.ditto_server import DittoStrategy
from ditto.base.ditto_client import ditto_client_fn_generator
from femnist import download_femnist
from utils import CLIENT_MODEL_DIR

import flwr as fl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ray

In [None]:
# Fetch the FEMNIST dataset; the simulation cell below reads client mappings from
# ./femnist/client_data_mappings/fed_natural.
# Per the original author's note, this does nothing if the data is already present.
download_femnist()

In [None]:
start_time = time.time()

# Maps each Ditto lambda to the Flower simulation result for that run.
results = {}

# Experiment configuration.
num_clients = 8                   # clients sampled into the federation
num_rounds = 50                   # federated rounds per simulation
epochs_per_round = 10             # local epochs each client trains per round
lambdas = np.linspace(0, 100, 6)  # Ditto lambda values swept: 0, 20, 40, 60, 80, 100
seed = 42                         # fixes which clients are sampled, for reproducibility

# The number of subdirectories in the natural-partition mapping folder is taken as
# the number of available clients; client IDs are their indices as strings
# (assumes the project names clients "0".."N-1" — confirm against femnist.py).
# Sample *distinct* clients: drawing with replacement could repeat an ID.
client_mapping_dir = './femnist/client_data_mappings/fed_natural'
num_available_clients = len(next(os.walk(client_mapping_dir))[1])
rng = np.random.default_rng(seed)
clients_ids = [
    str(client_idx)
    for client_idx in rng.choice(num_available_clients, size=num_clients, replace=False)
]

# Split the machine's CPUs evenly across clients, at least one each.
# os.cpu_count() may return None, hence the fallback to 1.
cpus_per_client = max((os.cpu_count() or 1) // num_clients, 1)

for _lambda in lambdas:
    print('lambda =', _lambda)
    # Reset client models: start each lambda from an empty CLIENT_MODEL_DIR so
    # nothing saved there by the previous run carries over.
    if os.path.exists(CLIENT_MODEL_DIR):
        shutil.rmtree(CLIENT_MODEL_DIR)
    os.makedirs(CLIENT_MODEL_DIR)

    # https://flower.ai/docs/framework/tutorial-series-customize-the-client-pytorch.html
    try:
        results[_lambda] = fl.simulation.start_simulation(
            clients_ids=clients_ids,
            client_fn=ditto_client_fn_generator(_lambda=_lambda, epochs_per_round=epochs_per_round),
            config=fl.server.ServerConfig(num_rounds=num_rounds),
            strategy=DittoStrategy(log_accuracy=True),
            client_resources={'num_cpus': cpus_per_client},
        )
    finally:
        # Clean up after every simulation, even a failed one, to prevent memory
        # leakage and so the next run does not inherit a live Ray runtime.
        ray.shutdown()
        gc.collect()

elapsed_time = time.time() - start_time
print('Elapsed time:', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

In [None]:
def _metric_by_round(history, metric_name):
    """Return one distributed metric of a simulation result as a Series indexed by round.

    history.metrics_distributed[metric_name] is expected to be a list of
    (round, value) pairs, as stored by Flower's History object.
    """
    pairs = history.metrics_distributed[metric_name]
    return pd.Series([value for _, value in pairs], index=[rnd for rnd, _ in pairs])


def _metric_frame(histories, metric_name):
    """Build a DataFrame with one column per lambda and one row per federated round.

    Rows are aligned on the actual round number (not list position), so runs
    with differing rounds are NaN-filled rather than silently misaligned.
    """
    frame = pd.DataFrame({
        _lambda: _metric_by_round(history, metric_name)
        for _lambda, history in histories.items()
    })
    frame.index.name = 'Round'
    return frame


global_accuracies = _metric_frame(results, 'avg_global_accuracy')
local_accuracies = _metric_frame(results, 'avg_local_accuracy')
# Positive values: the global model is more accurate than the local model.
accuracy_differential = global_accuracies - local_accuracies

In [None]:
# One panel per accuracy view; each panel draws one curve per lambda.
panels = [
    (global_accuracies, 'Mean global accuracy'),
    (local_accuracies, 'Mean local accuracy'),
    (accuracy_differential, 'Mean global - local accuracies'),
]

fig, axs = plt.subplots(1, len(panels), figsize=(15, 4))
for ax, (frame, ylabel) in zip(axs, panels):
    frame.plot(ax=ax)
    ax.set_xlabel('Round')
    ax.set_ylabel(ylabel)
    # Raw string: '\l' is not a valid Python escape, so a plain literal emits a
    # DeprecationWarning/SyntaxWarning; mathtext needs the literal backslash.
    ax.legend(title=r'$\lambda$')
    ax.grid()

# Reference line on the differential panel: above zero, global accuracy exceeds local.
axs[-1].axhline(y=0, color='black', linestyle='--', linewidth=2.5, alpha=0.5)

fig.suptitle(r'Ditto on FEMNIST: accuracy per round for each $\lambda$')
plt.tight_layout()
plt.show()