From 46f9c697cc4352e69cbddb9bb47cdb3b353b3cf2 Mon Sep 17 00:00:00 2001
From: Zachary Garrett
Date: Fri, 29 Mar 2019 16:49:15 -0700
Subject: [PATCH] Add a method for assigning tff.learning.Model weights back
 to a tf.keras.Model.

This is a workaround for issue #258, which uncovered that
tf.keras.Model.weights and tf.keras.Model.get_weights() are not ordered the
same.

- Add a toy example model that uses batch norm (and therefore has
  non-trainable variables); its test fails without this change.
- Move client optimizer variables to local_variables; this includes
  variables such as the iteration number.

PiperOrigin-RevId: 241075678
---
 .../tff/learning/framework/ModelWeights.md | 17 ++++
 .../python/learning/model_examples.py      | 49 +++++++++++++++--
 .../python/learning/model_utils.py         | 17 +++++-
 .../python/learning/model_utils_test.py    | 52 +++++++++++++++++++
 4 files changed, 130 insertions(+), 5 deletions(-)
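
Usage note (illustrative only, not part of the commit): the sketch below
condenses the new test_keras_model_using_batch_norm test to show the intended
round trip: train a wrapped Keras model as a tff.learning.Model, then push the
trained weights back into a freshly built tf.keras.Model with
assign_weights_to. It assumes the model_examples / model_utils modules as of
this revision and a runtime where the assignments execute eagerly; the
gradient_descent import is the TF-internal v2 SGD optimizer the new test uses
and may differ across TF versions.

```python
# A minimal sketch, condensed from the new test; not a supported recipe.
import collections

import numpy as np
import tensorflow as tf

# The v2 SGD optimizer used by the test (TF-internal import path; may move).
from tensorflow.python.keras.optimizer_v2 import gradient_descent

from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_utils

# Build and compile the batch-norm example model added by this change.
keras_model = model_examples.build_conv_batch_norm_keras_model()
keras_model.compile(
    optimizer=gradient_descent.SGD(learning_rate=0.01),
    loss=tf.keras.losses.sparse_categorical_crossentropy)

# Wrap the compiled Keras model as a tff.learning.Model.
dummy_batch = collections.OrderedDict([
    ('x', np.zeros([1, 28 * 28], dtype=np.float32)),
    ('y', np.zeros([1, 1], dtype=np.int64)),
])
tff_model = model_utils.from_compiled_keras_model(
    keras_model=keras_model, dummy_batch=dummy_batch)

# ... train tff_model, e.g. by repeatedly calling tff_model.train_on_batch ...

# Push the trained weights back into a fresh Keras model.
# keras_model.set_weights() would be unsafe here because of the ordering
# difference described in the commit message; assign_weights_to instead pairs
# each value with the matching entry of tf.keras.Model.weights.
tff_weights = model_utils.ModelWeights.from_model(tff_model)
fresh_keras_model = model_examples.build_conv_batch_norm_keras_model()
tff_weights.assign_weights_to(fresh_keras_model)
```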

diff --git a/docs/api_docs/python/tff/learning/framework/ModelWeights.md b/docs/api_docs/python/tff/learning/framework/ModelWeights.md
index 4e855f4750..ceadd76d0f 100644
--- a/docs/api_docs/python/tff/learning/framework/ModelWeights.md
+++ b/docs/api_docs/python/tff/learning/framework/ModelWeights.md
@@ -5,6 +5,7 @@
+<meta itemprop="property" content="assign_weights_to"/>
@@ -47,8 +48,24 @@
 Returns a list of weights in the same order as `tf.keras.Model.weights`.
 
 (Assuming that this ModelWeights object corresponds to the weights of a
 keras model).
 
+IMPORTANT: this is not the same order as `tf.keras.Model.get_weights()`, and
+hence will not work with `tf.keras.Model.set_weights()`. Instead, use
+`tff.learning.ModelWeights.assign_weights_to`.
+
 ## Methods
 
+<h3 id="assign_weights_to"><code>assign_weights_to</code></h3>
+
+```python
+assign_weights_to(keras_model)
+```
+
+Assign these TFF model weights to the weights of a `tf.keras.Model`.
+
+#### Args:
+
+*   `keras_model`: the `tf.keras.Model` object to assign weights to.
+
 <h3 id="from_model"><code>from_model</code></h3>
 
 ```python
diff --git a/tensorflow_federated/python/learning/model_examples.py b/tensorflow_federated/python/learning/model_examples.py
index d451dbd121..8e310d76f8 100644
--- a/tensorflow_federated/python/learning/model_examples.py
+++ b/tensorflow_federated/python/learning/model_examples.py
@@ -171,21 +171,21 @@ def _dense_all_zeros_layer(input_dims=None, output_dim=1):
   return build_keras_dense_layer()
 
 
-def build_linear_regresion_keras_sequential_model(feature_dims):
+def build_linear_regresion_keras_sequential_model(feature_dims=2):
   """Build a linear regression `tf.keras.Model` using the Sequential API."""
   keras_model = tf.keras.models.Sequential()
   keras_model.add(_dense_all_zeros_layer(feature_dims))
   return keras_model
 
 
-def build_linear_regresion_keras_functional_model(feature_dims):
+def build_linear_regresion_keras_functional_model(feature_dims=2):
   """Build a linear regression `tf.keras.Model` using the functional API."""
   a = tf.keras.layers.Input(shape=(feature_dims,))
   b = _dense_all_zeros_layer()(a)
   return tf.keras.Model(inputs=a, outputs=b)
 
 
-def build_linear_regresion_keras_subclass_model(feature_dims):
+def build_linear_regresion_keras_subclass_model(feature_dims=2):
   """Build a linear regression model by sub-classing `tf.keras.Model`."""
   del feature_dims  # unused.
 
@@ -207,3 +207,46 @@ def build_embedding_keras_model(vocab_size=10):
   keras_model.add(tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=5))
   keras_model.add(tf.keras.layers.Softmax())
   return keras_model
+
+
+def build_conv_batch_norm_keras_model():
+  """Builds a test model with convolution and batch normalization."""
+  # This is an example of a model that has trainable and non-trainable
+  # variables.
+  l = tf.keras.layers
+  data_format = 'channels_last'
+  max_pool = l.MaxPooling2D((2, 2), (2, 2),
+                            padding='same',
+                            data_format=data_format)
+  keras_model = tf.keras.models.Sequential([
+      l.Reshape(target_shape=[28, 28, 1], input_shape=(28 * 28,)),
+      l.Conv2D(
+          32,
+          5,
+          padding='same',
+          data_format=data_format,
+          activation=tf.nn.relu,
+          kernel_initializer='zeros',
+          bias_initializer='zeros'),
+      max_pool,
+      l.BatchNormalization(),
+      l.Conv2D(
+          64,
+          5,
+          padding='same',
+          data_format=data_format,
+          activation=tf.nn.relu,
+          kernel_initializer='zeros',
+          bias_initializer='zeros'),
+      max_pool,
+      l.BatchNormalization(),
+      l.Flatten(),
+      l.Dense(
+          1024,
+          activation=tf.nn.relu,
+          kernel_initializer='zeros',
+          bias_initializer='zeros'),
+      l.Dropout(0.4),
+      l.Dense(10, kernel_initializer='zeros', bias_initializer='zeros'),
+  ])
+  return keras_model
diff --git a/tensorflow_federated/python/learning/model_utils.py b/tensorflow_federated/python/learning/model_utils.py
index 3b1ad38f21..5521ab4130 100644
--- a/tensorflow_federated/python/learning/model_utils.py
+++ b/tensorflow_federated/python/learning/model_utils.py
@@ -96,9 +96,22 @@ def keras_weights(self):
 
     (Assuming that this ModelWeights object corresponds to the weights of a
     keras model).
+
+    IMPORTANT: this is not the same order as `tf.keras.Model.get_weights()`, and
+    hence will not work with `tf.keras.Model.set_weights()`. Instead, use
+    `tff.learning.ModelWeights.assign_weights_to`.
     """
     return list(self.trainable.values()) + list(self.non_trainable.values())
 
+  def assign_weights_to(self, keras_model):
+    """Assign these TFF model weights to the weights of a `tf.keras.Model`.
+
+    Args:
+      keras_model: the `tf.keras.Model` object to assign weights to.
+    """
+    for k, w in zip(keras_model.weights, self.keras_weights):
+      k.assign(w)
+
 
 def keras_weights_from_tff_weights(tff_weights):
   """Converts TFF's nested weights structure to flat weights.
@@ -442,8 +455,8 @@ def __init__(self, inner_model, dummy_batch):
                      inner_model.loss_functions[0], inner_model.metrics)
 
   @property
-  def non_trainable_variables(self):
-    return (super(_TrainableKerasModel, self).non_trainable_variables +
+  def local_variables(self):
+    return (super(_TrainableKerasModel, self).local_variables +
             self._keras_model.optimizer.variables())
 
   @tf.contrib.eager.function(autograph=False)
diff --git a/tensorflow_federated/python/learning/model_utils_test.py b/tensorflow_federated/python/learning/model_utils_test.py
index a68af73adb..c57a2e05d5 100644
--- a/tensorflow_federated/python/learning/model_utils_test.py
+++ b/tensorflow_federated/python/learning/model_utils_test.py
@@ -312,6 +312,58 @@ def loss_fn(y_true, y_pred):
     self.assertGreater(m['loss'][0], 0.0)
     self.assertEqual(m['loss'][1], input_vocab_size * num_iterations)
 
+  def test_keras_model_using_batch_norm(self):
+    model = model_examples.build_conv_batch_norm_keras_model()
+
+    def loss_fn(y_true, y_pred):
+      loss_per_example = tf.keras.losses.sparse_categorical_crossentropy(
+          y_true=y_true, y_pred=y_pred)
+      return tf.reduce_mean(loss_per_example)
+
+    model.compile(
+        optimizer=gradient_descent.SGD(learning_rate=0.01),
+        loss=loss_fn,
+        metrics=[NumBatchesCounter(), NumExamplesCounter()])
+
+    dummy_batch = collections.OrderedDict([
+        ('x', np.zeros([1, 28 * 28], dtype=np.float32)),
+        ('y', np.zeros([1, 1], dtype=np.int64)),
+    ])
+    tff_model = model_utils.from_compiled_keras_model(
+        keras_model=model, dummy_batch=dummy_batch)
+
+    batch_size = 2
+    batch = {
+        'x':
+            np.random.uniform(low=0.0, high=1.0,
+                              size=[batch_size, 28 * 28]).astype(np.float32),
+        'y':
+            np.random.random_integers(low=0, high=9,
+                                      size=[batch_size, 1]).astype(np.int64),
+    }
+
+    num_iterations = 2
+    for _ in range(num_iterations):
+      self.evaluate(tff_model.train_on_batch(batch))
+
+    m = self.evaluate(tff_model.report_local_outputs())
+    self.assertEqual(m['num_batches'], [num_iterations])
+    self.assertEqual(m['num_examples'], [batch_size * num_iterations])
+    self.assertGreater(m['loss'][0], 0.0)
+    self.assertEqual(m['loss'][1], batch_size * num_iterations)
+
+    # Ensure we can assign the FL trained model weights to a new model.
+    tff_weights = model_utils.ModelWeights.from_model(tff_model)
+    keras_model = model_examples.build_conv_batch_norm_keras_model()
+    tff_weights.assign_weights_to(keras_model)
+
+    for keras_w, tff_w in zip(keras_model.weights, tff_weights.keras_weights):
+      self.assertAllClose(
+          self.evaluate(keras_w),
+          self.evaluate(tff_w),
+          atol=1e-4,
+          msg='Variable [{}]'.format(keras_w.name))
+
   def test_wrap_tff_model_in_tf_computation(self):
     feature_dims = 3