In [1]:
from typing import Dict
import flwr as fl

from flower_helpers import (create_model, get_weights, test, 
                            load_data, load_stored_tff)
from config import  (NUM_ROUNDS, TRAIN_SIZE, VAL_PORTION, 
                    TEST_SIZE, BATCH_SIZE, LEARNING_RATE, 
                    EPOCHS, FRAC_FIT, FRAC_EVAL, MIN_FIT,
                    MIN_EVAL, MIN_AVAIL, FIT_CONFIG_FN,
                    NUM_CLIENTS, CLIENT_RESOURCES, NON_IID,
                    RAY_ARGS, NUM_CLASSES, TFF_DATA_DIR, 
                    MODEL_NAME, PRE_TRAINED, DOUBLE_TRAIN)
from client import FlowerClient

In [2]:
# Build the per-client train/validation loaders and the single test loader.
# NON_IID (from config) selects which helper supplies them.
if not NON_IID:
    # IID case: dataset pulled from Hugging Face by load_data
    trainloaders, valloaders, testloader = load_data(
        MODEL_NAME,
        TEST_SIZE,
        TRAIN_SIZE,
        VAL_PORTION,
        BATCH_SIZE,
        NUM_CLIENTS,
        NUM_CLASSES,
    )
else:
    # non-IID case: stored TFF data (train and test are already split)
    trainloaders, valloaders, testloader = load_stored_tff(
        TFF_DATA_DIR,
        BATCH_SIZE,
        DOUBLE_TRAIN,
    )

