-
Notifications
You must be signed in to change notification settings - Fork 580
/
computation_utils.py
313 lines (254 loc) · 11.3 KB
/
computation_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines utility functions for constructing TFF computations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import collections.abc
import copy

import attr
import six

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core import api as tff
def update_state(state, **kwargs):
  """Returns a new `state` with new values for fields in `kwargs`.

  The input `state` is never modified. A copy with the requested fields
  replaced is returned instead.

  Args:
    state: the structure with named fields to update. Must be a
      `collections.namedtuple` instance, an `attrs`-decorated class instance,
      or a `collections.abc.Mapping` (which must support `update` on its
      copy, e.g. `dict` / `OrderedDict`).
    **kwargs: the list of key-value pairs of fields to update in `state`.

  Returns:
    For mappings, a shallow copy of `state` with the fields in `kwargs`
    replaced. Otherwise, a new instance of `type(state)` built from the
    existing field values overridden by `kwargs`.

  Raises:
    KeyError: if kwargs contains a field that is not in state.
    TypeError: if state is not a structure with named fields.
  """
  # TODO(b/129569441): Support AnonymousTuple as well.
  is_mapping = isinstance(state, collections.abc.Mapping)
  if py_typecheck.is_named_tuple(state):
    fields = collections.OrderedDict(state._asdict())
  elif py_typecheck.is_attrs(state):
    # `recurse=False` keeps nested attrs instances as-is. The default
    # (`recurse=True`) would turn them into dicts when the object is rebuilt.
    fields = attr.asdict(
        state, recurse=False, dict_factory=collections.OrderedDict)
  elif is_mapping:
    # Copy so the caller's mapping is not mutated in place.
    fields = copy.copy(state)
  else:
    raise TypeError('state must be a structure with named fields (e.g. '
                    'dict, attrs class, collections.namedtuple), '
                    'but found {}'.format(type(state)))

  # Validate before updating so every supported structure reports an unknown
  # field with the documented KeyError, not a constructor TypeError.
  for key in kwargs:
    if key not in fields:
      raise KeyError(
          'state does not contain a field named "{!s}"'.format(key))
  fields.update(kwargs)

  if is_mapping:
    return fields
  return type(state)(**fields)
class StatefulFn(object):
  """Base class pairing a state initializer with a state-advancing function.

  Instances hold two callables: one that builds the initial state, and one
  that consumes a state (plus further arguments) and produces the next state
  together with any outputs. Subclasses narrow `__call__` to specific
  federated patterns.
  """

  def __init__(self, initialize_fn, next_fn):
    """Stores the initializer and step callables after validating them.

    Args:
      initialize_fn: A no-arg function that returns a Python container which
        can be converted to a `tff.Value`, placed on the `tff.SERVER`, and
        passed as the first argument of `__call__`. This may be called in
        vanilla TensorFlow code, typically wrapped as a `tff.tf_computation`,
        as part of the initialization of a larger state object.
      next_fn: A function matching the signature of `__call__`, see below.
    """
    # Both arguments are validated via `py_typecheck.check_callable`,
    # initializer first.
    for candidate in (initialize_fn, next_fn):
      py_typecheck.check_callable(candidate)
    self._initialize_fn, self._next_fn = initialize_fn, next_fn

  def initialize(self):
    """Invokes the stored initializer and returns the state it builds."""
    build_initial_state = self._initialize_fn
    return build_initial_state()

  def __call__(self, state, *args, **kwargs):
    """Advances `state` by one step using the stored `next_fn`.

    Args:
      state: A `tff.Value` placed on the `tff.SERVER`. It is passed through
        `tff.to_value` before being handed to `next_fn`.
      *args: Arguments to the function.
      **kwargs: Arguments to the function.

    Returns:
      A tuple of `tff.Value`s (state@SERVER, ...) where
        * state: The updated state, to be passed to the next invocation
          of call.
        * ...: The result of the aggregation.
    """
    state_as_value = tff.to_value(state)
    return self._next_fn(state_as_value, *args, **kwargs)
