-
Notifications
You must be signed in to change notification settings - Fork 155
/
fedbase.py
executable file
·126 lines (97 loc) · 4.21 KB
/
fedbase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from flearn.models.client import Client
from flearn.utils.model_utils import Metrics
from flearn.utils.tf_utils import process_grad
class BaseFedarated(object):
    '''Common state and helpers shared by the federated trainers.

    Holds one shared client model, the Client objects built from the
    dataset, the latest global parameters (``self.latest_model``) and a
    Metrics tracker. The class name's spelling is kept as-is because
    callers refer to it by this name.
    '''

    def __init__(self, params, learner, dataset):
        '''Build the shared model, the clients and the metrics tracker.

        Args:
            params: dict of options; every key/value is copied onto
                ``self``. Must contain 'model_params'. ``self.inner_opt``
                and ``self.seed`` must exist after the copy (either from
                ``params`` or set by a subclass beforehand).
            learner: callable (presumably the model class) invoked as
                ``learner(*params['model_params'], inner_opt, seed)``.
            dataset: ``(users, groups, train_data, test_data)`` tuple,
                forwarded to ``setup_clients``.
        '''
        # transfer parameters to self
        for key, val in params.items():
            setattr(self, key, val)
        # create worker nodes; start from a fresh default TF graph
        tf.reset_default_graph()
        self.client_model = learner(*params['model_params'], self.inner_opt, self.seed)
        self.clients = self.setup_clients(dataset, self.client_model)
        print('{} Clients in Total'.format(len(self.clients)))
        self.latest_model = self.client_model.get_params()
        # initialize system metrics
        self.metrics = Metrics(self.clients, params)

    def __del__(self):
        '''Close the shared client model, if one was ever created.'''
        # __init__ may have raised before client_model was assigned; an
        # AttributeError here would only surface as an "Exception ignored
        # in __del__" warning, so skip the close instead.
        client_model = getattr(self, 'client_model', None)
        if client_model is not None:
            client_model.close()

    def setup_clients(self, dataset, model=None):
        '''Instantiate one Client per user.

        Args:
            dataset: ``(users, groups, train_data, test_data)``;
                ``train_data`` and ``test_data`` are indexed by user.
                An empty ``groups`` means every client gets group None.
            model: model object handed to every Client.
        Return:
            list of Clients, in the order of ``users``.
        '''
        users, groups, train_data, test_data = dataset
        if len(groups) == 0:
            groups = [None for _ in users]
        all_clients = [Client(u, g, train_data[u], test_data[u], model) for u, g in zip(users, groups)]
        return all_clients

    def train_error_and_loss(self):
        '''Evaluate every client on its training data.

        Return:
            ``(ids, groups, num_samples, tot_correct, losses)``: parallel
            lists with one entry per client. ``tot_correct`` and
            ``losses`` are coerced to float.
        '''
        num_samples = []
        tot_correct = []
        losses = []
        for c in self.clients:
            ct, cl, ns = c.train_error_and_loss()
            tot_correct.append(ct * 1.0)
            num_samples.append(ns)
            losses.append(cl * 1.0)
        ids = [c.id for c in self.clients]
        groups = [c.group for c in self.clients]
        return ids, groups, num_samples, tot_correct, losses

    def show_grads(self):
        '''Compute each client's gradient at ``self.latest_model``.

        Return:
            list with one flattened gradient per client, followed by the
            sample-weighted average of those gradients (the global
            gradient) as the last element.
        '''
        model_len = process_grad(self.latest_model).size
        global_grads = np.zeros(model_len)
        intermediate_grads = []
        samples = []
        self.client_model.set_params(self.latest_model)
        for c in self.clients:
            num_samples, client_grads = c.get_grads(self.latest_model)
            samples.append(num_samples)
            global_grads = np.add(global_grads, client_grads * num_samples)
            intermediate_grads.append(client_grads)
        # NOTE: with no clients the divisor is 0 and the result is NaN.
        global_grads = global_grads * 1.0 / np.sum(np.asarray(samples))
        intermediate_grads.append(global_grads)
        return intermediate_grads

    def test(self):
        '''Evaluate ``self.latest_model`` on every client's test data.

        Return:
            ``(ids, groups, num_samples, tot_correct)``: parallel lists with
            one entry per client; ``tot_correct`` is coerced to float.
        '''
        num_samples = []
        tot_correct = []
        self.client_model.set_params(self.latest_model)
        for c in self.clients:
            ct, ns = c.test()
            tot_correct.append(ct * 1.0)
            num_samples.append(ns)
        ids = [c.id for c in self.clients]
        groups = [c.group for c in self.clients]
        return ids, groups, num_samples, tot_correct

    def save(self):
        '''Persist the trainer state; a no-op here (hook for subclasses).'''
        pass

    def select_clients(self, round, num_clients=20):
        '''Select ``num_clients`` distinct clients uniformly at random.

        Selection is uniform and without replacement; it is NOT weighted
        by sample count. The global NumPy RNG is reseeded with ``round``
        so every algorithm being compared picks the same clients in a
        given round. Note this side effect persists for later
        ``np.random`` calls.

        Args:
            round: round number, used as the RNG seed.
            num_clients: number of clients to select; default 20. Capped
                at ``len(self.clients)``.
        Return:
            ``(indices, clients)``: the selected positions in
            ``self.clients`` and a NumPy object array of the matching
            Client objects.
        '''
        num_clients = min(num_clients, len(self.clients))
        np.random.seed(round)  # make sure for each comparison, we are selecting the same clients each round
        indices = np.random.choice(range(len(self.clients)), num_clients, replace=False)
        return indices, np.asarray(self.clients)[indices]

    def aggregate(self, wsolns):
        '''Return the weighted average of several parameter lists.

        Args:
            wsolns: non-empty sequence of ``(w, soln)`` pairs, where ``w``
                is the client's weight (its number of local samples) and
                ``soln`` is a list of NumPy arrays, all with the same
                layout.
        Return:
            list of float64 arrays: ``sum(w * soln) / sum(w)``, element
            by element.
        Raises:
            ValueError: if ``wsolns`` is empty or the weights sum to zero.
                Previously these cases gave an IndexError or silently
                produced NaN/inf parameters.
        '''
        if len(wsolns) == 0:
            raise ValueError('aggregate() needs at least one (weight, solution) pair')
        total_weight = 0.0
        base = [0] * len(wsolns[0][1])
        for w, soln in wsolns:  # w is the number of local samples
            total_weight += w
            for i, v in enumerate(soln):
                base[i] += w * v.astype(np.float64)
        if total_weight == 0:
            raise ValueError('aggregate() weights sum to zero; cannot average')
        averaged_soln = [v / total_weight for v in base]
        return averaged_soln