Found cached dataset cifar10 (C:/Users/Jean/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\Jean\.cache\huggingface\datasets\cifar10\plain_text\1.0.0\447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4\cache-16b9e105e7ead8c5.arrow
Loading cached shuffled indices for dataset at C:\Users\Jean\.cache\huggingface\datasets\cifar10\plain_text\1.0.0\447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4\cache-6d58bc2a635b7b42.arrow
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [3]:
# Instantiate the model once, only to capture its config and starting weights.
net = create_model(MODEL_NAME, NUM_CLASSES, PRE_TRAINED)
# The config is kept at module level; evaluate() and the client factory read it.
MODEL_CONFIG = net.config
init_weights = get_weights(net)
# Flower exchanges weights as serialized Parameters, so convert the ndarrays.
init_param = fl.common.ndarrays_to_parameters(init_weights)
# The model object itself is not needed past this point.
del net

In [4]:
# server side evaluation function
def evaluate(server_round: int, params: fl.common.NDArrays,
             config: Dict[str, fl.common.Scalar]):
    """Evaluate the given parameters on the module-level ``testloader``.

    Args:
        server_round: Current round number (not used here).
        params: Model weights to evaluate.
        config: Evaluation config (not used here).

    Returns:
        A ``(test_loss, metrics)`` tuple, where ``metrics`` is the dict returned
        by ``test`` with ``'loss'``/``'accuracy'`` renamed to
        ``'test_loss'``/``'test_accuracy'``.
    """
    # test() returns a pair; only the metrics dict (second item) is used here.
    _, metrics = test(MODEL_CONFIG, params, testloader)
    # Prefix the keys so they are not confused with the clients' metrics.
    for old_key, new_key in (('loss', 'test_loss'), ('accuracy', 'test_accuracy')):
        metrics[new_key] = metrics.pop(old_key)
    return metrics['test_loss'], metrics

def weighted_average_eval(metrics):
    """Aggregate client validation metrics, weighted by example count.

    Args:
        metrics: Iterable of ``(num_examples, client_metrics)`` tuples, one per
            client. Each ``client_metrics`` dict must contain ``'val_loss'``
            and ``'val_accuracy'``.

    Returns:
        A dict with the weighted ``'val_loss'`` and ``'val_accuracy'``. An
        empty dict is returned when there is nothing to weight by (no clients,
        or a total of zero examples) instead of raising ``ZeroDivisionError``.

    Raises:
        KeyError: If a client's metrics lack one of the expected keys.
    """
    total_examples = 0
    weighted_val_loss = 0
    weighted_val_accuracy = 0
    # Single pass: accumulate the weights and the weighted sums together.
    for num_examples, client_metrics in metrics:
        total_examples += num_examples
        weighted_val_loss += num_examples * client_metrics['val_loss']
        weighted_val_accuracy += num_examples * client_metrics['val_accuracy']

    # Guard against dividing by zero when no examples were reported.
    if total_examples == 0:
        return {}

    aggregated_metrics = {'val_loss': weighted_val_loss / total_examples,
            'val_accuracy': weighted_val_accuracy / total_examples}
    print('\t',aggregated_metrics)
    return aggregated_metrics

def weighted_average_fit(metrics):
    """Aggregate client training metrics, weighted by example count.

    Args:
        metrics: Iterable of ``(num_examples, client_metrics)`` tuples, one per
            client. Each ``client_metrics`` dict must contain ``'train_loss'``
            and ``'train_accuracy'``.

    Returns:
        A dict with the weighted ``'train_loss'`` and ``'train_accuracy'``. An
        empty dict is returned when there is nothing to weight by (no clients,
        or a total of zero examples) instead of raising ``ZeroDivisionError``.

    Raises:
        KeyError: If a client's metrics lack one of the expected keys.
    """
    total_examples = 0
    weighted_train_loss = 0
    weighted_train_accuracy = 0
    # NOTE(review): an earlier comment said per-epoch metrics are included and
    # only the last one is needed. This code reads one value per key, so
    # presumably the client already reports only its final epoch. Confirm
    # against FlowerClient.fit.
    for num_examples, client_metrics in metrics:
        total_examples += num_examples
        weighted_train_loss += num_examples * client_metrics['train_loss']
        weighted_train_accuracy += num_examples * client_metrics['train_accuracy']

    # Guard against dividing by zero when no examples were reported.
    if total_examples == 0:
        return {}

    aggregated_metrics = {'train_loss': weighted_train_loss / total_examples,
            'train_accuracy': weighted_train_accuracy / total_examples}
    print('\t',aggregated_metrics)
    return aggregated_metrics

In [5]:
# FedAvg strategy; all sampling thresholds and hooks come from config or from
# the functions defined in this notebook.
strategy = fl.server.strategy.FedAvg(
    # starting global parameters, supplied by the server
    initial_parameters=init_param,

    # client sampling for training rounds
    fraction_fit=FRAC_FIT,
    min_fit_clients=MIN_FIT,

    # client sampling for evaluation rounds
    fraction_evaluate=FRAC_EVAL,
    min_evaluate_clients=MIN_EVAL,

    min_available_clients=MIN_AVAIL,

    # per-round configuration passed to clients for fit
    on_fit_config_fn=FIT_CONFIG_FN,

    # server-side evaluation and aggregation of client-reported metrics
    evaluate_fn=evaluate,
    fit_metrics_aggregation_fn=weighted_average_fit,
    evaluate_metrics_aggregation_fn=weighted_average_eval,
)

In [6]:
# Print a summary of the run configuration before starting the simulation.

# federation size
print('num clients:', NUM_CLIENTS)
print('num rounds:', NUM_ROUNDS)
print('--'*20)

# number of samples behind each data loader
print('client training set size:', [len(loader.dataset) for loader in trainloaders])
print('client validation set size:', [len(loader.dataset) for loader in valloaders])
print('test set size:', len(testloader.dataset))
print('--'*20)

# model and optimisation settings
print('model name:', MODEL_NAME)
print('num classes:', NUM_CLASSES)
print('pre-trained:', PRE_TRAINED)
print('learning rate:', LEARNING_RATE)
print('batch size:', BATCH_SIZE)
print('epochs:', EPOCHS)

num clients: 5
num rounds: 10
----------------------------------------
client training set size: [180, 180, 180, 180, 180]
client validation set size: [20, 20, 20, 20, 20]
test set size: 100
----------------------------------------
model name: facebook/deit-tiny-distilled-patch16-224
num classes: 10
pre-trained: True
learning rate: 0.0001
batch size: 32
epochs: 1


In [7]:
# Run the simulated federation. The call stays the cell's last expression so
# the notebook displays the value it returns.
fl.simulation.start_simulation(
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    num_clients=NUM_CLIENTS,
    # The client id is converted to int and used to pick that client's loaders.
    client_fn=lambda cid: FlowerClient(
        MODEL_CONFIG,
        trainloaders[int(cid)],
        valloaders[int(cid)],
    ),
    client_resources=CLIENT_RESOURCES,
    ray_init_args=RAY_ARGS,
)

INFO flwr 2023-04-16 16:54:25,993 | app.py:145 | Starting Flower simulation, config: ServerConfig(num_rounds=10, round_timeout=None)
2023-04-16 16:54:28,063	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
INFO flwr 2023-04-16 16:54:29,515 | app.py:179 | Flower VCE: Ray initialized with resources: {'memory': 11670284699.0, 'object_store_memory': 5835142348.0, 'CPU': 24.0, 'node:127.0.0.1': 1.0, 'GPU': 1.0}
INFO flwr 2023-04-16 16:54:29,516 | server.py:86 | Initializing global parameters
INFO flwr 2023-04-16 16:54:29,517 | server.py:266 | Using initial parameters provided by strategy
INFO flwr 2023-04-16 16:54:29,517 | server.py:88 | Evaluating initial parameters
INFO flwr 2023-04-16 16:54:32,248 | server.py:91 | initial parameters (loss, other metrics): 11.147294044494629, {'test_loss': 11.147294044494629, 'test_accuracy': 0.07}
INFO flwr 2023-04-16 16:54:32,248 | server.py:101 | FL starting
DEBUG flwr 2023-04-16 16:54:32,249

	 {'train_loss': 1.4104443788528442, 'train_accuracy': 0.28055555555555556}


INFO flwr 2023-04-16 16:55:12,896 | server.py:116 | fit progress: (1, 6.4984283447265625, {'test_loss': 6.4984283447265625, 'test_accuracy': 0.48}, 40.648197600006824)
DEBUG flwr 2023-04-16 16:55:12,897 | server.py:165 | evaluate_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:55:15,683 | server.py:179 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:55:15,684 | server.py:215 | fit_round 2: strategy sampled 2 clients (out of 5)


	 {'val_loss': 1.485456109046936, 'val_accuracy': 0.6}


DEBUG flwr 2023-04-16 16:55:49,653 | server.py:229 | fit_round 2 received 2 results and 0 failures


	 {'train_loss': 1.0235320329666138, 'train_accuracy': 0.5361111111111111}


INFO flwr 2023-04-16 16:55:52,429 | server.py:116 | fit progress: (2, 5.266861438751221, {'test_loss': 5.266861438751221, 'test_accuracy': 0.62}, 80.18139799998607)
DEBUG flwr 2023-04-16 16:55:52,429 | server.py:165 | evaluate_round 2: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:55:55,201 | server.py:179 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:55:55,202 | server.py:215 | fit_round 3: strategy sampled 2 clients (out of 5)


	 {'val_loss': 1.0066905617713928, 'val_accuracy': 0.675}


DEBUG flwr 2023-04-16 16:56:29,041 | server.py:229 | fit_round 3 received 2 results and 0 failures


	 {'train_loss': 0.8902626931667328, 'train_accuracy': 0.7916666666666666}


INFO flwr 2023-04-16 16:56:31,792 | server.py:116 | fit progress: (3, 4.384940147399902, {'test_loss': 4.384940147399902, 'test_accuracy': 0.67}, 119.54541980000795)
DEBUG flwr 2023-04-16 16:56:31,793 | server.py:165 | evaluate_round 3: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:56:34,559 | server.py:179 | evaluate_round 3 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:56:34,559 | server.py:215 | fit_round 4: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.9065665304660797, 'val_accuracy': 0.65}