class StatefulAggregateFn(StatefulFn):
  """A simple container for a stateful aggregation function.

  A typical (though trivial) example would be:

  ```
  stateless_federated_mean = tff.utils.StatefulAggregateFn(
      initialize_fn=lambda: (),  # The state is an empty tuple.
      next_fn=lambda state, value, weight=None: (
          state, tff.federated_mean(value, weight=weight)))
  ```
  """

  def __call__(self, state, value, weight=None):
    """Performs an aggregate of `value@CLIENTS`, producing `value@SERVER`.

    The aggregation is optionally parameterized by `weight@CLIENTS`.

    This is a function intended to (only) be invoked in the context
    of a `tff.federated_computation`. It should be compatible with the
    TFF type signature.

    ```
    (state@SERVER, value@CLIENTS, weight@CLIENTS) ->
    (state@SERVER, aggregate@SERVER).
    ```

    Args:
      state: A `tff.Value` placed on the `tff.SERVER`.
      value: A `tff.Value` to be aggregated, placed on the `tff.CLIENTS`.
      weight: An optional `tff.Value` for weighting `value`s, placed on the
        `tff.CLIENTS`.

    Returns:
      A tuple of `tff.Value`s `(state@SERVER, aggregate@SERVER)`, where
        * `state`: The updated state.
        * `aggregate`: The result of the aggregation of `value` weighted by
          `weight`.

    Raises:
      TypeError: if `state` is not placed at `tff.SERVER`, or if `value` (or
        `weight`, when given) is not placed at `tff.CLIENTS`. Argument and
        type-signature kinds are checked via `py_typecheck.check_type`.
    """
    py_typecheck.check_type(state, tff.Value)
    py_typecheck.check_type(state.type_signature, tff.FederatedType)
    if state.type_signature.placement is not tff.SERVER:
      raise TypeError('`state` argument must be a tff.Value placed at SERVER. '
                      'Got: {!s}'.format(state.type_signature))

    py_typecheck.check_type(value, tff.Value)
    py_typecheck.check_type(value.type_signature, tff.FederatedType)
    if value.type_signature.placement is not tff.CLIENTS:
      raise TypeError('`value` argument must be a tff.Value placed at CLIENTS. '
                      'Got: {!s}'.format(value.type_signature))

    if weight is not None:
      py_typecheck.check_type(weight, tff.Value)
      py_typecheck.check_type(weight.type_signature, tff.FederatedType)
      if weight.type_signature.placement is not tff.CLIENTS:
        raise TypeError('If not None, `weight` argument must be a tff.Value '
                        'placed at CLIENTS. Got: {!s}'.format(
                            weight.type_signature))

    # `state` has already been verified to be a `tff.Value`, so unlike
    # `StatefulFn.__call__` it is forwarded without `tff.to_value`.
    return self._next_fn(state, value, weight)
class StatefulBroadcastFn(StatefulFn):
  """A simple container for a stateful broadcast function.

  A typical (though trivial) example would be:

  ```
  stateless_federated_broadcast = tff.utils.StatefulBroadcastFn(
      initialize_fn=lambda: (),
      next_fn=lambda state, value: (
          state, tff.federated_broadcast(value)))
  ```
  """

  def __call__(self, state, value):
    """Performs a broadcast of `value@SERVER`, producing `value@CLIENTS`.

    This is a function intended to (only) be invoked in the context
    of a `tff.federated_computation`. It should be compatible with the
    TFF type signature
    `(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)`.

    Args:
      state: A `tff.Value` placed on the `tff.SERVER`.
      value: A `tff.Value` placed on the `tff.SERVER`, to be broadcast to the
        `tff.CLIENTS`.

    Returns:
      A tuple of `tff.Value`s `(state@SERVER, value@CLIENTS)` where
        * `state`: The updated state.
        * `value`: The input `value` now placed (communicated) to the
          `tff.CLIENTS`.

    Raises:
      TypeError: if `state` or `value` is not placed at `tff.SERVER`.
        Argument and type-signature kinds are checked via
        `py_typecheck.check_type`.
    """
    py_typecheck.check_type(state, tff.Value)
    py_typecheck.check_type(state.type_signature, tff.FederatedType)
    if state.type_signature.placement is not tff.SERVER:
      raise TypeError('`state` argument must be a tff.Value placed at SERVER. '
                      'Got: {!s}'.format(state.type_signature))

    py_typecheck.check_type(value, tff.Value)
    py_typecheck.check_type(value.type_signature, tff.FederatedType)
    # The broadcast source must live on the server. The message must name
    # SERVER to match the check.
    if value.type_signature.placement is not tff.SERVER:
      raise TypeError('`value` argument must be a tff.Value placed at SERVER. '
                      'Got: {!s}'.format(value.type_signature))

    return self._next_fn(state, value)
