Skip to content

Commit

Permalink
Add SparseCategoricalAccuracy metric back to example after clearing t…
Browse files Browse the repository at this point in the history
…he cached

session whenever wrapping a Keras model.

This workaround isn't needed when compiling against TF master; but TFF is
pinned the most recent public release (TF 1.13.1).

Additionally:
- Fixed s/metrics/losses/ module path for loss function
- Change test model to output tensors of length 1 rather than scalars (required by Keras)
PiperOrigin-RevId: 236652882
  • Loading branch information
ZacharyGarrett authored and tensorflower-gardener committed Mar 4, 2019
1 parent 117464b commit c4d6cd2
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
4 changes: 1 addition & 3 deletions tensorflow_federated/python/examples/mnist/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ licenses(["notice"]) # Apache 2.0
py_library(
name = "mnist",
srcs = ["__init__.py"],
deps = [
":models",
],
deps = [":models"],
)

py_library(
Expand Down
6 changes: 2 additions & 4 deletions tensorflow_federated/python/examples/mnist/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ def create_simple_keras_model(learning_rate=0.1):

def loss_fn(y_true, y_pred):
return tf.reduce_mean(
tf.keras.metrics.sparse_categorical_crossentropy(y_true, y_pred))
tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred))

model.compile(
loss=loss_fn,
optimizer=gradient_descent.SGD(learning_rate),
# TODO(b/124563513): Readd tf.keras.metrics.SparseCategoricalAccuracy()
# after upgrading to a version of TF that has this fixed.
metrics=[])
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model


Expand Down
16 changes: 9 additions & 7 deletions tensorflow_federated/python/examples/mnist/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ def keras_evaluate(state):
self.assertLess(np.mean(loss_list[1:]), loss_list[0])

def test_self_contained_example(self):
emnist_batch = collections.OrderedDict([('label', 5),
emnist_batch = collections.OrderedDict([('label', [5]),
('pixels', np.random.rand(28, 28))])

output_types = collections.OrderedDict([('label', tf.int32),
('pixels', tf.float32)])

output_shapes = collections.OrderedDict([('label', tf.TensorShape([])),
('pixels', tf.TensorShape([28,
28]))])
output_shapes = collections.OrderedDict([
('label', tf.TensorShape([1])),
('pixels', tf.TensorShape([28, 28])),
])

def generate_one_emnist_batch():
yield emnist_batch
Expand All @@ -96,7 +97,7 @@ def generate_one_emnist_batch():
output_types, output_shapes)

def client_data():
return models.keras_dataset_from_emnist(dataset).repeat(10).batch(20)
return models.keras_dataset_from_emnist(dataset).repeat(2).batch(2)

train_data = [client_data()]
sample_batch = tf.contrib.framework.nest.map_structure(
Expand All @@ -110,8 +111,9 @@ def model_fn():
state = trainer.initialize()
losses = []
for _ in range(2):
state, loss = trainer.next(state, train_data)
losses.append(loss)
state, outputs = trainer.next(state, train_data)
# Track the loss.
losses.append(outputs.loss)
self.assertLess(losses[1], losses[0])


Expand Down
9 changes: 9 additions & 0 deletions tensorflow_federated/python/learning/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ class _KerasModel(model_lib.Model):
"""Internal wrapper class for tf.keras.Model objects."""

def __init__(self, inner_model, dummy_batch, loss_func, metrics):
# TODO(b/124477598): the following set_session() should be removed in the
# future. This is a workaround for Keras' caching sessions in a way that
# isn't compatible with TFF. This is already fixed in TF master, but not as
# of v1.13.1.
#
# We do not use .clear_session() because it blows away the graph stack by
# resetting the default graph.
tf.keras.backend.set_session(None)

if hasattr(dummy_batch, '_asdict'):
dummy_batch = dummy_batch._asdict()
# Convert input to tensors, possibly from nested lists that need to be
Expand Down

0 comments on commit c4d6cd2

Please sign in to comment.