diff --git a/docs/api_docs/python/tff/learning/framework/ModelWeights.md b/docs/api_docs/python/tff/learning/framework/ModelWeights.md
index 4e855f4750..ceadd76d0f 100644
--- a/docs/api_docs/python/tff/learning/framework/ModelWeights.md
+++ b/docs/api_docs/python/tff/learning/framework/ModelWeights.md
@@ -5,6 +5,7 @@
+
@@ -47,8 +48,24 @@ Returns a list of weights in the same order as `tf.keras.Model.weights`.
(Assuming that this ModelWeights object corresponds to the weights of a keras
model).
+IMPORTANT: this is not the same order as `tf.keras.Model.get_weights()`, and
+hence will not work with `tf.keras.Model.set_weights()`. Instead, use
+`tff.learning.ModelWeights.assign_weights_to`.
+
## Methods
+
assign_weights_to
+
+```python
+assign_weights_to(keras_model)
+```
+
+Assign these TFF model weights to the weights of a `tf.keras.Model`.
+
+#### Args:
+
+* `keras_model`: the `tf.keras.Model` object to assign weights to.
+
from_model
```python
diff --git a/tensorflow_federated/python/learning/model_examples.py b/tensorflow_federated/python/learning/model_examples.py
index d451dbd121..8e310d76f8 100644
--- a/tensorflow_federated/python/learning/model_examples.py
+++ b/tensorflow_federated/python/learning/model_examples.py
@@ -171,21 +171,21 @@ def _dense_all_zeros_layer(input_dims=None, output_dim=1):
return build_keras_dense_layer()
-def build_linear_regresion_keras_sequential_model(feature_dims):
+def build_linear_regresion_keras_sequential_model(feature_dims=2):
"""Build a linear regression `tf.keras.Model` using the Sequential API."""
keras_model = tf.keras.models.Sequential()
keras_model.add(_dense_all_zeros_layer(feature_dims))
return keras_model
-def build_linear_regresion_keras_functional_model(feature_dims):
+def build_linear_regresion_keras_functional_model(feature_dims=2):
"""Build a linear regression `tf.keras.Model` using the functional API."""
a = tf.keras.layers.Input(shape=(feature_dims,))
b = _dense_all_zeros_layer()(a)
return tf.keras.Model(inputs=a, outputs=b)
-def build_linear_regresion_keras_subclass_model(feature_dims):
+def build_linear_regresion_keras_subclass_model(feature_dims=2):
"""Build a linear regression model by sub-classing `tf.keras.Model`."""
del feature_dims # unused.
@@ -207,3 +207,46 @@ def build_embedding_keras_model(vocab_size=10):
keras_model.add(tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=5))
keras_model.add(tf.keras.layers.Softmax())
return keras_model
+
+
+def build_conv_batch_norm_keras_model():
+ """Builds a test model with convolution and batch normalization."""
+ # This is an example of a model that has trainable and non-trainable
+ # variables.
+ l = tf.keras.layers
+ data_format = 'channels_last'
+ max_pool = l.MaxPooling2D((2, 2), (2, 2),
+ padding='same',
+ data_format=data_format)
+ keras_model = tf.keras.models.Sequential([
+ l.Reshape(target_shape=[28, 28, 1], input_shape=(28 * 28,)),
+ l.Conv2D(
+ 32,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=tf.nn.relu,
+ kernel_initializer='zeros',
+ bias_initializer='zeros'),
+ max_pool,
+ l.BatchNormalization(),
+ l.Conv2D(
+ 64,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=tf.nn.relu,
+ kernel_initializer='zeros',
+ bias_initializer='zeros'),
+ max_pool,
+ l.BatchNormalization(),
+ l.Flatten(),
+ l.Dense(
+ 1024,
+ activation=tf.nn.relu,
+ kernel_initializer='zeros',
+ bias_initializer='zeros'),
+ l.Dropout(0.4),
+ l.Dense(10, kernel_initializer='zeros', bias_initializer='zeros'),
+ ])
+ return keras_model
diff --git a/tensorflow_federated/python/learning/model_utils.py b/tensorflow_federated/python/learning/model_utils.py
index 3b1ad38f21..5521ab4130 100644
--- a/tensorflow_federated/python/learning/model_utils.py
+++ b/tensorflow_federated/python/learning/model_utils.py
@@ -96,9 +96,22 @@ def keras_weights(self):
(Assuming that this ModelWeights object corresponds to the weights of
a keras model).
+
+ IMPORTANT: this is not the same order as `tf.keras.Model.get_weights()`, and
+ hence will not work with `tf.keras.Model.set_weights()`. Instead, use
+ `tff.learning.ModelWeights.assign_weights_to`.
"""
return list(self.trainable.values()) + list(self.non_trainable.values())
+ def assign_weights_to(self, keras_model):
+ """Assign these TFF model weights to the weights of a `tf.keras.Model`.
+
+ Args:
+ keras_model: the `tf.keras.Model` object to assign weights to.
+ """
+ for k, w in zip(keras_model.weights, self.keras_weights):
+ k.assign(w)
+
def keras_weights_from_tff_weights(tff_weights):
"""Converts TFF's nested weights structure to flat weights.
@@ -442,8 +455,8 @@ def __init__(self, inner_model, dummy_batch):
inner_model.loss_functions[0], inner_model.metrics)
@property
- def non_trainable_variables(self):
- return (super(_TrainableKerasModel, self).non_trainable_variables +
+ def local_variables(self):
+ return (super(_TrainableKerasModel, self).local_variables +
self._keras_model.optimizer.variables())
@tf.contrib.eager.function(autograph=False)
diff --git a/tensorflow_federated/python/learning/model_utils_test.py b/tensorflow_federated/python/learning/model_utils_test.py
index a68af73adb..c57a2e05d5 100644
--- a/tensorflow_federated/python/learning/model_utils_test.py
+++ b/tensorflow_federated/python/learning/model_utils_test.py
@@ -312,6 +312,58 @@ def loss_fn(y_true, y_pred):
self.assertGreater(m['loss'][0], 0.0)
self.assertEqual(m['loss'][1], input_vocab_size * num_iterations)
+ def test_keras_model_using_batch_norm(self):
+ model = model_examples.build_conv_batch_norm_keras_model()
+
+ def loss_fn(y_true, y_pred):
+ loss_per_example = tf.keras.losses.sparse_categorical_crossentropy(
+ y_true=y_true, y_pred=y_pred)
+ return tf.reduce_mean(loss_per_example)
+
+ model.compile(
+ optimizer=gradient_descent.SGD(learning_rate=0.01),
+ loss=loss_fn,
+ metrics=[NumBatchesCounter(), NumExamplesCounter()])
+
+ dummy_batch = collections.OrderedDict([
+ ('x', np.zeros([1, 28 * 28], dtype=np.float32)),
+ ('y', np.zeros([1, 1], dtype=np.int64)),
+ ])
+ tff_model = model_utils.from_compiled_keras_model(
+ keras_model=model, dummy_batch=dummy_batch)
+
+ batch_size = 2
+ batch = {
+ 'x':
+ np.random.uniform(low=0.0, high=1.0,
+ size=[batch_size, 28 * 28]).astype(np.float32),
+ 'y':
+ np.random.random_integers(low=0, high=9, size=[batch_size,
+ 1]).astype(np.int64),
+ }
+
+ num_iterations = 2
+ for _ in range(num_iterations):
+ self.evaluate(tff_model.train_on_batch(batch))
+
+ m = self.evaluate(tff_model.report_local_outputs())
+ self.assertEqual(m['num_batches'], [num_iterations])
+ self.assertEqual(m['num_examples'], [batch_size * num_iterations])
+ self.assertGreater(m['loss'][0], 0.0)
+ self.assertEqual(m['loss'][1], batch_size * num_iterations)
+
+ # Ensure we can assign the FL trained model weights to a new model.
+ tff_weights = model_utils.ModelWeights.from_model(tff_model)
+ keras_model = model_examples.build_conv_batch_norm_keras_model()
+ tff_weights.assign_weights_to(keras_model)
+
+ for keras_w, tff_w in zip(keras_model.weights, tff_weights.keras_weights):
+ self.assertAllClose(
+ self.evaluate(keras_w),
+ self.evaluate(tff_w),
+ atol=1e-4,
+ msg='Variable [{}]'.format(keras_w.name))
+
def test_wrap_tff_model_in_tf_computation(self):
feature_dims = 3