DEBUG flwr 2023-04-16 16:57:08,486 | server.py:229 | fit_round 4 received 2 results and 0 failures


	 {'train_loss': 0.6526248753070831, 'train_accuracy': 0.8361111111111111}


INFO flwr 2023-04-16 16:57:11,226 | server.py:116 | fit progress: (4, 4.322912693023682, {'test_loss': 4.322912693023682, 'test_accuracy': 0.71}, 158.97975729999598)
DEBUG flwr 2023-04-16 16:57:11,226 | server.py:165 | evaluate_round 4: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:57:13,976 | server.py:179 | evaluate_round 4 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:57:13,977 | server.py:215 | fit_round 5: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.5926760137081146, 'val_accuracy': 0.8}


DEBUG flwr 2023-04-16 16:57:48,050 | server.py:229 | fit_round 5 received 2 results and 0 failures


	 {'train_loss': 0.6825152337551117, 'train_accuracy': 0.8527777777777777}


INFO flwr 2023-04-16 16:57:50,846 | server.py:116 | fit progress: (5, 4.245245456695557, {'test_loss': 4.245245456695557, 'test_accuracy': 0.77}, 198.6009255000099)
DEBUG flwr 2023-04-16 16:57:50,847 | server.py:165 | evaluate_round 5: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:57:53,642 | server.py:179 | evaluate_round 5 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:57:53,643 | server.py:215 | fit_round 6: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.5068043172359467, 'val_accuracy': 0.85}


