In [1]:
from typing import Dict
import flwr as fl
import torch
from flower_helpers import (create_model, get_weights, test, load_data)
from config import (NUM_ROUNDS, MODEL_NAME, NUM_CLASSES, 
                    PRE_TRAINED, TRAIN_SIZE, VAL_PORTION, 
                    TEST_SIZE, BATCH_SIZE, LEARNING_RATE, 
                    EPOCHS, FRAC_FIT, FRAC_EVAL, MIN_FIT,
                    MIN_EVAL, MIN_AVAIL, FIT_CONFIG_FN,
                    NUM_CLIENTS, CLIENT_RESOURCES)
from client import FlowerClient

In [2]:
def data_path(x):
    """Return the relative path of the saved ``.pth`` file named ``x``."""
    return f'tff_dataloaders_5clients/{x}.pth'


# NOTE(review): torch.load unpickles the file contents, so only point these
# paths at files produced by a trusted run.
trainloaders = torch.load(data_path('trainloaders'))
valloaders = torch.load(data_path('valloaders'))
testloader = torch.load(data_path('testloader'))

In [3]:
# trainloaders, valloaders, testloader = load_data()

In [4]:
# Build a model only long enough to capture its starting weights and its
# config; everything downstream works from MODEL_CONFIG and init_param.
initial_net = create_model()
init_weights = get_weights(initial_net)
MODEL_CONFIG = initial_net.config

# Flower exchanges weights as serialized Parameters rather than raw
# ndarrays, so convert once here for the strategy's initial_parameters.
init_param = fl.common.ndarrays_to_parameters(init_weights)

# The model object itself is no longer needed; drop the reference.
del initial_net

In [5]:
def evaluate(server_round: int, params: fl.common.NDArrays,
             config: Dict[str, fl.common.Scalar]):
    """Server-side evaluation of the global weights on ``testloader``.

    Args:
        server_round: Current round number (not used by this function).
        params: Model weights to evaluate.
        config: Evaluation configuration (not used by this function).

    Returns:
        A ``(loss, metrics)`` tuple, where ``metrics`` is the dict returned
        by ``test`` with 'loss' and 'accuracy' renamed to 'test_loss' and
        'test_accuracy'.
    """
    # The first value returned by `test` (the data size) is not needed here.
    _, metrics = test(MODEL_CONFIG, params, testloader)

    # Prefix the keys so the centralized results are not confused with the
    # clients' train/val metrics.
    loss_value = metrics.pop('loss')
    accuracy_value = metrics.pop('accuracy')
    metrics['test_loss'] = loss_value
    metrics['test_accuracy'] = accuracy_value

    return loss_value, metrics

def weighted_average_eval(metrics):
    """Aggregate client validation metrics into an example-weighted average.

    Args:
        metrics: List of ``(num_examples, client_metrics)`` tuples, one per
            client, where ``client_metrics`` holds 'val_loss' and
            'val_accuracy'.

    Returns:
        A dict with the weighted 'val_loss' and 'val_accuracy', or an empty
        dict when there is nothing to average (no clients, or a total of
        zero examples).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Dividing by the total would raise ZeroDivisionError; report no
        # aggregated metrics instead.
        return {}

    weighted_val_loss = sum(
        num_examples * client_metrics['val_loss']
        for num_examples, client_metrics in metrics
    )
    weighted_val_accuracy = sum(
        num_examples * client_metrics['val_accuracy']
        for num_examples, client_metrics in metrics
    )

    aggregated_metrics = {
        'val_loss': weighted_val_loss / total_examples,
        'val_accuracy': weighted_val_accuracy / total_examples,
    }
    # Echo the aggregate so it shows up between the server's log lines.
    print('\t', aggregated_metrics)
    return aggregated_metrics

def weighted_average_fit(metrics):
    """Aggregate client training metrics into an example-weighted average.

    Args:
        metrics: List of ``(num_examples, client_metrics)`` tuples, one per
            client, where ``client_metrics`` holds 'train_loss' and
            'train_accuracy' sequences with one entry per epoch.

    Returns:
        A dict with the weighted 'train_loss' and 'train_accuracy' taken
        from each client's last epoch, or an empty dict when there is
        nothing to average (no clients, or a total of zero examples).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Dividing by the total would raise ZeroDivisionError; report no
        # aggregated metrics instead.
        return {}

    # Clients report metrics for every epoch; only the final epoch ([-1])
    # reflects the weights that were sent back.
    weighted_train_loss = sum(
        num_examples * client_metrics['train_loss'][-1]
        for num_examples, client_metrics in metrics
    )
    weighted_train_accuracy = sum(
        num_examples * client_metrics['train_accuracy'][-1]
        for num_examples, client_metrics in metrics
    )

    aggregated_metrics = {
        'train_loss': weighted_train_loss / total_examples,
        'train_accuracy': weighted_train_accuracy / total_examples,
    }
    # Echo the aggregate so it shows up between the server's log lines.
    print('\t', aggregated_metrics)
    return aggregated_metrics

