In [1]:
import itertools
from copy import deepcopy

import jax
import fedjax

from distributions import *
from dataset_loader import *
from fedAvg_training import train_fedAvg

In [None]:
def fed_avg_gridsearch(params, ds, test_split, ds_info, display):
    """Train FedAvg on every combination of the hyper-parameter grid.

    Args:
        params: dict holding one list of candidate values under each of
            ``act_fn``, ``client_lr``, ``server_lr``, ``client_momentum``,
            ``server_momentum``, ``batch_size``, ``epochs_per_round`` and
            ``rounds``, plus ``runs`` (number of repetitions per combination)
            and ``skew`` (copied unchanged into every combination).
        ds: forwarded untouched to ``train_fedAvg``.
        test_split: forwarded untouched to ``train_fedAvg``.
        ds_info: forwarded untouched to ``train_fedAvg``.
        display: forwarded to ``train_fedAvg`` as ``display=``.

    Returns:
        A list with one ``(accuracy, params_copy)`` tuple per training run, in
        grid iteration order. ``accuracy`` is the ``'accuracy'`` entry of the
        value returned by ``train_fedAvg``; ``params_copy`` is a deep copy of
        the single-valued parameter dict used for that run.
    """
    # Axis order defines both the iteration order (left-most axis varies
    # slowest) and the key order of the per-run dict that gets printed.
    grid_keys = ('act_fn', 'client_lr', 'server_lr', 'client_momentum', 'server_momentum',
                 'batch_size', 'epochs_per_round', 'rounds')
    grid_axes = [params[key] for key in grid_keys]
    runs = params['runs']

    # Size of the full grid: every axis counts, including the activation function.
    total = 1
    for axis in grid_axes:
        total *= len(axis)

    print(f'Gridsearch on {total} values, total with folds :{total * runs}')

    count = 0
    res = []
    for combo in itertools.product(*grid_axes):
        # A separate name keeps the caller's grid dict distinct from the
        # single-valued dict handed to the trainer.
        run_params = dict(zip(grid_keys, combo), skew=params['skew'])
        print(f'Training with params : {run_params}')

        for _ in range(runs):
            run_res = train_fedAvg(run_params, ds, test_split, ds_info, display=display)['accuracy']
            print(count, run_res)

            # Deep copy so later changes to run_params cannot alter stored results.
            res.append((run_res, deepcopy(run_params)))
            count += 1

    return res

In [None]:
def run(params, ds, test_split, ds_info, display):
    """Run the FedAvg grid search for every (client count, skew) pair and save the results.

    For each value in ``params['clients_set']`` and each value in
    ``params['skews_set']`` the grid search is executed, its results are sorted
    by accuracy (best first) and written, one result per line, to
    ``"{dataset_name}_{skew_type}_skew_{skew}_{c}clients.txt"`` in the current
    working directory.

    Note:
        ``dataset_name`` and ``skew_type`` are not parameters; they are read
        from module scope and must be defined before this function is called.

    Args:
        params: hyper-parameter grid; must also contain ``clients_set`` and
            ``skews_set``. ``params['skew']`` is overwritten in place with the
            skew currently being evaluated.
        ds: forwarded untouched to ``fed_avg_gridsearch``.
        test_split: forwarded untouched to ``fed_avg_gridsearch``.
        ds_info: dict; ``ds_info['num_clients']`` is overwritten in place with
            the client count currently being evaluated.
        display: forwarded to ``fed_avg_gridsearch`` as ``display=``.

    Returns:
        None.
    """
    for c in params['clients_set']:
        ds_info['num_clients'] = c

        for skew in params['skews_set']:
            params['skew'] = skew

            # The search receives a deep copy of the grid, not the caller's dict.
            sorted_res = fed_avg_gridsearch(deepcopy(params), ds, test_split, ds_info,
                                            display=display)
            sorted_res.sort(key=lambda e: e[0], reverse=True)

            # The context manager closes the file even if formatting or writing a line raises.
            with open(f"{dataset_name}_{skew_type}_skew_{skew}_{c}clients.txt", "w") as textfile:
                for line in sorted_res:
                    textfile.write(str(line) + "\n")

