In [None]:
from ditto_server import DittoStrategy
from ditto_client import ditto_client_fn_generator
from femnist import download_femnist
from utils import CLIENT_MODEL_DIR

import shutil
import os
import gc

import flwr as fl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ray

In [None]:
# Make sure the FEMNIST dataset is available locally;
# download_femnist() is a no-op when the data is already present.
download_femnist()

In [None]:
# Sweep the Ditto lambda hyper-parameter and keep the Flower History of each run,
# keyed by the lambda value used.
results = {}

# Experiment settings are identical for every lambda, so define them once.
num_clients = 16
num_rounds = 25
epochs_per_round = 10
lambdas = np.linspace(0, 1, 5)  # 0.0, 0.25, 0.5, 0.75, 1.0

# Split the machine's CPUs evenly across simulated clients, giving each at least one.
# os.cpu_count() returns None when the count cannot be determined; fall back to 1
# instead of crashing with a TypeError on `None // num_clients`.
cpus_per_client = max((os.cpu_count() or 1) // num_clients, 1)

for _lambda in lambdas:
    print('lambda =', _lambda)

    # Reset client models: delete anything saved under CLIENT_MODEL_DIR by the
    # previous run, then recreate the (now empty) directory.
    if os.path.exists(CLIENT_MODEL_DIR):
        shutil.rmtree(CLIENT_MODEL_DIR)
    os.makedirs(CLIENT_MODEL_DIR)

    # https://flower.ai/docs/framework/tutorial-series-customize-the-client-pytorch.html
    try:
        results[_lambda] = fl.simulation.start_simulation(
            num_clients=num_clients,
            client_fn=ditto_client_fn_generator(_lambda=_lambda, epochs_per_round=epochs_per_round),
            config=fl.server.ServerConfig(num_rounds=num_rounds),
            strategy=DittoStrategy(log_accuracy=True),
            client_resources={'num_cpus': cpus_per_client},
        )
    finally:
        # Clean up after every simulation, even a failed one, to prevent memory
        # leakage and so the next run does not inherit a live Ray session.
        ray.shutdown()
        gc.collect()

In [None]:
def _metric_per_round(history, metric_name):
    """Return one distributed metric of a Flower ``History`` as a Series indexed by round.

    ``history.metrics_distributed[metric_name]`` holds ``(round, value)`` pairs.
    Keeping the round numbers as the index (rather than discarding them) means the
    plots show the actual server round on the x-axis instead of a 0-based position.

    Raises KeyError if ``metric_name`` was not reported during the simulation.
    """
    pairs = history.metrics_distributed[metric_name]
    return pd.Series(dict(pairs), dtype=float).rename_axis('Round')


# Per-lambda, per-round mean accuracy of the global and local models.
global_accuracies = {_lambda: _metric_per_round(history, 'avg_global_accuracy') for _lambda, history in results.items()}
local_accuracies = {_lambda: _metric_per_round(history, 'avg_local_accuracy') for _lambda, history in results.items()}

In [None]:
def plot_mean_accuracy(accuracies, ylabel):
    """Plot one accuracy-per-round line for each lambda and show the figure.

    Parameters
    ----------
    accuracies : dict
        Maps each lambda value to its per-round accuracies (a list or a Series);
        every entry becomes one line, labelled by its lambda in the legend.
    ylabel : str
        Label for the y-axis.
    """
    fig, ax = plt.subplots()
    pd.DataFrame.from_dict(accuracies).plot(ax=ax)
    ax.set_xlabel('Round')
    ax.set_ylabel(ylabel)
    # Raw string: '\l' is not a valid Python escape sequence, so a plain string
    # literal triggers a SyntaxWarning (Python 3.12+) / DeprecationWarning.
    ax.legend(title=r'$\lambda$')
    plt.show()


plot_mean_accuracy(global_accuracies, 'Mean global accuracy')
plot_mean_accuracy(local_accuracies, 'Mean local accuracy')