In [6]:
# FedAvg strategy wired with the sampling limits from config and the
# metric/evaluation callbacks defined in this notebook.
strategy = fl.server.strategy.FedAvg(
    # Starting point: the weights captured from the freshly created model.
    initial_parameters=init_param,

    # Client sampling for training rounds.
    fraction_fit=FRAC_FIT,
    min_fit_clients=MIN_FIT,
    # Client sampling for evaluation rounds.
    fraction_evaluate=FRAC_EVAL,
    min_evaluate_clients=MIN_EVAL,
    min_available_clients=MIN_AVAIL,

    # Per-round fit configuration callback (defined in config).
    on_fit_config_fn=FIT_CONFIG_FN,
    # Centralized evaluation on the server's test loader.
    evaluate_fn=evaluate,
    # Example-weighted aggregation of the metrics reported by clients.
    fit_metrics_aggregation_fn=weighted_average_fit,
    evaluate_metrics_aggregation_fn=weighted_average_eval,
)

In [7]:
def build_client(cid):
    """Create the simulated client for ``cid``.

    ``cid`` is converted with ``int`` and used to index the per-client
    train and validation loaders.
    """
    client_index = int(cid)
    return FlowerClient(
        MODEL_CONFIG,
        trainloaders[client_index],
        valloaders[client_index],
    )


# Kept as the cell's final expression so the returned value is displayed.
fl.simulation.start_simulation(
    client_fn=build_client,
    num_clients=NUM_CLIENTS,
    client_resources=CLIENT_RESOURCES,
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
)

INFO flwr 2023-04-10 17:38:33,480 | app.py:145 | Starting Flower simulation, config: ServerConfig(num_rounds=10, round_timeout=None)
2023-04-10 17:38:35,650	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
INFO flwr 2023-04-10 17:38:36,976 | app.py:179 | Flower VCE: Ray initialized with resources: {'memory': 14319938766.0, 'CPU': 24.0, 'object_store_memory': 7159969382.0, 'node:127.0.0.1': 1.0, 'GPU': 1.0}
INFO flwr 2023-04-10 17:38:36,977 | server.py:86 | Initializing global parameters
INFO flwr 2023-04-10 17:38:36,978 | server.py:266 | Using initial parameters provided by strategy
INFO flwr 2023-04-10 17:38:36,979 | server.py:88 | Evaluating initial parameters
INFO flwr 2023-04-10 17:38:43,020 | server.py:91 | initial parameters (loss, other metrics): 23.11473846435547, {'test_loss': 23.11473846435547, 'test_accuracy': 0.0}
INFO flwr 2023-04-10 17:38:43,021 | server.py:101 | FL starting
DEBUG flwr 2023-04-10 17:38:43,021 | 

	 {'train_loss': 3.146144151687622, 'train_accuracy': 0.14444444444444443}


INFO flwr 2023-04-10 17:39:29,777 | server.py:116 | fit progress: (1, 20.31015968322754, {'test_loss': 20.31015968322754, 'test_accuracy': 0.02}, 46.75605870000436)
DEBUG flwr 2023-04-10 17:39:29,778 | server.py:165 | evaluate_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:39:31,829 | server.py:179 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:39:31,830 | server.py:215 | fit_round 2: strategy sampled 2 clients (out of 5)


	 {'val_loss': 5.143322706222534, 'val_accuracy': 0.1}


DEBUG flwr 2023-04-10 17:40:08,017 | server.py:229 | fit_round 2 received 2 results and 0 failures


	 {'train_loss': 2.064218282699585, 'train_accuracy': 0.3888888888888889}