class IterativeProcess(object):
  """A process that includes an initialization and iterated computation.

  An iterated process will usually be driven by a control loop like:

  ```python
  def initialize():
    ...

  def next(state):
    ...

  iterative_process = IterativeProcess(initialize, next)
  state = iterative_process.initialize()
  for round_num in range(num_rounds):
    state = iterative_process.next(state)
  ```

  The iteration step can accept arguments in addition to `state` (which must be
  the first argument), and return additional values:

  ```python
  def next(state, item):
    ...

  iterative_process = ...
  state = iterative_process.initialize()
  for round_num in range(num_rounds):
    state, output = iterative_process.next(state, round_num)
  ```
  """

  def __init__(self, initialize_fn, next_fn):
    """Creates a `tff.IterativeProcess`.

    Args:
      initialize_fn: a no-arg `tff.Computation` that creates the initial state
        of the chained computation.
      next_fn: a `tff.Computation` that defines an iterated function. If
        `initialize_fn` returns a type _T_, then `next_fn` must also return type
        _T_ or multiple values where the first is of type _T_, and accept
        either a single argument of type _T_ or multiple arguments where the
        first argument must be of type _T_.

    Raises:
      TypeError: `initialize_fn` and `next_fn` are not compatible function
        types.
    """
    py_typecheck.check_type(initialize_fn, tff.Computation)
    if initialize_fn.type_signature.parameter is not None:
      raise TypeError(
          'initialize_fn must be a no-arg tff.Computation, but found parameter '
          '{}'.format(initialize_fn.type_signature))
    initialize_result_type = initialize_fn.type_signature.result

    py_typecheck.check_type(next_fn, tff.Computation)
    next_param_type = next_fn.type_signature.parameter
    # `next_fn` takes either the state alone or several arguments with the
    # state first. Compare the whole parameter first. Otherwise a state that
    # is itself a `tff.NamedTupleType` would be misread as an argument list
    # and reduced to its first element (or raise IndexError when empty).
    if (next_param_type != initialize_result_type and
        isinstance(next_param_type, tff.NamedTupleType)):
      next_first_param_type = next_param_type[0]
    else:
      next_first_param_type = next_param_type

    if initialize_result_type != next_first_param_type:
      raise TypeError('The return type of initialize_fn should match the '
                      'first parameter of next_fn, but found\n'
                      'initialize_fn.type_signature.result=\n{}\n'
                      'next_fn.type_signature.parameter[0]=\n{}'.format(
                          initialize_result_type, next_first_param_type))

    next_result_type = next_fn.type_signature.result
    if next_first_param_type != next_result_type:
      # This might be multiple output next_fn, check if the first result might
      # be the state. If still not the right type, raise an error.
      if isinstance(next_result_type, tff.NamedTupleType):
        next_result_type = next_result_type[0]
      if next_first_param_type != next_result_type:
        raise TypeError('The return type of next_fn should match the '
                        'first parameter, but found\n'
                        'next_fn.type_signature.parameter[0]=\n{}\n'
                        'actual next_result_type=\n{}'.format(
                            next_first_param_type, next_result_type))

    self._initialize_fn = initialize_fn
    self._next_fn = next_fn

  @property
  def initialize(self):
    """A no-arg `tff.Computation` that returns the initial state."""
    return self._initialize_fn

  @property
  def next(self):
    """A `tff.Computation` that produces the next state.

    The first argument of `next` should always be the current state
    (originally produced by `tff.IterativeProcess.initialize`), and the first
    (or only) returned value is the updated state.

    Returns:
      A `tff.Computation`.
    """
    return self._next_fn