Skip to content

Commit

Permalink
Implements handling of tuples of length 1 in federated_zip.
Browse files Browse the repository at this point in the history
In particular, federated_zip will now simply promote the federation of a 1-tuple of federated values, returning a federated 1-tuple of values.

WARNING: This CL is backwards-incompatible.
Now an IterativeProcess built through `build_federated_optimizer_process` will always return a *tuple* of metrics as the second return value of its next method, as opposed to returning a tuple if there are 2 or more metrics specified and a scalar if there is only 1.
PiperOrigin-RevId: 236763176
  • Loading branch information
jkr26 authored and tensorflower-gardener committed Mar 5, 2019
1 parent c4d6cd2 commit 7d997ae
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,8 @@
],
"source": [
"#@test {\"timeout\": 600, \"output\": \"ignore\"}\n",
"state, loss = iterative_process.next(state, federated_train_data)\n",
"print('round 1, loss={:.4f}'.format(loss))"
"state, metrics = iterative_process.next(state, federated_train_data)\n",
"print('round 1, loss={:.4f}'.format(metrics.loss))"
]
},
{
Expand Down Expand Up @@ -952,8 +952,8 @@
"source": [
"#@test {\"skip\": true}\n",
"for round_num in range(2, 11):\n",
" state, loss = iterative_process.next(state, federated_train_data)\n",
" print('round {:2d}, loss={:.4f}'.format(round_num, loss))"
" state, metrics = iterative_process.next(state, federated_train_data)\n",
" print('round {:2d}, loss={:.4f}'.format(round_num, metrics.loss))"
]
},
{
Expand Down
47 changes: 47 additions & 0 deletions tensorflow_federated/python/core/api/intrinsics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,53 @@ def foo(x, y):
str(foo.type_signature),
'(<{int32}@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)')

def test_federated_zip_with_single_unnamed_int_client(self):

@computations.federated_computation([
computation_types.FederatedType(tf.int32, placements.CLIENTS),
])
def foo(x):
return intrinsics.federated_zip(x)

self.assertEqual(
str(foo.type_signature), '(<{int32}@CLIENTS> -> {<int32>}@CLIENTS)')

def test_federated_zip_with_single_unnamed_int_server(self):

@computations.federated_computation([
computation_types.FederatedType(
tf.int32, placements.SERVER, all_equal=True),
])
def foo(x):
return intrinsics.federated_zip(x)

self.assertEqual(
str(foo.type_signature), '(<int32@SERVER> -> <int32>@SERVER)')

def test_federated_zip_with_single_named_bool_clients(self):

@computations.federated_computation([
('a', computation_types.FederatedType(tf.bool, placements.CLIENTS)),
])
def foo(x):
return intrinsics.federated_zip(x)

self.assertEqual(
str(foo.type_signature), '(<a={bool}@CLIENTS> -> {<a=bool>}@CLIENTS)')

def test_federated_zip_with_single_named_bool_server(self):

@computations.federated_computation([
('a',
computation_types.FederatedType(
tf.bool, placements.SERVER, all_equal=True)),
])
def foo(x):
return intrinsics.federated_zip(x)

self.assertEqual(
str(foo.type_signature), '(<a=bool@SERVER> -> <a=bool>@SERVER)')

def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self):

@computations.federated_computation([
Expand Down
21 changes: 13 additions & 8 deletions tensorflow_federated/python/core/impl/intrinsic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,6 @@ def federated_zip(self, value):
"""
# TODO(b/113112108): Extend this to accept *args.

# TODO(b/113112108): Allow for auto-extraction of NamedTuples of length 1.

# TODO(b/113112108): We use the iterate/unwrap approach below because
# our type system is not powerful enough to express the concept of
# "an operation that takes tuples of T of arbitrary length", and therefore
Expand All @@ -476,13 +474,10 @@ def federated_zip(self, value):
py_typecheck.check_type(value, value_base.Value)
py_typecheck.check_type(value.type_signature,
computation_types.NamedTupleType)
num_elements = len(anonymous_tuple.to_elements(value.type_signature))
if num_elements < 2:
raise TypeError(
'The federated zip operator zips tuples of at least two elements, '
'but the tuple given as argument has {} '
'elements.'.format(num_elements))
elements_to_zip = anonymous_tuple.to_elements(value.type_signature)
num_elements = len(elements_to_zip)
py_typecheck.check_type(elements_to_zip[0][1],
computation_types.FederatedType)
output_placement = elements_to_zip[0][1].placement
zip_apply_fn = {
placements.CLIENTS: self.federated_map,
Expand All @@ -492,6 +487,16 @@ def federated_zip(self, value):
raise TypeError(
'federated_zip only supports components with CLIENTS or '
'SERVER placement, [{}] is unsupported'.format(output_placement))
if num_elements == 0:
raise ValueError('federated_zip is only supported on nonempty tuples.')
if num_elements == 1:
input_ref = computation_building_blocks.Reference(
'value_in', elements_to_zip[0][1].member)
output_tuple = computation_building_blocks.Tuple([(elements_to_zip[0][0],
input_ref)])
lam = computation_building_blocks.Lambda(
'value_in', input_ref.type_signature, output_tuple)
return zip_apply_fn[output_placement](lam, value[0])
for _, elem in elements_to_zip:
py_typecheck.check_type(elem, computation_types.FederatedType)
if elem.placement is not output_placement:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ def model_fn():

prev_loss = np.inf
for _ in range(3):
server_state, loss = iterative_process.next(server_state, federated_ds)
self.assertLess(loss, prev_loss)
prev_loss = loss
server_state, metrics = iterative_process.next(server_state, federated_ds)
self.assertLess(metrics.loss, prev_loss)
prev_loss = metrics.loss

def test_execute_empty_data(self):
iterative_process = federated_averaging.build_federated_averaging_process(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_federated/python/learning/federated_sgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ def model_fn():
server_state = iterative_process.initialize()
prev_loss = np.inf
for _ in range(3):
server_state, loss = iterative_process.next(server_state, federated_ds)
self.assertLess(loss, prev_loss)
prev_loss = loss
server_state, metrics = iterative_process.next(server_state, federated_ds)
self.assertLess(metrics.loss, prev_loss)
prev_loss = metrics.loss

def test_execute_empty_data(self):
iterative_process = federated_sgd.build_federated_sgd_process(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,13 +394,8 @@ def _cast_to_float(x):
aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
client_outputs.model_output)

# Promote the FederatedType outside the NamedTupleType, or return the
# singluar federated value.
num_outputs = len(aggregated_outputs)
if num_outputs == 1:
aggregated_outputs = aggregated_outputs[0]
elif num_outputs >= 2:
aggregated_outputs = tff.federated_zip(aggregated_outputs)
# Promote the FederatedType outside the NamedTupleType
aggregated_outputs = tff.federated_zip(aggregated_outputs)

return server_state, aggregated_outputs

Expand Down

0 comments on commit 7d997ae

Please sign in to comment.