From 7d997aea47395be436ad03cacb4c8072d04086d1 Mon Sep 17 00:00:00 2001 From: krush Date: Mon, 4 Mar 2019 18:06:48 -0800 Subject: [PATCH] Implements handling of tuples of length 1 in federated_zip. In particular, federated_zip will now simply promote the federation of a 1-tuple of federated values, returning a federated 1-tuple of values. WARNING: This CL is backwards-incompatible. Now an IterativeProcess built through `build_federated_optimizer_process` will always return a *tuple* of metrics as the second return value of its next method, as opposed to returning a tuple if there are 2 or more metrics specified and a scalar if there is only 1. PiperOrigin-RevId: 236763176 --- ...ed_learning_for_image_classification.ipynb | 8 ++-- .../python/core/api/intrinsics_test.py | 47 +++++++++++++++++++ .../python/core/impl/intrinsic_factory.py | 21 +++++---- .../learning/federated_averaging_test.py | 6 +-- .../python/learning/federated_sgd_test.py | 6 +-- .../learning/framework/optimizer_utils.py | 9 +--- 6 files changed, 72 insertions(+), 25 deletions(-) diff --git a/docs/tutorials/federated_learning_for_image_classification.ipynb b/docs/tutorials/federated_learning_for_image_classification.ipynb index 4d85cb6f98..b9c240f2b8 100644 --- a/docs/tutorials/federated_learning_for_image_classification.ipynb +++ b/docs/tutorials/federated_learning_for_image_classification.ipynb @@ -891,8 +891,8 @@ ], "source": [ "#@test {\"timeout\": 600, \"output\": \"ignore\"}\n", - "state, loss = iterative_process.next(state, federated_train_data)\n", - "print('round 1, loss={:.4f}'.format(loss))" + "state, metrics = iterative_process.next(state, federated_train_data)\n", + "print('round 1, loss={:.4f}'.format(metrics.loss))" ] }, { @@ -952,8 +952,8 @@ "source": [ "#@test {\"skip\": true}\n", "for round_num in range(2, 11):\n", - " state, loss = iterative_process.next(state, federated_train_data)\n", - " print('round {:2d}, loss={:.4f}'.format(round_num, loss))" + " state, metrics = iterative_process.next(state, federated_train_data)\n", + " print('round {:2d}, loss={:.4f}'.format(round_num, metrics.loss))" ] }, { diff --git a/tensorflow_federated/python/core/api/intrinsics_test.py b/tensorflow_federated/python/core/api/intrinsics_test.py index e3414d7ee6..7c8249c4da 100644 --- a/tensorflow_federated/python/core/api/intrinsics_test.py +++ b/tensorflow_federated/python/core/api/intrinsics_test.py @@ -137,6 +137,53 @@ def foo(x, y): str(foo.type_signature), '(<{int32}@CLIENTS,bool@CLIENTS> -> {}@CLIENTS)') + def test_federated_zip_with_single_unnamed_int_client(self): + + @computations.federated_computation([ + computation_types.FederatedType(tf.int32, placements.CLIENTS), + ]) + def foo(x): + return intrinsics.federated_zip(x) + + self.assertEqual( + str(foo.type_signature), '(<{int32}@CLIENTS> -> {}@CLIENTS)') + + def test_federated_zip_with_single_unnamed_int_server(self): + + @computations.federated_computation([ + computation_types.FederatedType( + tf.int32, placements.SERVER, all_equal=True), + ]) + def foo(x): + return intrinsics.federated_zip(x) + + self.assertEqual( + str(foo.type_signature), '( -> @SERVER)') + + def test_federated_zip_with_single_named_bool_clients(self): + + @computations.federated_computation([ + ('a', computation_types.FederatedType(tf.bool, placements.CLIENTS)), + ]) + def foo(x): + return intrinsics.federated_zip(x) + + self.assertEqual( + str(foo.type_signature), '( -> {}@CLIENTS)') + + def test_federated_zip_with_single_named_bool_server(self): + + @computations.federated_computation([ + ('a', + computation_types.FederatedType( + tf.bool, placements.SERVER, all_equal=True)), + ]) + def foo(x): + return intrinsics.federated_zip(x) + + self.assertEqual( + str(foo.type_signature), '( -> @SERVER)') + def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self): @computations.federated_computation([ diff --git a/tensorflow_federated/python/core/impl/intrinsic_factory.py b/tensorflow_federated/python/core/impl/intrinsic_factory.py index cb43436948..dc7f9fbd05 100644 --- a/tensorflow_federated/python/core/impl/intrinsic_factory.py +++ b/tensorflow_federated/python/core/impl/intrinsic_factory.py @@ -463,8 +463,6 @@ def federated_zip(self, value): """ # TODO(b/113112108): Extend this to accept *args. - # TODO(b/113112108): Allow for auto-extraction of NamedTuples of length 1. - # TODO(b/113112108): We use the iterate/unwrap approach below because # our type system is not powerful enough to express the concept of # "an operation that takes tuples of T of arbitrary length", and therefore @@ -476,13 +474,10 @@ def federated_zip(self, value): py_typecheck.check_type(value, value_base.Value) py_typecheck.check_type(value.type_signature, computation_types.NamedTupleType) - num_elements = len(anonymous_tuple.to_elements(value.type_signature)) - if num_elements < 2: - raise TypeError( - 'The federated zip operator zips tuples of at least two elements, ' - 'but the tuple given as argument has {} ' - 'elements.'.format(num_elements)) elements_to_zip = anonymous_tuple.to_elements(value.type_signature) + num_elements = len(elements_to_zip) + py_typecheck.check_type(elements_to_zip[0][1], + computation_types.FederatedType) output_placement = elements_to_zip[0][1].placement zip_apply_fn = { placements.CLIENTS: self.federated_map, @@ -492,6 +487,16 @@ def federated_zip(self, value): raise TypeError( 'federated_zip only supports components with CLIENTS or ' 'SERVER placement, [{}] is unsupported'.format(output_placement)) + if num_elements == 0: + raise ValueError('federated_zip is only supported on nonempty tuples.') + if num_elements == 1: + input_ref = computation_building_blocks.Reference( + 'value_in', elements_to_zip[0][1].member) + output_tuple = computation_building_blocks.Tuple([(elements_to_zip[0][0], + input_ref)]) + lam = computation_building_blocks.Lambda( + 'value_in', input_ref.type_signature, output_tuple) + return zip_apply_fn[output_placement](lam, value[0]) for _, elem in elements_to_zip: py_typecheck.check_type(elem, computation_types.FederatedType) if elem.placement is not output_placement: diff --git a/tensorflow_federated/python/learning/federated_averaging_test.py b/tensorflow_federated/python/learning/federated_averaging_test.py index b5a10c2921..5f5190a4c0 100644 --- a/tensorflow_federated/python/learning/federated_averaging_test.py +++ b/tensorflow_federated/python/learning/federated_averaging_test.py @@ -219,9 +219,9 @@ def model_fn(): prev_loss = np.inf for _ in range(3): - server_state, loss = iterative_process.next(server_state, federated_ds) - self.assertLess(loss, prev_loss) - prev_loss = loss + server_state, metrics = iterative_process.next(server_state, federated_ds) + self.assertLess(metrics.loss, prev_loss) + prev_loss = metrics.loss def test_execute_empty_data(self): iterative_process = federated_averaging.build_federated_averaging_process( diff --git a/tensorflow_federated/python/learning/federated_sgd_test.py b/tensorflow_federated/python/learning/federated_sgd_test.py index 743bd35194..de91e171a9 100644 --- a/tensorflow_federated/python/learning/federated_sgd_test.py +++ b/tensorflow_federated/python/learning/federated_sgd_test.py @@ -220,9 +220,9 @@ def model_fn(): server_state = iterative_process.initialize() prev_loss = np.inf for _ in range(3): - server_state, loss = iterative_process.next(server_state, federated_ds) - self.assertLess(loss, prev_loss) - prev_loss = loss + server_state, metrics = iterative_process.next(server_state, federated_ds) + self.assertLess(metrics.loss, prev_loss) + prev_loss = metrics.loss def test_execute_empty_data(self): iterative_process = federated_sgd.build_federated_sgd_process( diff --git a/tensorflow_federated/python/learning/framework/optimizer_utils.py b/tensorflow_federated/python/learning/framework/optimizer_utils.py index 38de2efb2c..0b0d6ba242 100644 --- a/tensorflow_federated/python/learning/framework/optimizer_utils.py +++ b/tensorflow_federated/python/learning/framework/optimizer_utils.py @@ -394,13 +394,8 @@ def _cast_to_float(x): aggregated_outputs = dummy_model_for_metadata.federated_output_computation( client_outputs.model_output) - # Promote the FederatedType outside the NamedTupleType, or return the - # singluar federated value. - num_outputs = len(aggregated_outputs) - if num_outputs == 1: - aggregated_outputs = aggregated_outputs[0] - elif num_outputs >= 2: - aggregated_outputs = tff.federated_zip(aggregated_outputs) + # Promote the FederatedType outside the NamedTupleType + aggregated_outputs = tff.federated_zip(aggregated_outputs) return server_state, aggregated_outputs