INFO flwr 2023-04-10 17:40:14,017 | server.py:116 | fit progress: (2, 16.250322341918945, {'test_loss': 16.250322341918945, 'test_accuracy': 0.21}, 90.99567160000151)
DEBUG flwr 2023-04-10 17:40:14,017 | server.py:165 | evaluate_round 2: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:40:16,082 | server.py:179 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:40:16,083 | server.py:215 | fit_round 3: strategy sampled 2 clients (out of 5)


	 {'val_loss': 3.239913582801819, 'val_accuracy': 0.25}


DEBUG flwr 2023-04-10 17:40:52,154 | server.py:229 | fit_round 3 received 2 results and 0 failures


	 {'train_loss': 2.921006202697754, 'train_accuracy': 0.3}


INFO flwr 2023-04-10 17:40:58,079 | server.py:116 | fit progress: (3, 16.552156448364258, {'test_loss': 16.552156448364258, 'test_accuracy': 0.18}, 135.05719320000208)
DEBUG flwr 2023-04-10 17:40:58,079 | server.py:165 | evaluate_round 3: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:41:00,290 | server.py:179 | evaluate_round 3 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:41:00,290 | server.py:215 | fit_round 4: strategy sampled 2 clients (out of 5)


	 {'val_loss': 3.2266114950180054, 'val_accuracy': 0.35}


DEBUG flwr 2023-04-10 17:41:36,206 | server.py:229 | fit_round 4 received 2 results and 0 failures


	 {'train_loss': 2.375395655632019, 'train_accuracy': 0.4833333333333334}


INFO flwr 2023-04-10 17:41:42,235 | server.py:116 | fit progress: (4, 17.11565399169922, {'test_loss': 17.11565399169922, 'test_accuracy': 0.21}, 179.2132156000007)
DEBUG flwr 2023-04-10 17:41:42,235 | server.py:165 | evaluate_round 4: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:41:44,238 | server.py:179 | evaluate_round 4 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:41:44,239 | server.py:215 | fit_round 5: strategy sampled 2 clients (out of 5)


	 {'val_loss': 3.149528384208679, 'val_accuracy': 0.30000000000000004}


DEBUG flwr 2023-04-10 17:42:20,358 | server.py:229 | fit_round 5 received 2 results and 0 failures


	 {'train_loss': 2.2948511242866516, 'train_accuracy': 0.6722222222222222}


INFO flwr 2023-04-10 17:42:26,378 | server.py:116 | fit progress: (5, 16.31963539123535, {'test_loss': 16.31963539123535, 'test_accuracy': 0.22}, 223.35733769999933)
DEBUG flwr 2023-04-10 17:42:26,379 | server.py:165 | evaluate_round 5: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:42:28,415 | server.py:179 | evaluate_round 5 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:42:28,416 | server.py:215 | fit_round 6: strategy sampled 2 clients (out of 5)


	 {'val_loss': 2.5191184282302856, 'val_accuracy': 0.35}


DEBUG flwr 2023-04-10 17:43:03,899 | server.py:229 | fit_round 6 received 2 results and 0 failures


	 {'train_loss': 1.5101075768470764, 'train_accuracy': 0.7555555555555555}


INFO flwr 2023-04-10 17:43:10,045 | server.py:116 | fit progress: (6, 15.467330932617188, {'test_loss': 15.467330932617188, 'test_accuracy': 0.23}, 267.02328759999364)
DEBUG flwr 2023-04-10 17:43:10,045 | server.py:165 | evaluate_round 6: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:43:12,069 | server.py:179 | evaluate_round 6 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:43:12,069 | server.py:215 | fit_round 7: strategy sampled 2 clients (out of 5)


	 {'val_loss': 3.7281811237335205, 'val_accuracy': 0.2}


DEBUG flwr 2023-04-10 17:43:47,626 | server.py:229 | fit_round 7 received 2 results and 0 failures


	 {'train_loss': 0.7402545511722565, 'train_accuracy': 0.9111111111111111}


INFO flwr 2023-04-10 17:43:53,583 | server.py:116 | fit progress: (7, 16.898813247680664, {'test_loss': 16.898813247680664, 'test_accuracy': 0.2}, 310.5616507999948)
DEBUG flwr 2023-04-10 17:43:53,584 | server.py:165 | evaluate_round 7: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:43:55,776 | server.py:179 | evaluate_round 7 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:43:55,777 | server.py:215 | fit_round 8: strategy sampled 2 clients (out of 5)


	 {'val_loss': 2.5233311653137207, 'val_accuracy': 0.4}