DEBUG flwr 2023-04-16 16:58:27,652 | server.py:229 | fit_round 6 received 2 results and 0 failures


	 {'train_loss': 0.5156842023134232, 'train_accuracy': 0.8805555555555555}


INFO flwr 2023-04-16 16:58:30,313 | server.py:116 | fit progress: (6, 3.647756576538086, {'test_loss': 3.647756576538086, 'test_accuracy': 0.79}, 238.0682281999907)
DEBUG flwr 2023-04-16 16:58:30,314 | server.py:165 | evaluate_round 6: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:58:33,052 | server.py:179 | evaluate_round 6 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:58:33,053 | server.py:215 | fit_round 7: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.6340493559837341, 'val_accuracy': 0.825}


DEBUG flwr 2023-04-16 16:59:06,999 | server.py:229 | fit_round 7 received 2 results and 0 failures


	 {'train_loss': 0.3088770657777786, 'train_accuracy': 0.9555555555555556}


INFO flwr 2023-04-16 16:59:09,668 | server.py:116 | fit progress: (7, 3.7090768814086914, {'test_loss': 3.7090768814086914, 'test_accuracy': 0.79}, 277.42403809999814)
DEBUG flwr 2023-04-16 16:59:09,669 | server.py:165 | evaluate_round 7: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:59:12,418 | server.py:179 | evaluate_round 7 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:59:12,419 | server.py:215 | fit_round 8: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.5403808653354645, 'val_accuracy': 0.875}


DEBUG flwr 2023-04-16 16:59:46,123 | server.py:229 | fit_round 8 received 2 results and 0 failures


	 {'train_loss': 0.37962769716978073, 'train_accuracy': 0.9305555555555556}


INFO flwr 2023-04-16 16:59:48,748 | server.py:116 | fit progress: (8, 4.619601249694824, {'test_loss': 4.619601249694824, 'test_accuracy': 0.76}, 316.50499099999433)
DEBUG flwr 2023-04-16 16:59:48,749 | server.py:165 | evaluate_round 8: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 16:59:51,501 | server.py:179 | evaluate_round 8 received 2 results and 0 failures
DEBUG flwr 2023-04-16 16:59:51,501 | server.py:215 | fit_round 9: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.4340755045413971, 'val_accuracy': 0.825}


