In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to the project modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import jax

from federated_library.distributions import convert_to_federated_data
from federated_library.dataset_loader import load_tf_dataset
from federated_library.train_fed_avg import fed_avg_gridsearch
from read_data import get_client_res, get_fedavg_acc
from heuristic_funcs import aggregate_results
from constants import SKEW_TYPES, SKEWS

# Dataset names iterated by the evaluation loop and handed to the loader.
DATASETS = [
    "mnist",
    "svhn_cropped",
    "cifar10",
]
# Heuristic versions to compare (1 through 4); each is passed as `v` to the
# aggregation heuristic.
HEUR_VERSIONS = list(range(1, 5))

In [None]:
# Evaluate heuristic performance
#
# For every dataset / skew type / client count / skew value, the per-client
# results are read back and aggregated by each heuristic version into one
# hyperparameter configuration. That configuration is run through
# `fed_avg_gridsearch`, the sorted results are written to
# `heur_v<version>_results/`, and the best accuracy is printed next to the
# stored FedAvg accuracy.

# NR_EVALUATIONS = DATASETS x SKEW_TYPES x NR_CLIENTS x SKEW
# (each evaluation is run once per entry in HEUR_VERSIONS)

# Client counts to evaluate for every dataset / skew combination.
NR_CLIENTS_OPTIONS = (10, 20)

for dataset_name in DATASETS:
    for skew_type in SKEW_TYPES:
        print(f"DATASET: {dataset_name}, SKEW_TYPE: {skew_type}")
        ds, (x_test, y_test), ds_info = load_tf_dataset(
            dataset_name=dataset_name,
            skew_type=skew_type,
            decentralized=False,
            display=False
        )

        test_split = convert_to_federated_data(
            x_test, y_test, ds_info, is_train=False)

        # Defaults shared by every run of this dataset / skew type. This dict
        # is never mutated: each run below works on its own copy, so values
        # chosen by one heuristic version cannot leak into another run.
        base_hp_configs = dict(
            act_fn=[jax.nn.relu],
            client_lr=[0.01],
            client_momentum=[0.01],
            epochs_per_round=[2],
            rounds=[30],
            runs=[1]
        )

        for nr_clients in NR_CLIENTS_OPTIONS:
            # ds_info is handed to fed_avg_gridsearch below, so the client
            # count has to be updated before each batch of runs.
            ds_info['num_clients'] = nr_clients
            for skew in SKEWS[skew_type]:
                print(f"{nr_clients} clients, skew={skew}")

                hps, accs, best_acc, ratios = get_client_res(
                    dataset_name, skew, nr_clients, skew_type
                )

                # One aggregated hyperparameter mapping per heuristic version.
                agg_hp_configs = {
                    v: aggregate_results(
                        hps, accs, best_acc, ratios,
                        type_of_skew=skew_type, v=v)
                    for v in HEUR_VERSIONS
                }

                print(agg_hp_configs)

                # The stored FedAvg accuracy does not depend on the heuristic
                # version, so it is looked up once per setting.
                fedavg_acc = get_fedavg_acc(
                    dataset_name, skew, nr_clients, skew_type)

                for v in HEUR_VERSIONS:
                    # Fresh config per run: defaults + current skew, then the
                    # heuristic's choices on top. Every value is wrapped in a
                    # single-element list, matching the base configuration.
                    hp_configs = dict(base_hp_configs, skew=[skew])
                    for hp, value in agg_hp_configs[v].items():
                        hp_configs[hp] = [value]

                    print(hp_configs)

                    fedavg_hyperparams_grid_search_res = fed_avg_gridsearch(
                        hp_configs, ds, test_split, ds_info, display=False
                    )

                    # Best result first; element 0 of each entry is used as
                    # the accuracy in the comparison printed below.
                    sorted_heur_res = sorted(
                        fedavg_hyperparams_grid_search_res,
                        key=lambda e: e[0], reverse=True)

                    results_dir = f"heur_v{v}_results/"
                    # exist_ok avoids the check-then-create race.
                    os.makedirs(results_dir, exist_ok=True)

                    file_name = (f"{results_dir}{dataset_name}_{skew_type}_skew_"
                                 f"{skew}_{nr_clients}clients.txt")

                    # The context manager closes the file even if a write
                    # raises part-way through.
                    with open(file_name, "w") as textfile:
                        for line in sorted_heur_res:
                            textfile.write(str(line) + "\n")

                    print(dataset_name, skew, nr_clients, skew_type)
                    print("Accuracy", fedavg_acc, sorted_heur_res[0][0],
                          float(fedavg_acc) - sorted_heur_res[0][0])