In [1]:
import os
from copy import copy, deepcopy

import jax

from federated_library.dataset_loader import load_tf_dataset
from federated_library.distributions import convert_to_federated_data
from federated_library.train_fed_avg import fed_avg_gridsearch

In [None]:
# List the devices JAX exposes to this process (e.g. to confirm a GPU/TPU is visible before training).
print(jax.local_devices())

In [None]:
def run(params, ds, test_split, ds_info, display,
        dataset_name=None, skew_type=None, results_dir="results"):
    """Grid-search FedAvg for every (client count, skew) pair and save sorted results.

    For each ``nr_clients`` in ``params['clients_set']`` and each ``skew`` in
    ``params['skews_set']``, calls ``fed_avg_gridsearch`` with the remaining
    hyper-parameter lists plus ``skew=[skew]``. The returned entries are sorted
    by their first element (descending) and written one per line to
    ``{results_dir}/{dataset_name}_{skew_type}_skew_{skew}_{nr_clients}clients.txt``.

    Args:
        params: Mapping of hyper-parameter name -> list of values. Must contain
            ``clients_set`` and ``skews_set``. It is not modified.
        ds: Training data, forwarded unchanged to ``fed_avg_gridsearch``.
        test_split: Test data, forwarded unchanged to ``fed_avg_gridsearch``.
        ds_info: Dataset metadata supporting item assignment. A shallow copy
            with ``'num_clients'`` set is passed on, and the caller's object is
            left untouched.
        display: Forwarded to ``fed_avg_gridsearch``.
        dataset_name: Used in result file names. Defaults to the module-level
            ``dataset_name`` so existing notebook calls keep working.
        skew_type: Used in result file names. Defaults to the module-level
            ``skew_type`` for the same reason.
        results_dir: Directory for the result files. It is created if missing.

    Raises:
        NameError: If ``dataset_name`` or ``skew_type`` is neither passed nor
            defined at module level.
    """
    # Fall back to the notebook globals the calling cells rely on.
    namespace = globals()
    if dataset_name is None:
        if "dataset_name" not in namespace:
            raise NameError("run() needs dataset_name: pass it or define it globally")
        dataset_name = namespace["dataset_name"]
    if skew_type is None:
        if "skew_type" not in namespace:
            raise NameError("run() needs skew_type: pass it or define it globally")
        skew_type = namespace["skew_type"]

    clients_set = params['clients_set']
    skews_set = params['skews_set']

    # Work on a copy so the caller's params keep their sweep keys.
    grid_params = deepcopy(params)
    del grid_params['clients_set']
    del grid_params['skews_set']

    # Create the output directory up front so a missing directory fails fast,
    # not after an expensive grid search.
    os.makedirs(results_dir, exist_ok=True)

    for nr_clients in clients_set:
        # Shallow copy: set the client count without mutating the caller's ds_info.
        run_info = copy(ds_info)
        run_info['num_clients'] = nr_clients
        for skew in skews_set:
            grid_params['skew'] = [skew]
            grid_results = fed_avg_gridsearch(
                grid_params, ds, test_split, run_info, display=display
            )

            # Best first. Ranks on each entry's first element (presumably the
            # evaluation metric returned by fed_avg_gridsearch; TODO confirm).
            sorted_res = sorted(grid_results, key=lambda e: e[0], reverse=True)

            file_name = f"{results_dir}/{dataset_name}_{skew_type}_skew_{skew}_{nr_clients}clients.txt"
            print(file_name)
            # `with` guarantees the file is closed even if a write fails.
            with open(file_name, "w") as textfile:
                for line in sorted_res:
                    textfile.write(str(line) + "\n")

In [None]:
# Hyper-parameter sweep over every (skew type, dataset) pair.
# Assuming fed_avg_gridsearch evaluates the Cartesian product of its list-valued
# entries, each (dataset, skew value, client count) covers
#   4 server_lr x 3 server_momentum x 2 batch_size = 24 configs,
# so the full sweep is 24 x 2 client counts x 9 skew values x 4 datasets = 1728 configs.

DATASETS = ["mnist", "emnist", "cifar10", "svhn_cropped"]
SKEW_TYPES = ["label", "feature", "qty"]
# Skew strengths swept for each skew type (9 values in total).
SKEWS = {
    "feature": [0.02, 0.1],
    "label": [0.1, 1.0, 5.0],
    "qty": [0.1, 0.4, 1.0, 2.0]
}

# Small, fast configuration for smoke-testing the pipeline. It is not referenced
# by the sweep below; pass it to `run` in place of hp_configs for a quick check.
TEST_PARAMS = dict(
    act_fn=[jax.nn.relu],
    client_lr=[0.03],
    server_lr=[0.03],
    client_momentum=[0.9],
    server_momentum=[0.9],
    batch_size=[16],
    epochs_per_round=[2],
    rounds=[3],
    runs=[1],
    clients_set=[3],
    skews_set=[0.1]
)

