In [1]:
import os
from copy import deepcopy
import jax

from federated_library.distributions import convert_to_federated_data
from federated_library.dataset_loader import load_tf_dataset
from federated_library.train_fed_avg import fed_avg_gridsearch

In [None]:
# Show which devices JAX can use in this process (e.g. to confirm a GPU/TPU
# is visible) before launching the long-running grid search further down.
visible_devices = jax.local_devices()
print(visible_devices)

In [None]:
def _notebook_global(name):
    """Return the global ``name`` from this notebook's namespace.

    Only used as a backward-compatible fallback for callers of ``run`` that do
    not pass ``dataset_name`` / ``skew_type`` explicitly.

    Raises:
        NameError: if the global has not been defined.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"run() needs '{name}': pass it as a keyword argument or define it "
            "as a global before calling run()"
        ) from None


def run(params, ds, test_split, ds_info, display,
        dataset_name=None, skew_type=None, results_dir="results"):
    """Run a FedAvg hyperparameter grid search for every (clients, skew) pair.

    For each client count in ``params['clients_set']`` and each skew value in
    ``params['skews_set']``, this calls ``fed_avg_gridsearch`` with the
    remaining grid entries plus ``skew=[skew]``. It sorts the returned entries
    by their first element in descending order. It then writes one
    ``str(entry)`` per line to
    ``{results_dir}/{dataset_name}_{skew_type}_skew_{skew}_{nr_clients}clients.txt``.

    Args:
        params: grid-search dict. It must contain ``clients_set`` and
            ``skews_set``. It is not modified, because a deep copy is used.
        ds: dataset, forwarded unchanged to ``fed_avg_gridsearch``.
        test_split: test data, forwarded unchanged to ``fed_avg_gridsearch``.
        ds_info: dataset metadata. NOTE: ``ds_info['num_clients']`` is
            overwritten in place for each client count and keeps the last one.
        display: forwarded to ``fed_avg_gridsearch``.
        dataset_name: name used in the output file names. Defaults to the
            notebook global ``dataset_name``, which matches the old behaviour.
        skew_type: skew type used in the output file names. Defaults to the
            notebook global ``skew_type``.
        results_dir: directory for the result files. It is created if missing.

    Raises:
        NameError: if ``dataset_name`` or ``skew_type`` is neither passed nor
            defined globally. This is checked before any training starts.
    """
    # Resolve the file-name parts up front so a missing value fails fast,
    # before an expensive grid search has run.
    if dataset_name is None:
        dataset_name = _notebook_global("dataset_name")
    if skew_type is None:
        skew_type = _notebook_global("skew_type")

    clients_set = params['clients_set']
    skews_set = params['skews_set']

    # Work on a copy so the caller's dict keeps its clients/skews entries;
    # those two keys are loop axes here, not grid-search hyperparameters.
    grid = deepcopy(params)
    del grid['clients_set']
    del grid['skews_set']

    os.makedirs(results_dir, exist_ok=True)

    for nr_clients in clients_set:
        ds_info['num_clients'] = nr_clients
        for skew in skews_set:
            print(f"{nr_clients} clients, skew={skew}")
            grid['skew'] = [skew]
            grid_search_res = fed_avg_gridsearch(
                grid, ds, test_split, ds_info, display=display
            )

            # Best first. This presumably assumes element 0 of each entry is
            # the score (TODO confirm against fed_avg_gridsearch).
            sorted_res = sorted(grid_search_res, key=lambda e: e[0], reverse=True)

            file_name = (
                f"{results_dir}/{dataset_name}_{skew_type}"
                f"_skew_{skew}_{nr_clients}clients.txt"
            )
            print(file_name)
            # `with` guarantees the file is closed even if a write fails.
            with open(file_name, "w") as textfile:
                for line in sorted_res:
                    textfile.write(str(line) + "\n")

In [None]:
# Experiment grid: every dataset below is trained under every skew type.
DATASETS = [
    "mnist",
    "emnist",
    "cifar10",
    "svhn_cropped",
]
SKEW_TYPES = [
    "label",
    "feature",
    "qty",
]

# Skew strengths swept for each skew type. The keys must cover SKEW_TYPES,
# since the sweep cell looks up SKEWS[skew_type].
SKEWS = {
    "feature": [0.02, 0.1],
    "label": [0.1, 1.0, 5.0],
    "qty": [0.1, 0.4, 1.0, 2.0],
}

# NOTE(review): not referenced anywhere in this notebook. It looks like a
# single-point config for quick smoke runs (e.g. pass it to `run` in place of
# the full grid). Confirm whether it is still needed.
TEST_PARAMS = {
    "act_fn": [jax.nn.relu],
    "client_lr": [0.01],
    "server_lr": [0.03],
    "client_momentum": [0.9],
    "server_momentum": [0.9],
    "batch_size": [16],
    "epochs_per_round": [2],
    "rounds": [30],
    "runs": [1],
    "clients_set": [10, 20],
    "skews_set": [0.1],
}

# Sweep every (dataset, skew type) pair in nested-loop order.
# NOTE(review): called as below, `run` takes the dataset name and skew type for
# its output file names from the module-level names `dataset_name` and
# `skew_type`. So these loop variables must stay globals, which they are
# because this is top-level code.
for dataset_name, skew_type in ((d, s) for d in DATASETS for s in SKEW_TYPES):
    print(f"DATASET: {dataset_name}, SKEW_TYPE: {skew_type}")

    # `ds` is presumably the training data; only the returned test arrays are
    # converted to the federated format here.
    ds, (x_test, y_test), ds_info = load_tf_dataset(
        dataset_name=dataset_name,
        skew_type=skew_type,
        decentralized=False,
        display=False,
    )
    test_split = convert_to_federated_data(
        x_test, y_test, ds_info, is_train=False
    )

    # Hyperparameter grid for this pair: the client-side settings are fixed,
    # while server_lr, server_momentum and batch_size take several values.
    # clients_set / skews_set are the outer loop axes consumed by `run`.
    hp_configs = {
        "act_fn": [jax.nn.relu],
        "client_lr": [0.01],
        "server_lr": [0.01, 0.03, 0.05, 0.1, 0.3, 0.5],
        "client_momentum": [0.01],
        "server_momentum": [0.0, 0.3, 0.6, 0.9],
        "batch_size": [8, 16, 32],
        "epochs_per_round": [2],
        "rounds": [30],
        "runs": [1],
        "clients_set": [10, 20],
        "skews_set": SKEWS[skew_type],
    }

    run(hp_configs, ds, test_split, ds_info, display=False)