DEBUG flwr 2023-04-16 17:00:25,263 | server.py:229 | fit_round 9 received 2 results and 0 failures


	 {'train_loss': 0.19609037786722183, 'train_accuracy': 0.975}


INFO flwr 2023-04-16 17:00:27,905 | server.py:116 | fit progress: (9, 3.78309965133667, {'test_loss': 3.78309965133667, 'test_accuracy': 0.83}, 355.6617385999998)
DEBUG flwr 2023-04-16 17:00:27,905 | server.py:165 | evaluate_round 9: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 17:00:30,651 | server.py:179 | evaluate_round 9 received 2 results and 0 failures
DEBUG flwr 2023-04-16 17:00:30,652 | server.py:215 | fit_round 10: strategy sampled 2 clients (out of 5)


	 {'val_loss': 0.41833867132663727, 'val_accuracy': 0.825}


DEBUG flwr 2023-04-16 17:01:04,339 | server.py:229 | fit_round 10 received 2 results and 0 failures


	 {'train_loss': 0.2311234325170517, 'train_accuracy': 0.9444444444444444}


INFO flwr 2023-04-16 17:01:06,970 | server.py:116 | fit progress: (10, 4.116977691650391, {'test_loss': 4.116977691650391, 'test_accuracy': 0.79}, 394.72768010001164)
DEBUG flwr 2023-04-16 17:01:06,970 | server.py:165 | evaluate_round 10: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-04-16 17:01:09,696 | server.py:179 | evaluate_round 10 received 2 results and 0 failures
INFO flwr 2023-04-16 17:01:09,697 | server.py:144 | FL finished in 397.4542209999927
INFO flwr 2023-04-16 17:01:09,708 | app.py:202 | app_fit: losses_distributed [(1, 1.485456109046936), (2, 1.0066905617713928), (3, 0.9065665304660797), (4, 0.5926760137081146), (5, 0.5068043172359467), (6, 0.6340493559837341), (7, 0.5403808653354645), (8, 0.4340755045413971), (9, 0.41833867132663727), (10, 0.3967791944742203)]
INFO flwr 2023-04-16 17:01:09,709 | app.py:203 | app_fit: metrics_distributed {'val_loss': [(1, 1.485456109046936), (2, 1.0066905617713928), (3, 0.9065665304660797), (4, 0.5926760137081146), (5, 0.5068043

	 {'val_loss': 0.3967791944742203, 'val_accuracy': 0.85}


History (loss, distributed):
	round 1: 1.485456109046936
	round 2: 1.0066905617713928
	round 3: 0.9065665304660797
	round 4: 0.5926760137081146
	round 5: 0.5068043172359467
	round 6: 0.6340493559837341
	round 7: 0.5403808653354645
	round 8: 0.4340755045413971
	round 9: 0.41833867132663727
	round 10: 0.3967791944742203
History (loss, centralized):
	round 0: 11.147294044494629
	round 1: 6.4984283447265625
	round 2: 5.266861438751221
	round 3: 4.384940147399902
	round 4: 4.322912693023682
	round 5: 4.245245456695557
	round 6: 3.647756576538086
	round 7: 3.7090768814086914
	round 8: 4.619601249694824
	round 9: 3.78309965133667
	round 10: 4.116977691650391
History (metrics, distributed):
{'val_loss': [(1, 1.485456109046936), (2, 1.0066905617713928), (3, 0.9065665304660797), (4, 0.5926760137081146), (5, 0.5068043172359467), (6, 0.6340493559837341), (7, 0.5403808653354645), (8, 0.4340755045413971), (9, 0.41833867132663727), (10, 0.3967791944742203)], 'val_accuracy': [(1, 0.6), (2, 0.675), (3,