Skip to content

Commit

Permalink
Run checkpoint tests with TPUStrategy.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 247413794
  • Loading branch information
tomhennigan authored and tamaranorman committed May 10, 2019
1 parent 48720f1 commit a0464dc
Showing 1 changed file with 40 additions and 12 deletions.
52 changes: 40 additions & 12 deletions sonnet/golden_checkpoints/goldens_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def mirrored_all_devices():
return tf.distribute.MirroredStrategy(devices=all_visible_devices)


def all_goldens_and_strategies(test_method):
  """Parameterizes `test_method` over each golden paired with a strategy.

  One named case is generated per entry of `goldens.list_goldens()`. The case
  name is the golden's name suffixed with "_mirrored"; the test arguments are
  a fresh instance of the golden class and the `mirrored_all_devices` factory
  (passed uncalled, so the test decides when to build the strategy).

  Args:
    test_method: The test method to decorate.

  Returns:
    The result of applying `parameterized.named_parameters` to `test_method`.
  """
  # TODO(tomhennigan) Add TPU and ParameterServer tests.
  named_cases = []
  for _, name, cls in goldens.list_goldens():
    named_cases.append((name + "_mirrored", cls(), mirrored_all_devices))
  decorate = parameterized.named_parameters(named_cases)
  return decorate(test_method)
def with_soft_placement(f):
  """Returns a callable that invokes `f` inside a `tf.device(None)` scope.

  The intent (per the original author) is for `f` to run with soft device
  placement.

  Args:
    f: The callable to wrap. It is invoked with whatever positional and
      keyword arguments the returned wrapper receives.

  Returns:
    A function forwarding its arguments to `f` and returning `f`'s result.
  """
  def wrapper(*args, **kwargs):
    # NOTE(review): presumably `tf.device(None)` makes enclosing device scopes
    # be ignored for the duration of the call - confirm against the TF docs.
    with tf.device(None):
      result = f(*args, **kwargs)
    return result
  return wrapper


class GoldenCheckpointsTest(test_utils.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -176,9 +177,29 @@ def test_restore_golden(self, golden):
for variable in variables:
self.assertAllClose(variable.read_value(), goldens.range_like(variable))

@all_goldens_and_strategies
def test_checkpoint_distribution_strategy(self, golden, strategy_fn):
strategy = strategy_fn()
@all_goldens
def test_checkpoint_mirrored_strategy(self, golden):
  """Runs the checkpoint round-trip check under the mirrored strategy.

  The strategy comes from `mirrored_all_devices()` and the shared assertion
  helper is called with `use_function=False`.
  """
  self.assertCheckpointDistributionStrategy(
      golden, mirrored_all_devices(), use_function=False)

@all_goldens
def test_checkpoint_mirrored_strategy_function(self, golden):
  """Runs the checkpoint round-trip check under the mirrored strategy.

  Same as the non-function variant, except the shared assertion helper is
  called with `use_function=True`.
  """
  self.assertCheckpointDistributionStrategy(
      golden, mirrored_all_devices(), use_function=True)

@all_goldens
def test_checkpoint_tpu_strategy(self, golden):
  """Runs the checkpoint round-trip check under `TPUStrategy`.

  The test is skipped unless `self.primary_device` is "TPU". The shared
  assertion helper is called with `use_function=True`.
  """
  # Bail out before constructing the strategy when the primary device is not
  # a TPU; `skipTest` does not return.
  if self.primary_device != "TPU":
    self.skipTest("Test requires a TPU")

  tpu_strategy = tf.distribute.experimental.TPUStrategy()
  self.assertCheckpointDistributionStrategy(
      golden, tpu_strategy, use_function=True)

def assertCheckpointDistributionStrategy(self, golden, strategy,
use_function=True):
with strategy.scope():
module = golden.create_module()
variables = golden.create_all_variables(module)
Expand All @@ -187,6 +208,12 @@ def forward():
per_replica = strategy.experimental_run_v2(lambda: golden.forward(module))
return tf.stack(strategy.unwrap(per_replica), axis=0)

if use_function:
forward = tf.function(forward)
if self.primary_device == "TPU":
# TODO(b/132329316) Remove when `xla.compile` allows tf.device(TPU).
forward = with_soft_placement(forward)

# Assign sequential values to the weights and compute a forward pass.
for index, variable in enumerate(variables):
variable.assign(goldens.range_like(variable, start=index))
Expand All @@ -196,17 +223,18 @@ def forward():
checkpoint = TestCheckpoint(module=module)
checkpoint.save()

# Assign ones into the weights and do another forward pass. The result
# should be different.
# Assign different values into the weights and do another forward pass. The
# result should be different.
for variable in variables:
variable.assign(tf.ones_like(variable))
variable.assign(-tf.ones_like(variable))

if golden.deterministic:
y = forward()
self.assertNotAllClose(y, before_save_ys)

# Restore from the checkpoint and assert the module is in the same state.
checkpoint.restore_latest()
status = checkpoint.restore_latest()
status.assert_consumed()

for index, variable in enumerate(variables):
# Parameters should be restored to their previous values.
Expand Down

0 comments on commit a0464dc

Please sign in to comment.