In [None]:
def run_test(params, ds, test_split, ds_info, display):
    """Evaluate a list of fixed configurations, one per index of ``params['client_lr']``.

    Configuration ``i`` takes element ``i`` of ``clients_set``, ``skews_set``,
    ``client_lr``, ``client_momentum`` and ``batch_size``, and element ``0`` of
    ``act_fn``, ``server_lr``, ``server_momentum``, ``epochs_per_round`` and
    ``rounds``. Each configuration is wrapped as a one-point grid and handed to
    ``fed_avg_gridsearch``.

    Args:
        params: dict with the list-valued entries named above plus ``runs``.
        ds: forwarded untouched to ``fed_avg_gridsearch``.
        test_split: forwarded untouched to ``fed_avg_gridsearch``.
        ds_info: dict; ``ds_info['num_clients']`` is overwritten in place for
            each configuration.
        display: forwarded to ``fed_avg_gridsearch`` as ``display=``.

    Returns:
        The concatenated results of ``fed_avg_gridsearch`` for all
        configurations, in configuration order.
    """
    all_results = []
    for i in range(len(params['client_lr'])):
        ds_info['num_clients'] = params['clients_set'][i]

        # Single-element lists turn the grid search into exactly one combination.
        test_params = dict(
            clients=ds_info['num_clients'],
            skew=params['skews_set'][i],
            act_fn=[params['act_fn'][0]],
            client_lr=[params['client_lr'][i]],
            server_lr=[params['server_lr'][0]],
            client_momentum=[params['client_momentum'][i]],
            server_momentum=[params['server_momentum'][0]],
            batch_size=[params['batch_size'][i]],
            epochs_per_round=[params['epochs_per_round'][0]],
            rounds=[params['rounds'][0]],
            runs=params['runs'],
        )

        # Accumulate rather than return here, so every configuration is evaluated.
        all_results.extend(
            fed_avg_gridsearch(test_params, ds, test_split, ds_info, display=display))

    return all_results

In [None]:
# Datasets to sweep, and the candidate skew values available per skew type.
datasets = ["mnist", "emnist", "cifar10", "svhn_cropped"]
SKEWS = {
    "feature": [0.02, 0.1],
    "label": [0.1, 1.0, 5.0],
    "qty": [0.1, 0.4, 1.0, 2.0],
}

# NOTE: `dataset_name` and `skew_type` are read from module scope by `run`
# when it builds the result file names, so these names must stay as they are.
for dataset_name in datasets:
    print(dataset_name)
    experiment_name = f"{dataset_name}_non-iid"
    skew_type = "label"

    # Load the training data together with the held-out test arrays and dataset metadata.
    ds, (x_test, y_test), ds_info = load_tf_dataset(
        dataset_name=dataset_name,
        skew_type=skew_type,
        decentralized=False,
        display=True,
    )

    # Wrap the test arrays as a single client with id '0' and batch by 50.
    test_split = fedjax.create_tf_dataset_for_clients(
        to_ClientData([x_test], [y_test], ds_info, train=False),
        ['0'],
    ).batch(50)

    # Hyper-parameter grid: list-valued entries are searched exhaustively,
    # `runs` is the number of repetitions per combination, and
    # `clients_set` / `skews_set` are the outer sweeps performed by `run`.
    params = {
        'act_fn': [jax.nn.relu],
        'client_lr': [0.01],
        'server_lr': [0.01, 0.03, 0.05],
        'client_momentum': [0.01],
        'server_momentum': [0.0, 0.01, 0.3, 0.9],
        'batch_size': [8, 16, 32],
        'epochs_per_round': [2],
        'rounds': [30],
        'runs': 1,
        'clients_set': [10, 20],
        'skews_set': SKEWS[skew_type],
    }

    run(params, ds, test_split, ds_info, display=False)