DEBUG flwr 2023-04-10 17:44:31,180 | server.py:229 | fit_round 8 received 2 results and 0 failures


	 {'train_loss': 0.7859745770692825, 'train_accuracy': 0.9111111111111111}


INFO flwr 2023-04-10 17:44:37,099 | server.py:116 | fit progress: (8, 15.24635124206543, {'test_loss': 15.24635124206543, 'test_accuracy': 0.29}, 354.0773792999971)
DEBUG flwr 2023-04-10 17:44:37,100 | server.py:165 | evaluate_round 8: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:44:39,319 | server.py:179 | evaluate_round 8 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:44:39,320 | server.py:215 | fit_round 9: strategy sampled 2 clients (out of 5)


	 {'val_loss': 2.41679847240448, 'val_accuracy': 0.5}


DEBUG flwr 2023-04-10 17:45:14,957 | server.py:229 | fit_round 9 received 2 results and 0 failures


	 {'train_loss': 0.5332944989204407, 'train_accuracy': 0.9666666666666667}


INFO flwr 2023-04-10 17:45:20,878 | server.py:116 | fit progress: (9, 16.17179298400879, {'test_loss': 16.17179298400879, 'test_accuracy': 0.25}, 397.85720100000617)
DEBUG flwr 2023-04-10 17:45:20,879 | server.py:165 | evaluate_round 9: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:45:23,039 | server.py:179 | evaluate_round 9 received 2 results and 0 failures
DEBUG flwr 2023-04-10 17:45:23,039 | server.py:215 | fit_round 10: strategy sampled 2 clients (out of 5)


	 {'val_loss': 3.0911901593208313, 'val_accuracy': 0.3}


DEBUG flwr 2023-04-10 17:45:58,711 | server.py:229 | fit_round 10 received 2 results and 0 failures


	 {'train_loss': 0.2511480823159218, 'train_accuracy': 0.9944444444444445}


INFO flwr 2023-04-10 17:46:04,665 | server.py:116 | fit progress: (10, 14.921417236328125, {'test_loss': 14.921417236328125, 'test_accuracy': 0.27}, 441.64400329999626)
DEBUG flwr 2023-04-10 17:46:04,666 | server.py:165 | evaluate_round 10: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-10 17:46:06,815 | server.py:179 | evaluate_round 10 received 2 results and 0 failures
INFO flwr 2023-04-10 17:46:06,815 | server.py:144 | FL finished in 443.7938545999932
INFO flwr 2023-04-10 17:46:06,828 | app.py:202 | app_fit: losses_distributed [(1, 5.143322706222534), (2, 3.239913582801819), (3, 3.2266114950180054), (4, 3.149528384208679), (5, 2.5191184282302856), (6, 3.7281811237335205), (7, 2.5233311653137207), (8, 2.41679847240448), (9, 3.0911901593208313), (10, 1.7978174090385437)]
INFO flwr 2023-04-10 17:46:06,829 | app.py:203 | app_fit: metrics_distributed {'val_loss': [(1, 5.143322706222534), (2, 3.239913582801819), (3, 3.2266114950180054), (4, 3.149528384208679), (5, 2.519118428230

	 {'val_loss': 1.7978174090385437, 'val_accuracy': 0.55}


History (loss, distributed):
	round 1: 5.143322706222534
	round 2: 3.239913582801819
	round 3: 3.2266114950180054
	round 4: 3.149528384208679
	round 5: 2.5191184282302856
	round 6: 3.7281811237335205
	round 7: 2.5233311653137207
	round 8: 2.41679847240448
	round 9: 3.0911901593208313
	round 10: 1.7978174090385437
History (loss, centralized):
	round 0: 23.11473846435547
	round 1: 20.31015968322754
	round 2: 16.250322341918945
	round 3: 16.552156448364258
	round 4: 17.11565399169922
	round 5: 16.31963539123535
	round 6: 15.467330932617188
	round 7: 16.898813247680664
	round 8: 15.24635124206543
	round 9: 16.17179298400879
	round 10: 14.921417236328125
History (metrics, distributed):
{'val_loss': [(1, 5.143322706222534), (2, 3.239913582801819), (3, 3.2266114950180054), (4, 3.149528384208679), (5, 2.5191184282302856), (6, 3.7281811237335205), (7, 2.5233311653137207), (8, 2.41679847240448), (9, 3.0911901593208313), (10, 1.7978174090385437)], 'val_accuracy': [(1, 0.1), (2, 0.25), (3, 0.35), 