Skip to content

Commit

Permalink
Breakage in colabs caused by shifting the API surface and thus being …
Browse files Browse the repository at this point in the history
…out of sync between pip package and head. Reverting changes to learning directory and ipynb, keeping them in core.

PiperOrigin-RevId: 237058456
  • Loading branch information
jkr26 authored and tensorflower-gardener committed Mar 6, 2019
1 parent 7d997ae commit 45716b5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,8 @@
],
"source": [
"#@test {\"timeout\": 600, \"output\": \"ignore\"}\n",
"state, metrics = iterative_process.next(state, federated_train_data)\n",
"print('round 1, loss={:.4f}'.format(metrics.loss))"
"state, loss = iterative_process.next(state, federated_train_data)\n",
"print('round 1, loss={:.4f}'.format(loss))"
]
},
{
Expand Down Expand Up @@ -952,8 +952,8 @@
"source": [
"#@test {\"skip\": true}\n",
"for round_num in range(2, 11):\n",
" state, metrics = iterative_process.next(state, federated_train_data)\n",
" print('round {:2d}, loss={:.4f}'.format(round_num, metrics.loss))"
" state, loss = iterative_process.next(state, federated_train_data)\n",
" print('round {:2d}, loss={:.4f}'.format(round_num, loss))"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ def model_fn():

prev_loss = np.inf
for _ in range(3):
server_state, metrics = iterative_process.next(server_state, federated_ds)
self.assertLess(metrics.loss, prev_loss)
prev_loss = metrics.loss
server_state, loss = iterative_process.next(server_state, federated_ds)
self.assertLess(loss, prev_loss)
prev_loss = loss

def test_execute_empty_data(self):
iterative_process = federated_averaging.build_federated_averaging_process(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_federated/python/learning/federated_sgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ def model_fn():
server_state = iterative_process.initialize()
prev_loss = np.inf
for _ in range(3):
server_state, metrics = iterative_process.next(server_state, federated_ds)
self.assertLess(metrics.loss, prev_loss)
prev_loss = metrics.loss
server_state, loss = iterative_process.next(server_state, federated_ds)
self.assertLess(loss, prev_loss)
prev_loss = loss

def test_execute_empty_data(self):
iterative_process = federated_sgd.build_federated_sgd_process(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,13 @@ def _cast_to_float(x):
aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
client_outputs.model_output)

# Promote the FederatedType outside the NamedTupleType
aggregated_outputs = tff.federated_zip(aggregated_outputs)
# Promote the FederatedType outside the NamedTupleType, or return the
# singluar federated value.
num_outputs = len(aggregated_outputs)
if num_outputs == 1:
aggregated_outputs = aggregated_outputs[0]
elif num_outputs >= 2:
aggregated_outputs = tff.federated_zip(aggregated_outputs)

return server_state, aggregated_outputs

Expand Down

0 comments on commit 45716b5

Please sign in to comment.