# Grid shared by every (skew type, dataset) pair; only skews_set varies per skew type.
BASE_HP_CONFIGS = dict(
    act_fn=[jax.nn.relu],
    client_lr=[0.01],
    server_lr=[0.001, 0.005, 0.01, 0.05],
    client_momentum=[0.5],
    server_momentum=[0.5, 0.7, 0.9],
    batch_size=[16, 32],
    epochs_per_round=[2],
    rounds=[30],
    runs=[1],
    clients_set=[10, 20],
)

# `run`, as called here, names its result files from the module-level
# `dataset_name` and `skew_type`, so the loop variables must keep these names.
for skew_type in SKEW_TYPES:
    for dataset_name in DATASETS:
        print(f"DATASET: {dataset_name}, SKEW_TYPE: {skew_type}")
        ds, (x_test, y_test), ds_info = load_tf_dataset(
            dataset_name=dataset_name, skew_type=skew_type, decentralized=False,
            display=False
        )

        test_split = convert_to_federated_data(x_test, y_test, ds_info, is_train=False)

        # Fresh dict per iteration; `run` deep-copies it before use.
        hp_configs = dict(BASE_HP_CONFIGS, skews_set=SKEWS[skew_type])

        run(hp_configs, ds, test_split, ds_info, display=False)

In [None]:
def run_test(params, ds, test_split, ds_info, display):
    """Evaluate hand-picked FedAvg configurations, one per index of ``params['client_lr']``.

    Index ``i`` pairs ``client_lr[i]``, ``client_momentum[i]``, ``batch_size[i]``,
    ``clients_set[i]`` and ``skews_set[i]``. ``act_fn``, ``server_lr``,
    ``server_momentum``, ``epochs_per_round`` and ``rounds`` always use their
    first entry, and ``runs`` is forwarded as given. Each configuration is run
    through ``fed_avg_gridsearch`` with single-element lists.

    Args:
        params: Mapping of hyper-parameter name -> list of values. The paired
            lists above must have at least ``len(params['client_lr'])`` entries.
        ds: Training data, forwarded unchanged to ``fed_avg_gridsearch``.
        test_split: Test data, forwarded unchanged to ``fed_avg_gridsearch``.
        ds_info: Dataset metadata supporting item assignment. A shallow copy
            with ``'num_clients'`` set is passed on, and the caller's object is
            left untouched.
        display: Forwarded to ``fed_avg_gridsearch``.

    Returns:
        list: The result entries from every configuration, concatenated in
        index order. The list is empty when ``params['client_lr']`` is empty.
    """
    all_results = []
    for i in range(len(params['client_lr'])):
        # Shallow copy: set the client count without mutating the caller's ds_info.
        run_info = copy(ds_info)
        run_info['num_clients'] = params['clients_set'][i]

        skew = params['skews_set'][i]

        test_params = dict(
            clients=run_info['num_clients'],
            # NOTE(review): `run` passes skew as a one-element list ([skew]);
            # confirm fed_avg_gridsearch also accepts a bare scalar here.
            skew=skew,
            act_fn=[params['act_fn'][0]],
            client_lr=[params['client_lr'][i]],
            server_lr=[params['server_lr'][0]],
            client_momentum=[params['client_momentum'][i]],
            server_momentum=[params['server_momentum'][0]],
            batch_size=[params['batch_size'][i]],
            epochs_per_round=[params['epochs_per_round'][0]],
            rounds=[params['rounds'][0]],
            runs=params['runs'],
        )

        results = fed_avg_gridsearch(test_params, ds, test_split, run_info, display=display)
        # Accumulate rather than returning inside the loop, which would only
        # ever evaluate index 0.
        all_results.extend(results)

    return all_results

In [None]:
# Label-skew sweep with a wider server-side grid
# (3 server_lr x 4 server_momentum x 3 batch_size per dataset / skew value / client count).
datasets = ["mnist", "emnist", "cifar10", "svhn_cropped"]
SKEWS = {
    "feature": [0.02, 0.1],
    "label": [0.1, 1.0, 5.0],
    "qty": [0.1, 0.4, 1.0, 2.0]
}

# `run`, as called here, names its result files from the module-level
# `dataset_name` and `skew_type`, so both must stay global names.
# NOTE(review): with skew_type "label" and clients_set [10, 20], this writes the
# same results/<dataset>_label_skew_<skew>_<n>clients.txt files as the main
# sweep cell, overwriting them. Confirm that is intended.
skew_type = "label"

for dataset_name in datasets:
    print(dataset_name)

    ds, (x_test, y_test), ds_info = load_tf_dataset(
        dataset_name=dataset_name, skew_type=skew_type, decentralized=False, display=True
    )
    # Build the test split with the same helper as the main sweep cell
    # (`fedjax` / `to_ClientData` are never imported in this notebook).
    test_split = convert_to_federated_data(x_test, y_test, ds_info, is_train=False)

    params = dict(
        act_fn=[jax.nn.relu],
        client_lr=[0.01],
        server_lr=[0.01, 0.03, 0.05],
        client_momentum=[0.01],
        server_momentum=[0.0, 0.01, 0.3, 0.9],
        batch_size=[8, 16, 32],
        epochs_per_round=[2],
        rounds=[30],
        runs=[1],  # a list, like every other grid entry
        clients_set=[10, 20],
        skews_set=SKEWS[skew_type]
    )

    run(params, ds, test_split, ds_info, display=False)