Skip to content

Commit

Permalink
Fix checkpointing tests.
Browse files Browse the repository at this point in the history
1. Add calls to `run_restore_ops` to ensure that the restore ops are executed when using checkpoints in graph mode.
2. Ensure that layer orders are the same between the saved model and restored model. Otherwise there will be a race condition when restoring the checkpoint values.

PiperOrigin-RevId: 288719809
Change-Id: I6f87a481e9fe3ea1e8ebc667cfea61dcb0716236
  • Loading branch information
k-w-w authored and tensorflower-gardener committed Jan 8, 2020
1 parent 18845a4 commit 2b32964
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
32 changes: 11 additions & 21 deletions tensorflow/python/keras/saving/hdf5_format_test.py
Expand Up @@ -1099,7 +1099,7 @@ def test_weight_loading_subclassed_model(self):
self._weight_loading_test_template(SubclassedModel)

def _new_layer_weight_loading_test_template(
self, first_model_fn, second_model_fn, restore_init_fn):
self, first_model_fn, second_model_fn):
with self.cached_session() as session:
model = first_model_fn()
temp_dir = self.get_temp_dir()
Expand All @@ -1122,12 +1122,13 @@ def _new_layer_weight_loading_test_template(
self.addCleanup(shutil.rmtree, temp_dir)

second_model = second_model_fn()
second_model.load_weights(prefix)
status = second_model.load_weights(prefix)
second_model(x)
self.evaluate(restore_init_fn(second_model))
status.run_restore_ops()
second_model.save_weights(prefix)
# Check that the second model's checkpoint loads into the original model
model.load_weights(prefix)
status = model.load_weights(prefix)
status.run_restore_ops(session)
y = self.evaluate(model(x))
self.assertAllClose(ref_y, y)

Expand All @@ -1144,12 +1145,9 @@ def _restore_graph_model():
y = keras.layers.Dense(1, name='second')(x)
b = keras.layers.Dense(3, name='secondjr')(y)
return keras.models.Model(a, b)
def _restore_init_fn(restore_model):
return [v.initializer for v in restore_model.layers[-1].variables]

self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model,
_restore_init_fn)
_save_graph_model, _restore_graph_model)

@test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model_added_no_weight_layer(self):
Expand All @@ -1161,16 +1159,12 @@ def _save_graph_model():
def _restore_graph_model():
a = keras.layers.Input(shape=(2,))
x = keras.layers.Dense(3, name='first')(a)
y = keras.layers.Dropout(rate=0.1)(x)
b = keras.layers.Dense(1, name='second')(y)
return keras.models.Model(a, b)
def _restore_init_fn(restore_model):
del restore_model # unused
return []
b = keras.layers.Dense(1, name='second')(x)
y = keras.layers.Dropout(rate=0.1)(b)
return keras.models.Model(a, y)

self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model,
_restore_init_fn)
_save_graph_model, _restore_graph_model)

@test_util.run_in_graph_and_eager_modes
def test_weight_loading_subclassed_model_added_layer(self):
Expand All @@ -1186,12 +1180,8 @@ def __init__(self):
def call(self, a):
return self.b_layer(self.y_layer(self.x_layer(a)))

def _restore_init_fn(restore_model):
return [v.initializer for v in restore_model.y_layer.variables]

self._new_layer_weight_loading_test_template(
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
SubclassedModel, SubclassedModelRestore)

@test_util.run_in_graph_and_eager_modes
def test_incompatible_checkpoint(self):
Expand Down
1 change: 0 additions & 1 deletion tensorflow/python/training/tracking/base.py
Expand Up @@ -352,7 +352,6 @@ def gather_ops_or_named_saveables(self):
if serialized_tensor.checkpoint_key not in saveable.name:
saveable = None
del saveables_cache[self.trackable]
break
if saveable is None:
# If there was no cached SaveableObject, we should check if the Python
# object has the attribute.
Expand Down

0 comments on commit 2b32964

Please sign in to comment.