-
Notifications
You must be signed in to change notification settings - Fork 155
/
fedavg.py
executable file
·82 lines (66 loc) · 3.56 KB
/
fedavg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from tqdm import trange, tqdm
import tensorflow as tf
from .fedbase import BaseFedarated
class Server(BaseFedarated):
    '''Federated Averaging trainer.

    Every round the server hands the latest model to a selected subset of
    clients, lets each of them solve its local problem, and replaces the
    latest model with whatever ``self.aggregate`` builds from the returned
    solutions.
    '''

    def __init__(self, params, learner, dataset):
        '''Set up the local optimizer and initialise the base trainer.

        Args:
            params: mapping of run options. Only ``params['learning_rate']``
                is read here; the mapping is also forwarded to the base class.
            learner: forwarded unchanged to the base class.
            dataset: forwarded unchanged to the base class.
        '''
        print('Using Federated Average to Train')
        # Assigned before the base constructor runs, as in the original code;
        # presumably the base class reads self.inner_opt -- confirm in fedbase.
        self.inner_opt = tf.train.GradientDescentOptimizer(params['learning_rate'])
        super(Server, self).__init__(params, learner, dataset)

    def train(self):
        '''Train using Federated Averaging'''
        print('Training with {} workers ---'.format(self.clients_per_round))
        for i in trange(self.num_rounds, desc='Round: ', ncols=120):
            # test model
            if i % self.eval_every == 0:
                self._evaluate(i)
                tqdm.write('gradient difference: {}'.format(self._gradient_difference()))
                # save server model
                self.metrics.write()
                self.save()

            # choose K clients prop to data size
            selected_clients = self.select_clients(i, num_clients=self.clients_per_round)
            csolns = []  # buffer for receiving client solutions
            for c in tqdm(selected_clients, desc='Client: ', leave=False, ncols=120):
                # communicate the latest model
                c.set_params(self.latest_model)
                # solve minimization locally
                soln, stats = c.solve_inner(num_epochs=self.num_epochs, batch_size=self.batch_size)
                # gather solutions from client
                csolns.append(soln)
                # track communication cost
                self.metrics.update(rnd=i, cid=c.id, stats=stats)
            # update model
            self.latest_model = self.aggregate(csolns)

        # final test model; goes through the same evaluation path as the
        # in-loop checks so the last recorded entry has the same layout
        self._evaluate(self.num_rounds)
        # save server model
        self.metrics.write()
        self.save()

    def _evaluate(self, round_num):
        '''Evaluate the latest model, record the results and report them.

        Args:
            round_num: round index used only to label the printed lines.
        '''
        stats = self.test()
        stats_train = self.train_error_and_loss()
        self.metrics.accuracies.append(stats)
        self.metrics.train_accuracies.append(stats_train)
        # Presumably index 2 holds per-client sample counts, index 3 per-client
        # correct counts and index 4 per-client losses -- confirm in fedbase.
        tqdm.write('At round {} accuracy: {}'.format(round_num, np.sum(stats[3])*1.0/np.sum(stats[2])))
        tqdm.write('At round {} training accuracy: {}'.format(round_num, np.sum(stats_train[3])*1.0/np.sum(stats_train[2])))
        tqdm.write('At round {} training loss: {}'.format(round_num, np.dot(stats_train[4], stats_train[2])*1.0/np.sum(stats_train[2])))

    def _gradient_difference(self):
        '''Return how far client gradients are from their weighted average.

        Each client's gradient is weighted by the first value ``get_grads``
        returns (collected as ``num_samples``) to form the average; the
        result is the squared distance to that average, averaged over all
        clients.
        '''
        # NOTE(review): process_grad is not imported in the visible import
        # block of this file -- confirm it is in scope before relying on this.
        model_len = process_grad(self.latest_model).size
        weighted_sum = np.zeros(model_len)
        num_samples = []
        local_grads = []
        for c in self.clients:
            num, client_grad = c.get_grads(model_len)
            local_grads.append(client_grad)
            num_samples.append(num)
            # Accumulate this client's own gradient. The previous code added a
            # never-updated zero buffer here, so the average was always zero.
            weighted_sum = np.add(weighted_sum, client_grad * num)
        global_grads = weighted_sum * 1.0 / np.sum(np.asarray(num_samples))

        difference = 0
        for client_grad in local_grads:
            difference += np.sum(np.square(global_grads - client_grad))
        return difference * 1.0 / len(self.clients)